[perf] training: add attn_only selective activation checkpointing #1410

Open
rich7420 wants to merge 3 commits into hao-ai-lab:main from rich7420:attn-only-selective-ac

@rich7420 rich7420 commented May 27, 2026

Contributor

Motivation

Video DiT training has to run with activation checkpointing — at video sequence
lengths (e.g. 72k tokens) the activations don't fit otherwise. But the existing
modes leave a gap:

  • full recomputes the entire transformer block in backward, so it
    re-runs the attention forward — the single most expensive op — a second
    time. On Wan2.1-T2V-1.3B we measured the attention forward kernel running
    1200× under full where only 600× are needed; the other 600 are pure
    redundant recompute, ~12% of GPU compute burned every step.
  • ops (the stock selective mode) is meant to avoid that, but it OOMs
    at video sequence lengths because it also saves the FFN/GEMM intermediates
    (~51 GB across the blocks) — which are large but cheap to recompute.

So in practice there is no usable middle option: recompute everything (full,
wastes the attention recompute) or OOM (ops).

What this adds

attn_only: a per-block selective policy that saves only the attention forward
output (small: ~221 MB/block at seq 72k) and any functional collective output
(so communication is never re-issued in backward), and recomputes everything
else. It removes the attention recompute that full pays, while staying within
memory.
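
As a rough sanity check on the ~221 MB/block figure, the sketch below estimates the size of one attention output tensor from the token count. Only the 720×1280×77f resolution and the 72k sequence length come from this PR; the hidden size (1536), bf16 storage, the 4×8×8 VAE compression, and the 1×2×2 patch size are assumptions about Wan2.1-T2V-1.3B, and the helper names are hypothetical.

```python
def video_tokens(height, width, frames, vae_stride=(4, 8, 8), patch=(1, 2, 2)):
    """Token count after VAE compression and patchification (assumed layout)."""
    t = (frames - 1) // vae_stride[0] + 1   # assumed causal VAE: first frame kept as-is
    h = height // vae_stride[1] // patch[1]
    w = width // vae_stride[2] // patch[2]
    return (t // patch[0]) * h * w


def attn_output_bytes(tokens, hidden_size=1536, bytes_per_elem=2):
    """Size of one [tokens, hidden_size] attention output (bf16 assumed)."""
    return tokens * hidden_size * bytes_per_elem


tokens = video_tokens(720, 1280, 77)
size_mb = attn_output_bytes(tokens) / 1e6
print(f"{tokens} tokens, {size_mb:.1f} MB per block")
```

Under these assumptions the estimate lands on the 72k tokens and ~221 MB/block quoted above.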

Why this lever

  • Removes real work (FLOPs), not host-wait → the speedup is
    fabric-independent (helps on both PCIe and NVSwitch) and stacks with
    torch.compile (which won't merge separate nn.Linear weights or change
    the recompute structure).
  • Near-zero cost: +2 GB peak vs full; numerically equivalent (saves the
    forward output rather than recomputing it).
  • Last recoverable AC win: after attn_only, the dominant kernel is the
    flash backward (gradient computation — irreducible). This captures the
    remaining activation-checkpoint headroom.

activation_checkpoint.py is shared by both the legacy (fastvideo/training/)
and modular (fastvideo/train/) stacks, so both benefit from one change.

Opt-in, default-preserving

The default is unchanged (no AC unless requested). attn_only joins
full/ops/block_skip as a new point on the memory↔speed curve — it is not
a replacement for full, which remains the lowest-memory option for
memory-constrained training.

| mode | peak (1×H100) | note |
|---|---|---|
| full | 26.8 GB | recompute all; lowest memory |
| attn_only | 28.9 GB | +2 GB, no attention recompute |
| ops | OOM | also saves the FFN mm |

Use attn_only when you have headroom over full.

Changes

  • CheckpointType.ATTN_ONLY + dispatch branch
  • _apply_activation_checkpointing_attn_only: MUST_SAVE the attention forward
    output and any _c10d_functional collective (so a collective is never
    re-issued during the backward recompute — that is expensive and a cross-rank
    ordering hazard), PREFER_RECOMPUTE the rest; the per-op decision is cached
  • _is_attention_forward: backend-agnostic op-name match (FA2
    _flash_attn_forward, FA3 fwd, aten SDPA; backward excluded) — robust to
    the runtime-selected attention backend
  • CLI: add attn_only to --enable-gradient-checkpointing-type choices

Results

Kernel-level (nsys, 1×H100, Wan2.1-T2V-1.3B 720×1280×77f) — the direct,
hardware-independent measurement: the only kernel whose count changes is
flash_fwd_kernel (1200 → 600, −4.67 s); flash-bwd / GEMM / elementwise
counts are conserved. Total GPU compute 40.2 s → 35.3 s (−12.3%), ~95% of
which is the flash-fwd halving — i.e. the change removes exactly the
attention-forward recompute and nothing else.

Wall-clock step time (steady-state, no profiler):

| config | full | attn_only | speedup |
|---|---|---|---|
| H100×1 | 4.19 s | 3.75 s | −10.5% |
| H100×2 | 2.25 s | 2.08 s | −7.6% |
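
The speedup column reads as the relative change in step time versus full. A minimal sketch of that arithmetic, using only the step times from the table (the helper name `relative_change` is hypothetical):

```python
def relative_change(baseline_s: float, new_s: float) -> float:
    """Relative change in step time, in percent (negative means faster)."""
    return (new_s - baseline_s) / baseline_s * 100.0


# (full, attn_only) step times in seconds, copied from the table.
rows = {"H100x1": (4.19, 3.75), "H100x2": (2.25, 2.08)}
for config, (full_s, attn_only_s) in rows.items():
    print(f"{config}: {relative_change(full_s, attn_only_s):+.1f}%")
```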

The H100×2 result reproduced in a second independent run (2.28 s → 2.09 s). attn_only also speeds up multi-GPU L40S,
but Modal's per-job L40S allocation has large PCIe-topology variance (identical
work ranged 5.5–9.1 s/step), so we don't quote a clean L40S wall-time delta; the
−12.3% compute figure above is the allocation-independent one. The win narrows as
GPU count grows (fixed comm cost dilutes the compute saving) and scales with the
attention fraction of the step — largest for long-sequence video training, which
is why it is opt-in rather than a forced default.

Numerically equivalent: per-step loss differs from full within the
GPU non-determinism noise floor (full-vs-full on the same seed already differs by
~0.002 from flash-attn backward atomics). By construction it saves the forward
output rather than recomputing it, so it introduces no systematic drift.

Usage

--enable_gradient_checkpointing_type attn_only

Test plan

  • [added] CPU unit test (fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py):
    CheckpointType.ATTN_ONLY, the attention-forward classifier
    (FA2/FA3/SDPA forward matched; backward and FFN/QKV mm excluded), the
    save decision (attention forward + functional collectives saved, GEMM
    recomputed) and its caching, per-block checkpoint_wrapper application, and
    the error paths. No GPU required.
  • [manual] e2e training runs clean with attn_only on 1–2×H100 and 2×L40S,
    20 steps, no OOM.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a new selective activation-checkpointing policy called attn_only. This policy optimizes memory usage and training speed by only saving the expensive attention forward outputs and collective communication outputs, while recomputing cheaper GEMM/FFN intermediates. Unit tests have been added to verify the behavior of this new policy. The reviewer suggests two key improvements: first, saving all collective communication operators in the _c10d_functional namespace (rather than just reduce_scatter_tensor) to prevent potential distributed deadlocks; second, caching the results of the attention forward operator check to eliminate string-matching overhead during training.

Comment on lines +114 to +118
_reduce_scatter = torch.ops._c10d_functional.reduce_scatter_tensor.default

def _attn_only_policy(ctx, func, *args, **kwargs):
    save = _is_attention_forward(func) or func == _reduce_scatter
    return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE

high

Instead of only saving reduce_scatter_tensor, it is much safer and more robust to save all collective communication operators (e.g., all_gather_into_tensor, all_reduce, etc.) in the _c10d_functional namespace. Recomputing any communication operator during the backward pass is extremely expensive and can easily lead to distributed deadlocks or race conditions if different ranks execute them out of order. Checking for _c10d_functional in the operator's string representation will automatically match and save all functional collectives.

Suggested change
- _reduce_scatter = torch.ops._c10d_functional.reduce_scatter_tensor.default
-
- def _attn_only_policy(ctx, func, *args, **kwargs):
-     save = _is_attention_forward(func) or func == _reduce_scatter
-     return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE
+ def _attn_only_policy(ctx, func, *args, **kwargs):
+     # Save attention forward and any collective communication ops to avoid re-communicating during backward
+     save = _is_attention_forward(func) or "_c10d_functional" in str(func)
+     return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE

Comment on lines +90 to +101
def _is_attention_forward(func) -> bool:
    """True for the (expensive-to-recompute) attention forward op, across
    backends: flash_attn lib FA2 (`flash_attn::_flash_attn_forward`), FA3
    (`flash_attn_3::fwd`), FastVideo custom ops, CuTe variants, and aten SDPA.
    Matched by op name so we don't depend on which op object is registered at
    import time (the active backend varies at runtime). Backward ops are
    excluded so only the forward output is saved."""
    s = str(func)
    if "backward" in s:
        return False
    # FA2 names the op "...forward", FA3 (flash_attn_interface) names it "...fwd".
    return ("flash_attn" in s and ("forward" in s or "fwd" in s)) or "_scaled_dot_product" in s

medium

Converting func to a string and performing multiple substring checks on every single operator execution inside the checkpointed block can introduce non-trivial Python overhead. Since the set of unique operators executed during training is very small, caching the result of _is_attention_forward using a dictionary (with a fallback for unhashable objects) will completely eliminate this overhead after the first step.

_ATTN_FWD_CACHE = {}


def _is_attention_forward(func) -> bool:
    """True for the (expensive-to-recompute) attention forward op, across
    backends: flash_attn lib FA2 (`flash_attn::_flash_attn_forward`), FA3
    (`flash_attn_3::fwd`), FastVideo custom ops, CuTe variants, and aten SDPA.
    Matched by op name so we don't depend on which op object is registered at
    import time (the active backend varies at runtime). Backward ops are
    excluded so only the forward output is saved."""
    try:
        if func in _ATTN_FWD_CACHE:
            return _ATTN_FWD_CACHE[func]
    except TypeError:
        pass

    s = str(func)
    if "backward" in s:
        res = False
    else:
        # FA2 names the op "...forward", FA3 (flash_attn_interface) names it "...fwd".
        res = ("flash_attn" in s and ("forward" in s or "fwd" in s)) or "_scaled_dot_product" in s

    try:
        _ATTN_FWD_CACHE[func] = res
    except TypeError:
        pass
    return res

mergify Bot added the type: perf, scope: training, and scope: infra labels May 27, 2026

mergify Bot commented May 27, 2026

Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@SolitaryThinker

Collaborator

Hi @rich7420 — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.

TL;DR

The CLI default remains backwards-compatible, and the new attn_only path is plausibly using selective activation checkpointing to save attention and collective outputs while recomputing other block ops. The blockers are coverage and backend support: the tests never run a forward/backward parity check, and the op classifier recognizes Flash/SDPA names but not several shipped FastVideo attention backends.

Verdict: Changes requested

  • S0 (blockers): 0
  • S1 (must-fix): 2
  • S2 (should-fix; surfaced if persistent or important): 1
  • S3 (discussion): not shown here; see review.md

Findings (formatted for upload)

[S1] Add a real forward/backward parity test for attn_only

What: The new test file is explicitly CPU-only and says it exercises the op classifier and wrapper structure without running a forward/backward. The dummy block forward is marked never executed, so no test compares outputs, losses, input gradients, or parameter gradients for none, full, and attn_only.

Why it matters: Selective activation checkpointing can look structurally correct while still breaking autograd, RNG behavior, custom-op caching, or saved-tensor policy semantics. This PR changes training behavior, so the minimum safety bar is numerical and gradient parity against the uncheckpointed path.

Suggested fix: Add a small deterministic module with one real attention path plus an MLP, run one forward/backward under no checkpointing, full, and attn_only, then assert allclose for outputs, loss, input grads, and parameter grads. Keep a CPU/SDPA parity test always runnable, and add a CUDA backend case for at least one FastVideo-shipped attention implementation when available.

Evidence: fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py:1, fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py:72, fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py:82


[S1] Gate or cover non-Flash FastVideo attention backends

What: _is_attention_forward only returns true for op names containing flash_attn plus forward/fwd, or _scaled_dot_product. FastVideo also ships attention backends whose calls do not match that predicate, including sageattn, video_sparse_attn, moba_attn_varlen, and SLA's custom autograd attention.

Why it matters: For those backends, attn_only silently treats the attention kernel as PREFER_RECOMPUTE, so the documented guarantee that attention forward outputs are saved is false and custom-kernel re-entry during backward is unverified. That undermines both the performance claim and the training-safety story for any training recipe using those backends.

Suggested fix: Either expand the policy to recognize every training-supported attention backend and add backend-specific coverage, or fail fast when attn_only is requested with an unsupported backend. If the intended support matrix is only Flash/SDPA, document that in the CLI help and tests instead of claiming broad FastVideo attention coverage.

Evidence: fastvideo/training/activation_checkpoint.py:101, fastvideo/attention/backends/sage_attn.py:57, fastvideo/attention/backends/video_sparse_attn.py:292, fastvideo/attention/backends/vmoba.py:187, fastvideo/attention/backends/sla.py:308
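
The fail-fast option could be sketched as below. All names here are hypothetical: the backend identifiers, the supported set, and the function `validate_attn_only_backend` are illustrative assumptions, since the review does not say how FastVideo exposes the active attention backend.

```python
# Hypothetical identifiers; the real backend names may differ.
ATTN_ONLY_SUPPORTED_BACKENDS = frozenset({"flash_attn", "flash_attn_3", "torch_sdpa"})


def validate_attn_only_backend(checkpoint_type: str, attention_backend: str) -> None:
    """Raise early if attn_only is requested with a backend the classifier cannot match."""
    if checkpoint_type != "attn_only":
        return  # other checkpoint modes do not depend on the attention op name
    if attention_backend not in ATTN_ONLY_SUPPORTED_BACKENDS:
        supported = ", ".join(sorted(ATTN_ONLY_SUPPORTED_BACKENDS))
        raise ValueError(
            f"attn_only checkpointing does not recognize attention backend "
            f"{attention_backend!r}; supported backends: {supported}. "
            "Use 'full' instead or extend the attention-forward classifier.")


validate_attn_only_backend("attn_only", "flash_attn")  # passes silently
validate_attn_only_backend("full", "sage_attn")        # not attn_only, so no check
try:
    validate_attn_only_backend("attn_only", "sage_attn")
except ValueError as err:
    print(err)
```

Raising at setup time turns a silent fallback to recomputing attention into an explicit configuration error.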


[S2] Add distributed coverage for the collective-saving policy

What: The implementation explicitly saves _c10d_functional outputs to avoid recomputing FSDP2 collectives, but the new tests are CPU-only and avoid distributed initialization entirely.

Why it matters: FSDP/SP plus activation checkpointing is where collective ordering bugs and backward deadlocks show up. A string classifier test cannot validate that the checkpoint wrapper, selective policy, and distributed collectives interact safely in a real train step.

Suggested fix: Add a minimal two-rank distributed regression, even if it is nightly/GPU-gated, that runs one forward/backward with FSDP or the sequence-parallel path under attn_only and compares with the existing checkpoint mode. At minimum, assert no extra collective is issued during recompute for the supported FSDP path.

Evidence: fastvideo/training/activation_checkpoint.py:109, fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py:1


— Gob (@SolitaryThinker's AI reviewer). Full review (including S3 items) is archived locally.

rich7420 added 3 commits June 6, 2026 00:11
FULL activation checkpointing recomputes the entire transformer block in
backward, paying the expensive attention forward (flash_fwd) a second time.
The stock 'ops' selective mode avoids that but OOMs at seq 72k because it also
saves the FFN/GEMM mm intermediates (~51GB across 30 blocks) which are large
but cheap to recompute.

attn_only is a tailored per-block policy: MUST_SAVE only the attention forward
output (~221MB/block, small) and collective (reduce_scatter) outputs, and
PREFER_RECOMPUTE everything else. This eliminates the attention recompute that
FULL pays while keeping peak memory low.

Measured (Wan2.1-T2V-1.3B, 720x1280x77f, batch=1, H100, no nsys):
  full:      26.82 GB peak, 4.19 s/step
  attn_only: 28.88 GB peak (+2 GB, fits H100 and L40S 48GB), 3.75 s/step (-10.5%)

Numerically equivalent: attn_only-vs-full loss differs 0.0003-0.0009, within the
full-vs-full GPU non-determinism noise floor (0.0023, flash-attn backward atomics).
Orthogonal to torch.compile (matches the attention op by name across FA2/FA3/SDPA).

CPU-only coverage for the new attn_only policy: CheckpointType.ATTN_ONLY value,
the backend-agnostic attention-forward classifier (_is_attention_forward matches
FA2 forward / FA3 fwd / aten SDPA, excludes backward and the FFN/QKV mm), per-block
checkpoint_wrapper application, and the no-transformer-blocks / unknown-type error
paths. No GPU or distributed init required.

…p decision

Address review feedback on the attn_only policy:
- Save every _c10d_functional collective (not just reduce_scatter_tensor) so a
  collective is never re-issued during the backward recompute — recomputing one
  is expensive and a cross-rank ordering/deadlock hazard. This also covers
  FSDP2's all_gather_into_tensor in the checkpointed forward, which the previous
  reduce_scatter-only clause missed.
- Cache the per-op save decision (_attn_only_must_save); the set of distinct ops
  per step is tiny, so the string matching runs once per unique op instead of
  once per call.
Unit tests extended to cover collective-save and the cached decision.
@rich7420 rich7420 force-pushed the attn-only-selective-ac branch 3 times, most recently from 1986b1f to 8e18bc3 Compare June 5, 2026 17:53
@alexzms alexzms self-requested a review June 10, 2026 22:59
Labels

scope: infra (CI, tests, Docker, build), scope: training (Training pipeline, methods, configs), type: perf (Performance improvement)
