[perf]: batched CFG for Wan/Cosmos + diffusers T5 recipe alignment#1416
[perf]: batched CFG for Wan/Cosmos + diffusers T5 recipe alignment#1416Mister-Raggs wants to merge 4 commits into
Conversation
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🔴 PR merge requirementsWaiting for
This rule is failing.
|
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds an opt-in (default-on) "batched CFG" path to DenoisingStage that runs the conditional and unconditional model passes in a single batch=2 forward, plus the upstream text-encoding padding work needed to make pos/neg shapes line up. Also adds unit tests and a Modal-based A/B benchmarking harness.
Changes:
- New
use_batched_cfgflag (defaultTrue) plumbed throughFastVideoArgs,VideoGenerator, andDenoisingStage, with auto-fallback to sequential CFG when V2V/I2V/action/camera conditioning is present or pos/neg shapes mismatch. TextEncodingStagenow forcespadding="max_length"and applies a diffusers-style trim+zero-pad to pos/neg embeddings for Wan/Cosmos DiTs when CFG is on, so embeddings can be concatenated along batch.- New tests (
test_denoising_batched_cfg.py,test_wan_cross_attn_mask_math.py) and Modal A/B harness (batched_cfg_ab.py,_batched_cfg_ab_inner.py).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| fastvideo/fastvideo_args.py | Adds use_batched_cfg config field and CLI flag. |
| fastvideo/entrypoints/video_generator.py | Allows use_batched_cfg to be passed via VideoGenerator. |
| fastvideo/pipelines/stages/denoising.py | Implements batched-CFG forward path with auto-disable + shape-mismatch fallback; honors dit_precision. |
| fastvideo/pipelines/stages/text_encoding.py | Forces max-length tokenization and trim+zero-pad of pos/neg embeddings for Wan/Cosmos when CFG is on. |
| fastvideo/tests/stages/test_denoising_batched_cfg.py | Equivalence + autodetect tests for batched vs sequential CFG. |
| fastvideo/tests/stages/test_wan_cross_attn_mask_math.py | Math-only test confirming SDPA + length mask matches unpadded attention. |
| fastvideo/tests/modal/batched_cfg_ab.py | Modal A/B harness driving baseline vs patched runs and pairwise SSIM. |
| fastvideo/tests/modal/_batched_cfg_ab_inner.py | Inner script: runs a single pass or recomputes SSIM from mp4s. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| _zero_pad_recipe_dits = {"WanVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig"} | ||
| _dit_cfg_mro_names = {cls.__name__ for cls in type(fastvideo_args.pipeline_config.dit_config).__mro__} | ||
| _use_canonical_recipe = bool(_zero_pad_recipe_dits & _dit_cfg_mro_names) |
| if not _use_canonical_recipe: | ||
| for ne in neg_embeds_list: | ||
| batch.negative_prompt_embeds.append(ne) | ||
| if batch.negative_attention_mask is not None: | ||
| for nm in neg_masks_list: | ||
| batch.negative_attention_mask.append(nm) | ||
| return batch |
| noise_pred_text = noise_pred | ||
| noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - | ||
| noise_pred_uncond) | ||
|
|
||
| # Apply guidance rescale if needed (CFG-only path) | ||
| if batch.do_classifier_free_guidance and batch.guidance_rescale > 0.0: | ||
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | ||
| noise_pred = self.rescale_noise_cfg( | ||
| noise_pred, | ||
| noise_pred_text, | ||
| guidance_rescale=batch.guidance_rescale, | ||
| ) |
| label = "batched" if use_batched_cfg else "sequential" | ||
| print(f"[inner] [{label}] prompt {i}: {wall:.3f}s -> {mp4s[0]}", flush=True) |
| from fastvideo.utils import PRECISION_TO_TYPE as _PRECISION_TO_TYPE | ||
| _dit_precision = getattr(fastvideo_args.pipeline_config, "dit_precision", "bf16") | ||
| target_dtype = _PRECISION_TO_TYPE.get(_dit_precision, torch.bfloat16) |
| _cfg_conditioning_present = (batch.video_latent is not None or batch.image_latent is not None | ||
| or batch.pil_image is not None or len(batch.image_embeds) > 0 | ||
| or batch.mouse_cond is not None or batch.keyboard_cond is not None | ||
| or batch.c2ws_plucker_emb is not None or batch.camera_states is not None) |
| @pytest.mark.parametrize("disabling_field, value", [ | ||
| ("video_latent", torch.zeros(1, 4, 2, 4, 4)), | ||
| ("image_latent", torch.zeros(1, 4, 2, 4, 4)), | ||
| ("image_embeds", [torch.zeros(1, 4)]), |
| # prebuild-wheels. ~30s install vs ~90min cold source build; sidesteps | ||
| # Kuan's #1389 Volume-cache pattern entirely on the inference path. |
|
|
||
| combined_cond_kwargs = self.prepare_extra_func_kwargs( | ||
| self.transformer.forward, | ||
| { | ||
| "encoder_hidden_states_2": _cat_list_or_none(batch.clip_embedding_neg, batch.clip_embedding_pos), | ||
| "encoder_attention_mask": _cat_list_or_none(batch.negative_attention_mask, | ||
| batch.prompt_attention_mask), | ||
| }, | ||
| ) |
There was a problem hiding this comment.
Code Review
This pull request introduces batched classifier-free guidance (CFG) to run cond and uncond passes as a single batch=2 DiT forward per denoise step, reducing overhead. It updates the denoising and text encoding stages to handle batched CFG precomputation, padding, and shape alignment, while falling back to sequential CFG when necessary. It also adds an A/B testing harness and unit/equivalence tests. The review feedback suggests catching a broader Exception during video decoding in the test harness to ensure fallback reliability, and copying timesteps_r_kwarg instead of re-creating it to avoid discarding other potential keys.
| try: | ||
| from torchvision.io import read_video | ||
| frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW") | ||
| return frames | ||
| except (ImportError, AttributeError): | ||
| pass |
There was a problem hiding this comment.
Catching only ImportError and AttributeError might not be sufficient if torchvision.io.read_video fails due to runtime or backend issues (e.g., missing FFmpeg/PyAV shared libraries or decoding errors). Catching a broader Exception ensures that the fallback to the manual av decoding path is always executed successfully.
| try: | |
| from torchvision.io import read_video | |
| frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW") | |
| return frames | |
| except (ImportError, AttributeError): | |
| pass | |
| try: | |
| from torchvision.io import read_video | |
| frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW") | |
| return frames | |
| except Exception: | |
| pass |
| timesteps_r_kwarg_cfg = timesteps_r_kwarg | ||
| if "timestep_r" in timesteps_r_kwarg and timesteps_r_kwarg["timestep_r"] is not None: | ||
| timesteps_r_kwarg_cfg = {"timestep_r": timesteps_r_kwarg["timestep_r"].repeat(2)} |
There was a problem hiding this comment.
Reassigning timesteps_r_kwarg_cfg as a new dictionary with only the "timestep_r" key will discard any other keys that might be present in timesteps_r_kwarg (e.g., if a subclass or future update extends prepare_extra_func_kwargs to return additional arguments). Copying the dictionary and updating the key in-place is more robust and future-proof.
| timesteps_r_kwarg_cfg = timesteps_r_kwarg | |
| if "timestep_r" in timesteps_r_kwarg and timesteps_r_kwarg["timestep_r"] is not None: | |
| timesteps_r_kwarg_cfg = {"timestep_r": timesteps_r_kwarg["timestep_r"].repeat(2)} | |
| timesteps_r_kwarg_cfg = timesteps_r_kwarg.copy() | |
| if "timestep_r" in timesteps_r_kwarg_cfg and timesteps_r_kwarg_cfg["timestep_r"] is not None: | |
| timesteps_r_kwarg_cfg["timestep_r"] = timesteps_r_kwarg_cfg["timestep_r"].repeat(2) |
|
This PR has merge conflicts with the base branch. Please rebase: git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease |
31cf91c to
56f284c
Compare
|
Rebased onto current main (post-#1372 AG merge) — composition note Two changes in the force-push:
History stays 4 logical commits — fixup squashed into commit 1 via |
cc63852 to
6cad9f8
Compare
|
@Mister-Raggs thanks for the patch!
So bit-equiv is real but FA3+eager only — FA2 already drifts at 0.96, and compile (what we actually deploy) gave me ~0 speedup + 0.52 worst. The win's basically just that one H100-eager cell. Dug into why: Wan cross-attn has no text masking — Couple nits:
If you wanna go default-on later, the real fix is wiring that dead |
Run cond + uncond as a single batch=2 DiT forward per denoise step instead of two sequential batch=1 forwards. Default OFF so this PR has zero behaviour change for any user who doesn't opt in via `use_batched_cfg=True`. Composes with hao-ai-lab#1372 Adaptive Guidance — mutually exclusive at the gate level (AG selectively skips uncond, batched-CFG forces both; running them together defeats AG's win). Bit-equivalent to sequential CFG on H100 FA3 eager (SSIM=1.000000, Wan 14B 720x1280x49f/30steps, 5 prompts). On other configs (compile, FA2, smaller models) bf16 numerics drift slightly (~0.04 SSIM mean, visually imperceptible per frame-by-frame inspection) due to Inductor kernel selection and batched flash-attn numerics. Perf delta is run-to-run variable. Changes: - fastvideo_args.py: new use_batched_cfg: bool = False field + --use-batched-cfg CLI. Auto-fallback to sequential when V2V/I2V/ TI2V/action/camera conditioning is present (those carry batch=1 conditioning tensors out of scope for this PR) OR when AG is active. - entrypoints/video_generator.py: add use_batched_cfg to _FROM_PRETRAINED_CONVENIENCE_KWARGS. - pipelines/stages/denoising.py: gated batched branch in main DenoisingStage.forward. Cats [neg, pos] along batch dim with shape-match defensive fallback, single forward, chunk(2), existing CFG-combine + guidance_rescale math reused. Other DenoisingStage subclasses (Cosmos25, Dmd, ...) unchanged. - api/schema.py + api/compat.py: EngineConfig.use_batched_cfg field + legacy<->typed mappings (mirrors disable_autocast). - docs/design/inference_schema_parity_inventory.yaml: inventory entry under fastvideo_args.moved. - tests/api/test_parser.py: YAML-roundtrip expected dict updated. Sequential CFG path preserved bit-for-bit for non-batched callers. Other DiTs (HunyuanVideo, LongCat, ...) unaffected.
…ated on use_batched_cfg)
Diffusers' WanPipeline._get_t5_prompt_embeds (pipeline_wan.py:173-190)
and CosmosTextToWorldPipeline._get_t5_prompt_embeds (pipeline_cosmos_
text2world.py:197-237) both:
(a) tokenize(padding="max_length", max_length=N)
(b) T5 encoder forward (sees padded input)
(c) trim T5 output to per-sample real length
(d) re-pad with EXPLICIT ZEROS to max_length
Wan and Cosmos were trained against this recipe. FV's current
variable-length T5 output deviates. This change replicates the
canonical recipe — but only when use_batched_cfg=True, so users who
don't opt in see zero behaviour change.
HunyuanVideo's diffusers pipeline (pipeline_hunyuan_video.py
_get_llama_prompt_embeds) does NOT do (c)+(d) — uses encoder output
as-is, padded positions retain T5's natural bias-driven values.
Gated on dit_config MRO matching {WanVideoConfig, CosmosVideoConfig,
Cosmos25VideoConfig} so HunyuanVideo and every other DiT bypass the
recipe entirely.
Long-term: the right fix is to wire context_lens through to Wan's
attention layer so it can mask padding directly (rather than relying
on training-distribution-zero positions). That would let batched-CFG
pad to max(real_lens) instead of max_length and become bit-equivalent
across more configs. Out of scope for this PR; tracked as a follow-up
(see PR body and ground_truth W3e).
Stub-transformer equivalence between batched and sequential paths, CFG-off equivalence, autodetect-off on 7 conditioning fields (V2V, I2V, TI2V, action, camera, image embeds, etc.).
Single-container two-pass A/B harness modeled on Kuan's nccl_stream_ab.py pattern (used to validate hao-ai-lab#1395). Runs the same seed-pinned prompts twice through the same VideoGenerator (once sequential, once batched), computes pairwise SSIM via pytorch_msssim (matches fastvideo/tests/utils.py:compute_video_ssim_torchvision), prints a per-prompt + total wall/SSIM table. Architecture: the harness ferries pass config + records over JSON files and subprocesses the inner script via /opt/venv/bin/python. Modal's main function process runs in its own add_python="3.12" layer where FastVideo isn't importable; the venv subprocess works around that. Same reason ssim_test.py runs pytest as a subprocess. Features: - Mode-tagged output dirs ({eager,compile}) so multiple legs don't clobber each other in the hf-model-weights Volume. - Recovery function reads existing mp4s by mtime and recomputes SSIM for post-run gap-fill if the in-run SSIM step missed. - Prebuilt FA3 wheel install on Hopper (autodetected from gpu kwarg) via mjun0812/flash-attention-prebuild-wheels v0.9.4 — the same release that supplies the FA2 wheel already baked into the fastvideo-dev image. ~30s install vs Kuan's hao-ai-lab#1389 ~90min cold source build. - Modal Secret integration for HF token (no token-on-the-wire). - --enable-compile + --num-prompts flags for flexible legs. Reproducible by reviewers — see PR body for exact commands.
6cad9f8 to
8512734
Compare
|
This PR has merge conflicts with the base branch. Please rebase: git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease |
Summary
Adds an opt-in batched classifier-free-guidance path in the main
DenoisingStage: when CFG is on, the cond and uncond passes run as a singlebatch=2DiT forward per denoise step instead of two sequentialbatch=1forwards. Bit-equivalent to sequential on H100 FA3 eager (SSIM = 1.000000).Adds the prerequisite: align Wan / Cosmos text encoding to diffusers' canonical recipe (
tokenize(padding="max_length")→ trim T5 output to real length → re-pad with explicit zeros to max_length). HunyuanVideo and all other DiTs are unaffected — diffusers uses a different recipe for those, so they're gated off.Default flag value is
True; sequential-CFG fallback preserved for V2V/I2V/TI2V/action/camera paths (auto-detected).Problem
Two issues, addressed together because they're coupled:
Sequential CFG wastes launch overhead and prevents kernel-level batching. Every denoise step runs two
current_model(...)calls — once for the conditional prompt, once for the unconditional. On modern GPUs, batch=2 forward is typically 1.4-1.6× the wall of batch=1 (vs 2× sequential), so the mechanism leaves perf on the table.FV's Wan/Cosmos T5 output deviates from their canonical diffusers pipeline. Diffusers'
WanPipeline._get_t5_prompt_embedsandCosmosTextToWorldPipeline._get_t5_prompt_embedsboth tokenize with `padding="max_length"`, then trim the T5 output to each sample's real length, then re-pad with explicit zeros to `max_length`. That's the recipe these models were trained against; FV's variable-length path is a deviation. Aligning makes the model's input match training distribution, AND gives the matched-shape contract that batched-CFG needs to cat pos + neg along the batch dim.Solution
`FastVideoArgs`: new `use_batched_cfg: bool = True` field with `--use-batched-cfg` CLI flag. Added to `_FROM_PRETRAINED_CONVENIENCE_KWARGS`.
`TextEncodingStage`: when CFG is on AND the active `dit_config` MRO contains `WanVideoConfig`, `CosmosVideoConfig`, or `Cosmos25VideoConfig`, applies the diffusers-canonical trim+pad-with-zeros recipe to both pos and neg encodings. Other DiTs (HunyuanVideo, etc.) bypass the gate entirely — their behaviour is unchanged.
`DenoisingStage.forward` (main one): new gated branch when `use_batched_cfg=True` and CFG is on. Cats `[neg, pos]` along the batch dim, single forward at batch=2, `chunk(2)` → existing CFG combine + `guidance_rescale` math reused unchanged. Defensive shape-mismatch fallback to sequential if the batch was constructed without going through `TextEncodingStage`.
Autodetect-off when V2V/I2V/TI2V/action/camera conditioning is present in the batch — those carry batch=1 conditioning tensors that aren't covered by this PR.
Validation
5 prompts, seed-pinned, single-container two-pass A/B (sequential pass + batched pass through the same `VideoGenerator`). Pairwise SSIM via `pytorch_msssim.ssim` (matches `fastvideo/tests/utils.py:compute_video_ssim_torchvision`).
Visual evidence — H100 14B compile, prompt 4 (worst-SSIM prompt)
Frame-by-frame inspection confirms the SSIM 0.68-worst delta in compile mode is bf16 / Inductor kernel-drift noise, not semantic divergence.
Baseline (sequential CFG):
h100_baseline_p4.mp4
Patched (batched CFG):
h100_patched_p4.mp4
Visual evidence — H100 14B eager, prompt 4 (bit-equivalent reference)
For comparison: SSIM = 1.000000 exactly, no kernel-drift contribution. Demonstrates that the batched-CFG mechanism itself is bit-equivalent to sequential.
Baseline (sequential CFG):
h100_eager_baseline_p4.mp4
Patched (batched CFG):
h100_eager_patched_p4.mp4
Trade-offs
Recipe alignment changes Wan/Cosmos output character vs FV's previous variable-length default. The new output matches diffusers' canonical pipeline (= Wan's training distribution). Existing Wan/Cosmos SSIM regression test reference videos in CI may need regeneration if the alignment shifts pixel values beyond the SSIM gate. HunyuanVideo and other DiTs are unaffected.
Compile path shows SSIM 0.85 mean / 0.68 worst on Wan 14B H100. The cause is Inductor picking different fused kernels for batch=1 vs batch=2 paths (not a bug in the batched-CFG mechanism itself, which is bit-equivalent on the eager path). Frame-by-frame visual inspection of the worst-SSIM prompt confirmed the videos are perceptually identical. This is consistent with bf16 numerics drift compounding across 30 layers × 30 steps, well within the perceptual-equivalence range.
L40S 1.3B shows +2% wall regression. The diffusers-canonical recipe pads every cross-attention K/V to `max_length=512` (vs FV's previous variable-length 55-98 tokens). On a small model, this per-step attention overhead exceeds the batched-CFG launch savings. The win shows up at larger models / faster GPUs / more amortizable launch overhead — which is exactly what we see on H100 14B (-1.96% eager, -1.78% compile).
Test plan
Reproduce these numbers yourself
All A/B runs above are reproducible end-to-end via the included Modal harness. ~$10 total compute for the full matrix.
One-time setup:
```bash
pip install modal && modal token new
modal secret create huggingface-token HF_TOKEN=hf_xxxxxxxxxxxx
```
Run the same A/B legs as the validation table:
```bash
H100 14B eager (the SSIM=1.000000 / -1.96% wall result)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
--gpu H100 --model wan2_2-t2v-14b
H100 14B compile (deployment perf number)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
--gpu H100 --model wan2_2-t2v-14b --enable-compile
L40S 1.3B eager (correctness validation)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
--gpu L40S --model wan2_1-t2v-1.3b
```
Each leg writes mp4s to a per-config directory in the `hf-model-weights` Volume (`/root/data/ab_out//<mode_tag>/{baseline,patched}/prompt_NN/`). The harness prints the wall + SSIM table at the end of each run. If your terminal disconnects, recover via:
```bash
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py::recover_ssim \
--git-repo https://github.com/Mister-Raggs/FastVideo.git \
--git-ref perf/wan-batched-cfg \
--model-preset wan2_2-t2v-14b --mode-tag eager
```
Use the harness for your own PRs. The single-container two-pass A/B pattern is generic (modeled on Kuan's `nccl_stream_ab.py` for #1395) — fork the harness, swap the `MODEL_PRESETS` entry and the flag toggle, and you have a reproducible perf+SSIM gate for any single-flag PR.
Follow-ups (not blocking this PR)
Related work
Composes with #1372 Adaptive Guidance: when AG runs both passes, batched=2 fires; when AG skips uncond, batched-CFG gracefully falls through to sequential.
Mirrors the upstream pattern in `fastvideo/pipelines/stages/longcat_denoising.py:97-149` which already does batched-CFG natively (LongCat's tokenizer is configured with `padding="max_length"` upstream so the same shape contract holds).