diff --git a/.claude/skills/audit-comet-expression/SKILL.md b/.claude/skills/audit-comet-expression/SKILL.md
index 18b3400fe6..d13bf21166 100644
--- a/.claude/skills/audit-comet-expression/SKILL.md
+++ b/.claude/skills/audit-comet-expression/SKILL.md
@@ -122,6 +122,16 @@ Read the serde implementation and check:
 - Whether `getSupportLevel` is implemented and accurate
 - Whether all input types are handled
 - Whether any types are explicitly marked `Unsupported`
+- Whether `getIncompatibleReasons()` and `getUnsupportedReasons()` are overridden.
+  `getSupportLevel` controls runtime fallback, but `GenerateDocs` reads these two
+  methods to build the Compatibility Guide. If `getSupportLevel` returns
+  `Incompatible(Some(...))` or `Unsupported(Some(...))` but the corresponding
+  `get*Reasons()` method is not overridden, the reason will not appear in the
+  published docs. Expect both to return a `Seq[String]` containing the same
+  reason text used in `getSupportLevel`. Follow the pattern in
+  `spark/src/main/scala/org/apache/comet/serde/structs.scala::CometStructsToJson`
+  or `spark/src/main/scala/org/apache/comet/serde/datetime.scala::CometHour`:
+  extract the reason as a `private val` and reference it from both places.
 
 ### Shims
 
@@ -227,6 +237,7 @@ Also review the Comet implementation (Step 3) against the Spark behavior (Step 1
 - Are there behavioral differences that are NOT marked `Incompatible` but should be?
 - Are there behavioral differences between Spark versions that the Comet implementation does not account for (missing shim)?
 - Does the Rust implementation match the Spark behavior for all edge cases?
+- If `getSupportLevel` returns `Incompatible` or `Unsupported` with a reason, are `getIncompatibleReasons()` / `getUnsupportedReasons()` also overridden so the reason is picked up by `GenerateDocs` and appears in the Compatibility Guide?
 
 ---
 
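To make the `private val` pattern described above concrete, here is a minimal sketch. The `CometFoo` object, the `Foo` expression type, and the reason text are hypothetical; `convert` is elided:

```scala
// Sketch only: one private val holds the reason, and both the runtime-fallback
// path (getSupportLevel) and the docs path (getIncompatibleReasons, read by
// GenerateDocs) reference it, so the Compatibility Guide stays in sync.
object CometFoo extends CometExpressionSerde[Foo] {

  private val incompatReason: String =
    "Result may differ from Spark in some edge case."

  override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)

  override def getSupportLevel(expr: Foo): SupportLevel =
    Incompatible(Some(incompatReason))

  // convert(...) elided
}
```
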
diff --git a/common/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala b/common/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala
index 46a81818a0..a46e093534 100644
--- a/common/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala
+++ b/common/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala
@@ -26,4 +26,7 @@ import org.apache.spark.sql.types.DataType
 trait CometTypeShim {
   @nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility.
   def isStringCollationType(dt: DataType): Boolean = false
+
+  @nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility.
+  def hasNonDefaultStringCollation(dt: DataType): Boolean = false
 }
diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/CometTypeShim.scala b/common/src/main/spark-4.0/org/apache/comet/shims/CometTypeShim.scala
index 1b82c04b20..2eb52c3b79 100644
--- a/common/src/main/spark-4.0/org/apache/comet/shims/CometTypeShim.scala
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/CometTypeShim.scala
@@ -19,7 +19,7 @@
 
 package org.apache.comet.shims
 
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructType}
 
 trait CometTypeShim {
   // A `StringType` carries collation metadata in Spark 4.0. Only non-default (non-UTF8_BINARY)
@@ -31,4 +31,18 @@ trait CometTypeShim {
     case st: StringType => st.collationId != StringType.collationId
     case _ => false
   }
+
+  /**
+   * Returns true if `dt`, or any nested element/field/key/value type, is a `StringType` with a
+   * non-default (non-UTF8_BINARY) collation. Expression serdes can use this to fall back to
+   * Spark when they cannot honour collation semantics. Stubbed to `false` in Spark 3.x.
+   */
+  def hasNonDefaultStringCollation(dt: DataType): Boolean = dt match {
+    case _: StringType => isStringCollationType(dt)
+    case ArrayType(elementType, _) => hasNonDefaultStringCollation(elementType)
+    case MapType(kt, vt, _) =>
+      hasNonDefaultStringCollation(kt) || hasNonDefaultStringCollation(vt)
+    case StructType(fields) => fields.exists(f => hasNonDefaultStringCollation(f.dataType))
+    case _ => false
+  }
 }
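A quick sketch of what the recursive walk matches, assuming Spark 4.0's `StringType(collationName)` factory and some concrete `CometTypeShim` instance named `shim` (both `shim` and the collation name used here are illustrative):

```scala
import org.apache.spark.sql.types._

// Top-level collated string type: matched by the first case.
shim.hasNonDefaultStringCollation(StringType("UTF8_LCASE")) // true
// Collated string buried in array<struct<...>>: found via the recursion.
shim.hasNonDefaultStringCollation(
  ArrayType(StructType(Seq(StructField("s", StringType("UTF8_LCASE")))))) // true
// The default UTF8_BINARY collation never triggers a fallback.
shim.hasNonDefaultStringCollation(ArrayType(StringType)) // false
```
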
diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md
index 2caa0db86d..1e4b4e34bc 100644
--- a/docs/source/contributor-guide/spark_expressions_support.md
+++ b/docs/source/contributor-guide/spark_expressions_support.md
@@ -93,6 +93,9 @@
   - Spark 3.5.8 audited 2026-04-02
   - Spark 4.0.1 audited 2026-04-02 (pos=0 error message differs from Spark)
 - [x] array_intersect
+  - Spark 3.4.3 audited 2026-04-24 (result element order may differ from Spark when the right array is longer than the left; DataFusion probes the longer side)
+  - Spark 3.5.8 audited 2026-04-24 (same ordering incompatibility as 3.4.3)
+  - Spark 4.0.1 audited 2026-04-24 (ordering incompatibility as above; collated strings now fall back to Spark)
 - [x] array_join
 - [x] array_max
 - [ ] array_min
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 50833c00ca..6b19038dfb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde._
-import org.apache.comet.shims.CometExprShim
+import org.apache.comet.shims.{CometExprShim, CometTypeShim}
 
 object CometArrayRemove extends CometExpressionSerde[ArrayRemove]
@@ -191,12 +191,26 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
   }
 }
 
-object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {
+object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] with CometTypeShim {
 
-  override def getIncompatibleReasons(): Seq[String] = Seq(
-    "Null handling and ordering may differ from Spark")
+  private val incompatReason: String =
+    "Result array element order may differ from Spark when the right array is longer " +
+      "than the left (DataFusion probes the longer side)."
+
+  private val unsupportedCollationReason: String =
+    "array_intersect on collated strings is not supported."
+
+  override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason)
 
-  override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
+  override def getUnsupportedReasons(): Seq[String] = Seq(unsupportedCollationReason)
+
+  override def getSupportLevel(expr: ArrayIntersect): SupportLevel = {
+    if (hasNonDefaultStringCollation(expr.dataType)) {
+      Unsupported(Some(unsupportedCollationReason))
+    } else {
+      Incompatible(Some(incompatReason))
+    }
+  }
 
   override def convert(
       expr: ArrayIntersect,
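The divergence that `incompatReason` describes can be sketched as follows. The DataFusion result shown is derived from the rule stated in the reason text, not captured from a run, and `spark` is an ambient `SparkSession`:

```scala
// Spark preserves the left argument's element order:
spark.sql("SELECT array_intersect(array(2, 1), array(3, 1, 2))").show()
// Spark:            [2, 1]
// The right array is longer here, so DataFusion iterates it (using the shorter
// left side as the hash lookup) and emits the intersection in the right
// argument's order instead:
// Comet/DataFusion:  [1, 2]  -- hence sort_array in the tests below
```
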
diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_intersect.sql b/spark/src/test/resources/sql-tests/expressions/array/array_intersect.sql
index dc92f41edc..e3554a580e 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/array_intersect.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_intersect.sql
@@ -17,6 +17,13 @@
 
 -- Config: spark.comet.expression.ArrayIntersect.allowIncompatible=true
 
+-- DataFusion's array_intersect emits elements in the order of the longer input
+-- (it uses the shorter side as the hash lookup), while Spark emits elements in
+-- the order they appear in the left argument. Most cases below have
+-- left.length >= right.length so the element order matches Spark; cases that
+-- intentionally flip that relation use sort_array on the result so the
+-- assertion stays stable.
+
 statement
 CREATE TABLE test_array_intersect(a array<int>, b array<int>) USING parquet
 
@@ -37,3 +44,169 @@ SELECT array_intersect(array(1, 2, 3), b) FROM test_array_intersect
 
 -- literal + literal
 query
 SELECT array_intersect(array(1, 2, 3), array(2, 3, 4)), array_intersect(array(1, 2), array(3, 4)), array_intersect(array(), array(1)), array_intersect(cast(NULL as array<int>), array(1))
+
+-- right longer than left: element order diverges from Spark, so sort the result
+query
+SELECT sort_array(array_intersect(array(2, 1), array(3, 1, 2))), sort_array(array_intersect(array(3, 1), array(1, 2, 3, 4)))
+
+-- duplicate elements within and across arrays
+statement
+CREATE TABLE test_intersect_dups(a array<int>, b array<int>) USING parquet
+
+statement
+INSERT INTO test_intersect_dups VALUES (array(1, 1, 1), array(1, 1)), (array(1, 2, 1, 2), array(2, 1, 2, 1)), (array(1, 2, 3), array(1, 2, 3)), (array(1, 1, 2, 2, 3, 3), array(2, 2))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_dups
+
+-- both-NULL arrays and all-NULL element arrays
+statement
+CREATE TABLE test_intersect_nulls(a array<int>, b array<int>) USING parquet
+
+statement
+INSERT INTO test_intersect_nulls VALUES (array(NULL), array(NULL)), (array(NULL, NULL), array(NULL)), (array(NULL, NULL), array(NULL, NULL)), (array(1, NULL), array(1, NULL)), (array(1, NULL), array(NULL)), (array(1), array(NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_nulls
+
+query
+SELECT array_intersect(cast(NULL as array<int>), cast(NULL as array<int>))
+
+-- self-intersection (deduplication)
+query
+SELECT a, array_intersect(a, a) FROM test_intersect_dups
+
+-- empty array combinations
+query
+SELECT array_intersect(array(), array()), array_intersect(array(), array(1, 2)), array_intersect(array(1, 2), array())
+
+-- boolean arrays
+statement
+CREATE TABLE test_intersect_bool(a array<boolean>, b array<boolean>) USING parquet
+
+statement
+INSERT INTO test_intersect_bool VALUES (array(true, false), array(true)), (array(true, true), array(false)), (array(true, false, NULL), array(false, NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_bool
+
+-- tinyint arrays
+statement
+CREATE TABLE test_intersect_byte(a array<tinyint>, b array<tinyint>) USING parquet
+
+statement
+INSERT INTO test_intersect_byte VALUES (array(cast(127 as tinyint), cast(-128 as tinyint), cast(0 as tinyint)), array(cast(127 as tinyint), cast(1 as tinyint))), (array(cast(1 as tinyint), cast(2 as tinyint)), array(cast(3 as tinyint), cast(4 as tinyint)))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_byte
+
+-- smallint arrays
+statement
+CREATE TABLE test_intersect_short(a array<smallint>, b array<smallint>) USING parquet
+
+statement
+INSERT INTO test_intersect_short VALUES (array(cast(32767 as smallint), cast(-32768 as smallint)), array(cast(32767 as smallint))), (array(cast(1 as smallint), cast(2 as smallint)), array(cast(2 as smallint), cast(3 as smallint)))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_short
+
+-- bigint arrays with boundary values
+statement
+CREATE TABLE test_intersect_long(a array<bigint>, b array<bigint>) USING parquet
+
+statement
+INSERT INTO test_intersect_long VALUES (array(9223372036854775807, 1, -9223372036854775808), array(9223372036854775807, -9223372036854775808)), (array(0, 1, 2), array(3, 4, 5))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_long
+
+-- float arrays with NaN, Infinity, -Infinity
+statement
+CREATE TABLE test_intersect_float(a array<float>, b array<float>) USING parquet
+
+statement
+INSERT INTO test_intersect_float VALUES (array(1.5, 2.5, float('NaN')), array(2.5, float('NaN'))), (array(float('Infinity'), 1.0, float('-Infinity')), array(float('Infinity'), float('-Infinity'))), (array(cast(0.0 as float), cast(-0.0 as float)), array(cast(0.0 as float)))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_float
+
+-- double arrays with NaN, Infinity, -Infinity
+statement
+CREATE TABLE test_intersect_dbl(a array<double>, b array<double>) USING parquet
+
+statement
+INSERT INTO test_intersect_dbl VALUES (array(1.0, 2.0, double('NaN')), array(2.0, double('NaN'))), (array(double('Infinity'), 1.0, double('-Infinity')), array(double('Infinity'), double('-Infinity'))), (array(0.0, -0.0), array(0.0)), (array(1.0, 2.0, NULL), array(1.0, NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_dbl
+
+-- decimal arrays
+statement
+CREATE TABLE test_intersect_dec(a array<decimal(10,2)>, b array<decimal(10,2)>) USING parquet
+
+statement
+INSERT INTO test_intersect_dec VALUES (array(1.00, 2.50, 3.00), array(2.50, 3.00, 4.00)), (array(1.00, 2.00), array(3.00, 4.00)), (array(1.10, NULL), array(1.10, NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_dec
+
+-- date arrays
+statement
+CREATE TABLE test_intersect_date(a array<date>, b array<date>) USING parquet
+
+statement
+INSERT INTO test_intersect_date VALUES (array(date '2024-01-01', date '2024-06-15', date '2024-12-31'), array(date '2024-06-15', date '2024-12-31')), (array(date '2024-01-01'), array(date '2024-12-31')), (array(date '2024-01-01', NULL), array(date '2024-01-01', NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_date
+
+-- timestamp arrays
+statement
+CREATE TABLE test_intersect_ts(a array<timestamp>, b array<timestamp>) USING parquet
+
+statement
+INSERT INTO test_intersect_ts VALUES (array(timestamp '2024-01-01 00:00:00', timestamp '2024-06-15 12:00:00', timestamp '2024-12-31 23:59:59'), array(timestamp '2024-06-15 12:00:00', timestamp '2024-12-31 23:59:59')), (array(timestamp '2024-01-01 00:00:00'), array(timestamp '2024-12-31 23:59:59'))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_ts
+
+-- string arrays including empty strings and multibyte UTF-8
+statement
+CREATE TABLE test_intersect_str(a array<string>, b array<string>) USING parquet
+
+statement
+INSERT INTO test_intersect_str VALUES (array('a', 'b', 'c'), array('b', 'c', 'd')), (array('a', NULL, 'b'), array('a', NULL)), (array('', 'a'), array('', 'b')), (array('café', '中文', 'test'), array('café', 'other')), (NULL, array('a'))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_str
+
+-- binary arrays
+statement
+CREATE TABLE test_intersect_bin(a array<binary>, b array<binary>) USING parquet
+
+statement
+INSERT INTO test_intersect_bin VALUES (array(binary('abc'), binary('def'), binary('ghi')), array(binary('def'), binary('ghi'))), (array(binary('abc')), array(binary('xyz'))), (array(binary('abc'), NULL), array(binary('abc'), NULL))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_bin
+
+-- nested array<array<int>>
+statement
+CREATE TABLE test_intersect_nested(a array<array<int>>, b array<array<int>>) USING parquet
+
+statement
+INSERT INTO test_intersect_nested VALUES (array(array(1, 2), array(3, 4), array(5, 6)), array(array(3, 4), array(5, 6))), (array(array(1, 2)), array(array(3, 4))), (array(array(1, 2), cast(NULL as array<int>)), array(array(1, 2), cast(NULL as array<int>)))
+
+query
+SELECT a, b, array_intersect(a, b) FROM test_intersect_nested
+
+-- mixed column and literal with NULL elements
+query
+SELECT array_intersect(a, array(1, 2, NULL)) FROM test_array_intersect
+
+query
+SELECT array_intersect(array(1, NULL, 3), b) FROM test_array_intersect
+
+-- conditional (CASE WHEN) arrays
+query
+SELECT array_intersect(CASE WHEN a IS NOT NULL THEN a ELSE array(0) END, b) FROM test_array_intersect
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index c91099c9e0..767968b7c1 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -378,9 +378,9 @@ abstract class CometTestBase
           |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
           """.stripMargin)
     }
-    if (!QueryTest.compare(
-        QueryTest.prepareAnswer(sparkAnswer, isSorted),
-        QueryTest.prepareAnswer(cometAnswer, isSorted))) {
+    val preparedSpark = prepareCometAnswer(sparkAnswer, isSorted)
+    val preparedComet = prepareCometAnswer(cometAnswer, isSorted)
+    if (!QueryTest.compare(preparedSpark, preparedComet)) {
       val getRowType: Option[Row] => String = row =>
         row
           .map(r => if (r.schema == null) "struct<>" else r.schema.catalogString)
@@ -394,14 +394,47 @@ abstract class CometTestBase
           |${sideBySide(
             s"== Spark Answer - ${sparkAnswer.size} ==" +:
               getRowType(sparkAnswer.headOption) +:
-              QueryTest.prepareAnswer(sparkAnswer, isSorted).map(_.toString()),
+              preparedSpark.map(_.toString()),
             s"== Comet Answer - ${cometAnswer.size} ==" +:
               getRowType(cometAnswer.headOption) +:
-              QueryTest.prepareAnswer(cometAnswer, isSorted).map(_.toString())).mkString("\n")}
+              preparedComet.map(_.toString())).mkString("\n")}
           """.stripMargin)
     }
   }
 
+  /**
+   * Like `QueryTest.prepareAnswer` but recursively converts nested arrays to seqs. Spark's
+   * version only normalizes top-level `Array[_]`, leaving inner arrays (e.g. `Array[Byte]` from
+   * `array<binary>` columns) intact. Their default `toString` is the JVM identity (`[B@...`),
+   * which makes the toString-based sort in `prepareAnswer` non-deterministic and causes
+   * spurious mismatches between the two sides.
+   */
+  private def prepareCometAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+    val converted = answer.map(prepareCometRow)
+    if (isSorted) converted else converted.sortBy(_.toString())
+  }
+
+  private def prepareCometRow(row: Row): Row = {
+    Row.fromSeq(row.toSeq.map(normalizeForComparison))
+  }
+
+  private def normalizeForComparison(value: Any): Any = value match {
+    case null => null
+    case bd: java.math.BigDecimal => BigDecimal(bd)
+    case row: Row => prepareCometRow(row)
+    case arr: Array[_] => arr.toSeq.map(normalizeForComparison)
+    case map: scala.collection.Map[_, _] =>
+      map.map { case (k, v) => normalizeForComparison(k) -> normalizeForComparison(v) }
+    case seq: scala.collection.Iterable[_] => seq.map(normalizeForComparison).toSeq
+    case b: java.lang.Byte => b.byteValue
+    case s: java.lang.Short => s.shortValue
+    case i: java.lang.Integer => i.intValue
+    case l: java.lang.Long => l.longValue
+    case f: java.lang.Float => f.floatValue
+    case d: java.lang.Double => d.doubleValue
+    case x => x
+  }
+
   /**
    * A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
    */
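
The identity-`toString` problem that motivates `normalizeForComparison` is easy to reproduce in isolation; a minimal sketch:

```scala
// Equal byte arrays stringify to distinct JVM identity hashes, so a
// toString-based sort over rows containing them is non-deterministic:
val a = Array[Byte](1, 2, 3)
val b = Array[Byte](1, 2, 3)
a.toString               // e.g. "[B@1b6d3586" (differs per object instance)
b.toString               // e.g. "[B@4554617c"
a.toString == b.toString // false, despite equal contents
// After normalization to Seq, comparison and sorting are value-based:
a.toSeq == b.toSeq       // true; Seq(1, 2, 3).toString is stable
```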