[perf] training: add attn_only selective activation checkpointing#1410
[perf] training: add attn_only selective activation checkpointing#1410rich7420 wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new selective activation-checkpointing policy called attn_only. This policy optimizes memory usage and training speed by only saving the expensive attention forward outputs and collective communication outputs, while recomputing cheaper GEMM/FFN intermediates. Unit tests have been added to verify the behavior of this new policy. The reviewer suggests two key improvements: first, saving all collective communication operators in the _c10d_functional namespace (rather than just reduce_scatter_tensor) to prevent potential distributed deadlocks; second, caching the results of the attention forward operator check to eliminate string-matching overhead during training.
| _reduce_scatter = torch.ops._c10d_functional.reduce_scatter_tensor.default | ||
|
|
||
| def _attn_only_policy(ctx, func, *args, **kwargs): | ||
| save = _is_attention_forward(func) or func == _reduce_scatter | ||
| return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE |
There was a problem hiding this comment.
Instead of only saving reduce_scatter_tensor, it is much safer and more robust to save all collective communication operators (e.g., all_gather_into_tensor, all_reduce, etc.) in the _c10d_functional namespace. Recomputing any communication operator during the backward pass is extremely expensive and can easily lead to distributed deadlocks or race conditions if different ranks execute them out of order. Checking for _c10d_functional in the operator's string representation will automatically match and save all functional collectives.
| _reduce_scatter = torch.ops._c10d_functional.reduce_scatter_tensor.default | |
| def _attn_only_policy(ctx, func, *args, **kwargs): | |
| save = _is_attention_forward(func) or func == _reduce_scatter | |
| return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE | |
| def _attn_only_policy(ctx, func, *args, **kwargs): | |
| # Save attention forward and any collective communication ops to avoid re-communicating during backward | |
| save = _is_attention_forward(func) or "_c10d_functional" in str(func) | |
| return CheckpointPolicy.MUST_SAVE if save else CheckpointPolicy.PREFER_RECOMPUTE |
| def _is_attention_forward(func) -> bool: | ||
| """True for the (expensive-to-recompute) attention forward op, across | ||
| backends: flash_attn lib FA2 (`flash_attn::_flash_attn_forward`), FA3 | ||
| (`flash_attn_3::fwd`), FastVideo custom ops, CuTe variants, and aten SDPA. | ||
| Matched by op name so we don't depend on which op object is registered at | ||
| import time (the active backend varies at runtime). Backward ops are | ||
| excluded so only the forward output is saved.""" | ||
| s = str(func) | ||
| if "backward" in s: | ||
| return False | ||
| # FA2 names the op "...forward", FA3 (flash_attn_interface) names it "...fwd". | ||
| return ("flash_attn" in s and ("forward" in s or "fwd" in s)) or "_scaled_dot_product" in s |
There was a problem hiding this comment.
Converting func to a string and performing multiple substring checks on every single operator execution inside the checkpointed block can introduce non-trivial Python overhead. Since the set of unique operators executed during training is very small, caching the result of _is_attention_forward using a dictionary (with a fallback for unhashable objects) will completely eliminate this overhead after the first step.
_ATTN_FWD_CACHE = {}
def _is_attention_forward(func) -> bool:
"""True for the (expensive-to-recompute) attention forward op, across
backends: flash_attn lib FA2 (`flash_attn::_flash_attn_forward`), FA3
(`flash_attn_3::fwd`), FastVideo custom ops, CuTe variants, and aten SDPA.
Matched by op name so we don't depend on which op object is registered at
import time (the active backend varies at runtime). Backward ops are
excluded so only the forward output is saved."""
try:
if func in _ATTN_FWD_CACHE:
return _ATTN_FWD_CACHE[func]
except TypeError:
pass
s = str(func)
if "backward" in s:
res = False
else:
# FA2 names the op "...forward", FA3 (flash_attn_interface) names it "...fwd".
res = ("flash_attn" in s and ("forward" in s or "fwd" in s)) or "_scaled_dot_product" in s
try:
_ATTN_FWD_CACHE[func] = res
except TypeError:
pass
return res
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🔴 PR merge requirementsWaiting for
This rule is failing.
|
|
Hi @rich7420 — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off. TL;DRThe CLI default remains backwards-compatible, and the new Verdict: Changes requested
Findings (formatted for upload)[S1] Add a real forward/backward parity test for
|
FULL activation checkpointing recomputes the entire transformer block in backward, paying the expensive attention forward (flash_fwd) a second time. The stock 'ops' selective mode avoids that but OOMs at seq 72k because it also saves the FFN/GEMM mm intermediates (~51GB across 30 blocks) which are large but cheap to recompute. attn_only is a tailored per-block policy: MUST_SAVE only the attention forward output (~221MB/block, small) and collective (reduce_scatter) outputs, and PREFER_RECOMPUTE everything else. This eliminates the attention recompute that FULL pays while keeping peak memory low. Measured (Wan2.1-T2V-1.3B, 720x1280x77f, batch=1, H100, no nsys): full: 26.82 GB peak, 4.19 s/step attn_only: 28.88 GB peak (+2 GB, fits H100 and L40S 48GB), 3.75 s/step (-10.5%) Numerically equivalent: attn_only-vs-full loss differs 0.0003-0.0009, within the full-vs-full GPU non-determinism noise floor (0.0023, flash-attn backward atomics). Orthogonal to torch.compile (matches the attention op by name across FA2/FA3/SDPA).
CPU-only coverage for the new attn_only policy: CheckpointType.ATTN_ONLY value, the backend-agnostic attention-forward classifier (_is_attention_forward matches FA2 forward / FA3 fwd / aten SDPA, excludes backward and the FFN/QKV mm), per-block checkpoint_wrapper application, and the no-transformer-blocks / unknown-type error paths. No GPU or distributed init required.
…p decision Address review feedback on the attn_only policy: - Save every _c10d_functional collective (not just reduce_scatter_tensor) so a collective is never re-issued during the backward recompute — recomputing one is expensive and a cross-rank ordering/deadlock hazard. This also covers FSDP2's all_gather_into_tensor in the checkpointed forward, which the previous reduce_scatter-only clause missed. - Cache the per-op save decision (_attn_only_must_save); the set of distinct ops per step is tiny, so the string matching runs once per unique op instead of once per call. Unit tests extended to cover collective-save and the cached decision.
1986b1f to
8e18bc3
Compare
Motivation
Video DiT training has to run with activation checkpointing — at video sequence
lengths (e.g. 72k tokens) the activations don't fit otherwise. But the existing
modes leave a gap:
fullrecomputes the entire transformer block in backward, so itre-runs the attention forward — the single most expensive op — a second
time. On Wan2.1-T2V-1.3B we measured the attention forward kernel running
1200× under
fullwhere only 600× are needed; the other 600 are pureredundant recompute, ~12% of GPU compute burned every step.
ops(the stock selective mode) is meant to avoid that, but it OOMsat video sequence lengths because it also saves the FFN/GEMM intermediates
(~51 GB across the blocks) — which are large but cheap to recompute.
So in practice there is no usable middle option: recompute everything (
full,wastes the attention recompute) or OOM (
ops).What this adds
attn_only: a per-block selective policy that saves only the attentionforward output (small — ~221 MB/block at seq 72k) and any functional
collective output (so communication is never re-issued in backward), and
recomputes everything else. It removes the attention recompute
fullpays,while staying within memory.
Why this lever
fabric-independent (helps on both PCIe and NVSwitch) and stacks with
torch.compile(which won't merge separatenn.Linearweights or changethe recompute structure).
full; numerically equivalent (saves theforward output rather than recomputing it).
attn_only, the dominant kernel is theflash backward (gradient computation — irreducible). This captures the
remaining activation-checkpoint headroom.
activation_checkpoint.pyis shared by both the legacy (fastvideo/training/)and modular (
fastvideo/train/) stacks, so both benefit from one change.Opt-in, default-preserving
The default is unchanged (no AC unless requested).
attn_onlyjoinsfull/ops/block_skipas a new point on the memory↔speed curve — it is nota replacement for
full, which remains the lowest-memory option formemory-constrained training.
Use
attn_onlywhen you have headroom overfull.Changes
CheckpointType.ATTN_ONLY+ dispatch branch_apply_activation_checkpointing_attn_only: MUST_SAVE the attention forwardoutput and any
_c10d_functionalcollective (so a collective is neverre-issued during the backward recompute — that is expensive and a cross-rank
ordering hazard), PREFER_RECOMPUTE the rest; the per-op decision is cached
_is_attention_forward: backend-agnostic op-name match (FA2_flash_attn_forward, FA3fwd, aten SDPA; backward excluded) — robust tothe runtime-selected attention backend
attn_onlyto--enable-gradient-checkpointing-typechoicesResults
Kernel-level (nsys, 1×H100, Wan2.1-T2V-1.3B 720×1280×77f) — the direct,
hardware-independent measurement: the only kernel whose count changes is
flash_fwd_kernel(1200 → 600, −4.67 s); flash-bwd / GEMM / elementwisecounts are conserved. Total GPU compute 40.2 s → 35.3 s (−12.3%), ~95% of
which is the flash-fwd halving — i.e. the change removes exactly the
attention-forward recompute and nothing else.
Wall-clock step time (steady-state, no profiler):
H100×2 reproduced across two independent runs (2.28→2.09 in the other). It also speeds up multi-GPU L40S,
but Modal's per-job L40S allocation has large PCIe-topology variance (identical
work ranged 5.5–9.1 s/step), so we don't quote a clean L40S wall-time delta — the
−12.3% compute figure above is the allocation-independent one. The win narrows as
GPU count grows (fixed comm cost dilutes the compute saving) and scales with the
attention fraction of the step — largest for long-sequence video training, which
is why it is opt-in rather than a forced default.
Numerically equivalent: per-step loss differs from
fullwithin theGPU non-determinism noise floor (full-vs-full on the same seed already differs by
~0.002 from flash-attn backward atomics). By construction it saves the forward
output rather than recomputing it, so it introduces no systematic drift.
Usage
Test plan
fastvideo/tests/training/checkpoint/test_activation_checkpoint_attn_only.py):CheckpointType.ATTN_ONLY, the attention-forward classifier(FA2/FA3/SDPA forward matched; backward and FFN/QKV
mmexcluded), thesave decision (attention forward + functional collectives saved, GEMM
recomputed) and its caching, per-block
checkpoint_wrapperapplication, andthe error paths. No GPU required.
attn_onlyon 1–2×H100 and 2×L40S,20 steps, no OOM.