[feat]: make VSA tile cache configurable for training #1444
SolitaryThinker wants to merge 1 commit into
Code Review
This pull request introduces a configurable option, cache_tile_buf, to control the reuse of the per-step padded Video Sparse Attention (VSA) tile buffer across attention layers. It defaults to False during training to prevent out-of-memory (OOM) issues under activation checkpointing, while remaining True by default for inference. The option is integrated into configuration files, CLI arguments, and metadata builders, and is supported by new unit tests. The review feedback suggests a minor improvement to simplify a redundant boolean expression in the configuration parsing logic.
    run_name=str(tr.get("run_name", "") or ""),
    ),
    vsa_sparsity=float(vs.get("sparsity", 0.0) or 0.0),
    vsa_cache_tile_buf=bool(vs.get("cache_tile_buf", False) or False),
The expression bool(vs.get("cache_tile_buf", False) or False) contains a redundant or False check. Since bool(None) and bool(False) both evaluate to False, simply calling bool() on the retrieved value is sufficient and more readable.
Suggested change:
    - vsa_cache_tile_buf=bool(vs.get("cache_tile_buf", False) or False),
    + vsa_cache_tile_buf=bool(vs.get("cache_tile_buf", False)),
References
- PEP 8 encourages writing simple and readable expressions. Redundant boolean operations like `or False` inside a `bool()` cast should be avoided to keep the code clean.
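The equivalence claimed in the review comment can be checked directly. The following is a minimal sketch, not the project's code: the two helper functions and the sample `vs` dictionaries are hypothetical stand-ins for the parsed `vsa` config section.

```python
# Hypothetical helpers mirroring the two forms discussed in the review.
def parse_with_or(vs: dict) -> bool:
    # Original form: `or False` runs before the bool() cast.
    return bool(vs.get("cache_tile_buf", False) or False)


def parse_simplified(vs: dict) -> bool:
    # Suggested form: bool() alone already maps None/False/0/"" to False.
    return bool(vs.get("cache_tile_buf", False))


# Assumed inputs a YAML loader could plausibly produce for this key.
cases = [
    {},                           # key missing
    {"cache_tile_buf": None},     # key present but empty
    {"cache_tile_buf": False},
    {"cache_tile_buf": True},
    {"cache_tile_buf": 0},
    {"cache_tile_buf": "false"},  # quoted string: truthy in BOTH forms
]

results = [(parse_with_or(vs), parse_simplified(vs)) for vs in cases]
all_equal = all(a == b for a, b in results)
```

Both forms agree on every input, so the simplification does not change behavior. Note that neither form guards against a quoted string such as `"false"`, which is truthy in Python.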
This PR has merge conflicts with the base branch. Please rebase:
    git fetch origin main
    git rebase origin/main
    # Resolve any conflicts, then:
    git push --force-with-lease
PR hao-ai-lab#1434 hard-coded cache_tile_buf=False at both VSA training build sites to fix the hao-ai-lab#1423 activation-checkpointing OOM. That leaves memory-rich clusters (FSDP/SP-sharded, or runs without full activation checkpointing) unable to keep the per-step tile-buffer-reuse speedup.

Expose it as a config knob, defaulting to False to preserve hao-ai-lab#1434's OOM-safe training behavior; opt in with True to keep the cache:
- new framework: TrainingConfig.vsa_cache_tile_buf (training.vsa.cache_tile_buf), wired into WanModel._build_attention_metadata.
- legacy: TrainingArgs.VSA_cache_tile_buf (--VSA-cache-tile-buf), wired into TrainingPipeline._build_attention_metadata (inherited by distillation / self-forcing).
- inference is untouched (default True).

Add a value-equivalence test (cached vs uncached tilings are bit-identical, on a padded shape that exercises the zero-pad scatter) and document the option in example.yaml.
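The knob described in the commit message can be pictured with a small sketch. Everything here is illustrative: `TrainingConfigSketch`, `parse_training_config`, and `attention_metadata_kwargs` are hypothetical names, and only the key path `training.vsa.cache_tile_buf` and the `False` training default come from the commit message.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfigSketch:
    # Hypothetical stand-in for the real TrainingConfig.
    vsa_sparsity: float = 0.0
    vsa_cache_tile_buf: bool = False  # OOM-safe default for training


def parse_training_config(raw: dict) -> TrainingConfigSketch:
    # Tolerate missing or empty `training` / `vsa` sections.
    vs = (raw.get("training") or {}).get("vsa") or {}
    return TrainingConfigSketch(
        vsa_sparsity=float(vs.get("sparsity", 0.0) or 0.0),
        vsa_cache_tile_buf=bool(vs.get("cache_tile_buf", False)),
    )


def attention_metadata_kwargs(cfg: TrainingConfigSketch) -> dict:
    # Assumed shape of the arguments a metadata builder would receive;
    # the point is that the flag is forwarded instead of hard-coded.
    return {"sparsity": cfg.vsa_sparsity, "cache_tile_buf": cfg.vsa_cache_tile_buf}


default_cfg = parse_training_config({})
opt_in_cfg = parse_training_config(
    {"training": {"vsa": {"sparsity": 0.8, "cache_tile_buf": True}}}
)
```

With no `vsa` section the flag stays `False`, matching the OOM-safe behavior; a run with memory headroom sets `cache_tile_buf: true` to opt back into buffer reuse.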
327c5f7 to c9aeca7
What
Make the VSA per-step tile-buffer cache (`cache_tile_buf`) configurable for training, instead of forcing it off.

Why
#1434 fixed the #1423 activation-checkpointing OOM by hard-coding `cache_tile_buf=False` at both VSA training build sites. The cached padded buffer was pinned on the per-step `attn_metadata`, so under full activation checkpointing it survived into the backward recompute and inflated peak memory.

But the cache is purely a per-step speed optimization (reuse the padded QKVG tile scratch across VSA layers). On memory-rich setups (FSDP/SP-sharded large clusters, or runs that don't use full activation checkpointing) the OOM doesn't bite, and keeping the cache is a free speedup. Hard-coding `False` takes that off the table.

This exposes it as a config knob, defaulting to `False` to preserve #1434's OOM-safe training behavior. Opt in with `True` to keep the buffer-reuse speedup.

Changes
- New framework: `TrainingConfig.vsa_cache_tile_buf` (YAML `training.vsa.cache_tile_buf`), parsed in `config.py`, wired into `WanModel._build_attention_metadata` (`wan.py`). Inherited by cosmos/hunyuan/matrixgame2/wan_causal.
- Legacy: `TrainingArgs.VSA_cache_tile_buf` (`--VSA-cache-tile-buf`), wired into `TrainingPipeline._build_attention_metadata` (`training_pipeline.py`). Inherited by distillation + self-forcing.
- Inference is untouched (default `True`).
- Value-equivalence test: the cached and uncached `.tile()` paths produce bit-identical tilings, on a padded shape (`raw_latent_shape=(5,4,4)` → padded T=8) that actually exercises the zero-pad-position scatter. This pins the property that the flag is a pure performance knob, never a correctness change, and addresses the two MINOR test-depth suggestions from the #1434 review ([bugfix]: release VSA tile cache during training).
- Documented `cache_tile_buf` in `examples/train/configs/example.yaml`.

Stacked on #1434
This is stacked on #1434 (which introduces the `cache_tile_buf` plumbing; `main` has no such field yet). The net-new change here is one commit; the diff against `main` shows #1434's commit too until it merges. Once #1434 lands, I'll rebase and this collapses to just the configurability commit. Please review/merge #1434 first.

Test plan
- `ruff`/`yapf`/`codespell` clean on all changed files.
- … `(5,4,4)`).
- `cache_tile_buf` plumbing tests from #1434 still pass (structure unchanged).
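The value-equivalence property can be illustrated with a toy tiler in plain Python. This is a sketch under stated assumptions, not the VSA implementation: the `Tiler` class, the `scatter_index` helper, and the `(4, 4, 4)` tile size are hypothetical, chosen so that a raw shape of `(5, 4, 4)` pads to T=8 as in the description above.

```python
import math

TILE = (4, 4, 4)  # assumed tile size; pads T=5 up to T=8


def padded_shape(shape):
    # Round each dimension up to a multiple of the tile size.
    return tuple(math.ceil(s / t) * t for s, t in zip(shape, TILE))


def scatter_index(shape):
    """Map each raw token (row-major t, h, w) to its slot in a tile-major padded buffer."""
    _, ph, pw = padded_shape(shape)
    tt, th, tw = TILE
    nh, nw = ph // th, pw // tw
    index = []
    for t in range(shape[0]):
        for h in range(shape[1]):
            for w in range(shape[2]):
                tile_id = ((t // tt) * nh + h // th) * nw + w // tw
                in_tile = ((t % tt) * th + h % th) * tw + w % tw
                index.append(tile_id * tt * th * tw + in_tile)
    return index


class Tiler:
    """Toy tiler: scatters tokens into a zero-padded, tile-major buffer."""

    def __init__(self, shape, cache_tile_buf):
        self.index = scatter_index(shape)
        self.size = math.prod(padded_shape(shape))
        self.cache_tile_buf = cache_tile_buf
        self._buf = None

    def tile(self, tokens):
        if self.cache_tile_buf:
            if self._buf is None:
                self._buf = [0.0] * self.size  # allocated once, reused per layer
            buf = self._buf
        else:
            buf = [0.0] * self.size  # fresh scratch on every call
        for slot, value in zip(self.index, tokens):
            buf[slot] = value  # pad slots are never written, so they stay zero
        return list(buf)  # copy, so outputs of different calls can be compared


shape = (5, 4, 4)
layer_inputs = [
    [float(i + 1) for i in range(80)],     # "layer 1" tokens, all nonzero
    [float(-(i + 1)) for i in range(80)],  # "layer 2" tokens, all nonzero
]
cached, uncached = Tiler(shape, True), Tiler(shape, False)
cached_out = [cached.tile(x) for x in layer_inputs]
uncached_out = [uncached.tile(x) for x in layer_inputs]
```

Reuse is safe in this sketch because every call overwrites exactly the same valid slots and never touches the pad slots, which is why the second call on the cached tiler still matches a freshly allocated buffer.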