diff --git a/.gitignore b/.gitignore index a3c97ff992..33eaaa82dd 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ spark/benchmarks comet-event-trace.json __pycache__ output +.claude/ docs/comet-*/ docs/build/ docs/temp/ diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff index 3f04dd43b8..3746d085c3 100644 --- a/dev/diffs/3.4.3.diff +++ b/dev/diffs/3.4.3.diff @@ -520,7 +520,17 @@ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/s index a6b295578d6..91acca4306f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -@@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -260,7 +260,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite + } + } + +- test("SPARK-33853: explain codegen - check presence of subquery") { ++ test("SPARK-33853: explain codegen - check presence of subquery", ++ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + withTempView("df") { + val df1 = spark.range(1, 100) +@@ -463,7 +464,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } @@ -530,7 +540,7 @@ index a6b295578d6..91acca4306f 100644 withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt -@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -541,7 +543,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } diff --git a/dev/diffs/3.5.8.diff b/dev/diffs/3.5.8.diff index 00a6963dd9..b79da5296f 100644 --- a/dev/diffs/3.5.8.diff +++ b/dev/diffs/3.5.8.diff @@ -498,10 +498,20 @@ index f33432ddb6f..914afa6b01d 100644 } assert(scanOption.isDefined) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -index a206e97c353..fea1149b67d 100644 +index a206e97c353..8bd3ab5985a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -@@ -467,7 +467,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -264,7 +264,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite + } + } + +- test("SPARK-33853: explain codegen - check presence of subquery") { ++ test("SPARK-33853: explain codegen - check presence of subquery", ++ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + withTempView("df") { + val df1 = spark.range(1, 100) +@@ -467,7 +468,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } @@ -511,7 +521,7 @@ index a206e97c353..fea1149b67d 100644 withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt -@@ -545,7 +546,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -545,7 +547,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } diff --git a/dev/diffs/4.0.1.diff b/dev/diffs/4.0.1.diff index 0da594b58d..9db5a5b6e8 100644 --- a/dev/diffs/4.0.1.diff +++ b/dev/diffs/4.0.1.diff @@ -643,10 +643,20 @@ index 2c24cc7d570..5a1fe7017c3 100644 } assert(scanOption.isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -index 9c90e0105a4..fadf2f0f698 100644 +index 9c90e0105a4..ed6d4887b13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -@@ 
-470,7 +470,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -267,7 +267,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite + } + } + +- test("SPARK-33853: explain codegen - check presence of subquery") { ++ test("SPARK-33853: explain codegen - check presence of subquery", ++ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + withTempView("df") { + val df1 = spark.range(1, 100) +@@ -470,7 +471,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } @@ -656,7 +666,7 @@ index 9c90e0105a4..fadf2f0f698 100644 withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt -@@ -548,7 +549,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite +@@ -548,7 +550,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 4b15c26d27..d7a123e11d 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import 
org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -57,6 +59,14 @@ import org.apache.comet.shims.ShimSubqueryBroadcast object CometExecRule { + /** + * Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because + * the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have + * incompatible intermediate buffer formats between Spark and Comet. + */ + val COMET_UNSAFE_PARTIAL: TreeNodeTag[String] = + TreeNodeTag[String]("comet.unsafePartialAgg") + /** * Fully native operators. */ @@ -512,6 +522,12 @@ case class CometExecRule(session: SparkSession) normalizedPlan } + // Tag Partial aggregates that must not be converted to Comet because the + // corresponding Final aggregate cannot be converted and the intermediate buffer + // formats are incompatible. This runs before transform() so the tags are checked + // during the bottom-up conversion. Tags persist through AQE stage creation. 
+ tagUnsafePartialAggregates(planWithJoinRewritten) + var newPlan = transform(planWithJoinRewritten) // if the plan cannot be run fully natively then explain why (when appropriate @@ -732,4 +748,127 @@ case class CometExecRule(session: SparkSession) } } + /** + * Walk the plan to find Final-mode aggregates that cannot be converted to Comet. For each such + * Final, if the aggregate functions have incompatible intermediate buffer formats, tag the + * corresponding Partial-mode aggregate so it will also be skipped during conversion. + * + * This prevents the crash described in issue #1389 where a Comet Partial produces intermediate + * data in a format that the Spark Final cannot interpret. + */ + private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = { + plan.foreach { + case agg: BaseAggregateExec => + // Only consider single-mode Final aggregates. Multi-mode Finals come from Spark's + // distinct-aggregate rewrite, where the Comet partial (if any) feeds into a Spark + // PartialMerge rather than directly into a Final, which is a different code path + // than the Comet-Partial → Spark-Final crash scenario from issue #1389. + val modes = agg.aggregateExpressions.map(_.mode).distinct + if (modes == Seq(Final) && + !QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions) && + !canAggregateBeConverted(agg, Final)) { + findPartialAggInPlan(agg.child).foreach { partial => + // Only tag if the Partial would otherwise have been converted. If the Partial + // itself cannot be converted (e.g. the aggregate function is incompatible for the + // input type), there is no buffer-format mismatch to guard against, and tagging + // would mask the natural, more specific fallback reason. 
+ if (canAggregateBeConverted(partial, Partial)) { + partial.setTagValue( + CometExecRule.COMET_UNSAFE_PARTIAL, + "Partial aggregate disabled: corresponding final aggregate " + + "cannot be converted to Comet and intermediate buffer formats are incompatible") + } + } + } + case _ => + } + } + + /** + * Conservative check for whether an aggregate could be converted to Comet. Checks operator + * enablement, grouping expressions, aggregate expressions, and result expressions. + * Intentionally skips the sparkFinalMode / child-native checks since those depend on + * transformation state. + * + * WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert` + * (operators.scala). Any change to the convertibility rules there must be reflected here or + * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A + * shared predicate helper would be preferable. + */ + private def canAggregateBeConverted( + agg: BaseAggregateExec, + expectedMode: AggregateMode): Boolean = { + val handler = allExecs.get(agg.getClass) + if (handler.isEmpty) return false + val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]] + if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false + + // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method + agg match { + case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false + case _ => + } + + val aggregateExpressions = agg.aggregateExpressions + val groupingExpressions = agg.groupingExpressions + + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false + + if (groupingExpressions.exists(e => QueryPlanSerde.containsMapType(e.dataType))) return false + + if (!groupingExpressions.forall(e => + QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) { + return false + } + + if (aggregateExpressions.isEmpty) { + // Result expressions always checked when there are no aggregate expressions 
+ val attributes = + groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes + return agg.resultExpressions.forall(e => + QueryPlanSerde.exprToProto(e, attributes).isDefined) + } + + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.size != 1 || modes.head != expectedMode) return false + + // In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode + // it must bind to input attributes. This mirrors the `binding` calculation in + // `CometBaseAggregate.doConvert`. + val binding = expectedMode != Final + if (!aggregateExpressions.forall(e => + QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) { + return false + } + + // doConvert only checks resultExpressions in Final mode when aggregate expressions exist + // (Partial emits the buffer directly). Mirror that here to avoid false negatives. + if (expectedMode == Final) { + val attributes = + groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes + agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined) + } else { + true + } + } + + /** + * Look for a Partial-mode aggregate that feeds directly into the given plan (the child of a + * Final). Walks through exchanges and AQE stages only, stopping at anything else including + * other aggregate stages. This avoids tagging unrelated Partials found deeper in the plan (e.g. + * the non-distinct Partial in a distinct-aggregate rewrite, which is separated from the Final + * by intermediate PartialMerge stages). Requires `aggregateExpressions.nonEmpty` so that + * group-by-only dedup stages are not mistaken for the partial we want to tag. 
+ */ + private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match { + case agg: BaseAggregateExec + if agg.aggregateExpressions.nonEmpty && + agg.aggregateExpressions.forall(e => e.mode == Partial) => + Some(agg) + case a: AQEShuffleReadExec => findPartialAggInPlan(a.child) + case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan) + case e: ShuffleExchangeExec => findPartialAggInPlan(e.child) + case _ => None + } + } diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala index 316510400b..9a83152168 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala @@ -81,6 +81,17 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] { */ def getSupportLevel(expr: T): SupportLevel = Compatible(None) + /** + * Whether this aggregate's intermediate buffer format is compatible between Spark and Comet, + * making it safe to run the Partial in one engine and the Final in the other. Aggregates with + * simple single-value buffers (MIN, MAX, bitwise) are safe; those with complex or + * differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not. COUNT is + * intentionally excluded: mixed COUNT partial/final regressed AQE's + * PropagateEmptyRelationAfterAQE pattern (which matches BaseAggregateExec only) and the Spark + * 4.0 count-bug decorrelation for correlated IN subqueries. + */ + def supportsMixedPartialFinal: Boolean = false + /** * Convert a Spark expression into a protocol buffer representation that can be passed into * native code. 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c3dc6dcfd5..b55cc3bc86 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -282,6 +282,24 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[VariancePop] -> CometVariancePop, classOf[VarianceSamp] -> CometVarianceSamp) + /** + * Returns true if all aggregate expressions in the list have intermediate buffer formats that + * are compatible between Spark and Comet, making it safe to run Partial in one engine and Final + * in the other. + */ + def allAggsSupportMixedExecution(aggExprs: Seq[AggregateExpression]): Boolean = { + aggExprs.forall { aggExpr => + val fn = aggExpr.aggregateFunction + aggrSerdeMap.get(fn.getClass) match { + case Some(handler) => + handler + .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]] + .supportsMixedPartialFinal + case None => false + } + } + } + // A unique id for each expression. ~used to look up QueryContext during error creation. private val exprIdCounter = new AtomicLong(0) @@ -359,6 +377,19 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { false } + /** + * Returns true if the given data type is or contains a `MapType` at any nesting level. Arrow's + * row format (used by DataFusion's grouped hash aggregate for composite group keys) does not + * support `Map`, so grouping on any type that transitively contains a map would crash in native + * execution. + */ + def containsMapType(dt: DataType): Boolean = dt match { + case _: MapType => true + case a: ArrayType => containsMapType(a.elementType) + case s: StructType => s.fields.exists(f => containsMapType(f.dataType)) + case _ => false + } + /** * Serializes Spark datatype to protobuf. 
Note that, a datatype can be serialized by this method * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 6889fc02d9..1a536f04e6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -34,6 +34,8 @@ import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Min, @@ -81,6 +83,8 @@ object CometMin extends CometAggregateExpressionSerde[Min] { object CometMax extends CometAggregateExpressionSerde[Max] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Max, @@ -319,6 +323,8 @@ object CometLast extends CometAggregateExpressionSerde[Last] { } object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitAnd: BitAndAgg, @@ -353,6 +359,8 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { } object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitOr: BitOrAgg, @@ -387,6 +395,8 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { } object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitXor: BitXorAgg, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 
5f7e91529d..16c25a2d63 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -54,8 +54,10 @@ import com.google.protobuf.CodedOutputStream import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils +import org.apache.comet.rules.CometExecRule import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.QueryPlanSerde import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, isStringCollationType, supportedSortType} import org.apache.comet.serde.operator.CometSink @@ -1359,10 +1361,24 @@ trait CometBaseAggregate { // In distinct aggregates there can be a combination of modes val multiMode = modes.size > 1 // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation. + // if there is Comet partial aggregation, unless all aggregates have compatible + // intermediate buffer formats (safe for mixed Spark/Comet execution). 
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty - if (multiMode || sparkFinalMode) { + if (multiMode) { + return None + } + + if (sparkFinalMode && + !QueryPlanSerde.allAggsSupportMixedExecution(aggregate.aggregateExpressions)) { + return None + } + + // Check if this aggregate has been tagged as unsafe for mixed execution + // (Comet partial + Spark final with incompatible intermediate buffers) + val unsafeReason = aggregate.getTagValue(CometExecRule.COMET_UNSAFE_PARTIAL) + if (unsafeReason.isDefined) { + withInfo(aggregate, unsafeReason.get) return None } @@ -1377,11 +1393,7 @@ trait CometBaseAggregate { return None } - if (groupingExpressions.exists(expr => - expr.dataType match { - case _: MapType => true - case _ => false - })) { + if (groupingExpressions.exists(expr => QueryPlanSerde.containsMapType(expr.dataType))) { withInfo(aggregate, "Grouping on map types is not supported") return None } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 63714383ae..7f353c36e2 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx import org.apache.spark.sql.types.{DataTypes, StructField, StructType} import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} /** @@ -131,9 +132,8 @@ class CometExecRuleSuite extends CometTestBase { } } - // TODO this test exposes the bug described in - // https://github.com/apache/datafusion-comet/issues/1389 - ignore("CometExecRule should not allow Comet partial and Spark final hash aggregate") { + // Regression test for https://github.com/apache/datafusion-comet/issues/1389 + 
test("CometExecRule should not allow Comet partial and Spark final hash aggregate") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") @@ -149,7 +149,8 @@ class CometExecRuleSuite extends CometTestBase { CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { val transformedPlan = applyCometExecRule(sparkPlan) - // if the final aggregate cannot be converted to Comet, then neither should be + // SUM has incompatible intermediate buffers, so if the final aggregate cannot + // be converted to Comet, neither should be assert( countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount) assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) @@ -181,6 +182,78 @@ class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should allow safe Comet partial and Spark final hash aggregate") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + // Query uses only safe aggregates (MIN, MAX) with compatible intermediate buffers + val sparkPlan = + createSparkPlan(spark, "SELECT MIN(id), MAX(id) FROM test_data GROUP BY (id % 3)") + + val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec]) + assert(originalHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // Safe aggregates allow mixed execution: partial can be Comet, final stays Spark + assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) // final only + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) // partial + } + } + } + + test("CometExecRule should allow safe Spark partial and Comet final hash aggregate") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + // Query uses only safe aggregates (MIN, 
MAX) with compatible intermediate buffers + val sparkPlan = + createSparkPlan(spark, "SELECT MIN(id), MAX(id) FROM test_data GROUP BY (id % 3)") + + val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec]) + assert(originalHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // Safe aggregates allow mixed execution: partial stays Spark, final can be Comet + assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) // partial only + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) // final + } + } + } + + test("CometExecRule should not convert hash aggregate when grouping key contains map type") { + // Spark 3.4/3.5 reject `array<map<int, int>>` as a grouping key in the analyzer (not orderable), + // so the plan never reaches CometExecRule on those versions. The guard we're exercising + // (containsMapType) only matters on Spark 4.0+, which permits the GROUP BY to be analyzed. + assume(isSpark40Plus) + // Arrow's row format, used by DataFusion's grouped hash aggregate for composite keys, does + // not support Map at any nesting level. Grouping by a type that transitively contains a map + // (e.g. array<map<int, int>>) must stay on Spark to avoid a native row-encoding crash.
+ val sparkPlan = createSparkPlan( + spark, + """SELECT count(*) + |FROM VALUES (ARRAY(MAP(1, 2), MAP(1, 3))), + | (ARRAY(MAP(2, 3), MAP(1, 3))) AS t(a) + |GROUP BY a""".stripMargin) + + val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec]) + assert(originalHashAggCount == 2) + + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data")