11 changes: 11 additions & 0 deletions .claude/skills/audit-comet-expression/SKILL.md
@@ -122,6 +122,16 @@ Read the serde implementation and check:
- Whether `getSupportLevel` is implemented and accurate
- Whether all input types are handled
- Whether any types are explicitly marked `Unsupported`
- Whether `getIncompatibleReasons()` and `getUnsupportedReasons()` are overridden.
`getSupportLevel` controls runtime fallback, but `GenerateDocs` reads these two
methods to build the Compatibility Guide. If `getSupportLevel` returns
`Incompatible(Some(...))` or `Unsupported(Some(...))` but the corresponding
`get*Reasons()` method is not overridden, the reason will not appear in the
published docs. Expect both to return a `Seq[String]` containing the same
reason text used in `getSupportLevel`. Follow the pattern in
`spark/src/main/scala/org/apache/comet/serde/structs.scala::CometStructsToJson`
or `spark/src/main/scala/org/apache/comet/serde/datetime.scala::CometHour`:
  extract the reason as a `private val` and reference it from both places, as in the sketch below.
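
A minimal sketch of that shape (`CometSomeExpr` and `SomeExpr` are placeholder names; `CometStructsToJson` and `CometHour`, cited above, are the real in-tree examples):

```scala
object CometSomeExpr extends CometExpressionSerde[SomeExpr] {

  // Single source of truth for the reason text, referenced from both
  // getSupportLevel (runtime fallback) and getIncompatibleReasons (GenerateDocs).
  private val incompatReason: String =
    "Result ordering may differ from Spark."

  override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)

  override def getSupportLevel(expr: SomeExpr): SupportLevel =
    Incompatible(Some(incompatReason))
}
```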

### Shims

@@ -227,6 +237,7 @@ Also review the Comet implementation (Step 3) against the Spark behavior (Step 1):
- Are there behavioral differences that are NOT marked `Incompatible` but should be?
- Are there behavioral differences between Spark versions that the Comet implementation does not account for (missing shim)?
- Does the Rust implementation match the Spark behavior for all edge cases?
- If `getSupportLevel` returns `Incompatible` or `Unsupported` with a reason, are `getIncompatibleReasons()` / `getUnsupportedReasons()` also overridden so the reason is picked up by `GenerateDocs` and appears in the Compatibility Guide?

---

@@ -26,4 +26,7 @@ import org.apache.spark.sql.types.DataType
trait CometTypeShim {
@nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility.
def isStringCollationType(dt: DataType): Boolean = false

@nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility.
def hasNonDefaultStringCollation(dt: DataType): Boolean = false
}
@@ -19,7 +19,7 @@

package org.apache.comet.shims

import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructType}

trait CometTypeShim {
// A `StringType` carries collation metadata in Spark 4.0. Only non-default (non-UTF8_BINARY)
@@ -31,4 +31,18 @@ trait CometTypeShim
case st: StringType => st.collationId != StringType.collationId
case _ => false
}

/**
* Returns true if `dt`, or any nested element/field/key/value type, is a `StringType` with a
* non-default (non-UTF8_BINARY) collation. Expression serdes can use this to fall back to Spark
* when they cannot honour collation semantics. Stubbed to `false` in Spark 3.x.
*/
def hasNonDefaultStringCollation(dt: DataType): Boolean = dt match {
case _: StringType => isStringCollationType(dt)
case ArrayType(elementType, _) => hasNonDefaultStringCollation(elementType)
case MapType(kt, vt, _) =>
hasNonDefaultStringCollation(kt) || hasNonDefaultStringCollation(vt)
case StructType(fields) => fields.exists(f => hasNonDefaultStringCollation(f.dataType))
case _ => false
}
}
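
An illustrative use of the traversal on the Spark 4.x side (whether `DataType.fromDDL` accepts `COLLATE` in type strings is an assumption here, and `new CometTypeShim {}` is instantiated only for demonstration):

```scala
import org.apache.spark.sql.types._

val shim = new org.apache.comet.shims.CometTypeShim {}

// Assumption: the Spark 4.x DDL type parser understands COLLATE.
val collatedMap = DataType.fromDDL("map<string, string collate UTF8_LCASE>")
val plainArray = DataType.fromDDL("array<string>")

shim.hasNonDefaultStringCollation(collatedMap) // true: the map value type is collated
shim.hasNonDefaultStringCollation(plainArray)  // false: only default UTF8_BINARY strings
```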
3 changes: 3 additions & 0 deletions docs/source/contributor-guide/spark_expressions_support.md
@@ -93,6 +93,9 @@
- Spark 3.5.8 audited 2026-04-02
- Spark 4.0.1 audited 2026-04-02 (pos=0 error message differs from Spark)
- [x] array_intersect
- Spark 3.4.3 audited 2026-04-24 (result element order may differ from Spark when the right array is longer than the left; DataFusion probes the longer side)
- Spark 3.5.8 audited 2026-04-24 (same ordering incompatibility as 3.4.3)
- Spark 4.0.1 audited 2026-04-24 (ordering incompatibility as above; collated strings now fall back to Spark)
- [x] array_join
- [x] array_max
- [ ] array_min
24 changes: 19 additions & 5 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
import org.apache.comet.shims.{CometExprShim, CometTypeShim}

object CometArrayRemove
extends CometExpressionSerde[ArrayRemove]
@@ -191,12 +191,26 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
}
}

object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {
object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] with CometTypeShim {

override def getIncompatibleReasons(): Seq[String] = Seq(
"Null handling and ordering may differ from Spark")
private val incompatReason: String =
"Result array element order may differ from Spark when the right array is longer " +
"than the left (DataFusion probes the longer side)."

private val unsupportedCollationReason: String =
"array_intersect on collated strings is not supported."

override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)

override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
override def getUnsupportedReasons(): Seq[String] = Seq(unsupportedCollationReason)

override def getSupportLevel(expr: ArrayIntersect): SupportLevel = {
if (hasNonDefaultStringCollation(expr.dataType)) {
Unsupported(Some(unsupportedCollationReason))
} else {
Incompatible(Some(incompatReason))
}
}

override def convert(
expr: ArrayIntersect,
@@ -17,6 +17,13 @@

-- Config: spark.comet.expression.ArrayIntersect.allowIncompatible=true

-- DataFusion's array_intersect emits elements in the order of the longer input
-- (it uses the shorter side as the hash lookup), while Spark emits elements in
-- the order they appear in the left argument. Most cases below have
-- left.length >= right.length so the element order matches Spark; cases that
-- intentionally flip that relation use sort_array on the result so the
-- assertion stays stable.

statement
CREATE TABLE test_array_intersect(a array<int>, b array<int>) USING parquet

@@ -37,3 +44,169 @@ SELECT array_intersect(array(1, 2, 3), b) FROM test_array_intersect
-- literal + literal
query
SELECT array_intersect(array(1, 2, 3), array(2, 3, 4)), array_intersect(array(1, 2), array(3, 4)), array_intersect(array(), array(1)), array_intersect(cast(NULL as array<int>), array(1))

-- right longer than left: element order diverges from Spark, so sort the result
query
SELECT sort_array(array_intersect(array(2, 1), array(3, 1, 2))), sort_array(array_intersect(array(3, 1), array(1, 2, 3, 4)))

-- duplicate elements within and across arrays
statement
CREATE TABLE test_intersect_dups(a array<int>, b array<int>) USING parquet

statement
INSERT INTO test_intersect_dups VALUES (array(1, 1, 1), array(1, 1)), (array(1, 2, 1, 2), array(2, 1, 2, 1)), (array(1, 2, 3), array(1, 2, 3)), (array(1, 1, 2, 2, 3, 3), array(2, 2))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_dups

-- both-NULL arrays and all-NULL element arrays
statement
CREATE TABLE test_intersect_nulls(a array<int>, b array<int>) USING parquet

statement
INSERT INTO test_intersect_nulls VALUES (array(NULL), array(NULL)), (array(NULL, NULL), array(NULL)), (array(NULL, NULL), array(NULL, NULL)), (array(1, NULL), array(1, NULL)), (array(1, NULL), array(NULL)), (array(1), array(NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_nulls

query
SELECT array_intersect(cast(NULL as array<int>), cast(NULL as array<int>))

-- self-intersection (deduplication)
query
SELECT a, array_intersect(a, a) FROM test_intersect_dups

-- empty array combinations
query
SELECT array_intersect(array(), array()), array_intersect(array(), array(1, 2)), array_intersect(array(1, 2), array())

-- boolean arrays
statement
CREATE TABLE test_intersect_bool(a array<boolean>, b array<boolean>) USING parquet

statement
INSERT INTO test_intersect_bool VALUES (array(true, false), array(true)), (array(true, true), array(false)), (array(true, false, NULL), array(false, NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_bool

-- tinyint arrays
statement
CREATE TABLE test_intersect_byte(a array<tinyint>, b array<tinyint>) USING parquet

statement
INSERT INTO test_intersect_byte VALUES (array(cast(127 as tinyint), cast(-128 as tinyint), cast(0 as tinyint)), array(cast(127 as tinyint), cast(1 as tinyint))), (array(cast(1 as tinyint), cast(2 as tinyint)), array(cast(3 as tinyint), cast(4 as tinyint)))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_byte

-- smallint arrays
statement
CREATE TABLE test_intersect_short(a array<smallint>, b array<smallint>) USING parquet

statement
INSERT INTO test_intersect_short VALUES (array(cast(32767 as smallint), cast(-32768 as smallint)), array(cast(32767 as smallint))), (array(cast(1 as smallint), cast(2 as smallint)), array(cast(2 as smallint), cast(3 as smallint)))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_short

-- bigint arrays with boundary values
statement
CREATE TABLE test_intersect_long(a array<bigint>, b array<bigint>) USING parquet

statement
INSERT INTO test_intersect_long VALUES (array(9223372036854775807, 1, -9223372036854775808), array(9223372036854775807, -9223372036854775808)), (array(0, 1, 2), array(3, 4, 5))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_long

-- float arrays with NaN, Infinity, -Infinity
statement
CREATE TABLE test_intersect_float(a array<float>, b array<float>) USING parquet

statement
INSERT INTO test_intersect_float VALUES (array(1.5, 2.5, float('NaN')), array(2.5, float('NaN'))), (array(float('Infinity'), 1.0, float('-Infinity')), array(float('Infinity'), float('-Infinity'))), (array(cast(0.0 as float), cast(-0.0 as float)), array(cast(0.0 as float)))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_float

-- double arrays with NaN, Infinity, -Infinity
statement
CREATE TABLE test_intersect_dbl(a array<double>, b array<double>) USING parquet

statement
INSERT INTO test_intersect_dbl VALUES (array(1.0, 2.0, double('NaN')), array(2.0, double('NaN'))), (array(double('Infinity'), 1.0, double('-Infinity')), array(double('Infinity'), double('-Infinity'))), (array(0.0, -0.0), array(0.0)), (array(1.0, 2.0, NULL), array(1.0, NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_dbl

-- decimal arrays
statement
CREATE TABLE test_intersect_dec(a array<decimal(10,2)>, b array<decimal(10,2)>) USING parquet

statement
INSERT INTO test_intersect_dec VALUES (array(1.00, 2.50, 3.00), array(2.50, 3.00, 4.00)), (array(1.00, 2.00), array(3.00, 4.00)), (array(1.10, NULL), array(1.10, NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_dec

-- date arrays
statement
CREATE TABLE test_intersect_date(a array<date>, b array<date>) USING parquet

statement
INSERT INTO test_intersect_date VALUES (array(date '2024-01-01', date '2024-06-15', date '2024-12-31'), array(date '2024-06-15', date '2024-12-31')), (array(date '2024-01-01'), array(date '2024-12-31')), (array(date '2024-01-01', NULL), array(date '2024-01-01', NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_date

-- timestamp arrays
statement
CREATE TABLE test_intersect_ts(a array<timestamp>, b array<timestamp>) USING parquet

statement
INSERT INTO test_intersect_ts VALUES (array(timestamp '2024-01-01 00:00:00', timestamp '2024-06-15 12:00:00', timestamp '2024-12-31 23:59:59'), array(timestamp '2024-06-15 12:00:00', timestamp '2024-12-31 23:59:59')), (array(timestamp '2024-01-01 00:00:00'), array(timestamp '2024-12-31 23:59:59'))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_ts

-- string arrays including empty strings and multibyte UTF-8
statement
CREATE TABLE test_intersect_str(a array<string>, b array<string>) USING parquet

statement
INSERT INTO test_intersect_str VALUES (array('a', 'b', 'c'), array('b', 'c', 'd')), (array('a', NULL, 'b'), array('a', NULL)), (array('', 'a'), array('', 'b')), (array('café', '中文', 'test'), array('café', 'other')), (NULL, array('a'))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_str

-- binary arrays
statement
CREATE TABLE test_intersect_bin(a array<binary>, b array<binary>) USING parquet

statement
INSERT INTO test_intersect_bin VALUES (array(binary('abc'), binary('def'), binary('ghi')), array(binary('def'), binary('ghi'))), (array(binary('abc')), array(binary('xyz'))), (array(binary('abc'), NULL), array(binary('abc'), NULL))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_bin

-- nested array<array<int>>
statement
CREATE TABLE test_intersect_nested(a array<array<int>>, b array<array<int>>) USING parquet

statement
INSERT INTO test_intersect_nested VALUES (array(array(1, 2), array(3, 4), array(5, 6)), array(array(3, 4), array(5, 6))), (array(array(1, 2)), array(array(3, 4))), (array(array(1, 2), cast(NULL as array<int>)), array(array(1, 2), cast(NULL as array<int>)))

query
SELECT a, b, array_intersect(a, b) FROM test_intersect_nested

-- mixed column and literal with NULL elements
query
SELECT array_intersect(a, array(1, 2, NULL)) FROM test_array_intersect

query
SELECT array_intersect(array(1, NULL, 3), b) FROM test_array_intersect

-- conditional (CASE WHEN) arrays
query
SELECT array_intersect(CASE WHEN a IS NOT NULL THEN a ELSE array(0) END, b) FROM test_array_intersect
43 changes: 38 additions & 5 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -378,9 +378,9 @@ abstract class CometTestBase
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
if (!QueryTest.compare(
QueryTest.prepareAnswer(sparkAnswer, isSorted),
QueryTest.prepareAnswer(cometAnswer, isSorted))) {
val preparedSpark = prepareCometAnswer(sparkAnswer, isSorted)
val preparedComet = prepareCometAnswer(cometAnswer, isSorted)
if (!QueryTest.compare(preparedSpark, preparedComet)) {
val getRowType: Option[Row] => String = row =>
row
.map(r => if (r.schema == null) "struct<>" else r.schema.catalogString)
@@ -394,14 +394,47 @@
|${sideBySide(
s"== Spark Answer - ${sparkAnswer.size} ==" +:
getRowType(sparkAnswer.headOption) +:
QueryTest.prepareAnswer(sparkAnswer, isSorted).map(_.toString()),
preparedSpark.map(_.toString()),
s"== Comet Answer - ${cometAnswer.size} ==" +:
getRowType(cometAnswer.headOption) +:
QueryTest.prepareAnswer(cometAnswer, isSorted).map(_.toString())).mkString("\n")}
preparedComet.map(_.toString())).mkString("\n")}
""".stripMargin)
}
}

/**
* Like `QueryTest.prepareAnswer` but recursively converts nested arrays to seqs. Spark's
* version only normalizes top-level `Array[_]`, leaving inner arrays (e.g. `Array[Byte]` from
* `array<binary>`) intact. Their default `toString` is the JVM identity (`[B@<hex>`), which
* makes the toString-based sort in `prepareAnswer` non-deterministic and causes spurious
* mismatches between the two sides.
*/
private def prepareCometAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
val converted = answer.map(prepareCometRow)
if (isSorted) converted else converted.sortBy(_.toString())
}

private def prepareCometRow(row: Row): Row = {
Row.fromSeq(row.toSeq.map(normalizeForComparison))
}

private def normalizeForComparison(value: Any): Any = value match {
case null => null
case bd: java.math.BigDecimal => BigDecimal(bd)
case row: Row => prepareCometRow(row)
case arr: Array[_] => arr.toSeq.map(normalizeForComparison)
case map: scala.collection.Map[_, _] =>
map.map { case (k, v) => normalizeForComparison(k) -> normalizeForComparison(v) }
case seq: scala.collection.Iterable[_] => seq.map(normalizeForComparison).toSeq
case b: java.lang.Byte => b.byteValue
case s: java.lang.Short => s.shortValue
case i: java.lang.Integer => i.intValue
case l: java.lang.Long => l.longValue
case f: java.lang.Float => f.floatValue
case d: java.lang.Double => d.doubleValue
case x => x
}
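
// Why the recursive pass is needed (illustrative sketch, not executed here):
// equal byte arrays stringify by JVM identity, so sorting rows by toString
// can order the Spark and Comet sides differently.
//
//   val x = Array[Byte](1, 2, 3)
//   val y = Array[Byte](1, 2, 3)
//   x.toString == y.toString // false: e.g. "[B@1b6d3586" vs "[B@4554617c"
//   x.toSeq == y.toSeq       // true: Seq equality is structural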

/**
* A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
*/