Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1006,9 +1006,11 @@ object SQLConf {
buildConf("spark.sql.shuffle.spreadNullJoinKeys.enabled")
.doc("When true, Spark may spread rows with NULL equi-join keys across shuffle partitions " +
"for shuffled LEFT, RIGHT, and FULL OUTER equi-joins and LEFT ANTI equi-joins on " +
"nullable keys to reduce shuffle skew. Null-aware join output partitioning does not " +
"satisfy a strict ClusteredDistribution, so downstream grouping, windowing, or " +
"equi-joins may require an extra shuffle. If one input is already hash partitioned, " +
"nullable keys to reduce shuffle skew. For LEFT OUTER joins, Spark may also derive a " +
"nullable physical join key from a safe left-side residual predicate so rows that " +
"cannot match do not collapse onto one reducer. Null-aware join output partitioning " +
"does not satisfy a strict ClusteredDistribution, so downstream grouping, windowing, " +
"or equi-joins may require an extra shuffle. If one input is already hash partitioned, " +
"only the other input may be reshuffled into the null-aware layout, so the pre-shuffled " +
"input can keep its NULL skew.")
.version("4.3.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import org.apache.spark.sql.execution.streaming.runtime.{StreamingExecutionRelat
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.BooleanType

/**
* Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
Expand Down Expand Up @@ -181,6 +182,40 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object JoinSelection extends Strategy with JoinSelectionHelper {
private val hintErrorHandler = conf.hintErrorHandler

/**
 * Decides whether `condition` may be embedded in a planner-generated join guard key.
 *
 * The whole expression tree must be built from an allow-list of node kinds: attribute
 * references, literals, `IN` over a literal-only list, and the boolean / null-test /
 * equality / `Coalesce` / `GetStructField` operators matched below. Any other node kind
 * anywhere in the tree makes the condition ineligible.
 *
 * NOTE(review): the allow-list presumably targets expressions that are cheap and cannot
 * fail at runtime -- confirm the intent before extending it.
 *
 * @param condition the expression (or sub-expression) to inspect
 * @return true if every node of `condition` is on the allow-list
 */
private def canEvaluateAsJoinGuard(condition: Expression): Boolean = {
  // Evaluated on demand, so leaf cases and rejected nodes never walk their children.
  def everyChildQualifies: Boolean = condition.children.forall(canEvaluateAsJoinGuard)

  condition match {
    case _: AttributeReference | _: Literal =>
      true
    case inExpr: In =>
      // The candidate list must be literal-only before the children are inspected.
      inExpr.list.forall(_.isInstanceOf[Literal]) && everyChildQualifies
    case _: And | _: Or | _: Not |
        _: IsNull | _: IsNotNull |
        _: EqualTo | _: EqualNullSafe |
        _: InSet | _: Coalesce | _: GetStructField =>
      everyChildQualifies
    case _ =>
      false
  }
}

/**
 * Optionally appends a nullable "guard" key to the physical join keys.
 *
 * The guard is added only when all of the following hold: null-join-key spreading is
 * enabled in the conf, the join is a LEFT OUTER join, and the residual (non-equi)
 * condition is deterministic, references only attributes produced by `left`, and passes
 * [[canEvaluateAsJoinGuard]]. In every other case the keys are returned unchanged.
 *
 * @param joinType    the join type being planned
 * @param leftKeys    equi-join keys of the left side
 * @param rightKeys   equi-join keys of the right side
 * @param nonEquiCond the residual join condition, if any
 * @param left        the left logical child; its output set bounds the residual's references
 * @return the (left, right) key sequences, each extended by one key when a guard is added
 */
private def addUnmatchableRowGuard(
    joinType: JoinType,
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    nonEquiCond: Option[Expression],
    left: LogicalPlan): (Seq[Expression], Seq[Expression]) = {

  // Checks run in this order and short-circuit, so the conf is consulted first.
  def qualifiesForGuard(residual: Expression): Boolean = {
    conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) &&
      joinType == LeftOuter &&
      residual.deterministic &&
      residual.references.subsetOf(left.outputSet) &&
      canEvaluateAsJoinGuard(residual)
  }

  nonEquiCond.filter(qualifiesForGuard).map { residual =>
    // A left-only residual that does not evaluate to TRUE means the left row cannot match.
    // Encode that as a join key that is TRUE when the residual holds and NULL otherwise, so
    // the existing null-aware shuffle can spread such rows. The residual itself is left in
    // place by the caller as a defensive correctness check.
    val guardKey = If(residual, Literal.TrueLiteral, Literal(null, BooleanType))
    // The tag lets the physical join recognise that a guard key is present.
    guardKey.setTagValue(joins.ShuffledJoinTags.UNMATCHABLE_ROW_GUARD, ())
    (leftKeys :+ guardKey, rightKeys :+ Literal.TrueLiteral)
  }.getOrElse((leftKeys, rightKeys))
}

private def checkHintBuildSide(
onlyLookingAtHint: Boolean,
buildSide: Option[BuildSide],
Expand Down Expand Up @@ -238,6 +273,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond,
_, left, right, hint) =>
val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys)
lazy val (shuffledLeftKeys, shuffledRightKeys) = addUnmatchableRowGuard(
joinType, leftKeys, rightKeys, nonEquiCond, left)

def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = {
if (hashJoinSupport) {
val buildSide = getBroadcastBuildSide(j, onlyLookingAtHint, conf)
Expand Down Expand Up @@ -265,8 +303,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
buildSide.map {
buildSide =>
Seq(joins.ShuffledHashJoinExec(
leftKeys,
rightKeys,
shuffledLeftKeys,
shuffledRightKeys,
joinType,
buildSide,
nonEquiCond,
Expand All @@ -286,7 +324,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def createSortMergeJoin() = {
if (canMerge(joinType) && RowOrdering.isOrderable(leftKeys)) {
Some(Seq(joins.SortMergeJoinExec(
leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right))))
shuffledLeftKeys,
shuffledRightKeys,
joinType,
nonEquiCond,
planLater(left),
planLater(right))))
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftAnti, LeftExistence, LeftOuter, LeftSingle, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.internal.SQLConf

/**
 * Tree-node tags used to pass planner decisions about shuffled joins to the physical
 * join operators.
 */
private[sql] object ShuffledJoinTags {
  /** Marker tag (carries no payload) for a join key added as an unmatchable-row guard. */
  val UNMATCHABLE_ROW_GUARD: TreeNodeTag[Unit] = TreeNodeTag("unmatchable_row_join_guard")
}

/**
* Holds common logic for join operators by shuffling two child relations
* using the join keys.
Expand All @@ -48,6 +54,9 @@ trait ShuffledJoin extends JoinCodegenSupport {
preservedSideHasNullableKeys
}

// True when at least one left join key carries the UNMATCHABLE_ROW_GUARD tag.
private lazy val hasUnmatchableRowGuard: Boolean = leftKeys.exists { key =>
  key.getTagValue(ShuffledJoinTags.UNMATCHABLE_ROW_GUARD).nonEmpty
}

// Appends a "(skew=true)" marker to the inherited node name for skew joins.
override def nodeName: String = {
  val skewSuffix = if (isSkewJoin) "(skew=true)" else ""
  super.nodeName + skewSuffix
}
Expand All @@ -59,6 +68,13 @@ trait ShuffledJoin extends JoinCodegenSupport {
// We re-arrange the shuffle partitions to deal with skew join, and the new children
// partitioning doesn't satisfy `ClusteredDistribution`.
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
} else if (canSpreadNullJoinKeys && hasUnmatchableRowGuard) {
// Use every physical join key. In particular, a planner-added nullable guard key must not
// be dropped in favor of an existing partitioning on only the original equi-join keys.
ClusteredDistribution(
leftKeys, requireAllClusterKeys = true, allowNullKeySpreading = true) ::
ClusteredDistribution(
rightKeys, requireAllClusterKeys = true, allowNullKeySpreading = true) :: Nil
} else if (canSpreadNullJoinKeys) {
ClusteredDistribution(leftKeys, allowNullKeySpreading = true) ::
ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil
Expand Down
114 changes: 113 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.internal.config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, If, Literal, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, NullAwareHashPartitioning}
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
Expand Down Expand Up @@ -58,6 +59,117 @@ class JoinSuite extends SharedSparkSession with AdaptiveSparkPlanHelper
assert(planned.size === 1)
}

// Verifies that, for a LEFT OUTER sort-merge join with a left-only residual predicate, the
// planner appends a guard key to both key lists, the join still returns the expected rows,
// and both shuffles use a null-aware hash partitioning over the two keys.
test("SPARK-57648: spread rows rejected by a safe left outer join residual") {
// AQE and broadcast joins are disabled so the join is inspected as a sort-merge join below.
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.SHUFFLE_PARTITIONS.key -> "4",
SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
// Three left rows covering eligible = TRUE, FALSE and NULL.
val left = Seq(
(1, java.lang.Boolean.TRUE, "match"),
(1, java.lang.Boolean.FALSE, "false"),
(2, null.asInstanceOf[java.lang.Boolean], "null")).toDF("k", "eligible", "lv").as("l")
val right = Seq((1, "right-1"), (2, "right-2")).toDF("k", "rv").as("r")
// `l.eligible` is a residual that references only the left side.
val joined = left.join(
right, $"l.k" === $"r.k" && $"l.eligible", "left_outer")

val sortMergeJoin = joined.queryExecution.sparkPlan.collectFirst {
case join: SortMergeJoinExec => join
}.getOrElse(fail("Expected a sort-merge join"))
// One original equi-join key plus one appended guard key on each side.
assert(sortMergeJoin.leftKeys.size == 2)
assert(sortMergeJoin.rightKeys.size == 2)
// The guard is an `If` on the left and the TRUE literal on the right.
assert(sortMergeJoin.leftKeys.last.isInstanceOf[If])
assert(sortMergeJoin.rightKeys.last == Literal.TrueLiteral)
// The residual condition is still attached to the join.
assert(sortMergeJoin.condition.nonEmpty)
// NOTE(review): the second extractor field is presumably `requireAllClusterKeys` -- confirm
// against the ClusteredDistribution definition.
assert(sortMergeJoin.requiredChildDistribution.forall {
case ClusteredDistribution(_, true, _, _) => true
case _ => false
})

// A compound left-only predicate combining NOT, IN, OR, COALESCE, AND and IS NOT NULL.
val compoundPredicate = !($"l.lv".isin("false")) ||
(functions.coalesce($"l.eligible", functions.lit(false)) && $"l.lv".isNotNull)
val compoundJoin = left.join(
right, $"l.k" === $"r.k" && compoundPredicate, "left_outer")
val compoundSortMergeJoin = compoundJoin.queryExecution.sparkPlan.collectFirst {
case join: SortMergeJoinExec => join
}.getOrElse(fail("Expected a sort-merge join"))
// The compound predicate also yields an appended guard key.
assert(compoundSortMergeJoin.leftKeys.size == 2)
assert(compoundSortMergeJoin.leftKeys.last.isInstanceOf[If])

// Only the row with eligible = TRUE finds a match; the other two are null-extended.
checkAnswer(joined, Seq(
Row(1, true, "match", 1, "right-1"),
Row(1, false, "false", null, null),
Row(2, null, "null", null, null)))

// Both sides are shuffled, each with a null-aware hash partitioning over two expressions.
val shufflePartitionings = joined.queryExecution.executedPlan.collect {
case exchange: ShuffleExchangeExec => exchange.outputPartitioning
}
assert(shufflePartitionings.size == 2)
assert(shufflePartitionings.forall {
case NullAwareHashPartitioning(expressions, _) => expressions.size == 2
case _ => false
})
}
}

// Checks that a LEFT OUTER shuffled hash join with a left-only residual is planned with an
// extra key on each side, keeps its residual condition, and still returns the right rows.
test("SPARK-57648: spread unmatchable rows for shuffled hash join") {
  // AQE and auto-broadcast are off so the SHUFFLE_HASH hint decides the physical join.
  val confs = Seq(
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.SHUFFLE_PARTITIONS.key -> "4",
    SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true")

  withSQLConf(confs: _*) {
    val leftSide = Seq((1, true), (1, false)).toDF("k", "eligible").as("l")
    val rightSide = Seq((1, "right-1")).toDF("k", "rv").as("r").hint("SHUFFLE_HASH")
    val joinCondition = $"l.k" === $"r.k" && $"l.eligible"
    val outerJoin = leftSide.join(rightSide, joinCondition, "left_outer")

    val plannedJoin = outerJoin.queryExecution.sparkPlan.collectFirst {
      case hashJoin: ShuffledHashJoinExec => hashJoin
    }
    plannedJoin match {
      case Some(hashJoin) =>
        // One equi-join key plus one guard key per side; the residual is retained.
        assert(hashJoin.leftKeys.size == 2)
        assert(hashJoin.rightKeys.size == 2)
        assert(hashJoin.condition.nonEmpty)
      case None =>
        fail("Expected a shuffled hash join")
    }

    val expectedRows = Seq(
      Row(1, true, 1, "right-1"),
      Row(1, false, null, null))
    checkAnswer(outerJoin, expectedRows)
  }
}

// Checks the cases in which no guard key may be added: the feature flag is off, the
// residual uses an expression outside the allow-list, or the join is a broadcast hash join.
test("SPARK-57648: do not add an unmatchable-row guard outside its safe scope") {
  // Number of left keys on the first sort-merge join found in `plan`.
  def sortMergeLeftKeyCount(plan: SparkPlan): Int = plan.collectFirst {
    case smj: SortMergeJoinExec => smj.leftKeys.size
  }.getOrElse(fail("Expected a sort-merge join"))

  withSQLConf(
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    val leftSide = Seq((1, true, "match")).toDF("k", "eligible", "lv").as("l")
    val rightSide = Seq((1, "right-1")).toDF("k", "rv").as("r")
    val keyEquality = $"l.k" === $"r.k"

    // Feature disabled: only the original equi-join key remains.
    withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "false") {
      val disabledJoin = leftSide.join(rightSide, keyEquality && $"l.eligible", "left_outer")
      assert(sortMergeLeftKeyCount(disabledJoin.queryExecution.sparkPlan) == 1)
    }

    withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
      // Feature enabled, but `contains` is not an expression the guard accepts.
      val containsJoin = leftSide.join(
        rightSide, keyEquality && $"l.lv".contains("match"), "left_outer")
      assert(sortMergeLeftKeyCount(containsJoin.queryExecution.sparkPlan) == 1)

      // Feature enabled with an eligible residual, but planned as a broadcast hash join.
      val hintedJoin = leftSide.join(
        rightSide.hint("BROADCAST"), keyEquality && $"l.eligible", "left_outer")
      val broadcastLeftKeyCount = hintedJoin.queryExecution.sparkPlan.collectFirst {
        case bhj: BroadcastHashJoinExec => bhj.leftKeys.size
      }.getOrElse(fail("Expected a broadcast hash join"))
      assert(broadcastLeftKeyCount == 1)
    }
  }
}

def assertJoin(pair: (String, Class[_ <: BinaryExecNode])): Any = {
val sqlString = pair._1
val c = pair._2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
Seq((true, true, 5), (false, true, 3), (true, false, 7), (false, false, 5)).foreach {
case (partial, filter, expected) =>
withSQLConf(
SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString,
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString,
Expand Down Expand Up @@ -1100,6 +1101,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
Seq(("true", 5), ("false", 5)).foreach {
case (enable, expected) =>
withSQLConf(
SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString,
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
Expand Down Expand Up @@ -1146,6 +1148,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
Seq(("true", 5), ("false", 5)).foreach {
case (enable, expected) =>
withSQLConf(
SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> true.toString,
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
Expand Down