From 141a946ca4728396e45124f7b0d178c6229ba84d Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 23 Jun 2026 23:17:09 +0000 Subject: [PATCH] [SPARK-57644][SQL] Support generated column values on V2 table writes Add support for auto-filling generated column values and enforcing generated column constraints during V2 table writes (INSERT). When a catalog declares SUPPORT_GENERATED_COLUMN_ON_WRITE: - Missing generated columns are auto-filled using the generation expression - User-provided generated column values are validated against the generation expression via CheckInvariant(EqualNullSafe(col, genExpr)) - MERGE, UPDATE, and streaming writes with generated columns are blocked until support is implemented for those operations Key changes: - TableCatalogCapability: Add SUPPORT_GENERATED_COLUMN_ON_WRITE - CatalogV2Util: Encode generation expression in V2-to-StructField roundtrip - TableOutputResolver: Fill missing generated columns in by-name and by-position write paths, checking generation expression before null defaults - ResolveTableConstraints: Add generated column constraints using V2 expression with SQL parser fallback, only for user-provided columns - Analyzer (ResolveOutputRelation): Gate on capability, strip generation expression metadata from auto-filled columns so constraints are scoped correctly - RewriteRowLevelCommand: Block MERGE/UPDATE with generated columns - ResolveWriteToStream: Block streaming writes with generated columns - SQLConf: Add generatedColumn.allowNullableIngest.enabled config Co-authored-by: Isaac --- .../catalog/TableCatalogCapability.java | 20 +- .../sql/catalyst/analysis/Analyzer.scala | 35 +- .../analysis/ResolveTableConstraints.scala | 97 +++- .../analysis/RewriteMergeIntoTable.scala | 3 + .../analysis/RewriteRowLevelCommand.scala | 23 +- .../analysis/RewriteUpdateTable.scala | 1 + .../analysis/TableOutputResolver.scala | 169 ++++-- .../sql/catalyst/util/GeneratedColumn.scala | 39 +- .../sql/connector/catalog/CatalogV2Util.scala | 
6 + .../apache/spark/sql/internal/SQLConf.scala | 12 + .../catalog/InMemoryTableCatalog.scala | 3 +- .../runtime/ResolveWriteToStream.scala | 15 +- .../connector/GeneratedColumnWriteSuite.scala | 544 ++++++++++++++++++ 13 files changed, 904 insertions(+), 63 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GeneratedColumnWriteSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java index a60c827d5ace1..ac829395b9bcd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java @@ -92,5 +92,23 @@ public enum TableCatalogCapability { * {@link TableCatalog#createTable}. * See {@link Column#identityColumnSpec()}. */ - SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS + SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS, + + /** + * Signals that the TableCatalog supports Spark auto-filling generated column values and + * enforcing generated column constraints during writes. + *

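+ * When this capability is present, Spark will:
+ * <ul>
+ * <li>auto-fill missing generated columns using the generation expression</li>
+ * <li>validate user-provided generated column values against the generation expression</li>
+ * </ul>
+ * <p>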

+ * Without this capability, the connector is responsible for handling generated column values + * during writes. + * + * @since 4.3.0 + */ + SUPPORT_GENERATED_COLUMN_ON_WRITE } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7a5048ed0a41d..ff21bd4b66f21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils, GeneratedColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -3853,13 +3853,29 @@ class Analyzer( val defaultValueFillMode = if (conf.coerceInsertNestedTypes && v2Write.schemaEvolutionEnabled) RECURSE else FILL - val projection = TableOutputResolver.resolveOutputColumns( - v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf, - defaultValueFillMode) + // Only let TableOutputResolver see generation expression metadata if the catalog + // supports auto-filling generated columns on write. 
+ val expected = v2Write.table match { + case r: DataSourceV2Relation if !supportsGeneratedColumnOnWrite(r) => + r.output.map(GeneratedColumn.removeGenerationExpressionMetadata) + case _ => v2Write.table.output + } + val (projection, autoFilledGenCols) = + TableOutputResolver.resolveOutputColumnsWithGeneratedInfo( + v2Write.table.name, expected, v2Write.query, v2Write.isByName, conf, + defaultValueFillMode) if (projection != v2Write.query) { val cleanedTable = v2Write.table match { case r: DataSourceV2Relation => - r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) + r.copy(output = r.output.map { attr => + val cleaned = CharVarcharUtils.cleanAttrMetadata(attr) + if (autoFilledGenCols.contains(attr.name)) { + GeneratedColumn.removeGenerationExpressionMetadata(cleaned) + .asInstanceOf[AttributeReference] + } else { + cleaned + } + }) case other => other } v2Write.withNewQuery(projection).withNewTable(cleanedTable) @@ -3869,6 +3885,15 @@ class Analyzer( } } + private def supportsGeneratedColumnOnWrite(r: DataSourceV2Relation): Boolean = { + r.catalog.exists { + case tc: TableCatalog => + tc.capabilities().contains( + TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE) + case _ => false + } + } + private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. 
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala index 41631b24a83ed..933901791ba25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, Expression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, EqualNullSafe, Expression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, V2WriteCommand, WriteDelta} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND -import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.catalyst.util.GeneratedColumn +import org.apache.spark.sql.connector.catalog.{CatalogManager, GenerationExpression, TableCatalog, + TableCatalogCapability} import org.apache.spark.sql.connector.catalog.constraints.Check import org.apache.spark.sql.connector.write.RowLevelOperation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -39,18 +42,22 @@ class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[L if v2Write.table.resolved && v2Write.query.resolved && !containsCheckInvariant(v2Write.query) && v2Write.outputResolved => v2Write.table match { - case r: DataSourceV2Relation - if r.table.constraints != null && r.table.constraints.nonEmpty => - // Check constraint is the only enforced constraint for DSV2 tables. 
- val checkInvariants = r.table.constraints.collect { - case c: Check => - val unresolvedExpr = buildCatalystExpression(c) - val columnExtractors = mutable.Map[String, Expression]() - buildColumnExtractors(unresolvedExpr, columnExtractors) - CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql) - } - // Combine the check invariants into a single expression using conjunctive AND. - checkInvariants.reduceOption(And).fold(v2Write)( + case r: DataSourceV2Relation => + val tableCheckInvariants = + if (r.table.constraints != null && r.table.constraints.nonEmpty) { + r.table.constraints.collect { + case c: Check => + val unresolvedExpr = buildCatalystExpression(c) + val columnExtractors = mutable.Map[String, Expression]() + buildColumnExtractors(unresolvedExpr, columnExtractors) + CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql) + }.toSeq + } else { + Seq.empty + } + val genColInvariants = buildGeneratedColumnConstraints(r, v2Write) + val allInvariants = tableCheckInvariants ++ genColInvariants + allInvariants.reduceOption(And).fold(v2Write)( condition => v2Write.withNewQuery(Filter(condition, v2Write.query))) case _ => v2Write @@ -72,6 +79,68 @@ class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[L .getOrElse(catalogManager.v1SessionCatalog.parser.parseExpression(c.predicateSql())) } + /** + * For each user-provided generated column, add a CheckInvariant that validates the column + * value matches the generation expression. Auto-filled generated columns are excluded: + * ResolveOutputRelation strips their GENERATION_EXPRESSION metadata from the table output, + * so they won't appear here. See the comment in ResolveOutputRelation for the full contract. 
+ */ + private def hasGeneratedColumnsOnWrite(r: DataSourceV2Relation): Boolean = { + r.catalog.exists { + case tc: TableCatalog => + tc.capabilities().contains( + TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE) + case _ => false + } && r.table.columns().exists(c => c.columnGenerationExpression() != null) + } + + private def buildGeneratedColumnConstraints( + r: DataSourceV2Relation, + v2Write: V2WriteCommand): Seq[Expression] = { + if (!hasGeneratedColumnsOnWrite(r)) return Seq.empty + + // Use V2 columns from the table to access both V2 expressions and SQL strings. + // Only add constraints for generated columns whose GENERATION_EXPRESSION metadata + // is still present in the table output -- ResolveOutputRelation strips the metadata + // from auto-filled columns so they are excluded here. + val v2Columns = r.table.columns() + val resolver = catalogManager.v1SessionCatalog.conf.resolver + val userProvidedGenCols = r.output + .filter(attr => GeneratedColumn.isGeneratedColumn(attr.metadata)) + .map(_.name) + .toSet + val parser = catalogManager.v1SessionCatalog.parser + + v2Columns.flatMap { col => + Option(col.columnGenerationExpression()) + .filter(_ => userProvidedGenCols.exists(n => resolver(n, col.name))) + .map { genExpr => + val catalystExpr = buildGenerationCatalystExpression(genExpr, parser) + val colRef = UnresolvedAttribute(col.name) + val columnExtractors = Seq(col.name -> colRef) + val genExprSql = Option(genExpr.getSql()).getOrElse(catalystExpr.sql) + CheckInvariant( + EqualNullSafe(colRef, catalystExpr), + columnExtractors, + "Generated Column", + s"${col.name} <=> $genExprSql" + ) + } + }.toSeq + } + + /** + * Convert a V2 GenerationExpression to a Catalyst expression. + * Try the V2 expression first, fall back to parsing the SQL string. 
+ */ + private def buildGenerationCatalystExpression( + genExpr: GenerationExpression, + parser: ParserInterface): Expression = { + Option(genExpr.getExpression) + .flatMap(V2ExpressionUtils.toCatalyst) + .getOrElse(parser.parseExpression(genExpr.getSql)) + } + private def buildColumnExtractors( expr: Expression, columnExtractors: mutable.Map[String, Expression]): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 8281f89bd2e8e..e11a985d85ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -52,6 +52,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper EliminateSubqueryAliases(aliasedTable) match { case r: DataSourceV2Relation => + checkNoGeneratedColumns(r, MERGE) validateMergeIntoConditions(m) // NOT MATCHED conditions may only refer to columns in source so they can be pushed down @@ -85,6 +86,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper EliminateSubqueryAliases(aliasedTable) match { case r: DataSourceV2Relation => + checkNoGeneratedColumns(r, MERGE) validateMergeIntoConditions(m) // there are only NOT MATCHED actions, use a left anti join to remove any matching rows @@ -124,6 +126,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned => EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => + checkNoGeneratedColumns(r, MERGE) validateMergeIntoConditions(m) val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty()) table.operation match { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 48c48eb323bd7..efdc6b667b6e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalP import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ -import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.catalog.{SupportsRowLevelOperations, TableCatalog, TableCatalogCapability} import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command @@ -47,6 +47,27 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { protected def groupFilterEnabled: Boolean = conf.runtimeRowLevelOperationGroupFilterEnabled + /** + * Throws if the catalog supports auto-filling generated columns on write and the table + * has generated columns. MERGE and UPDATE with generated columns are not yet supported. 
+ */ + protected def checkNoGeneratedColumns( + relation: DataSourceV2Relation, + command: Command): Unit = { + val supportsGenCol = relation.catalog.exists { + case tc: TableCatalog => + tc.capabilities().contains( + TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE) + case _ => false + } + if (supportsGenCol && + relation.table.columns().exists(_.columnGenerationExpression() != null)) { + throw QueryCompilationErrors.unsupportedTableOperationError( + relation.catalog.get, relation.identifier.get, + s"${command.toString} with generated columns") + } + } + protected def buildOperationTable( table: SupportsRowLevelOperations, command: Command, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index f235374bd5d6f..5664e4b3a24c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -41,6 +41,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => + checkNoGeneratedColumns(r, UPDATE) val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) val updateCond = cond.getOrElse(TrueLiteral) table.operation match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index aa379a63a2af9..f77076b3263f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull 
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, GeneratedColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -94,12 +95,44 @@ object TableOutputResolver extends SQLConfHelper with Logging { byName: Boolean, conf: SQLConf, defaultValueFillMode: DefaultValueFillMode.Value = NONE): LogicalPlan = { + resolveOutputColumnsInternal( + tableName, expected, query, byName, conf, defaultValueFillMode)._1 + } + + /** + * Same as [[resolveOutputColumns]], but also returns the names of generated columns that + * were auto-filled (not provided by the user). This allows callers to distinguish + * user-provided generated column values from auto-filled ones. + */ + def resolveOutputColumnsWithGeneratedInfo( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean, + conf: SQLConf, + defaultValueFillMode: DefaultValueFillMode.Value = NONE + ): (LogicalPlan, Set[String]) = { + resolveOutputColumnsInternal( + tableName, expected, query, byName, conf, defaultValueFillMode) + } + + private def resolveOutputColumnsInternal( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean, + conf: SQLConf, + defaultValueFillMode: DefaultValueFillMode.Value + ): (LogicalPlan, Set[String]) = { if (expected.size < query.output.size) { throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError( tableName, expected.map(_.name), query.output) } + // Track which generated columns are auto-filled (not provided by the user). 
+ val autoFilledGenCols = mutable.Set[String]() + // In RECURSE mode, allow fewer source columns than target by filling trailing columns // with defaults. In other modes, a column count mismatch in by-position resolution is // an error. @@ -117,14 +150,20 @@ object TableOutputResolver extends SQLConfHelper with Logging { errors += _, Nil, defaultValueFillMode, - enforceFullOutput = true) + enforceFullOutput = true, + autoFilledGenCols) } else { if (expected.size > query.output.size && !fillDefaultValue) { - throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( - tableName, expected.map(_.name), query.output) + // Allow if all missing trailing columns are generated columns + val missingCols = expected.drop(query.output.size) + if (!missingCols.forall(col => GeneratedColumn.isGeneratedColumn(col.metadata))) { + throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( + tableName, expected.map(_.name), query.output) + } } resolveColumnsByPosition( - tableName, query.output, expected, conf, errors += _, fillDefaultValue = fillDefaultValue) + tableName, query.output, expected, conf, errors += _, + fillDefaultValue = fillDefaultValue, autoFilledGenCols = autoFilledGenCols) } if (errors.nonEmpty) { @@ -132,11 +171,12 @@ object TableOutputResolver extends SQLConfHelper with Logging { tableName, expected.map(_.name).map(toSQLId).mkString(", ")) } - if (resolved == query.output) { + val plan = if (resolved == query.output) { query } else { Project(resolved, query) } + (plan, autoFilledGenCols.toSet) } def resolveUpdate( @@ -337,23 +377,39 @@ object TableOutputResolver extends SQLConfHelper with Logging { addError: String => Unit, colPath: Seq[String] = Nil, defaultValueFillMode: DefaultValueFillMode.Value, - enforceFullOutput: Boolean = false): Seq[NamedExpression] = { + enforceFullOutput: Boolean = false, + autoFilledGenCols: mutable.Set[String] = mutable.Set.empty): Seq[NamedExpression] = { val matchedCols = mutable.HashSet.empty[String] val 
reordered = expectedCols.flatMap { expectedCol => val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { - val defaultExpr = if (Set(FILL, RECURSE).contains(defaultValueFillMode)) { - getDefaultValueExprOrNullLit(expectedCol, conf.useNullsForMissingDefaultColumnValues) - } else { - None - } - if (defaultExpr.isEmpty) { - throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError( - tableName, newColPath.quoted - ) + // Check for generated column expression first, before falling back to defaults, + // since getDefaultValueExprOrNullLit may return null for nullable columns. + GeneratedColumn.getGenerationExpression(expectedCol.metadata) match { + case Some(genExprSql) => + autoFilledGenCols += expectedCol.name + val genExpr = CatalystSqlParser.parseExpression(genExprSql) + Some(applyColumnMetadata(genExpr, expectedCol)) + case None => + // When a table has generated columns and the config is off, do not + // fill missing nullable non-generated columns with null. 
+ val useNullAsDefault = conf.useNullsForMissingDefaultColumnValues && + (conf.generatedColumnAllowNullableIngest || + !expectedCols.exists(c => GeneratedColumn.isGeneratedColumn(c.metadata))) + val defaultExpr = if (Set(FILL, RECURSE).contains(defaultValueFillMode)) { + getDefaultValueExprOrNullLit(expectedCol, useNullAsDefault) + } else { + None + } + if (defaultExpr.isDefined) { + Some(applyDefaultWithLengthCheck(defaultExpr.get, expectedCol, conf)) + } else { + throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError( + tableName, newColPath.quoted + ) + } } - Some(applyDefaultWithLengthCheck(defaultExpr.get, expectedCol, conf)) } else if (matched.length > 1) { throw QueryCompilationErrors.incompatibleDataToTableAmbiguousColumnNameError( tableName, newColPath.quoted @@ -416,7 +472,8 @@ object TableOutputResolver extends SQLConfHelper with Logging { conf: SQLConf, addError: String => Unit, colPath: Seq[String] = Nil, - fillDefaultValue: Boolean = false): Seq[NamedExpression] = { + fillDefaultValue: Boolean = false, + autoFilledGenCols: mutable.Set[String] = mutable.Set.empty): Seq[NamedExpression] = { val actualExpectedCols = expectedCols.map { attr => attr.withDataType { CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) } } @@ -433,16 +490,20 @@ object TableOutputResolver extends SQLConfHelper with Logging { ) } } else if (inputCols.size < actualExpectedCols.size && !fillDefaultValue) { - val missingColsStr = actualExpectedCols.takeRight(actualExpectedCols.size - inputCols.size) - .map(col => toSQLId(col.name)) - .mkString(", ") - if (colPath.isEmpty) { - throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(tableName, - actualExpectedCols.map(_.name), inputCols.map(_.toAttribute)) - } else { - throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError( - tableName, colPath.quoted, missingColsStr - ) + // Allow if all missing trailing columns are generated columns + val missingCols = 
actualExpectedCols.drop(inputCols.size) + if (!missingCols.forall(col => GeneratedColumn.isGeneratedColumn(col.metadata))) { + val missingColsStr = missingCols + .map(col => toSQLId(col.name)) + .mkString(", ") + if (colPath.isEmpty) { + throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(tableName, + actualExpectedCols.map(_.name), inputCols.map(_.toAttribute)) + } else { + throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError( + tableName, colPath.quoted, missingColsStr + ) + } } } @@ -466,15 +527,35 @@ object TableOutputResolver extends SQLConfHelper with Logging { } } - val defaults = if (fillDefaultValue) { - actualExpectedCols.drop(inputCols.size).map { expectedCol => - val defaultExpr = getDefaultValueExprOrNullLit( - expectedCol, conf.useNullsForMissingDefaultColumnValues) - if (defaultExpr.isEmpty) { - throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError( - tableName, (colPath :+ expectedCol.name).quoted) + val trailingCols = actualExpectedCols.drop(inputCols.size) + val defaults = if (fillDefaultValue || trailingCols.nonEmpty) { + trailingCols.map { expectedCol => + // Check for generated column expression first, before falling back to defaults. + GeneratedColumn.getGenerationExpression(expectedCol.metadata) match { + case Some(genExprSql) => + autoFilledGenCols += expectedCol.name + val genExpr = CatalystSqlParser.parseExpression(genExprSql) + // For by-position, manually resolve references against matched columns + // since the query may use different column names than the table. 
+ val resolvedGenExpr = resolveGenerationExprReferences(genExpr, matched, conf) + applyColumnMetadata(resolvedGenExpr, expectedCol) + case None => + val useNullAsDefault = conf.useNullsForMissingDefaultColumnValues && + (conf.generatedColumnAllowNullableIngest || + !actualExpectedCols.exists(c => GeneratedColumn.isGeneratedColumn(c.metadata))) + val defaultExpr = if (fillDefaultValue) { + getDefaultValueExprOrNullLit( + expectedCol, useNullAsDefault) + } else { + None + } + if (defaultExpr.isDefined) { + applyDefaultWithLengthCheck(defaultExpr.get, expectedCol, conf) + } else { + throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError( + tableName, (colPath :+ expectedCol.name).quoted) + } } - applyDefaultWithLengthCheck(defaultExpr.get, expectedCol, conf) } } else { Nil @@ -662,6 +743,22 @@ object TableOutputResolver extends SQLConfHelper with Logging { } // scalastyle:on argcount + /** + * For by-position writes, resolve references in a generation expression against + * the already-matched output columns (which have table column names). + */ + private def resolveGenerationExprReferences( + genExpr: Expression, + matched: Seq[NamedExpression], + conf: SQLConf): Expression = { + genExpr.transform { + case u: UnresolvedAttribute if u.nameParts.size == 1 => + matched.collectFirst { + case alias: Alias if conf.resolver(alias.name, u.nameParts.head) => alias.child + }.getOrElse(u) + } + } + // For table insertions, capture the overflow errors and show proper message. // Without this method, the overflow errors of castings will show hints for turning off ANSI SQL // mode, which are not helpful since the behavior is controlled by the store assignment policy. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala index 2f92428cc53e2..fbb8d9a9d8f5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.ColumnDefinition import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableCatalogCapability} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField, StructType} /** * This object contains utility methods and values for Generated Columns @@ -37,15 +38,29 @@ object GeneratedColumn { * Whether the given `field` is a generated column */ def isGeneratedColumn(field: StructField): Boolean = { - field.metadata.contains(GENERATION_EXPRESSION_METADATA_KEY) + isGeneratedColumn(field.metadata) + } + + /** + * Whether the given metadata indicates a generated column + */ + def isGeneratedColumn(metadata: Metadata): Boolean = { + metadata.contains(GENERATION_EXPRESSION_METADATA_KEY) } /** * Returns the generation expression stored in the column metadata if it exists */ def getGenerationExpression(field: StructField): Option[String] = { - if (isGeneratedColumn(field)) { - Some(field.metadata.getString(GENERATION_EXPRESSION_METADATA_KEY)) + getGenerationExpression(field.metadata) + } + + /** + * Returns the generation expression stored in the metadata if it exists + */ + def getGenerationExpression(metadata: Metadata): Option[String] = { + if (isGeneratedColumn(metadata)) { + Some(metadata.getString(GENERATION_EXPRESSION_METADATA_KEY)) } else { None } 
@@ -74,4 +89,20 @@ object GeneratedColumn { } } } + + /** + * Returns `attr` without its generation expression metadata (unchanged if it has none). + * Used when the catalog does not support auto-filling generated columns on write. + */ + def removeGenerationExpressionMetadata(attr: Attribute): Attribute = { + if (isGeneratedColumn(attr.metadata)) { + val cleaned = new MetadataBuilder() + .withMetadata(attr.metadata) + .remove(GENERATION_EXPRESSION_METADATA_KEY) + .build() + attr.withMetadata(cleaned) + } else { + attr + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index e34a21eeec762..0823fb403b703 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -653,6 +653,12 @@ private[sql] object CatalogV2Util { Option(col.defaultValue()).foreach { default => f = encodeDefaultValue(default, f) } + Option(col.generationExpression()).foreach { genExpr => + f = f.copy(metadata = new MetadataBuilder() + .withMetadata(f.metadata) + .putString(GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY, genExpr) + .build()) + } f } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index db74d0378fc19..6855140dbce69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5399,6 +5399,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GENERATED_COLUMN_ALLOW_NULLABLE_INGEST = + buildConf("spark.sql.generatedColumn.allowNullableIngest.enabled") + .doc("When true, writing to a table with generated columns allows omitting nullable " + + "non-generated columns from the input. 
Missing nullable columns are filled with null. " + + "When false, all non-generated columns must be provided.") + .version("4.3.0") + .booleanConf + .createWithDefault(true) + val SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION = buildConf("spark.sql.legacy.skipTypeValidationOnAlterPartition") .internal() @@ -8556,6 +8565,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def useNullsForMissingDefaultColumnValues: Boolean = getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES) + def generatedColumnAllowNullableIngest: Boolean = + getConf(SQLConf.GENERATED_COLUMN_ALLOW_NULLABLE_INGEST) + def unionIsResolvedWhenDuplicatesPerChildResolved: Boolean = getConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index c9a6c4acfa014..1a3de97602be8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -287,7 +287,8 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE, TableCatalogCapability.SUPPORT_TABLE_CONSTRAINT, TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS, - TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS + TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS, + TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE ).asJava } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala index 0be430591dbd8..8db40b26ca18a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.{FlowAssigned, StreamingRelationV2, UserProvided, WriteToStream, WriteToStreamStatement} -import org.apache.spark.sql.connector.catalog.SupportsWrite +import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCatalog, TableCatalogCapability} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.streaming.{ContinuousTrigger, RealTimeTrigger} import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager @@ -64,6 +64,19 @@ object ResolveWriteToStream extends Rule[LogicalPlan] { } } + // Block streaming writes to generated-column tables when the catalog fills them on write. 
+ s.catalogAndIdent.foreach { case (catalog, ident) => + catalog match { + case tc: TableCatalog + if tc.capabilities().contains( + TableCatalogCapability.SUPPORT_GENERATED_COLUMN_ON_WRITE) && + s.sink.columns().exists(_.columnGenerationExpression() != null) => + throw QueryCompilationErrors.unsupportedTableOperationError( + catalog, ident, "streaming write with generated columns") + case _ => + } + } + WriteToStream( s.userSpecifiedName.orNull, s.userSpecifiedSinkName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GeneratedColumnWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GeneratedColumnWriteSuite.scala new file mode 100644 index 0000000000000..b82464503dbcc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GeneratedColumnWriteSuite.scala @@ -0,0 +1,544 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.expressions.CheckInvariant +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.internal.SQLConf + +/** + * Tests for generated column auto-fill and constraint enforcement during writes. + */ +class GeneratedColumnWriteSuite extends QueryTest with DatasourceV2SQLBase { + + test("INSERT by name auto-fills missing generated column") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(eventDate) VALUES (DATE'2024-06-15')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024)) + } + } + + test("INSERT by name with matching explicit value succeeds") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(eventDate, eventYear) VALUES (DATE'2024-06-15', 2024)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024)) + } + } + + test("INSERT by name with non-matching explicit value fails") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + val ex = intercept[SparkRuntimeException] { + sql(s"INSERT INTO testcat.$tblName(eventDate, eventYear) VALUES (DATE'2024-06-15', 2025)") + } + assert(ex.getCondition == "CHECK_CONSTRAINT_VIOLATION") + } + } + + 
test("INSERT by position auto-fills trailing generated column") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + // Insert by position without column list — only provide the non-generated column + sql(s"INSERT INTO testcat.$tblName VALUES (DATE'2024-06-15')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024)) + } + } + + test("INSERT by position with matching explicit value succeeds") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName VALUES (DATE'2024-06-15', 2024)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024)) + } + } + + test("INSERT by position with non-matching explicit value fails") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + val ex = intercept[SparkRuntimeException] { + sql(s"INSERT INTO testcat.$tblName VALUES (DATE'2024-06-15', 2025)") + } + assert(ex.getCondition == "CHECK_CONSTRAINT_VIOLATION") + } + } + + test("INSERT auto-fills multiple generated columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)), + | eventMonth INT GENERATED ALWAYS AS (month(eventDate)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(eventDate) VALUES (DATE'2024-06-15')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024, 6)) + } + } + + test("INSERT with 
expression referencing multiple columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT, + | c INT GENERATED ALWAYS AS (a + b) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(a, b) VALUES (3, 5)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(3, 5, 8)) + } + } + + test("allowNullableIngest config controls missing non-generated columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b STRING, + | c INT GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + // Config ON (default): missing nullable column 'b' is filled with null + sql(s"INSERT INTO testcat.$tblName(a) VALUES (1)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(1, null, 2)) + // Config OFF: missing nullable column 'b' causes an error + withSQLConf(SQLConf.GENERATED_COLUMN_ALLOW_NULLABLE_INGEST.key -> "false") { + val ex = intercept[AnalysisException] { + sql(s"INSERT INTO testcat.$tblName(a) VALUES (2)") + } + assert(ex.getMessage.contains("`b`")) + } + } + } + + test("works alongside table CHECK constraints") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + sql(s"ALTER TABLE testcat.$tblName ADD CONSTRAINT positive_a CHECK (a > 0)") + + // Both constraints pass: a > 0 and b = a + 1 (auto-filled) + sql(s"INSERT INTO testcat.$tblName(a) VALUES (5)") + checkAnswer(spark.table(s"testcat.$tblName"), Row(5, 6)) + + // Table CHECK constraint fails: a <= 0 + val ex1 = intercept[SparkRuntimeException] { + sql(s"INSERT INTO testcat.$tblName(a) VALUES (-1)") + } + assert(ex1.getCondition == "CHECK_CONSTRAINT_VIOLATION") + assert(ex1.getMessage.contains("positive_a")) + + // Generated column constraint fails: user provides wrong b + val ex2 = intercept[SparkRuntimeException] { + 
sql(s"INSERT INTO testcat.$tblName(a, b) VALUES (5, 999)") + } + assert(ex2.getCondition == "CHECK_CONSTRAINT_VIOLATION") + assert(ex2.getMessage.contains("Generated Column")) + } + } + + test("NULL input produces NULL generated column value") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(eventDate) VALUES (NULL)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(null, null)) + } + } + + test("type coercion in generated column expression") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b LONG GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(a) VALUES (100)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(100, 101L)) + } + } + + test("multiple rows each get their own generated value") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a * 10) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(a) VALUES (1), (2), (3)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(1, 10) :: Row(2, 20) :: Row(3, 30) :: Nil) + } + } + + test("generated column in the middle of schema") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1), + | c STRING + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(a, c) VALUES (5, 'hello')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(5, 6, "hello")) + } + } + + test("NULL explicit value matching NULL generation result succeeds") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT 
GENERATED ALWAYS AS (year(eventDate)) + |) USING foo""".stripMargin) + // NULL <=> year(NULL) → NULL <=> NULL → true (EqualNullSafe) + sql(s"INSERT INTO testcat.$tblName(eventDate, eventYear) VALUES (NULL, NULL)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(null, null)) + } + } + + test("NULL explicit value not matching generation result fails") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + // NULL <=> (a + 1) where a=5 → NULL <=> 6 → false → violation + val ex = intercept[SparkRuntimeException] { + sql(s"INSERT INTO testcat.$tblName(a, b) VALUES (5, NULL)") + } + assert(ex.getCondition == "CHECK_CONSTRAINT_VIOLATION") + } + } + + test("INSERT OVERWRITE with generated columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(a) VALUES (1)") + sql(s"INSERT OVERWRITE testcat.$tblName(a) VALUES (10)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(10, 11)) + } + } + + test("DataFrame writeTo append with missing generated column") { + import testImplicits._ + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a * 3) + |) USING foo""".stripMargin) + Seq(4, 5).toDF("a").writeTo(s"testcat.$tblName").append() + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(4, 12) :: Row(5, 15) :: Nil) + } + } + + test("non-trailing generated column by position is not auto-filled") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT GENERATED ALWAYS AS (b + 1), + | b INT + |) USING foo""".stripMargin) + // By position with 1 value: the missing column (a) is not trailing, + // so it cannot be auto-filled by position + val 
ex = intercept[AnalysisException] { + sql(s"INSERT INTO testcat.$tblName VALUES (10)") + } + assert(ex.getMessage.contains("not enough data columns")) + } + } + + test("INSERT by name with columns in different order") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1), + | c STRING + |) USING foo""".stripMargin) + // Provide columns in reverse order + sql(s"INSERT INTO testcat.$tblName(c, a) VALUES ('hello', 7)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(7, 8, "hello")) + } + } + + test("INSERT SELECT auto-fills generated columns") { + val tblName = "my_tab" + val srcName = "src_tab" + withTable(s"testcat.$tblName", s"testcat.$srcName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a * 10) + |) USING foo""".stripMargin) + sql(s"CREATE TABLE testcat.$srcName(a INT) USING foo") + sql(s"INSERT INTO testcat.$srcName VALUES (1), (2), (3)") + sql(s"INSERT INTO testcat.$tblName(a) SELECT a FROM testcat.$srcName") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(1, 10) :: Row(2, 20) :: Row(3, 30) :: Nil) + } + } + + test("INSERT with complex generation expression using CAST") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | ts TIMESTAMP, + | ts_date DATE GENERATED ALWAYS AS (CAST(ts AS DATE)) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(ts) VALUES (TIMESTAMP'2024-06-15 10:30:00')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Timestamp.valueOf("2024-06-15 10:30:00"), + java.sql.Date.valueOf("2024-06-15"))) + } + } + + test("INSERT with case-insensitive column matching") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | EventDate DATE, + | EventYear INT GENERATED ALWAYS AS (year(EventDate)) + |) USING foo""".stripMargin) + // Use different case in INSERT 
column list + sql(s"INSERT INTO testcat.$tblName(eventdate) VALUES (DATE'2024-06-15')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024)) + } + } + + test("INSERT missing nullable non-generated column fails when allowNullableIngest is off") { + val tblName = "my_tab" + withSQLConf(SQLConf.GENERATED_COLUMN_ALLOW_NULLABLE_INGEST.key -> "false") { + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b STRING, + | c INT GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + // Missing non-generated, nullable column 'b' should fail + // when allowNullableIngest is off + val ex = intercept[AnalysisException] { + sql(s"INSERT INTO testcat.$tblName(a) VALUES (1)") + } + assert(ex.getMessage.contains("`b`")) + } + } + } + + test("INSERT with generated partition column") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | eventDate DATE, + | eventYear INT GENERATED ALWAYS AS (year(eventDate)) + |) USING foo PARTITIONED BY (eventYear)""".stripMargin) + sql(s"INSERT INTO testcat.$tblName(eventDate) VALUES (DATE'2024-06-15')") + sql(s"INSERT INTO testcat.$tblName(eventDate) VALUES (DATE'2023-01-01')") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(java.sql.Date.valueOf("2024-06-15"), 2024) :: + Row(java.sql.Date.valueOf("2023-01-01"), 2023) :: Nil) + } + } + + test("INSERT SELECT auto-fills generated column referencing multiple columns") { + val src = "src_tab" + val tgt = "tgt_tab" + withTable(s"testcat.$src", s"testcat.$tgt") { + sql(s"CREATE TABLE testcat.$src(a INT, b INT) USING foo") + sql(s"INSERT INTO testcat.$src VALUES (3, 5)") + sql(s"""CREATE TABLE testcat.$tgt( + | a INT, + | b INT, + | c INT GENERATED ALWAYS AS (a + b) + |) USING foo""".stripMargin) + sql(s"INSERT INTO testcat.$tgt(a, b) SELECT a, b FROM testcat.$src") + checkAnswer( + spark.table(s"testcat.$tgt"), + Row(3, 5, 8)) + } + } + + test("streaming write with generated columns is blocked") { + import 
testImplicits._ + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + withTempDir { checkpointDir => + sql(s"""CREATE TABLE testcat.$tblName( + | id INT, + | doubled INT GENERATED ALWAYS AS (id * 2) + |) USING foo""".stripMargin) + val inputData = MemoryStream[Int] + val df = inputData.toDF().toDF("id") + val ex = intercept[AnalysisException] { + df.writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(s"testcat.$tblName") + } + assert(ex.getCondition == "UNSUPPORTED_FEATURE.TABLE_OPERATION") + assert(ex.getMessage.contains("streaming")) + assert(ex.getMessage.contains("generated columns")) + } + } + } + + test("by-position auto-fill with type cast on source column") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + // Table expects LONG for 'a', generation expression is a + 1 + sql(s"""CREATE TABLE testcat.$tblName( + | a LONG, + | b LONG GENERATED ALWAYS AS (a + 1) + |) USING foo""".stripMargin) + // Insert INT value by position — gets cast to LONG, then generation expression uses + // the cast value + sql(s"INSERT INTO testcat.$tblName VALUES (42)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(42L, 43L)) + } + } + + test("mix of auto-filled and user-provided generated columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1), + | c INT GENERATED ALWAYS AS (a * 10) + |) USING foo""".stripMargin) + // Provide 'a' and 'c' (user-provided), omit 'b' (auto-filled) + // b is auto-filled with a + 1 = 6 + // c is user-provided with correct value a * 10 = 50 → passes constraint + sql(s"INSERT INTO testcat.$tblName(a, c) VALUES (5, 50)") + checkAnswer( + spark.table(s"testcat.$tblName"), + Row(5, 6, 50)) + + // Now provide wrong value for c → constraint violation on c only + val ex = intercept[SparkRuntimeException] { + sql(s"INSERT INTO testcat.$tblName(a, c) VALUES (5, 999)") + } + assert(ex.getCondition 
== "CHECK_CONSTRAINT_VIOLATION") + } + } + + test("plan has constraint for user-provided but not auto-filled generated columns") { + val tblName = "my_tab" + withTable(s"testcat.$tblName") { + sql(s"""CREATE TABLE testcat.$tblName( + | a INT, + | b INT GENERATED ALWAYS AS (a + 1), + | c INT GENERATED ALWAYS AS (a * 10) + |) USING foo""".stripMargin) + + def hasCheckInvariant(sqlText: String): Boolean = { + val parsed = spark.sessionState.sqlParser.parsePlan(sqlText) + val analyzed = spark.sessionState.analyzer.executeAndCheck( + parsed, new QueryPlanningTracker) + analyzed.exists { node => + node.expressions.exists(_.exists(_.isInstanceOf[CheckInvariant])) + } + } + + // Auto-filled: no CheckInvariant in analyzed plan + assert(!hasCheckInvariant(s"INSERT INTO testcat.$tblName(a) VALUES (5)"), + "Auto-filled generated columns should not have CheckInvariant in plan") + + // User-provided: CheckInvariant should appear + assert(hasCheckInvariant(s"INSERT INTO testcat.$tblName(a, b) VALUES (5, 6)"), + "User-provided generated columns should have CheckInvariant in plan") + } + } + + // TODO: Add tests for MERGE and UPDATE blocking once TxnTableCatalog delegates + // capabilities() to its underlying catalog. The check is in RewriteRowLevelCommand + // but can't be tested until the transactional catalog wrapper propagates + // SUPPORT_GENERATED_COLUMN_ON_WRITE. +}