Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,9 @@ class OracleIntegrationSuite extends SharedJDBCIntegrationSuite
case LogicalRelationWithTable(JDBCRelation(_, parts, _, _), _) =>
val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
assert(whereClauses === Set(
""""D" < '2018-07-11' or "D" is null""",
""""D" >= '2018-07-11' AND "D" < '2018-07-15'""",
""""D" >= '2018-07-15'"""))
""""D" < {d '2018-07-11'} or "D" is null""",
""""D" >= {d '2018-07-11'} AND "D" < {d '2018-07-15'}""",
""""D" >= {d '2018-07-15'}"""))
}
assert(df1.collect().toSet === expectedResult)

Expand All @@ -483,8 +483,8 @@ class OracleIntegrationSuite extends SharedJDBCIntegrationSuite
case LogicalRelationWithTable(JDBCRelation(_, parts, _, _), _) =>
val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
assert(whereClauses === Set(
""""T" < '2018-07-15 20:50:32.5' or "T" is null""",
""""T" >= '2018-07-15 20:50:32.5'"""))
""""T" < {ts '2018-07-15 20:50:32.5'} or "T" is null""",
""""T" >= {ts '2018-07-15 20:50:32.5'}"""))
}
assert(df2.collect().toSet === expectedResult)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -112,8 +112,9 @@ private[sql] object JDBCRelation extends Logging {
"Operation not allowed: the lower bound of partitioning column is larger than the upper " +
s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")

val dialect = JdbcDialects.get(jdbcOptions.url)
val boundValueToString: Long => String =
toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId, dialect)
val numPartitions =
if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
(upperBound - lowerBound) < 0) {
Expand Down Expand Up @@ -216,21 +217,18 @@ private[sql] object JDBCRelation extends Logging {
/**
 * Renders a partition bound value (internal `Long` encoding) as a SQL literal for use in
 * a generated partition WHERE clause.
 *
 * Numeric bounds are emitted as plain numbers. Date and timestamp bounds are first
 * formatted as strings and then routed through `dialect.compileValue`, so that dialects
 * for strict-typing engines can emit typed literals (e.g. `{d '...'}`, `DATE '...'`)
 * instead of bare quoted strings.
 *
 * @param value      bound value encoded as a Long — presumably days since epoch for
 *                   `DateType` (note the `value.toInt` below) and microseconds for
 *                   `TimestampType`; confirm against the callers
 * @param columnType type of the partitioning column (`NumericType`, `DateType` or
 *                   `TimestampType`)
 * @param timeZoneId session time zone id used when formatting timestamp values
 * @param dialect    JDBC dialect whose `compileValue` produces the final literal text
 */
private def toBoundValueInWhereClause(
    value: Long,
    columnType: DataType,
    timeZoneId: String,
    dialect: JdbcDialect): String = {
  columnType match {
    case _: NumericType => value.toString
    case DateType =>
      // Format to an ISO date string, then let the dialect decide the literal syntax.
      val date = DateFormatter().format(value.toInt)
      dialect.compileValue(java.sql.Date.valueOf(date)).toString
    case TimestampType =>
      // Fraction-preserving formatter so sub-second precision survives the round trip.
      val timestampFormatter = TimestampFormatter.getFractionFormatter(
        DateTimeUtils.getZoneId(timeZoneId))
      val ts = timestampFormatter.format(value)
      dialect.compileValue(java.sql.Timestamp.valueOf(ts)).toString
  }
}

Expand Down
60 changes: 60 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,66 @@ class JDBCSuite extends SharedSparkSession {
assert(lastPredicate == """"PartitionColumn" >= '2020-08-02'""")
}

test("SPARK-28587: columnPartition should use dialect.compileValue for date/timestamp bounds") {
  // Strict-typing engines (e.g. Athena, Phoenix) require typed literals (DATE '...',
  // TIMESTAMP '...') rather than bare quoted strings, and their dialects override
  // compileValue accordingly. This verifies toBoundValueInWhereClause delegates to the
  // dialect when rendering partition bounds.
  //
  // The unique "jdbc:typed:" URL prefix guarantees only this test dialect matches — no
  // built-in dialect, and therefore no AggregatedDialect (which does not delegate
  // compileValue). The H2 driver class merely short-circuits DriverManager.getDriver();
  // columnPartition never opens a connection, so the mismatch is harmless.
  val dialect = new JdbcDialect {
    override def canHandle(url: String): Boolean = url.startsWith("jdbc:typed:")
    override def compileValue(value: Any): Any = value match {
      case d: java.sql.Date => s"DATE '$d'"
      case t: java.sql.Timestamp => s"TIMESTAMP '$t'"
      case other => super.compileValue(other)
    }
  }
  JdbcDialects.registerDialect(dialect)
  try {
    // Partitions a single-column table on [lower, upper) and returns the WHERE clauses.
    def whereClausesFor(
        column: String,
        columnType: DataType,
        lower: String,
        upper: String): Seq[String] = {
      val schema = StructType(Seq(StructField(column, columnType)))
      val options = new JDBCOptions("jdbc:typed:test", "t", Map(
        "driver" -> "org.h2.Driver",
        "partitionColumn" -> column,
        "lowerBound" -> lower,
        "upperBound" -> upper,
        "numPartitions" -> "2"))
      val parts = JDBCRelation.columnPartition(
        schema,
        analysis.caseInsensitiveResolution,
        TimeZone.getDefault.toZoneId.toString,
        options)
      parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSeq
    }

    val dateClauses = whereClausesFor("d", DateType, "2024-01-01", "2024-03-31")
    assert(dateClauses.forall(c => c == null || c.contains("DATE '")),
      s"Expected DATE '...' typed literals but got: ${dateClauses.mkString(", ")}")

    val tsClauses = whereClausesFor(
      "ts", TimestampType, "2024-01-01 00:00:00.0", "2024-03-31 00:00:00.0")
    assert(tsClauses.forall(c => c == null || c.contains("TIMESTAMP '")),
      s"Expected TIMESTAMP '...' typed literals but got: ${tsClauses.mkString(", ")}")
  } finally {
    JdbcDialects.unregisterDialect(dialect)
  }
}

test("overflow of partition bound difference does not give negative stride") {
val df = sql("SELECT * FROM partsoverflow")
checkNumPartitions(df, expectedNumPartitions = 3)
Expand Down