diff --git a/core/src/main/scala/org/apache/spark/internal/config/Tests.scala b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala index 8ecb14be1dfb8..987563eb3da83 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Tests.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala @@ -41,8 +41,39 @@ private[spark] object Tests { val INJECT_SHUFFLE_FETCH_FAILURES = ConfigBuilder("spark.testing.injectShuffleFetchFailures") - .doc("Injecting fetch failures for shuffle stages by providing an invalid BlockManager " + - "location for the first stage attempt. Testing only flag!") + .doc("Corrupt the registered MapStatus of partition 0 on the first successful attempt " + + "of every shuffle map stage, to induce downstream FetchFailed and stage retry. " + + "Testing only.") + .booleanConf + .createWithDefault(false) + + val INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY = + ConfigBuilder("spark.testing.injectShuffleFetchFailuresDownstreamDelay") + .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Defer the producer's partition-0 " + + "corruption until N ShuffleMapStage consumer task successes have been observed. " + + "Default 1; set to 0 to corrupt at registration. Testing only.") + .intConf + .checkValue(_ >= 0, "Downstream-success delay must be non-negative") + .createWithDefault(1) + + val INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY = + ConfigBuilder("spark.testing.injectShuffleFetchFailuresResultStageDelay") + .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Counterpart to " + + "INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY for ResultStage consumers. 
With the " + + "default 0, when a ResultStage is the consumer of a pending corruption it is corrupted " + + "before the result tasks dispatch, so the result stage has no completed tasks when " + + "INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE fires (the rollback would " + + "otherwise abort the result stage, since OSS Spark does not support rolling result " + + "stages back). Set to N > 0 to defer until N result-stage tasks have succeeded - this " + + "is the only way to actually exercise the result-stage abort path. Testing only.") + .intConf + .checkValue(_ >= 0, "Result-stage-success delay must be non-negative") + .createWithDefault(0) + + val INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE = + ConfigBuilder("spark.testing.injectShuffleForceChecksumMismatchOnRecompute") + .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Flag the recompute as checksum " + + "mismatched, forcing downstream `rollbackSucceedingStages`. Testing only.") .booleanConf .createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22720b98aafde..68d507bff9822 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -179,6 +179,30 @@ private[spark] class DAGScheduler( private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]() + // For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task + // we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide + // when to fire INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the + // task whose stageAttemptId is not the recorded one. 
+ private val injectShuffleFetchFailuresCorruptedAttempt = + new ConcurrentHashMap[Int, Int]() + + // For INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY > 0: shuffles whose mapper-0 corruption + // has been deferred until enough downstream consumer tasks succeed. The value is the mapId + // we will eventually swap to an invalid BlockManagerId. + private val injectShuffleFetchFailuresPendingDelayedCorruption = + new ConcurrentHashMap[Int, Long]() + + // For INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY: per-shuffle counter of consumer + // task-success events observed so far. + private val injectShuffleFetchFailuresDownstreamSuccessCount = + new ConcurrentHashMap[Int, Int]() + + // Bogus BlockManagerId used by INJECT_SHUFFLE_FETCH_FAILURES to mark a corrupted MapStatus. + // Any consumer task that tries to fetch a block from this location will FetchFailed because + // the executorId is INVALID_EXECUTOR_ID. + private val injectShuffleFetchFailuresInvalidBlockManagerId = + BlockManagerId(BlockManagerId.INVALID_EXECUTOR_ID, "invalid", 1, None) + // Job groups that are cancelled with `cancelFutureJobs` as true, with at most // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in // this set, the job will be immediately cancelled. @@ -933,6 +957,11 @@ private[spark] class DAGScheduler( } for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) { shuffleIdToMapStage.remove(k) + if (Utils.isTesting) { + injectShuffleFetchFailuresCorruptedAttempt.remove(k) + injectShuffleFetchFailuresPendingDelayedCorruption.remove(k) + injectShuffleFetchFailuresDownstreamSuccessCount.remove(k) + } } if (waitingStages.contains(stage)) { logDebug("Removing stage %d from waiting set.".format(stageId)) @@ -1618,6 +1647,116 @@ private[spark] class DAGScheduler( } } + /** + * Returns true when this just-completed shuffle map task should have its output corrupted by + * the test-only fetch-failure injection. 
We corrupt only the partition-0 task, and only on + * the stage attempt that first successfully completes partition 0 - latched into + * injectShuffleFetchFailuresCorruptedAttempt. Recomputes (later attempts) of that partition + * are left clean so the consumer can make progress on its retry. The latch is per-shuffle, + * so non-leaf stages whose earlier attempts failed on fetch from upstream are still + * corrupted on their first successful attempt. + */ + private def shouldCorruptShuffleOutputForTest(shuffleId: Int, task: Task[_]): Boolean = { + if (task.partitionId != 0) return false + val recorded = injectShuffleFetchFailuresCorruptedAttempt.computeIfAbsent( + shuffleId, _ => task.stageAttemptId) + recorded == task.stageAttemptId + } + + /** + * Apply the test-only fetch-failure injection to this just-completed map task: with + * DOWNSTREAM_DELAY > 0 record the mapId so maybeApplyDelayedCorruptionForTest can corrupt + * it later, otherwise update the MapStatus location to + * injectShuffleFetchFailuresInvalidBlockManagerId inline. + */ + private def corruptShuffleOutputForTest(shuffleId: Int, status: MapStatus): Unit = { + val downstreamDelay = + sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY) + if (downstreamDelay > 0) { + injectShuffleFetchFailuresPendingDelayedCorruption.put(shuffleId, status.mapId) + } else { + status.updateLocation(injectShuffleFetchFailuresInvalidBlockManagerId) + } + } + + /** + * For INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE: returns true when this shuffle map + * task is the recompute of a partition whose previous successful attempt was the one corrupted + * by INJECT_SHUFFLE_FETCH_FAILURES. Forcing the mismatch on the recompute drives the rollback + * path - downstream ShuffleMapStages get cleaned up and re-run fully, downstream ResultStages + * are aborted. 
+ */ + private def isForcedChecksumMismatchForTest(shuffleId: Int, task: Task[_]): Boolean = { + if (!sc.conf.get(config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE)) return false + if (task.partitionId != 0) return false + val recorded = + injectShuffleFetchFailuresCorruptedAttempt.getOrDefault(shuffleId, -1) + recorded >= 0 && recorded != task.stageAttemptId + } + + /** + * Apply the deferred mapper-0 corruption (configured via + * INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY for ShuffleMapStage consumers and + * INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY for ResultStage consumers) when enough + * consumer tasks have succeeded. Walks the just-completed stage's direct shuffle parents, + * increments the per-shuffle consumer-success counter, and corrupts the registered MapStatus + * when the counter reaches the configured delay. + */ + private def maybeApplyDelayedCorruptionForTest(stage: Stage): Unit = { + if (!sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES)) return + if (injectShuffleFetchFailuresPendingDelayedCorruption.isEmpty) return + val isResultStage = stage.isInstanceOf[ResultStage] + val delay = if (isResultStage) { + sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY) + } else { + sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY) + } + if (delay <= 0) return // delay == 0 was already handled at submission time + + val parentShuffleIds = stage.parents.collect { + case sms: ShuffleMapStage => sms.shuffleDep.shuffleId + } + parentShuffleIds.foreach { shuffleId => + if (injectShuffleFetchFailuresPendingDelayedCorruption.containsKey(shuffleId)) { + val newCount = injectShuffleFetchFailuresDownstreamSuccessCount.merge(shuffleId, 1, _ + _) + if (newCount >= delay) { + val mapId = injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId) + mapOutputTracker.updateMapOutput( + shuffleId, mapId, injectShuffleFetchFailuresInvalidBlockManagerId) + logInfo(s"Test injection: 
corrupted mapper-0 of shuffle $shuffleId after " + + s"$newCount downstream consumer successes") + } + } + } + } + + /** + * For INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY = 0: when a ResultStage is about to + * dispatch tasks, fire any pending mapper-0 corruption for its direct shuffle parents + * BEFORE result tasks start. This keeps the result stage at zero finished tasks when + * INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE later triggers `rollbackSucceedingStages`, + * so the rollback path does not abort a partially-finished result stage. + */ + private def maybePreemptiveCorruptionForResultStage(stage: Stage): Unit = { + if (!stage.isInstanceOf[ResultStage]) return + if (!sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES)) return + if (sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY) > 0) return + if (injectShuffleFetchFailuresPendingDelayedCorruption.isEmpty) return + + val parentShuffleIds = stage.parents.collect { + case sms: ShuffleMapStage => sms.shuffleDep.shuffleId + } + parentShuffleIds.foreach { shuffleId => + if (injectShuffleFetchFailuresPendingDelayedCorruption.containsKey(shuffleId)) { + val mapId = injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId) + mapOutputTracker.updateMapOutput( + shuffleId, mapId, injectShuffleFetchFailuresInvalidBlockManagerId) + logInfo(s"Test injection: corrupted mapper-0 of shuffle $shuffleId before result-stage " + + s"submission") + } + } + } + private def configureShufflePushMergerLocations(stage: ShuffleMapStage): Unit = { if (stage.shuffleDep.getMergerLocs.nonEmpty) return val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( @@ -1635,6 +1774,10 @@ private[spark] class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") + if (Utils.isTesting) { + maybePreemptiveCorruptionForResultStage(stage) + } + // For statically indeterminate stages being retried, we trigger 
rollback BEFORE task // submission. This is more efficient than deferring to task completion because: // 1. It avoids submitting a partial stage that would need to be cancelled @@ -2271,6 +2414,10 @@ private[spark] class DAGScheduler( taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) } + if (Utils.isTesting) { + maybeApplyDelayedCorruptionForTest(stage) + } + task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -2342,21 +2489,16 @@ private[spark] class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - // For testing purposes, inject fetch failures controlled from the driver-side by - // supplying an invalid location. if (Utils.isTesting && sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES) && - task.stageAttemptId == 0) { - val currentLocation = status.location - val invalidLocation = BlockManagerId( - execId = BlockManagerId.INVALID_EXECUTOR_ID, - host = currentLocation.host, - port = currentLocation.port, - topologyInfo = currentLocation.topologyInfo) - status.updateLocation(invalidLocation) + shouldCorruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, task)) { + corruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, status) } - val isChecksumMismatched = mapOutputTracker.registerMapOutput( + val realChecksumMismatched = mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + val isChecksumMismatched = realChecksumMismatched || + (Utils.isTesting && + isForcedChecksumMismatchForTest(shuffleStage.shuffleDep.shuffleId, task)) if (isChecksumMismatched) { shuffleStage.isChecksumMismatched = isChecksumMismatched // Runtime detection of nondeterministic output via checksum mismatch. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala index 847a12f4f305c..885091ecd3fb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.metric import scala.util.Random +import org.apache.spark.SparkException import org.apache.spark.internal.config import org.apache.spark.sql.{Column, Dataset} import org.apache.spark.sql.execution.adaptive.{AQETestHelper, DisableAdaptiveExecutionSuite} @@ -348,11 +349,15 @@ class MetricsFailureInjectionSuite runQueryWithMetrics() { finalDf => if (injectFailure) { assert(stage1Metric.value > 300) + // The non-deterministic UDF in stage 1 makes mapper 0's recompute produce a different + // checksum from its corrupted first attempt, which fires rollbackSucceedingStages and + // re-runs stage 2 in full. The raw stage 2 accumulator therefore overcounts; SLAM + // remains stable. + assert(stage2Metric.value > 5, s"stage2Metric=${stage2Metric.value}") } else { assert(stage1Metric.value === 300) + assert(stage2Metric.value === 5) } - // Stage2 doesn't have a downstream shuffle stage we can fail. 
- assert(stage2Metric.value === 5) assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) @@ -361,4 +366,226 @@ class MetricsFailureInjectionSuite assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5)) } } + + test("Three stage metrics block failure injection") { + val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter") + val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter") + val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter") + val stage1SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM") + val stage2SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM") + val stage3SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM") + + withTable("primary_table", "secondary_table") { + setUpTestTable("primary_table") + setUpTestTable("secondary_table") + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") { + val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric)) + val stage1 = spark.read.table("primary_table") + .filter(Column(stage1MetricsExpr)) + val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric)) + val stage2 = stage1.join( + spark.read.table("secondary_table"), + usingColumn = "id", + joinType = "fullOuter") + .filter(Column(stage2MetricsExpr)) + val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric)) + val stage3 = stage2 + .groupBy("primary_table.low_cardinality_col") + .count() + .filter(Column(stage3MetricsExpr)) + val finalDf = stage3.as[(Int, Long)] + val result = finalDf.collect() + assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap) + + // Both stage1 (leaf) and stage2 (non-leaf) get corrupted on their first successful + // attempt and re-run. 
stage3 is a result stage with no shuffle output, so it isn't + // corrupted and runs only once successfully. + assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}") + assert(stage2Metric.value > 300, s"stage2Metric=${stage2Metric.value}") + assert(stage3Metric.value === 5) + + // SLAM correctly reports each stage's last successful attempt's contribution only. + assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) + + assert(stage1SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5)) + } + } + } + + test("Three stage metrics force-checksum-mismatch on recompute") { + // INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE additionally flags the recompute of the + // partition-0 task as a checksum mismatch. The DAGScheduler then runs + // `rollbackSucceedingStages`, which (a) for downstream ShuffleMapStages clears their map + // outputs and forces a full retry of every previously-finished partition, and (b) for the + // ResultStage downstream is a no-op because the result stage hasn't started yet - it just + // runs once after the rollback completes. + // + // Without a timing guarantee the FetchFailed in stage 2 may fire before any stage 2 task + // finishes, in which case the rollback has nothing to clear and stage 2 metrics look the + // same as in the recompute-only mode. So we only assert `stage2Metric > 300`, which is the + // sum of partial-attempt-0 contributions (>=1 partition since rollback had something to + // roll back) plus a full attempt-1; the deterministic version of this scenario lives in + // the delayed-corruption test below. 
+ val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter") + val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter") + val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter") + val stage1SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM") + val stage2SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM") + val stage3SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM") + + withTable("primary_table", "secondary_table") { + setUpTestTable("primary_table") + setUpTestTable("secondary_table") + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true", + config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key -> "true") { + val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric)) + val stage1 = spark.read.table("primary_table") + .filter(Column(stage1MetricsExpr)) + val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric)) + val stage2 = stage1.join( + spark.read.table("secondary_table"), + usingColumn = "id", + joinType = "fullOuter") + .filter(Column(stage2MetricsExpr)) + val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric)) + val stage3 = stage2 + .groupBy("primary_table.low_cardinality_col") + .count() + .filter(Column(stage3MetricsExpr)) + val finalDf = stage3.as[(Int, Long)] + val result = finalDf.collect() + assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap) + + // The recompute-with-mismatch injection drives `rollbackSucceedingStages` against the + // checksum-mismatched producer. Stage 2 is a downstream ShuffleMapStage and gets its map + // outputs cleared and rerun. The total raw accumulator on stage 2 is + // (partial-attempt-0 contributions) + (full-attempt-1 = 300). 
In the recompute-only + // mode the excess over 300 is just the missing partitions re-run by attempt 1; + // here it is larger still when the rollback had any partitions to clear. + assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}") + assert(stage2Metric.value > 300, s"stage2Metric=${stage2Metric.value}") + assert(stage3Metric.value === 5) + + // SLAM still reports the last successful attempt's contribution per RDD. + assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) + + assert(stage1SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5)) + } + } + } + + test("Three stage metrics force-checksum-mismatch with delayed corruption") { + // Same setup as the previous test but with INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY=1 + // and shuffle.partitions=20 (much greater than the test's local[2] cores). The DAGScheduler + // event loop is single-threaded for completion events, so deferring the producer's + // mapper-0 corruption until after one consumer success guarantees AT LEAST ONE consumer + // task fully completed before the FetchFailed cascade kicks in. With the forced rollback, + // those completed-then-cleared partitions all re-run during the rollback retry, giving a + // strict lower bound on the raw stage-2 accumulator that's not reachable in + // recompute-only mode. 
+ val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter") + val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter") + val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter") + val stage1SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM") + val stage2SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM") + val stage3SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM") + + withTable("primary_table", "secondary_table") { + setUpTestTable("primary_table") + setUpTestTable("secondary_table") + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "20") { + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true", + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY.key -> "1", + config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key -> "true") { + val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric)) + val stage1 = spark.read.table("primary_table") + .filter(Column(stage1MetricsExpr)) + val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric)) + val stage2 = stage1.join( + spark.read.table("secondary_table"), + usingColumn = "id", + joinType = "fullOuter") + .filter(Column(stage2MetricsExpr)) + val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric)) + val stage3 = stage2 + .groupBy("primary_table.low_cardinality_col") + .count() + .filter(Column(stage3MetricsExpr)) + val finalDf = stage3.as[(Int, Long)] + val result = finalDf.collect() + assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap) + + // With delay=1 and 20 shuffle partitions on local[2], at least one stage-2 reducer + // task is guaranteed to fully process its rows before the corruption fires. The forced + // rollback then re-runs all 20 stage-2 partitions, replaying those previously- + // completed ones. 
The recompute-only baseline is 300 (full coverage across attempts) + // + size(mapper 0) for the FetchFailed-driven retry; the rollback adds at least one + // already-completed partition's worth on top of that. Partition sizes vary with the + // hash of `id`, so we just assert "strictly above the recompute-only baseline" rather + // than a tight numeric bound. + assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}") + assert(stage2Metric.value > 315, + s"stage2Metric should be above the mode-2 baseline (~315) because the rollback " + + s"re-played a partition that completed in attempt 0, got ${stage2Metric.value}") + assert(stage3Metric.value === 5) + + assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) + + assert(stage1SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5)) + } + } + } + } + + test("Force checksum mismatch aborts a downstream ResultStage") { + // 2-stage query whose downstream is a ResultStage. With RESULT_STAGE_DELAY=1 the result + // stage gets at least one finished task before the FetchFailed cascade, so by the time + // the forced checksum mismatch on stage 1 mapper 0 fires `rollbackSucceedingStages`, + // the result stage's findMissingPartitions is strictly less than numTasks - which OSS + // Spark cannot roll back, so the job aborts. With the default RESULT_STAGE_DELAY=0 the + // result stage is corrupted before any task dispatches and the rollback path does not + // abort. 
+ withTable("test_table") { + setUpTestTable("test_table") + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "20") { + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true", + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY.key -> "1", + config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key -> "true") { + val df = spark.read.table("test_table") + .groupBy("low_cardinality_col") + .count() + val ex = intercept[SparkException] { + df.collect() + } + assert(ex.getMessage.contains("indeterminate"), + s"expected an 'indeterminate'-stage abort, got: ${ex.getMessage}") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricIntegrationSuite.scala index 2e7af075a3e74..6b6848e908cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricIntegrationSuite.scala @@ -703,3 +703,23 @@ class SQLLastAttemptMetricIntegrationSuiteWithStageRetries }(pos) } } + +class SQLLastAttemptMetricIntegrationSuiteWithChecksumMismatch + extends SQLLastAttemptMetricIntegrationSuite { + override protected def withRetries = true + + override protected def test( + testName: String, + testTags: org.scalatest.Tag*) + (testFun: => Any) + (implicit pos: org.scalactic.source.Position): Unit = { + super.test(testName, testTags : _*) { + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true", + config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key -> "true") { + // Forced checksum-mismatch rollback should also not affect SLAM metrics. 
+ testFun + } + }(pos) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricPlanShapesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricPlanShapesSuite.scala index ea8d9568f7e4b..a86a12a7fa813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricPlanShapesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLLastAttemptMetricPlanShapesSuite.scala @@ -119,6 +119,10 @@ class SQLLastAttemptMetricPlanShapesSuite def hasStageRetries: Boolean = spark.sparkContext.conf .getOption(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key).contains("true") + def hasChecksumMismatch: Boolean = spark.sparkContext.conf + .getOption(config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key) + .contains("true") + def hasAQEReplans: Boolean = AQETestHelper.isForcedCancellationEnabled } @@ -132,10 +136,10 @@ class SQLLastAttemptMetricPlanShapesSuite )(testTags: Tag*): Unit = { for { useAQE <- BOOLEAN_DOMAIN - stageRetries <- BOOLEAN_DOMAIN + failureMode <- FailureMode.all aqeReplans <- if (useAQE) BOOLEAN_DOMAIN else Seq(false) } test(s"$label - " + - s"useAQE=$useAQE, stageRetries=$stageRetries, aqeReplans=$aqeReplans", + s"useAQE=$useAQE, failureMode=$failureMode, aqeReplans=$aqeReplans", testTags: _*) { // There is some special handling for df.cache() / df.persist() / df.localCheckpoint() tests. 
@@ -147,13 +151,12 @@ class SQLLastAttemptMetricPlanShapesSuite withSQLConf(extraSQLConfs.toSeq: _*) { val aqeRetryMetrics = if (aqeReplans) Seq(testSLAMetric) else Seq.empty AQETestHelper.withForcedCancellation(aqeRetryMetrics: _*) { - withSparkContextConf( - config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> stageRetries.toString) { + withSparkContextConf(failureMode.sparkContextConfs: _*) { val resultDf = spark.sql(sqlQuery) val _ = resultDf.collect() // normal value of the metrics shall not work with retries or replans - if (!stageRetries && !aqeReplans) { + if (!failureMode.causesStageRetries && !aqeReplans) { metricValueCheck(Some(testSLAMetric.value)) } // test LastRDDValue @@ -487,4 +490,27 @@ object SQLLastAttemptMetricPlanShapesSuite { val LARGE_CARDINALITY: Int = 111 val TABLE_NAME: String = "test_table" + + sealed trait FailureMode { + def sparkContextConfs: Seq[(String, String)] + def causesStageRetries: Boolean + } + object FailureMode { + case object NoFailure extends FailureMode { + val sparkContextConfs: Seq[(String, String)] = Seq.empty + val causesStageRetries: Boolean = false + } + case object FetchFailure extends FailureMode { + val sparkContextConfs: Seq[(String, String)] = + Seq(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") + val causesStageRetries: Boolean = true + } + case object ChecksumMismatch extends FailureMode { + val sparkContextConfs: Seq[(String, String)] = Seq( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true", + config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE.key -> "true") + val causesStageRetries: Boolean = true + } + val all: Seq[FailureMode] = Seq(NoFailure, FetchFailure, ChecksumMismatch) + } } diff --git a/sql/hive/src/test/resources/conf/binding-policy-exceptions/configs-without-binding-policy-exceptions b/sql/hive/src/test/resources/conf/binding-policy-exceptions/configs-without-binding-policy-exceptions index 2aa6cb885ca31..32e062ba90c35 100644 --- 
a/sql/hive/src/test/resources/conf/binding-policy-exceptions/configs-without-binding-policy-exceptions +++ b/sql/hive/src/test/resources/conf/binding-policy-exceptions/configs-without-binding-policy-exceptions @@ -1174,6 +1174,9 @@ spark.test.noStageRetry spark.testing spark.testing.dynamicAllocation.schedule.enabled spark.testing.injectShuffleFetchFailures +spark.testing.injectShuffleFetchFailuresDownstreamDelay +spark.testing.injectShuffleFetchFailuresResultStageDelay +spark.testing.injectShuffleForceChecksumMismatchOnRecompute spark.testing.memory spark.testing.nCoresPerExecutor spark.testing.nExecutorsPerHost