Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions core/src/main/scala/org/apache/spark/internal/config/Tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,39 @@ private[spark] object Tests {

// Master switch for the driver-side shuffle fetch-failure injection (consumed by
// DAGScheduler). Fix: a diff-collapse artifact had left TWO chained .doc(...) calls here
// (the stale text plus its replacement); ConfigBuilder.doc overrides, so the stale first
// call was dead weight and has been removed - only the current doc remains.
val INJECT_SHUFFLE_FETCH_FAILURES =
  ConfigBuilder("spark.testing.injectShuffleFetchFailures")
    .doc("Corrupt the registered MapStatus of partition 0 on the first successful attempt " +
      "of every shuffle map stage, to induce downstream FetchFailed and stage retry. " +
      "Testing only.")
    .booleanConf
    .createWithDefault(false)

// Test-only companion knob for INJECT_SHUFFLE_FETCH_FAILURES: how many downstream
// ShuffleMapStage task successes to observe before the producer's partition-0 output is
// corrupted (0 = corrupt immediately at map-output registration).
val INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY =
  ConfigBuilder("spark.testing.injectShuffleFetchFailuresDownstreamDelay")
    .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Defer the producer's partition-0 corruption " +
      "until N ShuffleMapStage consumer task successes have been observed. Default 1; set to " +
      "0 to corrupt at registration. Testing only.")
    .intConf
    .checkValue(_ >= 0, "Downstream-success delay must be non-negative")
    .createWithDefault(1)

// Test-only: like the DOWNSTREAM_DELAY knob above, but counting ResultStage (rather than
// ShuffleMapStage) consumer-task successes. Defaults to 0 so the corruption lands before any
// result task finishes, keeping the result stage rollback-safe.
val INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY =
  ConfigBuilder("spark.testing.injectShuffleFetchFailuresResultStageDelay")
    .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Counterpart to " +
      "INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY for ResultStage consumers. With the " +
      "default 0, when a ResultStage is the consumer of a pending corruption it is " +
      "corrupted before the result tasks dispatch, so the result stage has no completed " +
      "tasks when INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE fires (the rollback " +
      "would otherwise abort the result stage, since OSS Spark does not support rolling " +
      "result stages back). Set to N > 0 to defer until N result-stage tasks have " +
      "succeeded - this is the only way to actually exercise the result-stage abort path. " +
      "Testing only.")
    .intConf
    .checkValue(_ >= 0, "Result-stage-success delay must be non-negative")
    .createWithDefault(0)

// Test-only: once the injected FetchFailed forces a recompute of the corrupted partition,
// report that recompute as checksum-mismatched so the rollback machinery is exercised.
val INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE =
  ConfigBuilder("spark.testing.injectShuffleForceChecksumMismatchOnRecompute")
    .doc("Used with INJECT_SHUFFLE_FETCH_FAILURES. Flag the recompute as checksum mismatched, " +
      "forcing downstream `rollbackSucceedingStages`. Testing only.")
    .booleanConf
    .createWithDefault(false)

Expand Down
164 changes: 153 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,30 @@ private[spark] class DAGScheduler(

// Job id -> query execution id. NOTE(review): presumably associates a job with the SQL query
// execution that submitted it; the writers/readers are outside this chunk - verify at callers.
private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]()

// ---- Test-only state for the INJECT_SHUFFLE_FETCH_FAILURES family of flags. ----
// All maps are keyed by shuffleId and cleaned up when the owning stage is removed.

// For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task
// we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide
// when to fire INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the
// task whose stageAttemptId is not the recorded one.
private val injectShuffleFetchFailuresCorruptedAttempt =
new ConcurrentHashMap[Int, Int]()

// For INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY > 0: shuffles whose mapper-0 corruption
// has been deferred until enough downstream consumer tasks succeed. The value is the mapId
// we will eventually swap to an invalid BlockManagerId.
private val injectShuffleFetchFailuresPendingDelayedCorruption =
new ConcurrentHashMap[Int, Long]()

// For INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY: per-shuffle counter of consumer
// task-success events observed so far. Compared against the configured delay each time a
// consumer task of that shuffle succeeds.
private val injectShuffleFetchFailuresDownstreamSuccessCount =
new ConcurrentHashMap[Int, Int]()

// Bogus BlockManagerId used by INJECT_SHUFFLE_FETCH_FAILURES to mark a corrupted MapStatus.
// Any consumer task that tries to fetch a block from this location will FetchFailed because
// the executorId is INVALID_EXECUTOR_ID.
private val injectShuffleFetchFailuresInvalidBlockManagerId =
BlockManagerId(BlockManagerId.INVALID_EXECUTOR_ID, "invalid", 1, None)

// Job groups that are cancelled with `cancelFutureJobs` as true, with at most
// `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in
// this set, the job will be immediately cancelled.
Expand Down Expand Up @@ -933,6 +957,11 @@ private[spark] class DAGScheduler(
}
for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
shuffleIdToMapStage.remove(k)
if (Utils.isTesting) {
injectShuffleFetchFailuresCorruptedAttempt.remove(k)
injectShuffleFetchFailuresPendingDelayedCorruption.remove(k)
injectShuffleFetchFailuresDownstreamSuccessCount.remove(k)
}
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
Expand Down Expand Up @@ -1618,6 +1647,116 @@ private[spark] class DAGScheduler(
}
}

/**
 * Test-only predicate: should the output of this just-completed shuffle map task be
 * corrupted by the fetch-failure injection?
 *
 * Only the partition-0 task is ever corrupted, and only on the stage attempt that first
 * completes partition 0 successfully; that attempt id is latched into
 * injectShuffleFetchFailuresCorruptedAttempt. Later attempts (recomputes) of that partition
 * stay clean so the consumer's retry can make progress. Because the latch is keyed
 * per-shuffle, a non-leaf stage whose earlier attempts died fetching from upstream is still
 * corrupted on its own first successful attempt.
 */
private def shouldCorruptShuffleOutputForTest(shuffleId: Int, task: Task[_]): Boolean = {
  task.partitionId == 0 && {
    // computeIfAbsent atomically latches the first successful attempt id for this shuffle;
    // equality then separates that attempt (corrupt) from any later recompute (leave clean).
    val latchedAttempt = injectShuffleFetchFailuresCorruptedAttempt.computeIfAbsent(
      shuffleId, _ => task.stageAttemptId)
    latchedAttempt == task.stageAttemptId
  }
}

/**
 * Test-only: apply the fetch-failure injection to this just-completed map task's MapStatus.
 *
 * With INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY > 0, only record the mapId so that
 * maybeApplyDelayedCorruptionForTest can corrupt it once enough consumer tasks succeed;
 * otherwise rewrite the MapStatus location to the invalid BlockManagerId inline.
 */
private def corruptShuffleOutputForTest(shuffleId: Int, status: MapStatus): Unit = {
  val deferUntilSuccesses =
    sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY)
  if (deferUntilSuccesses > 0) {
    // Deferred path: remember which mapId to corrupt later.
    injectShuffleFetchFailuresPendingDelayedCorruption.put(shuffleId, status.mapId)
  } else {
    // Immediate path: any consumer fetching from this location will FetchFail.
    status.updateLocation(injectShuffleFetchFailuresInvalidBlockManagerId)
  }
}

/**
 * Test-only predicate for INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE: true iff this
 * shuffle map task is the recompute of the partition whose previous successful attempt was
 * corrupted by INJECT_SHUFFLE_FETCH_FAILURES. Forcing the mismatch on the recompute drives
 * the rollback path - downstream ShuffleMapStages get cleaned up and re-run fully,
 * downstream ResultStages are aborted.
 */
private def isForcedChecksumMismatchForTest(shuffleId: Int, task: Task[_]): Boolean = {
  val forceEnabled =
    sc.conf.get(config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE)
  forceEnabled && task.partitionId == 0 && {
    // -1 sentinel: no attempt of this shuffle has been corrupted yet, so nothing to force.
    val corruptedAttempt =
      injectShuffleFetchFailuresCorruptedAttempt.getOrDefault(shuffleId, -1)
    corruptedAttempt >= 0 && corruptedAttempt != task.stageAttemptId
  }
}

/**
 * Test-only: apply the deferred mapper-0 corruption (configured via
 * INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY for ShuffleMapStage consumers and
 * INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY for ResultStage consumers) once enough
 * consumer tasks have succeeded. Walks the just-completed stage's direct shuffle parents,
 * bumps the per-shuffle consumer-success counter, and corrupts the registered MapStatus
 * when the counter reaches the configured delay.
 */
private def maybeApplyDelayedCorruptionForTest(stage: Stage): Unit = {
  val injectionOn = sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES)
  if (injectionOn && !injectShuffleFetchFailuresPendingDelayedCorruption.isEmpty) {
    // ResultStage consumers have their own delay knob; every other stage kind uses the
    // ShuffleMapStage one.
    val requiredSuccesses = stage match {
      case _: ResultStage =>
        sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY)
      case _ =>
        sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY)
    }
    // delay == 0 is handled eagerly at submission time, never here.
    if (requiredSuccesses > 0) {
      stage.parents.foreach {
        case sms: ShuffleMapStage =>
          val shuffleId = sms.shuffleDep.shuffleId
          if (injectShuffleFetchFailuresPendingDelayedCorruption.containsKey(shuffleId)) {
            val observedSuccesses =
              injectShuffleFetchFailuresDownstreamSuccessCount.merge(shuffleId, 1, _ + _)
            if (observedSuccesses >= requiredSuccesses) {
              val mapId = injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId)
              mapOutputTracker.updateMapOutput(
                shuffleId, mapId, injectShuffleFetchFailuresInvalidBlockManagerId)
              logInfo(s"Test injection: corrupted mapper-0 of shuffle $shuffleId after " +
                s"$observedSuccesses downstream consumer successes")
            }
          }
        case _ =>
      }
    }
  }
}

/**
 * Test-only, for INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY = 0: when a ResultStage is
 * about to dispatch tasks, fire any pending mapper-0 corruption for its direct shuffle
 * parents BEFORE the result tasks start. The result stage then still has zero finished tasks
 * when INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE later triggers
 * `rollbackSucceedingStages`, so the rollback path does not abort a partially-finished
 * result stage.
 */
private def maybePreemptiveCorruptionForResultStage(stage: Stage): Unit = {
  // Short-circuit order mirrors the cheapness of each check: stage kind, flags, pending set.
  val applies = stage.isInstanceOf[ResultStage] &&
    sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES) &&
    sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY) <= 0 &&
    !injectShuffleFetchFailuresPendingDelayedCorruption.isEmpty
  if (applies) {
    stage.parents.foreach {
      case sms: ShuffleMapStage =>
        val shuffleId = sms.shuffleDep.shuffleId
        if (injectShuffleFetchFailuresPendingDelayedCorruption.containsKey(shuffleId)) {
          val mapId = injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId)
          mapOutputTracker.updateMapOutput(
            shuffleId, mapId, injectShuffleFetchFailuresInvalidBlockManagerId)
          logInfo(s"Test injection: corrupted mapper-0 of shuffle $shuffleId before " +
            s"result-stage submission")
        }
      case _ =>
    }
  }
}

private def configureShufflePushMergerLocations(stage: ShuffleMapStage): Unit = {
if (stage.shuffleDep.getMergerLocs.nonEmpty) return
val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
Expand All @@ -1635,6 +1774,10 @@ private[spark] class DAGScheduler(
private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
logDebug("submitMissingTasks(" + stage + ")")

if (Utils.isTesting) {
maybePreemptiveCorruptionForResultStage(stage)
}

// For statically indeterminate stages being retried, we trigger rollback BEFORE task
// submission. This is more efficient than deferring to task completion because:
// 1. It avoids submitting a partial stage that would need to be cancelled
Expand Down Expand Up @@ -2271,6 +2414,10 @@ private[spark] class DAGScheduler(
taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
}

if (Utils.isTesting) {
maybeApplyDelayedCorruptionForTest(stage)
}

task match {
case rt: ResultTask[_, _] =>
// Cast to ResultStage here because it's part of the ResultTask
Expand Down Expand Up @@ -2342,21 +2489,16 @@ private[spark] class DAGScheduler(
// The epoch of the task is acceptable (i.e., the task was launched after the most
// recent failure we're aware of for the executor), so mark the task's output as
// available.
// For testing purposes, inject fetch failures controlled from the driver-side by
// supplying an invalid location.
if (Utils.isTesting &&
sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES) &&
task.stageAttemptId == 0) {
val currentLocation = status.location
val invalidLocation = BlockManagerId(
execId = BlockManagerId.INVALID_EXECUTOR_ID,
host = currentLocation.host,
port = currentLocation.port,
topologyInfo = currentLocation.topologyInfo)
status.updateLocation(invalidLocation)
shouldCorruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, task)) {
corruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, status)
}
val isChecksumMismatched = mapOutputTracker.registerMapOutput(
val realChecksumMismatched = mapOutputTracker.registerMapOutput(
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
val isChecksumMismatched = realChecksumMismatched ||
(Utils.isTesting &&
isForcedChecksumMismatchForTest(shuffleStage.shuffleDep.shuffleId, task))
if (isChecksumMismatched) {
shuffleStage.isChecksumMismatched = isChecksumMismatched
// Runtime detection of nondeterministic output via checksum mismatch.
Expand Down
Loading