Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,23 @@ public enum TableCatalogCapability {
* {@link TableCatalog#createTable}.
* See {@link Column#identityColumnSpec()}.
*/
SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS
SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS,

/**
* Signals that the TableCatalog supports Spark auto-filling generated column values and
* enforcing generated column constraints during writes.
* <p>
* When this capability is present, Spark will:
* <ul>
* <li>Auto-compute missing generated column values using the generation expression.</li>
* <li>Validate explicitly-provided generated column values against the generation
* expression. The comparison is null-safe, so a NULL value only passes when the
* generation expression also evaluates to NULL.</li>
* </ul>
* <p>
* Without this capability, the connector is responsible for handling generated column values
* during writes.
* <p>
* MERGE and UPDATE are not yet supported on tables with generated columns when this
* capability is declared; Spark fails such commands with an unsupported table operation
* error.
*
* @since 4.3.0
*/
SUPPORT_GENERATED_COLUMN_ON_WRITE
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils, GeneratedColumn}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
Expand Down Expand Up @@ -3853,13 +3853,29 @@ class Analyzer(
val defaultValueFillMode =
if (conf.coerceInsertNestedTypes && v2Write.schemaEvolutionEnabled) RECURSE
else FILL
val projection = TableOutputResolver.resolveOutputColumns(
v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf,
defaultValueFillMode)
// Only let TableOutputResolver see generation expression metadata if the catalog
// supports auto-filling generated columns on write.
val expected = v2Write.table match {
case r: DataSourceV2Relation if !supportsGeneratedColumnOnWrite(r) =>
r.output.map(GeneratedColumn.removeGenerationExpressionMetadata)
case _ => v2Write.table.output
}
val (projection, autoFilledGenCols) =
TableOutputResolver.resolveOutputColumnsWithGeneratedInfo(
v2Write.table.name, expected, v2Write.query, v2Write.isByName, conf,
defaultValueFillMode)
if (projection != v2Write.query) {
val cleanedTable = v2Write.table match {
case r: DataSourceV2Relation =>
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
r.copy(output = r.output.map { attr =>
val cleaned = CharVarcharUtils.cleanAttrMetadata(attr)
if (autoFilledGenCols.contains(attr.name)) {
GeneratedColumn.removeGenerationExpressionMetadata(cleaned)
.asInstanceOf[AttributeReference]
} else {
cleaned
}
})
case other => other
}
v2Write.withNewQuery(projection).withNewTable(cleanedTable)
Expand All @@ -3869,6 +3885,15 @@ class Analyzer(
}
}

/**
 * Whether the catalog backing `r` is a [[TableCatalog]] that declares
 * `SUPPORT_GENERATED_COLUMN_ON_WRITE`. Relations without a catalog, or with a catalog that
 * is not a [[TableCatalog]], never qualify.
 */
private def supportsGeneratedColumnOnWrite(r: DataSourceV2Relation): Boolean = {
  r.catalog match {
    case Some(tableCatalog: TableCatalog) =>
      val declared = tableCatalog.capabilities()
      declared.contains(TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE)
    case _ =>
      false
  }
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, Expression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, EqualNullSafe, Expression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, V2WriteCommand, WriteDelta}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.catalyst.util.GeneratedColumn
import org.apache.spark.sql.connector.catalog.{CatalogManager, GenerationExpression, TableCatalog,
TableCatalogCapability}
import org.apache.spark.sql.connector.catalog.constraints.Check
import org.apache.spark.sql.connector.write.RowLevelOperation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
Expand All @@ -39,18 +42,22 @@ class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[L
if v2Write.table.resolved && v2Write.query.resolved &&
!containsCheckInvariant(v2Write.query) && v2Write.outputResolved =>
v2Write.table match {
case r: DataSourceV2Relation
if r.table.constraints != null && r.table.constraints.nonEmpty =>
// Check constraint is the only enforced constraint for DSV2 tables.
val checkInvariants = r.table.constraints.collect {
case c: Check =>
val unresolvedExpr = buildCatalystExpression(c)
val columnExtractors = mutable.Map[String, Expression]()
buildColumnExtractors(unresolvedExpr, columnExtractors)
CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql)
}
// Combine the check invariants into a single expression using conjunctive AND.
checkInvariants.reduceOption(And).fold(v2Write)(
case r: DataSourceV2Relation =>
val tableCheckInvariants =
if (r.table.constraints != null && r.table.constraints.nonEmpty) {
r.table.constraints.collect {
case c: Check =>
val unresolvedExpr = buildCatalystExpression(c)
val columnExtractors = mutable.Map[String, Expression]()
buildColumnExtractors(unresolvedExpr, columnExtractors)
CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql)
}.toSeq
} else {
Seq.empty
}
val genColInvariants = buildGeneratedColumnConstraints(r, v2Write)
val allInvariants = tableCheckInvariants ++ genColInvariants
allInvariants.reduceOption(And).fold(v2Write)(
condition => v2Write.withNewQuery(Filter(condition, v2Write.query)))
case _ =>
v2Write
Expand All @@ -72,6 +79,68 @@ class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[L
.getOrElse(catalogManager.v1SessionCatalog.parser.parseExpression(c.predicateSql()))
}

/**
 * Returns true when the relation's catalog is a [[TableCatalog]] that declares
 * `SUPPORT_GENERATED_COLUMN_ON_WRITE` and at least one column of the table has a
 * generation expression.
 */
private def hasGeneratedColumnsOnWrite(r: DataSourceV2Relation): Boolean = {
  val catalogSupportsGenCols = r.catalog.exists {
    case tc: TableCatalog =>
      tc.capabilities().contains(
        TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE)
    case _ => false
  }
  catalogSupportsGenCols &&
    r.table.columns().exists(c => c.columnGenerationExpression() != null)
}

/**
 * For each user-provided generated column, add a CheckInvariant that validates the column
 * value matches the generation expression (using null-safe equality). Auto-filled generated
 * columns are excluded: ResolveOutputRelation strips their GENERATION_EXPRESSION metadata
 * from the table output, so they won't appear here.
 *
 * @param r the target relation of the write
 * @param v2Write the write command being analyzed (currently not consulted)
 * @return one invariant per user-provided generated column, or an empty Seq when the catalog
 *         lacks the capability or the table has no generated columns
 */
private def buildGeneratedColumnConstraints(
    r: DataSourceV2Relation,
    v2Write: V2WriteCommand): Seq[Expression] = {
  if (!hasGeneratedColumnsOnWrite(r)) {
    Seq.empty
  } else {
    val resolver = catalogManager.v1SessionCatalog.conf.resolver
    val parser = catalogManager.v1SessionCatalog.parser
    // Only add constraints for generated columns whose GENERATION_EXPRESSION metadata
    // is still present in the table output -- ResolveOutputRelation strips the metadata
    // from auto-filled columns so they are excluded here.
    val userProvidedGenCols = r.output
      .filter(attr => GeneratedColumn.isGeneratedColumn(attr.metadata))
      .map(_.name)
      .toSet

    // Use V2 columns from the table to access both V2 expressions and SQL strings.
    r.table.columns().flatMap { col =>
      Option(col.columnGenerationExpression())
        .filter(_ => userProvidedGenCols.exists(n => resolver(n, col.name)))
        .map { genExpr =>
          val catalystExpr = buildGenerationCatalystExpression(genExpr, parser)
          // `quoted` keeps the column name as a single name part. The plain `apply` would
          // parse it as a dotted multipart identifier, which breaks for top-level columns
          // whose names contain dots or other special characters.
          val colRef = UnresolvedAttribute.quoted(col.name)
          val columnExtractors = Seq(col.name -> colRef)
          // The SQL text may be absent; fall back to the SQL of the converted expression.
          val genExprSql = Option(genExpr.getSql()).getOrElse(catalystExpr.sql)
          CheckInvariant(
            EqualNullSafe(colRef, catalystExpr),
            columnExtractors,
            "Generated Column",
            s"${col.name} <=> $genExprSql"
          )
        }
    }.toSeq
  }
}

/**
 * Translates a V2 [[GenerationExpression]] into a Catalyst [[Expression]].
 *
 * The structured V2 expression is preferred when it is present and convertible; otherwise
 * the SQL text of the generation expression is parsed with `parser`.
 *
 * NOTE(review): the fallback passes `genExpr.getSql` straight to the parser. Callers in this
 * file treat `getSql` as possibly null, so confirm that at least one of the two
 * representations is always available.
 */
private def buildGenerationCatalystExpression(
    genExpr: GenerationExpression,
    parser: ParserInterface): Expression = {
  val convertedFromV2 = Option(genExpr.getExpression).flatMap(V2ExpressionUtils.toCatalyst)
  convertedFromV2 match {
    case Some(converted) => converted
    case None => parser.parseExpression(genExpr.getSql)
  }
}

private def buildColumnExtractors(
expr: Expression,
columnExtractors: mutable.Map[String, Expression]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
checkNoGeneratedColumns(r, MERGE)
validateMergeIntoConditions(m)

// NOT MATCHED conditions may only refer to columns in source so they can be pushed down
Expand Down Expand Up @@ -85,6 +86,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
checkNoGeneratedColumns(r, MERGE)
validateMergeIntoConditions(m)

// there are only NOT MATCHED actions, use a left anti join to remove any matching rows
Expand Down Expand Up @@ -124,6 +126,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned =>
EliminateSubqueryAliases(aliasedTable) match {
case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) =>
checkNoGeneratedColumns(r, MERGE)
validateMergeIntoConditions(m)
val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty())
table.operation match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalP
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.catalog.{SupportsRowLevelOperations, TableCatalog, TableCatalogCapability}
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
Expand All @@ -47,6 +47,27 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {

protected def groupFilterEnabled: Boolean = conf.runtimeRowLevelOperationGroupFilterEnabled

/**
 * Rejects row-level commands against tables with generated columns when the catalog has
 * opted into Spark-managed generated columns on write. MERGE and UPDATE with generated
 * columns are not yet supported in that mode.
 *
 * @param relation the target relation of the row-level command
 * @param command the row-level command being rewritten; used in the error message
 */
protected def checkNoGeneratedColumns(
    relation: DataSourceV2Relation,
    command: Command): Unit = {
  val catalogAutoFills = relation.catalog match {
    case Some(tc: TableCatalog) =>
      tc.capabilities().contains(
        TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE)
    case _ => false
  }
  // Evaluated lazily so the table columns are only inspected when the capability is present.
  def tableHasGeneratedColumns: Boolean =
    relation.table.columns().exists(col => col.columnGenerationExpression() != null)

  if (catalogAutoFills && tableHasGeneratedColumns) {
    throw QueryCompilationErrors.unsupportedTableOperationError(
      relation.catalog.get, relation.identifier.get,
      s"${command.toString} with generated columns")
  }
}

protected def buildOperationTable(
table: SupportsRowLevelOperations,
command: Command,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {

EliminateSubqueryAliases(aliasedTable) match {
case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) =>
checkNoGeneratedColumns(r, UPDATE)
val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty())
val updateCond = cond.getOrElse(TrueLiteral)
table.operation match {
Expand Down
Loading