diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ff2dd2dbd4833..9b609059f2306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1006,9 +1006,11 @@ object SQLConf { buildConf("spark.sql.shuffle.spreadNullJoinKeys.enabled") .doc("When true, Spark may spread rows with NULL equi-join keys across shuffle partitions " + "for shuffled LEFT, RIGHT, and FULL OUTER equi-joins and LEFT ANTI equi-joins on " + - "nullable keys to reduce shuffle skew. Null-aware join output partitioning does not " + - "satisfy a strict ClusteredDistribution, so downstream grouping, windowing, or " + - "equi-joins may require an extra shuffle. If one input is already hash partitioned, " + + "nullable keys to reduce shuffle skew. For LEFT OUTER joins, Spark may also derive a " + + "nullable physical join key from a safe left-side residual predicate so rows that " + + "cannot match do not collapse onto one reducer. Null-aware join output partitioning " + + "does not satisfy a strict ClusteredDistribution, so downstream grouping, windowing, " + + "or equi-joins may require an extra shuffle. If one input is already hash partitioned, " + "only the other input may be reshuffled into the null-aware layout, so the pre-shuffled " + "input can keep its NULL skew.") .version("4.3.0") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6e761fbe07b27..0b011f2a278be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.execution.streaming.runtime.{StreamingExecutionRelat import org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.BooleanType /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -181,6 +182,40 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object JoinSelection extends Strategy with JoinSelectionHelper { private val hintErrorHandler = conf.hintErrorHandler + private def canEvaluateAsJoinGuard(condition: Expression): Boolean = condition match { + case _: AttributeReference | _: Literal => true + case in: In => + in.list.forall(_.isInstanceOf[Literal]) && in.children.forall(canEvaluateAsJoinGuard) + case _: And | _: Or | _: Not | _: IsNull | _: IsNotNull | _: EqualTo | + _: EqualNullSafe | _: InSet | _: Coalesce | _: GetStructField => + condition.children.forall(canEvaluateAsJoinGuard) + case _ => false + } + + private def addUnmatchableRowGuard( + joinType: JoinType, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + nonEquiCond: Option[Expression], + left: LogicalPlan): (Seq[Expression], Seq[Expression]) = { + nonEquiCond match { + case Some(condition) + if conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) && + joinType == LeftOuter && + condition.deterministic && + condition.references.subsetOf(left.outputSet) && + canEvaluateAsJoinGuard(condition) => + // A left-side residual that is not TRUE proves that this left row cannot match. Turn + // that fact into a nullable physical join key so the existing null-aware shuffle can + // spread the row. Keep the residual condition as a defensive correctness check. + val leftGuard = If(condition, Literal.TrueLiteral, Literal(null, BooleanType)) + leftGuard.setTagValue(joins.ShuffledJoinTags.UNMATCHABLE_ROW_GUARD, ()) + (leftKeys :+ leftGuard, rightKeys :+ Literal.TrueLiteral) + case _ => + (leftKeys, rightKeys) + } + } + private def checkHintBuildSide( onlyLookingAtHint: Boolean, buildSide: Option[BuildSide], @@ -238,6 +273,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, _, left, right, hint) => val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys) + lazy val (shuffledLeftKeys, shuffledRightKeys) = addUnmatchableRowGuard( + joinType, leftKeys, rightKeys, nonEquiCond, left) + def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { val buildSide = getBroadcastBuildSide(j, onlyLookingAtHint, conf) @@ -265,8 +303,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { buildSide.map { buildSide => Seq(joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, + shuffledLeftKeys, + shuffledRightKeys, joinType, buildSide, nonEquiCond, @@ -286,7 +324,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def createSortMergeJoin() = { if (canMerge(joinType) && RowOrdering.isOrderable(leftKeys)) { Some(Seq(joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right)))) + shuffledLeftKeys, + shuffledRightKeys, + joinType, + nonEquiCond, + planLater(left), + planLater(right)))) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 6b7ae1577054e..e165a84767e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -20,8 +20,14 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftAnti, LeftExistence, LeftOuter, LeftSingle, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.internal.SQLConf +private[sql] object ShuffledJoinTags { + val UNMATCHABLE_ROW_GUARD: TreeNodeTag[Unit] = + TreeNodeTag[Unit]("unmatchable_row_join_guard") +} + /** * Holds common logic for join operators by shuffling two child relations * using the join keys. @@ -48,6 +54,9 @@ trait ShuffledJoin extends JoinCodegenSupport { preservedSideHasNullableKeys } + private lazy val hasUnmatchableRowGuard: Boolean = + leftKeys.exists(_.getTagValue(ShuffledJoinTags.UNMATCHABLE_ROW_GUARD).isDefined) + override def nodeName: String = { if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName } @@ -59,6 +68,13 @@ trait ShuffledJoin extends JoinCodegenSupport { // We re-arrange the shuffle partitions to deal with skew join, and the new children // partitioning doesn't satisfy `ClusteredDistribution`. UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + } else if (canSpreadNullJoinKeys && hasUnmatchableRowGuard) { + // Use every physical join key. In particular, a planner-added nullable guard key must not + // be dropped in favor of an existing partitioning on only the original equi-join keys. + ClusteredDistribution( + leftKeys, requireAllClusterKeys = true, allowNullKeySpreading = true) :: + ClusteredDistribution( + rightKeys, requireAllClusterKeys = true, allowNullKeySpreading = true) :: Nil } else if (canSpreadNullJoinKeys) { ClusteredDistribution(leftKeys, allowNullKeySpreading = true) :: ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3ea77446f268d..fcc597fe95483 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -26,9 +26,10 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.internal.config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, If, Literal, SortOrder} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper} import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, NullAwareHashPartitioning} import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} @@ -58,6 +59,117 @@ class JoinSuite extends SharedSparkSession with AdaptiveSparkPlanHelper assert(planned.size === 1) } + test("SPARK-57648: spread rows rejected by a safe left outer join residual") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val left = Seq( + (1, java.lang.Boolean.TRUE, "match"), + (1, java.lang.Boolean.FALSE, "false"), + (2, null.asInstanceOf[java.lang.Boolean], "null")).toDF("k", "eligible", "lv").as("l") + val right = Seq((1, "right-1"), (2, "right-2")).toDF("k", "rv").as("r") + val joined = left.join( + right, $"l.k" === $"r.k" && $"l.eligible", "left_outer") + + val sortMergeJoin = joined.queryExecution.sparkPlan.collectFirst { + case join: SortMergeJoinExec => join + }.getOrElse(fail("Expected a sort-merge join")) + assert(sortMergeJoin.leftKeys.size == 2) + assert(sortMergeJoin.rightKeys.size == 2) + assert(sortMergeJoin.leftKeys.last.isInstanceOf[If]) + assert(sortMergeJoin.rightKeys.last == Literal.TrueLiteral) + assert(sortMergeJoin.condition.nonEmpty) + assert(sortMergeJoin.requiredChildDistribution.forall { + case ClusteredDistribution(_, true, _, _) => true + case _ => false + }) + + val compoundPredicate = !($"l.lv".isin("false")) || + (functions.coalesce($"l.eligible", functions.lit(false)) && $"l.lv".isNotNull) + val compoundJoin = left.join( + right, $"l.k" === $"r.k" && compoundPredicate, "left_outer") + val compoundSortMergeJoin = compoundJoin.queryExecution.sparkPlan.collectFirst { + case join: SortMergeJoinExec => join + }.getOrElse(fail("Expected a sort-merge join")) + assert(compoundSortMergeJoin.leftKeys.size == 2) + assert(compoundSortMergeJoin.leftKeys.last.isInstanceOf[If]) + + checkAnswer(joined, Seq( + Row(1, true, "match", 1, "right-1"), + Row(1, false, "false", null, null), + Row(2, null, "null", null, null))) + + val shufflePartitionings = joined.queryExecution.executedPlan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(shufflePartitionings.size == 2) + assert(shufflePartitionings.forall { + case NullAwareHashPartitioning(expressions, _) => expressions.size == 2 + case _ => false + }) + } + } + + test("SPARK-57648: spread unmatchable rows for shuffled hash join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val left = Seq((1, true), (1, false)).toDF("k", "eligible").as("l") + val right = Seq((1, "right-1")).toDF("k", "rv").as("r").hint("SHUFFLE_HASH") + val joined = left.join( + right, $"l.k" === $"r.k" && $"l.eligible", "left_outer") + + val shuffledHashJoin = joined.queryExecution.sparkPlan.collectFirst { + case join: ShuffledHashJoinExec => join + }.getOrElse(fail("Expected a shuffled hash join")) + assert(shuffledHashJoin.leftKeys.size == 2) + assert(shuffledHashJoin.rightKeys.size == 2) + assert(shuffledHashJoin.condition.nonEmpty) + + checkAnswer(joined, Seq( + Row(1, true, 1, "right-1"), + Row(1, false, null, null))) + } + } + + test("SPARK-57648: do not add an unmatchable-row guard outside its safe scope") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val left = Seq((1, true, "match")).toDF("k", "eligible", "lv").as("l") + val right = Seq((1, "right-1")).toDF("k", "rv").as("r") + + withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "false") { + val joined = left.join( + right, $"l.k" === $"r.k" && $"l.eligible", "left_outer") + val sortMergeJoin = joined.queryExecution.sparkPlan.collectFirst { + case join: SortMergeJoinExec => join + }.getOrElse(fail("Expected a sort-merge join")) + assert(sortMergeJoin.leftKeys.size == 1) + } + + withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val unsupportedPredicate = left.join( + right, $"l.k" === $"r.k" && $"l.lv".contains("match"), "left_outer") + val sortMergeJoin = unsupportedPredicate.queryExecution.sparkPlan.collectFirst { + case join: SortMergeJoinExec => join + }.getOrElse(fail("Expected a sort-merge join")) + assert(sortMergeJoin.leftKeys.size == 1) + + val broadcastJoin = left.join( + right.hint("BROADCAST"), $"l.k" === $"r.k" && $"l.eligible", "left_outer") + val broadcastHashJoin = broadcastJoin.queryExecution.sparkPlan.collectFirst { + case join: BroadcastHashJoinExec => join + }.getOrElse(fail("Expected a broadcast hash join")) + assert(broadcastHashJoin.leftKeys.size == 1) + } + } + } + def assertJoin(pair: (String, Class[_ <: BinaryExecNode])): Any = { val sqlString = pair._1 val c = pair._2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 711f6dbdcdb11..a6080eb99f311 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1050,6 +1050,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Seq((true, true, 5), (false, true, 3), (true, false, 7), (false, false, 5)).foreach { case (partial, filter, expected) => withSQLConf( + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString, SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, @@ -1100,6 +1101,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Seq(("true", 5), ("false", 5)).foreach { case (enable, expected) => withSQLConf( + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString, SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { @@ -1146,6 +1148,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Seq(("true", 5), ("false", 5)).foreach { case (enable, expected) => withSQLConf( + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString, SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {