Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ spark/benchmarks
comet-event-trace.json
__pycache__
output
.claude/
docs/comet-*/
docs/build/
docs/temp/
14 changes: 12 additions & 2 deletions dev/diffs/3.4.3.diff
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,17 @@ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/s
index a6b295578d6..91acca4306f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -260,7 +260,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

- test("SPARK-33853: explain codegen - check presence of subquery") {
+ test("SPARK-33853: explain codegen - check presence of subquery",
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
withTempView("df") {
val df1 = spark.range(1, 100)
@@ -463,7 +464,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand All @@ -530,7 +540,7 @@ index a6b295578d6..91acca4306f 100644
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -541,7 +543,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand Down
16 changes: 13 additions & 3 deletions dev/diffs/3.5.8.diff
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,20 @@ index f33432ddb6f..914afa6b01d 100644
}
assert(scanOption.isDefined)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index a206e97c353..fea1149b67d 100644
index a206e97c353..8bd3ab5985a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -467,7 +467,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -264,7 +264,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

- test("SPARK-33853: explain codegen - check presence of subquery") {
+ test("SPARK-33853: explain codegen - check presence of subquery",
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
withTempView("df") {
val df1 = spark.range(1, 100)
@@ -467,7 +468,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand All @@ -511,7 +521,7 @@ index a206e97c353..fea1149b67d 100644
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -545,7 +546,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -545,7 +547,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand Down
16 changes: 13 additions & 3 deletions dev/diffs/4.0.1.diff
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,20 @@ index 2c24cc7d570..5a1fe7017c3 100644
}
assert(scanOption.isDefined)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 9c90e0105a4..fadf2f0f698 100644
index 9c90e0105a4..ed6d4887b13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -470,7 +470,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -267,7 +267,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

- test("SPARK-33853: explain codegen - check presence of subquery") {
+ test("SPARK-33853: explain codegen - check presence of subquery",
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
withTempView("df") {
val df1 = spark.range(1, 100)
@@ -470,7 +471,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand All @@ -656,7 +666,7 @@ index 9c90e0105a4..fadf2f0f698 100644
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -548,7 +549,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
@@ -548,7 +550,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}

Expand Down
141 changes: 140 additions & 1 deletion spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
Expand All @@ -57,6 +59,14 @@ import org.apache.comet.shims.ShimSubqueryBroadcast

object CometExecRule {

/**
 * Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because
 * the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have
 * incompatible intermediate buffer formats between Spark and Comet.
 *
 * Set by `tagUnsafePartialAggregates` before the bottom-up conversion runs; the String value
 * is the human-readable fallback reason shown to users. Per the call-site notes, the tag
 * persists through AQE query-stage creation, so it remains visible when the Partial is
 * considered for conversion inside a later query stage.
 */
val COMET_UNSAFE_PARTIAL: TreeNodeTag[String] =
  TreeNodeTag[String]("comet.unsafePartialAgg")

/**
* Fully native operators.
*/
Expand Down Expand Up @@ -512,6 +522,12 @@ case class CometExecRule(session: SparkSession)
normalizedPlan
}

// Tag Partial aggregates that must not be converted to Comet because the
// corresponding Final aggregate cannot be converted and the intermediate buffer
// formats are incompatible. This runs before transform() so the tags are checked
// during the bottom-up conversion. Tags persist through AQE stage creation.
tagUnsafePartialAggregates(planWithJoinRewritten)

var newPlan = transform(planWithJoinRewritten)

// if the plan cannot be run fully natively then explain why (when appropriate
Expand Down Expand Up @@ -732,4 +748,127 @@ case class CometExecRule(session: SparkSession)
}
}

/**
 * Scan the plan for Final-mode aggregates that Comet cannot convert. Whenever such a Final is
 * found and its aggregate functions do not all support mixed Spark/Comet execution (i.e. their
 * intermediate buffer encodings differ between engines), locate the matching Partial-mode
 * aggregate below it and tag that Partial so conversion skips it too.
 *
 * Without this, a Comet Partial could emit intermediate buffers that the Spark Final cannot
 * decode — the crash reported in issue #1389.
 */
private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = {
  // A Final that would force the crash scenario: single-mode Final (multi-mode Finals come
  // from Spark's distinct-aggregate rewrite, where any Comet partial feeds a Spark
  // PartialMerge instead — a different code path), whose functions lack compatible buffer
  // formats, and which Comet itself cannot convert.
  def isUnconvertibleFinal(finalAgg: BaseAggregateExec): Boolean = {
    finalAgg.aggregateExpressions.map(_.mode).distinct == Seq(Final) &&
    !QueryPlanSerde.allAggsSupportMixedExecution(finalAgg.aggregateExpressions) &&
    !canAggregateBeConverted(finalAgg, Final)
  }

  plan.foreach {
    case finalAgg: BaseAggregateExec if isUnconvertibleFinal(finalAgg) =>
      // Only tag a Partial that conversion would otherwise accept. If the Partial is already
      // unconvertible on its own (e.g. unsupported input type), there is no buffer-format
      // mismatch to guard against, and tagging would hide the more specific fallback reason.
      for {
        partial <- findPartialAggInPlan(finalAgg.child)
        if canAggregateBeConverted(partial, Partial)
      } {
        partial.setTagValue(
          CometExecRule.COMET_UNSAFE_PARTIAL,
          "Partial aggregate disabled: corresponding final aggregate " +
            "cannot be converted to Comet and intermediate buffer formats are incompatible")
      }
    case _ =>
  }
}

/**
 * Conservative check for whether an aggregate could be converted to Comet. Checks operator
 * enablement, grouping expressions, aggregate expressions, and result expressions.
 * Intentionally skips the sparkFinalMode / child-native checks since those depend on
 * transformation state.
 *
 * WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert`
 * (operators.scala). Any change to the convertibility rules there must be reflected here or
 * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
 * shared predicate helper would be preferable.
 *
 * @param agg the aggregate operator to probe (not modified)
 * @param expectedMode the single AggregateMode the operator's expressions must all carry
 * @return true if every mirrored predicate passes, i.e. conversion would be attempted
 */
private def canAggregateBeConverted(
    agg: BaseAggregateExec,
    expectedMode: AggregateMode): Boolean = {
  // No Comet serde registered for this operator class => never convertible.
  val handler = allExecs.get(agg.getClass)
  if (handler.isEmpty) return false
  val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
  // Per-operator enable/disable configuration.
  if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false

  // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
  agg match {
    case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false
    case _ =>
  }

  val aggregateExpressions = agg.aggregateExpressions
  val groupingExpressions = agg.groupingExpressions

  // A degenerate aggregate with neither grouping nor aggregate expressions is not converted.
  if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false

  // Map-typed group keys crash in native execution (Arrow row format limitation).
  if (groupingExpressions.exists(e => QueryPlanSerde.containsMapType(e.dataType))) return false

  // Every grouping expression must be serializable against the child's output.
  if (!groupingExpressions.forall(e =>
      QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) {
    return false
  }

  if (aggregateExpressions.isEmpty) {
    // Result expressions always checked when there are no aggregate expressions
    val attributes =
      groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
    return agg.resultExpressions.forall(e =>
      QueryPlanSerde.exprToProto(e, attributes).isDefined)
  }

  // All aggregate expressions must share exactly the expected mode (mixed modes bail out).
  val modes = aggregateExpressions.map(_.mode).distinct
  if (modes.size != 1 || modes.head != expectedMode) return false

  // In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
  // it must bind to input attributes. This mirrors the `binding` calculation in
  // `CometBaseAggregate.doConvert`.
  val binding = expectedMode != Final
  if (!aggregateExpressions.forall(e =>
      QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
    return false
  }

  // doConvert only checks resultExpressions in Final mode when aggregate expressions exist
  // (Partial emits the buffer directly). Mirror that here to avoid false negatives.
  if (expectedMode == Final) {
    val attributes =
      groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
    agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
  } else {
    true
  }
}

/**
 * Locate the Partial-mode aggregate that feeds directly into the given plan (the child of a
 * Final). Traversal descends only through shuffle exchanges and AQE stages; any other node —
 * including another aggregate stage — terminates the search. This keeps unrelated Partials
 * deeper in the plan (e.g. the non-distinct Partial in a distinct-aggregate rewrite, which is
 * separated from the Final by intermediate PartialMerge stages) from being tagged. The
 * `nonEmpty` requirement ensures group-by-only dedup stages are not mistaken for the partial
 * we want to tag.
 */
private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = {
  plan match {
    case candidate: BaseAggregateExec =>
      val exprs = candidate.aggregateExpressions
      if (exprs.nonEmpty && exprs.forall(_.mode == Partial)) Some(candidate) else None
    case read: AQEShuffleReadExec => findPartialAggInPlan(read.child)
    case stage: ShuffleQueryStageExec => findPartialAggInPlan(stage.plan)
    case exchange: ShuffleExchangeExec => findPartialAggInPlan(exchange.child)
    case _ => None
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
*/
def getSupportLevel(expr: T): SupportLevel = Compatible(None)

/**
 * Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
 * making it safe to run the Partial in one engine and the Final in the other. Aggregates with
 * simple single-value buffers (MIN, MAX, bitwise) are safe; those with complex or
 * differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not. COUNT is
 * intentionally excluded: mixed COUNT partial/final regressed AQE's
 * PropagateEmptyRelationAfterAQE pattern (which matches BaseAggregateExec only) and the Spark
 * 4.0 count-bug decorrelation for correlated IN subqueries.
 *
 * Defaults to false (conservative): an implementation must explicitly opt in after verifying
 * its buffer encoding matches Spark's byte-for-byte.
 */
def supportsMixedPartialFinal: Boolean = false

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
Expand Down
31 changes: 31 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,24 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[VariancePop] -> CometVariancePop,
classOf[VarianceSamp] -> CometVarianceSamp)

/**
 * Returns true when every aggregate expression in `aggExprs` has an intermediate buffer
 * format compatible between Spark and Comet, so the Partial may run in one engine and the
 * Final in the other. An aggregate function with no registered Comet serde counts as
 * unsupported.
 */
def allAggsSupportMixedExecution(aggExprs: Seq[AggregateExpression]): Boolean =
  aggExprs.forall { expr =>
    aggrSerdeMap
      .get(expr.aggregateFunction.getClass)
      .exists { serde =>
        serde
          .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]]
          .supportsMixedPartialFinal
      }
  }

// A unique id for each expression. ~used to look up QueryContext during error creation.
private val exprIdCounter = new AtomicLong(0)

Expand Down Expand Up @@ -359,6 +377,19 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
false
}

/**
 * Returns true if the given data type is or transitively contains a `MapType` at any nesting
 * level. Arrow's row format (used by DataFusion's grouped hash aggregate for composite group
 * keys) does not support `Map`, so grouping on any type containing a map would crash in
 * native execution.
 */
def containsMapType(dt: DataType): Boolean =
  dt match {
    case _: MapType => true
    case arr: ArrayType => containsMapType(arr.elementType)
    case struct: StructType => struct.fields.exists(field => containsMapType(field.dataType))
    case _ => false
  }

/**
* Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
* doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
Expand Down
10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.comet.shims.CometEvalModeUtil

object CometMin extends CometAggregateExpressionSerde[Min] {

override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Min,
Expand Down Expand Up @@ -81,6 +83,8 @@ object CometMin extends CometAggregateExpressionSerde[Min] {

object CometMax extends CometAggregateExpressionSerde[Max] {

override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Max,
Expand Down Expand Up @@ -319,6 +323,8 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
}

object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitAnd: BitAndAgg,
Expand Down Expand Up @@ -353,6 +359,8 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
}

object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitOr: BitOrAgg,
Expand Down Expand Up @@ -387,6 +395,8 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
}

object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitXor: BitXorAgg,
Expand Down
Loading
Loading