diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 2eb182bd7b..d39734e5d3 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -140,6 +140,38 @@ jobs:
run: |
./dev/ci/check-working-tree-clean.sh
+ # Compile-only verification for Spark 4.1. Tests are intentionally skipped: the spark-4.1
+ # profile is currently a build target only, and several runtime/test failures are tracked
+  # in follow-up PRs. The spark-4.1 profile is excluded from lint-java because
+  # semanticdb-scalac_2.13.17 is not yet published and the lint job activates -Psemanticdb.
+ build-spark-4-1:
+ needs: lint
+ name: Build Spark 4.1, JDK 17
+ runs-on: ubuntu-latest
+ container:
+ image: amd64/rust
+ steps:
+ - uses: actions/checkout@v6
+
+ - name: Setup Rust & Java toolchain
+ uses: ./.github/actions/setup-builder
+ with:
+ rust-version: ${{ env.RUST_VERSION }}
+ jdk-version: 17
+
+ - name: Cache Maven dependencies
+ uses: actions/cache@v5
+ with:
+ path: |
+ ~/.m2/repository
+ /root/.m2/repository
+ key: ${{ runner.os }}-java-maven-${{ hashFiles('**/pom.xml') }}-spark-4.1-build
+ restore-keys: |
+ ${{ runner.os }}-java-maven-
+
+ - name: Compile (skip tests)
+ run: ./mvnw -B install -DskipTests -Dmaven.test.skip=true -Pspark-4.1
+
# Build native library once and share with all test jobs
build-native:
needs: lint
diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/CometTypeShim.scala b/common/src/main/spark-4.1/org/apache/comet/shims/CometTypeShim.scala
new file mode 100644
index 0000000000..f75fde7de9
--- /dev/null
+++ b/common/src/main/spark-4.1/org/apache/comet/shims/CometTypeShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.datasources.VariantMetadata
+import org.apache.spark.sql.types.{DataType, StringType, StructType}
+
+trait CometTypeShim {
+ // A `StringType` carries collation metadata since Spark 4.0. Only non-default (non-UTF8_BINARY)
+ // collations have semantics Comet's byte-level hashing/sorting/equality cannot honor. The
+ // default `StringType` object is `StringType(UTF8_BINARY_COLLATION_ID)`, so comparing
+ // `collationId` against that instance's id picks out non-default collations without needing
+ // `private[sql]` helpers on `StringType`.
+ def isStringCollationType(dt: DataType): Boolean = dt match {
+ case st: StringType => st.collationId != StringType.collationId
+ case _ => false
+ }
+
+ // Since Spark 4.0, `PushVariantIntoScan` rewrites `VariantType` columns into a `StructType` whose
+ // fields each carry `__VARIANT_METADATA_KEY` metadata, then pushes `variant_get` paths down as
+ // ordinary struct field accesses. Comet's native scans don't understand the on-disk Parquet
+ // variant shredding layout, so reading such a struct natively returns nulls. Detect the marker
+ // and force scan fallback.
+ def isVariantStruct(s: StructType): Boolean = VariantMetadata.isVariantStruct(s)
+}
diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/ShimBatchReader.scala b/common/src/main/spark-4.1/org/apache/comet/shims/ShimBatchReader.scala
new file mode 100644
index 0000000000..8ce51b31b6
--- /dev/null
+++ b/common/src/main/spark-4.1/org/apache/comet/shims/ShimBatchReader.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+
+object ShimBatchReader {
+ def newPartitionedFile(partitionValues: InternalRow, file: String): PartitionedFile =
+ PartitionedFile(
+ partitionValues,
+ SparkPath.fromUrlString(file),
+ -1, // -1 means we read the entire file
+ -1,
+ Array.empty[String],
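+ // No preferred locations; the trailing zeros are modificationTime and fileSize, which are not needed here.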
+ 0,
+ 0)
+}
diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/ShimCometConf.scala b/common/src/main/spark-4.1/org/apache/comet/shims/ShimCometConf.scala
new file mode 100644
index 0000000000..0eb57c52b4
--- /dev/null
+++ b/common/src/main/spark-4.1/org/apache/comet/shims/ShimCometConf.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+trait ShimCometConf {
+ protected val COMET_SCHEMA_EVOLUTION_ENABLED_DEFAULT = true
+}
diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/ShimFileFormat.scala b/common/src/main/spark-4.1/org/apache/comet/shims/ShimFileFormat.scala
new file mode 100644
index 0000000000..1702db135a
--- /dev/null
+++ b/common/src/main/spark-4.1/org/apache/comet/shims/ShimFileFormat.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRowIndexUtil
+import org.apache.spark.sql.types.StructType
+
+object ShimFileFormat {
+ // A name for a temporary column that holds row indexes computed by the file format reader
+ // until they can be placed in the _metadata struct.
+ val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
+
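+ // Delegates to Spark's ParquetRowIndexUtil; returns -1 when the schema has no row index column.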
+ def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int =
+ ParquetRowIndexUtil.findRowIndexColumnIndexInSchema(sparkSchema)
+}
diff --git a/common/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala b/common/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 0000000000..b6a1b56d97
--- /dev/null
+++ b/common/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
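+ // TaskMetrics._externalAccums is private[spark]; this shim lives under org.apache.spark so it can read it directly.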
+ def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_, _]] =
+ taskMetrics._externalAccums.lastOption
+}
diff --git a/pom.xml b/pom.xml
index d1c81b8352..644556c1e1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -682,6 +682,30 @@ under the License.
+
+
+ spark-4.1
+
+
+ 2.13.17
+ 2.13
+ 4.1.1
+ 4.1
+ 1.16.0
+ 4.13.6
+ 2.0.17
+ spark-4.1
+ not-needed-yet
+
+ 17
+ ${java.version}
+ ${java.version}
+
+
+
scala-2.12
diff --git a/spark/pom.xml b/spark/pom.xml
index 0d679ace62..baac3a2557 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -273,6 +273,31 @@ under the License.
+    <profile>
+      <id>spark-4.1</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.iceberg</groupId>
+          <artifactId>iceberg-spark-runtime-4.0_${scala.binary.version}</artifactId>
+          <version>1.10.0</version>
+          <scope>test</scope>
+        </dependency>
+        <dependency>
+          <groupId>org.eclipse.jetty</groupId>
+          <artifactId>jetty-server</artifactId>
+          <version>11.0.26</version>
+          <scope>test</scope>
+        </dependency>
+        <dependency>
+          <groupId>org.eclipse.jetty</groupId>
+          <artifactId>jetty-servlet</artifactId>
+          <version>11.0.26</version>
+          <scope>test</scope>
+        </dependency>
+      </dependencies>
+    </profile>
generate-docs
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
index 70721366c7..15755248ba 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
@@ -41,7 +41,6 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
@@ -171,8 +170,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
mapOutputWriter
.commitAllPartitions(ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE)
.getPartitionLengths();
- mapStatus =
- MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapStatus = MapStatusHelper.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
return;
}
final long openStartTime = System.nanoTime();
@@ -262,7 +260,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
// TODO: We probably can move checksum generation here when concatenating partition files
partitionLengths = writePartitionedData(mapOutputWriter);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapStatus = MapStatusHelper.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
index 8930a52884..995facaef5 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
@@ -50,7 +50,6 @@
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.BaseShuffleHandle;
@@ -288,7 +287,7 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapStatus = MapStatusHelper.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
}
@VisibleForTesting
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 1ef901d065..d67530ac41 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -165,6 +165,10 @@ object CometSparkSessionExtensions extends Logging {
org.apache.spark.SPARK_VERSION >= "4.0"
}
+ def isSpark41Plus: Boolean = {
+ org.apache.spark.SPARK_VERSION >= "4.1"
+ }
+
/**
* Whether we should override Spark memory configuration for Comet. This only returns true when
* Comet native execution is enabled and/or Comet shuffle is enabled and Comet doesn't use
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 6889fc02d9..c87c7ae00d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -219,7 +219,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
return None
}
- val evalMode = sum.evalMode
+ val evalMode = CometEvalModeUtil.sumEvalMode(sum)
val childExpr = exprToProto(sum.child, inputs, binding)
val dataType = serializeDataType(sum.dataType)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
index aa47dfa166..ed3e029e23 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
@@ -19,7 +19,6 @@
package org.apache.spark.sql.comet.execution.shuffle
-import java.util.Collections
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._
@@ -101,7 +100,9 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
case 2 =>
c.newInstance(conf, null).asInstanceOf[IndexShuffleBlockResolver]
case 3 =>
- c.newInstance(conf, null, Collections.emptyMap())
+ // Spark 4.1 changed the third parameter type from java.util.Map to
+ // java.util.concurrent.ConcurrentMap. ConcurrentHashMap satisfies both.
+ c.newInstance(conf, null, new ConcurrentHashMap[Int, OpenHashSet[Long]]())
.asInstanceOf[IndexShuffleBlockResolver]
}
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/MapStatusHelper.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/MapStatusHelper.scala
new file mode 100644
index 0000000000..4689b28626
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/MapStatusHelper.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
+// Spark 4.1 added a `checksumVal` parameter (with a default of 0) to MapStatus.apply.
+// Java callers can't use Scala default parameters, so we wrap the call here. The Scala
+// compiler fills in the default per Spark version.
+object MapStatusHelper {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long): MapStatus =
+ MapStatus(loc, uncompressedSizes, mapTaskId)
+}
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index 600931c346..f80a8909f6 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -20,6 +20,7 @@
package org.apache.comet.shims
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.comet.expressions.CometEvalMode
import org.apache.comet.serde.CommonStringExprs
@@ -54,4 +55,6 @@ object CometEvalModeUtil {
case EvalMode.TRY => CometEvalMode.TRY
case EvalMode.ANSI => CometEvalMode.ANSI
}
+
+ def sumEvalMode(s: Sum): EvalMode.Value = s.evalMode
}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
index d9b80ab488..d3e3270700 100644
--- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
@@ -20,6 +20,7 @@
package org.apache.comet.shims
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.types.DataTypes
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -97,4 +98,6 @@ object CometEvalModeUtil {
case EvalMode.TRY => CometEvalMode.TRY
case EvalMode.ANSI => CometEvalMode.ANSI
}
+
+ def sumEvalMode(s: Sum): EvalMode.Value = s.evalMode
}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index 38a3b9d726..86b28b715b 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -20,6 +20,7 @@
package org.apache.comet.shims
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.internal.SQLConf
@@ -155,4 +156,6 @@ object CometEvalModeUtil {
case EvalMode.TRY => CometEvalMode.TRY
case EvalMode.ANSI => CometEvalMode.ANSI
}
+
+ def sumEvalMode(s: Sum): EvalMode.Value = s.evalMode
}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 0000000000..4f21e5eafa
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
+import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim extends CommonStringExprs {
+ protected def evalMode(c: Cast): CometEvalMode.Value =
+ CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
+
+ protected def binaryOutputStyle: BinaryOutputStyle = {
+ // In Spark 4.1, BINARY_OUTPUT_STYLE is an optional enum conf, so getConf returns
+ // Option[BinaryOutputStyle] values directly instead of strings.
+ SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE) match {
+ case Some(SQLConf.BinaryOutputStyle.UTF8) => BinaryOutputStyle.UTF8
+ case Some(SQLConf.BinaryOutputStyle.BASIC) => BinaryOutputStyle.BASIC
+ case Some(SQLConf.BinaryOutputStyle.BASE64) => BinaryOutputStyle.BASE64
+ case Some(SQLConf.BinaryOutputStyle.HEX) => BinaryOutputStyle.HEX
+ case _ => BinaryOutputStyle.HEX_DISCRETE
+ }
+ }
+
+ def versionSpecificExprToProtoInternal(
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[Expr] = {
+ expr match {
+ case knc: KnownNotContainsNull =>
+ // Since Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)).
+ // Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact.
+ knc.child match {
+ case filter: ArrayFilter =>
+ filter.function.children.headOption match {
+ case Some(_: IsNotNull) =>
+ val arrayChild = filter.left
+ val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType
+ val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding)
+ val returnType = ArrayType(elementType)
+ val scalarExpr = scalarFunctionExprToProtoWithReturnType(
+ "spark_array_compact",
+ returnType,
+ false,
+ arrayExprProto)
+ optExprWithInfo(scalarExpr, knc, arrayChild)
+ case _ => exprToProtoInternal(knc.child, inputs, binding)
+ }
+ case _ => exprToProtoInternal(knc.child, inputs, binding)
+ }
+
+ case s: StaticInvoke
+ if s.staticObject == classOf[StringDecode] &&
+ s.dataType.isInstanceOf[StringType] &&
+ s.functionName == "decode" &&
+ s.arguments.size == 4 &&
+ s.inputTypes == Seq(
+ BinaryType,
+ StringTypeWithCollation(supportsTrimCollation = true),
+ BooleanType,
+ BooleanType) =>
+ val Seq(bin, charset, _, _) = s.arguments
+ stringDecode(expr, charset, bin, inputs, binding)
+
+ case expr @ ToPrettyString(child, timeZoneId) =>
+ val castSupported = CometCast.isSupported(
+ child.dataType,
+ DataTypes.StringType,
+ timeZoneId,
+ CometEvalMode.TRY)
+
+ val isCastSupported = castSupported match {
+ case Compatible(_) => true
+ case Incompatible(_) => true
+ case _ => false
+ }
+
+ if (isCastSupported) {
+ exprToProtoInternal(child, inputs, binding) match {
+ case Some(p) =>
+ val toPrettyString = ExprOuterClass.ToPrettyString
+ .newBuilder()
+ .setChild(p)
+ .setTimezone(timeZoneId.getOrElse("UTC"))
+ .setBinaryOutputStyle(binaryOutputStyle)
+ .build()
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setToPrettyString(toPrettyString)
+ .build())
+ case _ =>
+ withInfo(expr, child)
+ None
+ }
+ } else {
+ None
+ }
+
+ case wb: WidthBucket =>
+ val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding))
+ val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*)
+ optExprWithInfo(optExpr, wb, wb.children: _*)
+
+ // Since Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is
+ // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the
+ // original StructsToJson and recurse so support-level checks apply.
+ case i: Invoke =>
+ (i.targetObject, i.functionName, i.arguments) match {
+ case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) =>
+ exprToProtoInternal(
+ StructsToJson(evaluator.options, child, evaluator.timeZoneId),
+ inputs,
+ binding)
+ case _ => None
+ }
+
+ case _ => None
+ }
+ }
+}
+
+object CometEvalModeUtil {
+ def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match {
+ case EvalMode.LEGACY => CometEvalMode.LEGACY
+ case EvalMode.TRY => CometEvalMode.TRY
+ case EvalMode.ANSI => CometEvalMode.ANSI
+ }
+
+ // In Spark 4.1, Sum carries a NumericEvalContext rather than a direct EvalMode.
+ def sumEvalMode(s: Sum): EvalMode.Value = s.evalContext.evalMode
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
new file mode 100644
index 0000000000..e88cbf1f53
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.SparkContext
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
+
+trait ShimCometBroadcastExchangeExec {
+
+ def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+ // Setup a job tag here so later it may get cancelled by tag if necessary.
+ sc.addJobTag(broadcastExchange.jobTag)
+ sc.setInterruptOnCancel(true)
+ }
+
+ def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+ sc.cancelJobsWithTag(broadcastExchange.jobTag)
+ }
+
+ def maxBroadcastTableBytes(conf: SQLConf): Long = {
+ JavaUtils.byteStringAsBytes(conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB"))
+ }
+
+}
+
+object ShimCometBroadcastExchangeExec {
+ val SPARK_MAX_BROADCAST_TABLE_SIZE = "spark.sql.maxBroadcastTableSize"
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
new file mode 100644
index 0000000000..a8ae94cd7c
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types.StructType
+
+trait ShimCometShuffleExchangeExec {
+ def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): CometShuffleExchangeExec = {
+ CometShuffleExchangeExec(
+ s.outputPartitioning,
+ s.child,
+ s,
+ s.shuffleOrigin,
+ shuffleType,
+ s.advisoryPartitionSize)
+ }
+
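+ // DataTypeUtils.fromAttributes replaces the old StructType.fromAttributes helper on newer Spark versions.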
+ protected def fromAttributes(attributes: Seq[Attribute]): StructType =
+ DataTypeUtils.fromAttributes(attributes)
+
+ protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int =
+ shuffleDependency.shuffleId
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
new file mode 100644
index 0000000000..cac636c45c
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.internal.SQLConf
+
+trait ShimCometSparkSessionExtensions {
+ protected def getPushedAggregate(scan: ParquetScan): Option[Aggregation] = scan.pushedAggregate
+
+ protected def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = true
+
+ protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
+
+ def injectQueryStageOptimizerRuleShim(
+ extensions: SparkSessionExtensions,
+ rule: Rule[SparkPlan]): Unit = {
+ extensions.injectQueryStageOptimizerRule(_ => rule)
+ }
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSQLConf.scala
new file mode 100644
index 0000000000..bdb2739460
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSQLConf.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
+
+trait ShimSQLConf {
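+ // LegacyBehaviorPolicy is a top-level object in Spark 4.x; on 3.x the equivalent values live in SQLConf.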
+ protected val LEGACY = LegacyBehaviorPolicy.LEGACY
+ protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSubqueryBroadcast.scala
new file mode 100644
index 0000000000..73d9e53c4a
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSubqueryBroadcast.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec}
+
+trait ShimSubqueryBroadcast {
+
+ /**
+ * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`,
+ * Spark 4.x has `indices: Seq[Int]`.
+ */
+ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = {
+ sab.indices
+ }
+
+ /** The corresponding shim for SubqueryBroadcastExec. */
+ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = {
+ sub.indices
+ }
+}
diff --git a/spark/src/main/spark-4.1/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-4.1/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
new file mode 100644
index 0000000000..4e48744fc4
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD
+
+trait ShimCometDriverPlugin {
+ protected def getMemoryOverheadMinMib(sparkConf: SparkConf): Long =
+ sparkConf.get(EXECUTOR_MIN_MEMORY_OVERHEAD)
+}
diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
new file mode 100644
index 0000000000..3d9b963a93
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, FileSourceConstantMetadataAttribute, Literal}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, ScalarSubquery}
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+trait ShimCometScanExec extends ShimStreamSourceAwareSparkPlan {
+ def wrapped: FileSourceScanExec
+
+ lazy val fileConstantMetadataColumns: Seq[AttributeReference] =
+ wrapped.fileConstantMetadataColumns
+
+ protected def newFileScanRDD(
+ fsRelation: HadoopFsRelation,
+ readFunction: PartitionedFile => Iterator[InternalRow],
+ filePartitions: Seq[FilePartition],
+ readSchema: StructType,
+ options: ParquetOptions): FileScanRDD = {
+ new FileScanRDD(
+ fsRelation.sparkSession,
+ readFunction,
+ filePartitions,
+ readSchema,
+ fileConstantMetadataColumns,
+ fsRelation.fileFormat.fileConstantMetadataExtractors,
+ options)
+ }
+
+ // see SPARK-39634
+ protected def isNeededForSchema(sparkSchema: StructType): Boolean = false
+
+ protected def getPartitionedFile(
+ f: FileStatusWithMetadata,
+ p: PartitionDirectory): PartitionedFile =
+ PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values, 0, f.getLen)
+
+ protected def splitFiles(
+ sparkSession: SparkSession,
+ file: FileStatusWithMetadata,
+ filePath: Path,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow): Seq[PartitionedFile] =
+ PartitionedFileUtil.splitFiles(file, filePath, isSplitable, maxSplitBytes, partitionValues)
+
+ protected def getPushedDownFilters(
+ relation: HadoopFsRelation,
+ dataFilters: Seq[Expression]): Seq[Filter] = {
+ translateToV1Filters(relation, dataFilters, _.toLiteral)
+ }
+
+ // From Spark FileSourceScanLike
+ private def translateToV1Filters(
+ relation: HadoopFsRelation,
+ dataFilters: Seq[Expression],
+ scalarSubqueryToLiteral: ScalarSubquery => Literal): Seq[Filter] = {
+ val scalarSubqueryReplaced = dataFilters.map(_.transform {
+ // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can
+ // support translating it.
+ case scalarSubquery: ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery)
+ })
+
+ val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+ // `dataFilters` should not include any constant metadata col filters
+ // because the metadata struct has been flattened in FileSourceStrategy
+ // and thus metadata col filters are invalid to be pushed down. Metadata that is generated
+ // during the scan can be used for filters.
+ scalarSubqueryReplaced
+ .filterNot(_.references.exists {
+ case FileSourceConstantMetadataAttribute(_) => true
+ case _ => false
+ })
+ .flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+ }
+
+}
diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
new file mode 100644
index 0000000000..41ccdd0402
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.shuffle.ShuffleWriteProcessor
+
+trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor {}
diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
new file mode 100644
index 0000000000..46bdd2ec03
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import java.io.FileNotFoundException
+
+import scala.util.matching.Regex
+
+import org.apache.spark.QueryContext
+import org.apache.spark.SparkException
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object ShimSparkErrorConverter {
+ val ObjectLocationPattern: Regex = "Object at location (.+?) not found".r
+}
+
+/**
+ * Spark 4.1-specific implementation for converting error types to proper Spark exceptions.
+ */
+trait ShimSparkErrorConverter {
+
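+ // Parse float/double literals from native error parameters, accepting inf/infinity/nan
+ // spellings as well as plain numeric strings.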
+ private def parseFloatLiteral(value: String): Float = {
+ value.toLowerCase match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+ case "-inf" | "-infinity" => Float.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Float.NaN
+ case _ => value.toFloat
+ }
+ }
+
+ private def parseDoubleLiteral(value: String): Double = {
+ val normalized = value.toLowerCase.stripSuffix("d")
+ normalized match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+ case "-inf" | "-infinity" => Double.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Double.NaN
+ case _ => normalized.toDouble
+ }
+ }
+
+ /**
+ * Convert error type string and parameters to appropriate Spark exception. Version-specific
+ * implementations call the correct QueryExecutionErrors.* methods.
+ *
+ * @param errorType
+ * The error type from JSON (e.g., "DivideByZero")
+ * @param errorClass
+ * The Spark error class (e.g., "DIVIDE_BY_ZERO")
+ * @param params
+ * Error parameters from JSON
+ * @param context
+ * QueryContext array with SQL text and position information
+ * @param summary
+ * Formatted summary string showing error location
+ * @return
+ * Throwable (specific exception type from QueryExecutionErrors), or None if unknown
+ */
+ def convertErrorType(
+ errorType: String,
+ _errorClass: String,
+ params: Map[String, Any],
+ context: Array[QueryContext],
+ _summary: String): Option[Throwable] = {
+
+ errorType match {
+
+ case "DivideByZero" =>
+ Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull))
+
+ case "RemainderByZero" =>
+ // Spark 4.0 removed remainderByZeroError, so construct a SparkException with the
+ // REMAINDER_BY_ZERO error class directly.
+ Some(
+ new SparkException(
+ errorClass = "REMAINDER_BY_ZERO",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "IntervalDividedByZero" =>
+ Some(QueryExecutionErrors.intervalDividedByZeroError(context.headOption.orNull))
+
+ case "BinaryArithmeticOverflow" =>
+ Some(
+ QueryExecutionErrors.binaryArithmeticCauseOverflowError(
+ params("value1").toString.toShort,
+ params("symbol").toString,
+ params("value2").toString.toShort,
+ params("functionName").toString))
+
+ case "ArithmeticOverflow" =>
+ val fromType = params("fromType").toString
+ Some(
+ QueryExecutionErrors
+ .arithmeticOverflowError(fromType + " overflow", "", context.headOption.orNull))
+
+ case "IntegralDivideOverflow" =>
+ Some(QueryExecutionErrors.overflowInIntegralDivideError(context.headOption.orNull))
+
+ case "DecimalSumOverflow" =>
+ Some(QueryExecutionErrors.overflowInSumOfDecimalError(context.headOption.orNull, ""))
+
+ case "NumericValueOutOfRange" =>
+ val decimal = Decimal(params("value").toString)
+ Some(
+ QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+ decimal,
+ params("precision").toString.toInt,
+ params("scale").toString.toInt,
+ context.headOption.orNull))
+
+ case "DatetimeOverflow" =>
+ // Spark 4.0+ has no datetimeOverflowError helper, so construct the DATETIME_OVERFLOW SparkException directly.
+ Some(
+ new SparkException(
+ errorClass = "DATETIME_OVERFLOW",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "InvalidArrayIndex" =>
+ Some(
+ QueryExecutionErrors.invalidArrayIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ context.headOption.orNull))
+
+ case "InvalidElementAtIndex" =>
+ Some(
+ QueryExecutionErrors.invalidElementAtIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ context.headOption.orNull))
+
+ case "InvalidIndexOfZero" =>
+ Some(QueryExecutionErrors.invalidIndexOfZeroError(context.headOption.orNull))
+
+ case "InvalidBitmapPosition" =>
+ Some(
+ QueryExecutionErrors.invalidBitmapPositionError(
+ params("bitPosition").toString.toLong,
+ params("bitmapNumBytes").toString.toLong))
+
+ case "DuplicatedMapKey" =>
+ Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key")))
+
+ case "NullMapKey" =>
+ Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError())
+
+ case "MapKeyValueDiffSizes" =>
+ Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError())
+
+ case "ExceedMapSizeLimit" =>
+ Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt))
+
+ case "CollectionSizeLimitExceeded" =>
+ Some(
+ QueryExecutionErrors.createArrayWithElementsExceedLimitError(
+ "array",
+ params("numElements").toString.toLong))
+
+ case "NotNullAssertViolation" =>
+ Some(
+ QueryExecutionErrors.foundNullValueForNotNullableFieldError(
+ params("fieldName").toString))
+
+ case "ValueIsNull" =>
+ Some(
+ QueryExecutionErrors.fieldCannotBeNullError(
+ params.getOrElse("rowIndex", 0).toString.toInt,
+ params("fieldName").toString))
+
+ case "CannotParseTimestamp" =>
+ Some(
+ QueryExecutionErrors.ansiDateTimeParseError(
+ new Exception(params("message").toString),
+ params("suggestedFunc").toString))
+
+ case "InvalidFractionOfSecond" =>
+ Some(QueryExecutionErrors.invalidFractionOfSecondError(params("value").toString.toDouble))
+
+ case "CastInvalidValue" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToNumberError(targetType, str, context.headOption.orNull))
+
+ case "InvalidInputInCastToDatetime" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToDatetimeError(str, targetType, context.headOption.orNull))
+
+ case "CastOverFlow" =>
+ val fromType = getDataType(params("fromType").toString)
+ val toType = getDataType(params("toType").toString)
+ val valueStr = params("value").toString
+
+ // Convert string value to appropriate type for toSQLValue
+ val typedValue: Any = fromType match {
+ case _: DecimalType =>
+ // Parse decimal string (may have "BD" suffix from BigDecimal.toString)
+ val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr
+ Decimal(cleanStr)
+ case ByteType =>
+ // Strip "T" suffix for TINYINT literals
+ val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr
+ cleanStr.toByte
+ case ShortType =>
+ // Strip "S" suffix for SMALLINT literals
+ val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr
+ cleanStr.toShort
+ case IntegerType => valueStr.toInt
+ case LongType =>
+ // Strip "L" suffix for BIGINT literals
+ val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
+ cleanStr.toLong
+ case FloatType => parseFloatLiteral(valueStr)
+ case DoubleType => parseDoubleLiteral(valueStr)
+ case StringType => UTF8String.fromString(valueStr)
+ case _ => valueStr // Fallback to string
+ }
+
+ Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType))
+
+ case "CannotParseDecimal" =>
+ Some(QueryExecutionErrors.cannotParseDecimalError())
+
+ case "InvalidUtf8String" =>
+ val hexStr = UTF8String.fromString(params("hexString").toString)
+ Some(QueryExecutionErrors.invalidUTF8StringError(hexStr))
+
+ case "UnexpectedPositiveValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForStartInFunctionError(
+ params("parameterName").toString))
+
+ case "UnexpectedNegativeValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForLengthInFunctionError(
+ params("parameterName").toString,
+ params("actualValue").toString.toInt))
+
+ case "InvalidRegexGroupIndex" =>
+ Some(
+ QueryExecutionErrors.invalidRegexGroupIndexError(
+ params("functionName").toString,
+ params("groupCount").toString.toInt,
+ params("groupIndex").toString.toInt))
+
+ case "DatatypeCannotOrder" =>
+ Some(
+ QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
+ params("dataType").toString))
+
+ case "ScalarSubqueryTooManyRows" =>
+ Some(QueryExecutionErrors.multipleRowScalarSubqueryError(context.headOption.orNull))
+
+ case "IntervalArithmeticOverflowWithSuggestion" =>
+ Some(
+ QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError(
+ params.get("functionName").map(_.toString).getOrElse(""),
+ context.headOption.orNull))
+
+ case "IntervalArithmeticOverflowWithoutSuggestion" =>
+ Some(
+ QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(
+ context.headOption.orNull))
+
+ case "DuplicateFieldCaseInsensitive" =>
+ Some(
+ QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError(
+ params("requiredFieldName").toString,
+ params("matchedOrcFields").toString))
+
+ case "FileNotFound" =>
+ val msg = params("message").toString
+ val path = ShimSparkErrorConverter.ObjectLocationPattern
+ .findFirstMatchIn(msg)
+ .map(_.group(1))
+ .getOrElse(msg)
+ Some(
+ QueryExecutionErrors
+ .fileNotExistError(path, new FileNotFoundException(s"File $path does not exist")))
+
+ case _ =>
+ // Unknown error type - return None to trigger fallback
+ None
+ }
+ }
+
+ private def getDataType(typeName: String): DataType = {
+ typeName.toUpperCase match {
+ case "BYTE" | "TINYINT" => ByteType
+ case "SHORT" | "SMALLINT" => ShortType
+ case "INT" | "INTEGER" => IntegerType
+ case "LONG" | "BIGINT" => LongType
+ case "FLOAT" | "REAL" => FloatType
+ case "DOUBLE" => DoubleType
+ case "DECIMAL" => DecimalType.SYSTEM_DEFAULT
+ case "STRING" | "VARCHAR" => StringType
+ case "BINARY" => BinaryType
+ case "BOOLEAN" => BooleanType
+ case "DATE" => DateType
+ case "TIMESTAMP" => TimestampType
+ case _ =>
+ try {
+ DataType.fromDDL(typeName)
+ } catch {
+ case _: Exception =>
+ // fromDDL rejects types that are syntactically invalid in SQL DDL, such as
+ // DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true).
+ // Parse those manually rather than silently falling back to StringType.
+ if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) {
+ val inner = typeName.substring("DECIMAL(".length, typeName.length - 1)
+ val parts = inner.split(",")
+ if (parts.length == 2) {
+ try {
+ DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt)
+ } catch {
+ case _: Exception => StringType
+ }
+ } else StringType
+ } else StringType
+ }
+ }
+ }
+}
diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala
new file mode 100644
index 0000000000..93552fc00f
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+import org.apache.spark.sql.execution.StreamSourceAwareSparkPlan
+
+trait ShimStreamSourceAwareSparkPlan extends StreamSourceAwareSparkPlan {
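+ // Comet's scan nodes are not bound to a streaming source, so there is no stream to expose.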
+ override def getStream: Option[SparkDataStream] = None
+}
diff --git a/spark/src/test/spark-4.1/org/apache/comet/exec/CometShuffle4_0Suite.scala b/spark/src/test/spark-4.1/org/apache/comet/exec/CometShuffle4_0Suite.scala
new file mode 100644
index 0000000000..3d0ec5006d
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/comet/exec/CometShuffle4_0Suite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.exec
+
+import java.util.Collections
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.expressions.Expressions.identity
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{FloatType, LongType, StringType, TimestampType}
+
+class CometShuffle4_0Suite extends CometColumnarShuffleSuite {
+ override protected val asyncShuffleEnable: Boolean = false
+
+ protected val adaptiveExecutionEnabled: Boolean = true
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
+ }
+
+ override def afterAll(): Unit = {
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat")
+ super.afterAll()
+ }
+
+ private val emptyProps: java.util.Map[String, String] = {
+ Collections.emptyMap[String, String]
+ }
+ private val items: String = "items"
+ private val itemsColumns: Array[Column] = Array(
+ Column.create("id", LongType),
+ Column.create("name", StringType),
+ Column.create("price", FloatType),
+ Column.create("arrive_time", TimestampType))
+
+ private val purchases: String = "purchases"
+ private val purchasesColumns: Array[Column] = Array(
+ Column.create("item_id", LongType),
+ Column.create("price", FloatType),
+ Column.create("time", TimestampType))
+
+ protected def catalog: InMemoryCatalog = {
+ val catalog = spark.sessionState.catalogManager.catalog("testcat")
+ catalog.asInstanceOf[InMemoryCatalog]
+ }
+
+ private def createTable(
+ table: String,
+ columns: Array[Column],
+ partitions: Array[Transform],
+ catalog: InMemoryTableCatalog = catalog): Unit = {
+ catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps)
+ }
+
+ private def selectWithMergeJoinHint(t1: String, t2: String): String = {
+ s"SELECT /*+ MERGE($t1, $t2) */ "
+ }
+
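+  // Builds a merge-join query between the two `testcat` tables on the given key columns, with a
+  // deterministic ORDER BY so results can be compared against Spark.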
+ private def createJoinTestDF(
+ keys: Seq[(String, String)],
+ extraColumns: Seq[String] = Nil,
+ joinType: String = ""): DataFrame = {
+ val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "")
+ sql(s"""
+ |${selectWithMergeJoinHint("i", "p")}
+ |id, name, i.price as purchase_price, p.price as sale_price $extraColList
+ |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p
+ |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")}
+ |ORDER BY id, purchase_price, sale_price $extraColList
+ |""".stripMargin)
+ }
+
+ test("Fallback to Spark for unsupported partitioning") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(
+ SQLConf.V2_BUCKETING_ENABLED.key -> "true",
+ "spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString,
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
+ val df = createJoinTestDF(Seq("id" -> "item_id"))
+ checkSparkAnswer(df)
+ }
+ }
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/comet/iceberg/RESTCatalogHelper.scala b/spark/src/test/spark-4.1/org/apache/comet/iceberg/RESTCatalogHelper.scala
new file mode 100644
index 0000000000..bd53804b8d
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/comet/iceberg/RESTCatalogHelper.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.iceberg
+
+import java.io.File
+import java.nio.file.Files
+
+/** Helper trait for setting up REST catalog with Jetty 11 (jakarta.servlet) for Spark 4.1 */
+trait RESTCatalogHelper {
+
+  /** Helper to set up REST catalog with embedded Jetty server (Spark 4.1 / Jetty 11) */
+ def withRESTCatalog(f: (String, org.eclipse.jetty.server.Server, File) => Unit): Unit =
+ withRESTCatalog()(f)
+
+ /**
+ * Helper to set up REST catalog with optional credential vending.
+ *
+ * @param vendedCredentials
+ * Storage credentials to inject into loadTable responses, simulating REST catalog credential
+ * vending. When non-empty, these are added to every LoadTableResponse.config().
+ * @param warehouseLocation
+ * Override the warehouse location (e.g., for S3). Defaults to a local temp directory.
+ */
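+  // Hypothetical usage sketch (the catalog name and configs below are illustrative only):
+  //   withRESTCatalog() { (uri, _, _) =>
+  //     spark.conf.set("spark.sql.catalog.rest_cat", "org.apache.iceberg.spark.SparkCatalog")
+  //     spark.conf.set("spark.sql.catalog.rest_cat.type", "rest")
+  //     spark.conf.set("spark.sql.catalog.rest_cat.uri", uri)
+  //     // ... create and query Iceberg tables through the REST-backed catalog ...
+  //   }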
+ def withRESTCatalog(
+ vendedCredentials: Map[String, String] = Map.empty,
+ warehouseLocation: Option[String] = None)(
+ f: (String, org.eclipse.jetty.server.Server, File) => Unit): Unit = {
+ import org.apache.iceberg.inmemory.InMemoryCatalog
+ import org.apache.iceberg.CatalogProperties
+ import org.apache.iceberg.rest.{RESTCatalogAdapter, RESTCatalogServlet}
+ import org.eclipse.jetty.server.Server
+ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+ import org.eclipse.jetty.server.handler.gzip.GzipHandler
+
+ val warehouseDir = Files.createTempDirectory("comet-rest-catalog-test").toFile
+ val effectiveWarehouse = warehouseLocation.getOrElse(warehouseDir.getAbsolutePath)
+
+ val backendCatalog = new InMemoryCatalog()
+ backendCatalog.initialize(
+ "in-memory",
+ java.util.Map.of(CatalogProperties.WAREHOUSE_LOCATION, effectiveWarehouse))
+
+ val adapter = new RESTCatalogAdapter(backendCatalog)
+ if (vendedCredentials.nonEmpty) {
+ import scala.jdk.CollectionConverters._
+ adapter.setVendedCredentials(vendedCredentials.asJava)
+ }
+ val servlet = new RESTCatalogServlet(adapter)
+
+ val servletContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS)
+ servletContext.setContextPath("/")
+ val servletHolder = new ServletHolder(servlet.asInstanceOf[jakarta.servlet.Servlet])
+ servletHolder.setInitParameter("jakarta.ws.rs.Application", "ServiceListPublic")
+ servletContext.addServlet(servletHolder, "/*")
+ servletContext.setVirtualHosts(null)
+ servletContext.insertHandler(new GzipHandler())
+
+ val httpServer = new Server(0) // random port
+ httpServer.setHandler(servletContext)
+
+ try {
+ httpServer.start()
+ val restUri = httpServer.getURI.toString.stripSuffix("/")
+ f(restUri, httpServer, warehouseDir)
+ } finally {
+ try {
+ httpServer.stop()
+ httpServer.join()
+ } catch {
+ case _: Exception => // ignore cleanup errors
+ }
+ try {
+ backendCatalog.close()
+ } catch {
+ case _: Exception => // ignore cleanup errors
+ }
+ def deleteRecursively(file: File): Unit = {
+ if (file.isDirectory) {
+ file.listFiles().foreach(deleteRecursively)
+ }
+ file.delete()
+ }
+ deleteRecursively(warehouseDir)
+ }
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala b/spark/src/test/spark-4.1/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
new file mode 100644
index 0000000000..80811d701f
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.SQLQueryTestHelper
+
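+// For Spark 4.1 this shim only needs to mix in Spark's SQLQueryTestHelper so the shared TPCH
+// test sources compile across Spark versions; no extra members are required.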
+trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper {}
diff --git a/spark/src/test/spark-4.1/org/apache/iceberg/rest/RESTCatalogServlet.java b/spark/src/test/spark-4.1/org/apache/iceberg/rest/RESTCatalogServlet.java
new file mode 100644
index 0000000000..b54dacac48
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/iceberg/rest/RESTCatalogServlet.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.rest;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.Reader;
+import java.io.UncheckedIOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import jakarta.servlet.http.HttpServlet;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.iceberg.exceptions.RESTException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.io.CharStreams;
+import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod;
+import org.apache.iceberg.rest.RESTCatalogAdapter.Route;
+import org.apache.iceberg.rest.responses.ErrorResponse;
+import org.apache.iceberg.util.Pair;
+
+import static java.lang.String.format;
+
+/**
+ * The RESTCatalogServlet provides a servlet implementation used in combination with a
+ * RESTCatalogAdapter to proxy the REST Spec to any Catalog implementation.
+ * Modified version of Iceberg's org/apache/iceberg/rest/RESTCatalogServlet.java
+ */
+public class RESTCatalogServlet extends HttpServlet {
+ private static final Logger LOG = LoggerFactory.getLogger(RESTCatalogServlet.class);
+
+ private final RESTCatalogAdapter restCatalogAdapter;
+  private final Map<String, String> responseHeaders =
+      ImmutableMap.of("Content-Type", "application/json");
+
+ public RESTCatalogServlet(RESTCatalogAdapter restCatalogAdapter) {
+ this.restCatalogAdapter = restCatalogAdapter;
+ }
+
+ @Override
+ protected void doGet(HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+ execute(ServletRequestContext.from(request), response);
+ }
+
+ @Override
+ protected void doHead(HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+ execute(ServletRequestContext.from(request), response);
+ }
+
+ @Override
+ protected void doPost(HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+ execute(ServletRequestContext.from(request), response);
+ }
+
+ @Override
+ protected void doDelete(HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+ execute(ServletRequestContext.from(request), response);
+ }
+
+ protected void execute(ServletRequestContext context, HttpServletResponse response)
+ throws IOException {
+ response.setStatus(HttpServletResponse.SC_OK);
+ responseHeaders.forEach(response::setHeader);
+
+ if (context.error().isPresent()) {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
+ RESTObjectMapper.mapper().writeValue(response.getWriter(), context.error().get());
+ return;
+ }
+
+ try {
+ Object responseBody =
+ restCatalogAdapter.execute(
+ context.method(),
+ context.path(),
+ context.queryParams(),
+ context.body(),
+ context.route().responseClass(),
+ context.headers(),
+ handle(response));
+
+ if (responseBody != null) {
+ RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody);
+ }
+ } catch (RESTException e) {
+ LOG.error("Error processing REST request", e);
+ response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
+ } catch (Exception e) {
+ LOG.error("Unexpected exception when processing REST request", e);
+ response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
+ }
+ }
+
+  protected Consumer<ErrorResponse> handle(HttpServletResponse response) {
+ return (errorResponse) -> {
+ response.setStatus(errorResponse.code());
+ try {
+ RESTObjectMapper.mapper().writeValue(response.getWriter(), errorResponse);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ };
+ }
+
+ public static class ServletRequestContext {
+ private HTTPMethod method;
+ private Route route;
+ private String path;
+    private Map<String, String> headers;
+    private Map<String, String> queryParams;
+ private Object body;
+
+ private ErrorResponse errorResponse;
+
+ private ServletRequestContext(ErrorResponse errorResponse) {
+ this.errorResponse = errorResponse;
+ }
+
+ private ServletRequestContext(
+ HTTPMethod method,
+ Route route,
+ String path,
+        Map<String, String> headers,
+        Map<String, String> queryParams,
+ Object body) {
+ this.method = method;
+ this.route = route;
+ this.path = path;
+ this.headers = headers;
+ this.queryParams = queryParams;
+ this.body = body;
+ }
+
+ static ServletRequestContext from(HttpServletRequest request) throws IOException {
+ HTTPMethod method = HTTPMethod.valueOf(request.getMethod());
+ String path = request.getRequestURI().substring(1);
+      Pair<Route, Map<String, String>> routeContext = Route.from(method, path);
+
+ if (routeContext == null) {
+ return new ServletRequestContext(
+ ErrorResponse.builder()
+ .responseCode(400)
+ .withType("BadRequestException")
+ .withMessage(format("No route for request: %s %s", method, path))
+ .build());
+ }
+
+ Route route = routeContext.first();
+ Object requestBody = null;
+ if (route.requestClass() != null) {
+ requestBody =
+ RESTObjectMapper.mapper().readValue(request.getReader(), route.requestClass());
+ } else if (route == Route.TOKENS) {
+ try (Reader reader = new InputStreamReader(request.getInputStream())) {
+ requestBody = RESTUtil.decodeFormData(CharStreams.toString(reader));
+ }
+ }
+
+      Map<String, String> queryParams =
+ request.getParameterMap().entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()[0]));
+      Map<String, String> headers =
+ Collections.list(request.getHeaderNames()).stream()
+ .collect(Collectors.toMap(Function.identity(), request::getHeader));
+
+ return new ServletRequestContext(method, route, path, headers, queryParams, requestBody);
+ }
+
+ public HTTPMethod method() {
+ return method;
+ }
+
+ public Route route() {
+ return route;
+ }
+
+ public String path() {
+ return path;
+ }
+
+    public Map<String, String> headers() {
+ return headers;
+ }
+
+    public Map<String, String> queryParams() {
+ return queryParams;
+ }
+
+ public Object body() {
+ return body;
+ }
+
+    public Optional<ErrorResponse> error() {
+ return Optional.ofNullable(errorResponse);
+ }
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/spark/comet/shims/ShimTestUtils.scala b/spark/src/test/spark-4.1/org/apache/spark/comet/shims/ShimTestUtils.scala
new file mode 100644
index 0000000000..923ae68f2e
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/comet/shims/ShimTestUtils.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import java.io.File
+
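+// `org.apache.spark.TestUtils` is `private[spark]`, so this shim lives under the
+// `org.apache.spark` package tree in order to delegate to it.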
+object ShimTestUtils {
+ def listDirectory(path: File): Array[String] =
+ org.apache.spark.TestUtils.listDirectory(path)
+}
diff --git a/spark/src/test/spark-4.1/org/apache/spark/sql/CometCollationSuite.scala b/spark/src/test/spark-4.1/org/apache/spark/sql/CometCollationSuite.scala
new file mode 100644
index 0000000000..463e169b66
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/sql/CometCollationSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+class CometCollationSuite extends CometTestBase {
+
+ // Queries that group, sort, or shuffle on a non-default collated string must fall back to
+ // Spark because Comet's shuffle/sort/aggregate compare raw bytes rather than collation-aware
+ // keys. The shuffle-exchange rule is the primary line of defense (see #1947), so these tests
+ // pin down the fallback reason it emits.
+ private val hashShuffleCollationReason =
+ "unsupported hash partitioning data type for columnar shuffle"
+ private val rangeShuffleCollationReason =
+ "unsupported range partitioning data type for columnar shuffle"
+
+ test("listagg DISTINCT with utf8_lcase collation (issue #1947)") {
+ checkSparkAnswerAndFallbackReason(
+ "SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) " +
+ "WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) " +
+ "FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)",
+ hashShuffleCollationReason)
+ }
+
+ test("DISTINCT on utf8_lcase collated string groups case-insensitively") {
+ checkSparkAnswerAndFallbackReason(
+ "SELECT DISTINCT c1 COLLATE utf8_lcase AS c " +
+ "FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) ORDER BY c",
+ hashShuffleCollationReason)
+ }
+
+ test("GROUP BY utf8_lcase collated string groups case-insensitively") {
+ checkSparkAnswerAndFallbackReason(
+ "SELECT lower(c1 COLLATE utf8_lcase) AS k, count(*) " +
+ "FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) " +
+ "GROUP BY c1 COLLATE utf8_lcase ORDER BY k",
+ hashShuffleCollationReason)
+ }
+
+ test("ORDER BY utf8_lcase collated string sorts case-insensitively") {
+ checkSparkAnswerAndFallbackReason(
+ "SELECT c1 COLLATE utf8_lcase AS c " +
+ "FROM (VALUES ('A'), ('b'), ('a'), ('B')) AS t(c1) ORDER BY c",
+ rangeShuffleCollationReason)
+ }
+
+ test("default UTF8_BINARY string still runs through Comet") {
+ // Sanity check that the collation fallback does not over-block the default string type.
+ withParquetTable(Seq(("a", 1), ("b", 2), ("a", 3)), "tbl") {
+ checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1")
+ }
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/spark/sql/CometToPrettyStringSuite.scala b/spark/src/test/spark-4.1/org/apache/spark/sql/CometToPrettyStringSuite.scala
new file mode 100644
index 0000000000..e7f1757bf6
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/sql/CometToPrettyStringSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.classic.Dataset
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle
+import org.apache.spark.sql.types.DataTypes
+
+import org.apache.comet.CometFuzzTestBase
+import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.serde.Compatible
+
+class CometToPrettyStringSuite extends CometFuzzTestBase {
+
+ test("ToPrettyString") {
+ val style = List(
+ BinaryOutputStyle.UTF8,
+ BinaryOutputStyle.BASIC,
+ BinaryOutputStyle.BASE64,
+ BinaryOutputStyle.HEX,
+ BinaryOutputStyle.HEX_DISCRETE)
+ style.foreach(s =>
+ withSQLConf(SQLConf.BINARY_OUTPUT_STYLE.key -> s.toString) {
+ val df = spark.read.parquet(filename)
+ df.createOrReplaceTempView("t1")
+ val table = spark.sessionState.catalog.lookupRelation(TableIdentifier("t1"))
+
+ for (field <- df.schema.fields) {
+ val col = field.name
+ val prettyExpr = Alias(ToPrettyString(UnresolvedAttribute(col)), s"pretty_$col")()
+ val plan = Project(Seq(prettyExpr), table)
+ val analyzed = spark.sessionState.analyzer.execute(plan)
+ val result: DataFrame = Dataset.ofRows(spark, analyzed)
+ val supportLevel = CometCast.isSupported(
+ field.dataType,
+ DataTypes.StringType,
+ Some(spark.sessionState.conf.sessionLocalTimeZone),
+ CometEvalMode.TRY)
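+          // When Comet can cast this type to string compatibly, require the plan to stay on
+          // Comet operators; otherwise only check that results match Spark.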
+ supportLevel match {
+ case _: Compatible => checkSparkAnswerAndOperator(result)
+ case _ => checkSparkAnswer(result)
+ }
+ }
+ })
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-4.1/org/apache/spark/sql/ShimCometTestBase.scala
new file mode 100644
index 0000000000..5ad4543220
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/sql/ShimCometTestBase.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}
+
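+// Spark 4.x moved the concrete `SparkSession`/`Dataset` implementations into
+// `org.apache.spark.sql.classic`; this shim wraps the internal calls the shared test sources
+// need so they compile unchanged against Spark 4.1.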
+trait ShimCometTestBase {
+ type SparkSessionType = SparkSession
+
+ def createSparkSessionWithExtensions(conf: SparkConf): SparkSessionType = {
+ SparkSession
+ .builder()
+ .config(conf)
+ .master("local[1]")
+ .withExtensions(new org.apache.comet.CometSparkSessionExtensions)
+ .getOrCreate()
+ }
+
+ def datasetOfRows(spark: SparkSession, plan: LogicalPlan): DataFrame = {
+ Dataset.ofRows(spark, plan)
+ }
+
+ def getColumnFromExpression(expr: Expression): Column = {
+ new Column(ExpressionColumnNode.apply(expr))
+ }
+
+ def extractLogicalPlan(df: DataFrame): LogicalPlan = {
+ df.queryExecution.analyzed
+ }
+
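+  // Wraps `MakeDecimal` with `nullOnOverflow = true`, so values that do not fit the target
+  // precision/scale produce null instead of raising an error.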
+ def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+ new Column(ExpressionColumnNode.apply(MakeDecimal(child, precision, scale, true)))
+ }
+}
diff --git a/spark/src/test/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala b/spark/src/test/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
new file mode 100644
index 0000000000..f03608d3e3
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
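+// Spark 4.1 needs no TPCDS-specific compatibility members, so the shim is intentionally empty.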
+trait ShimCometTPCDSQuerySuite {}