[perf]: batched CFG for Wan/Cosmos + diffusers T5 recipe alignment#1416

Open
Mister-Raggs wants to merge 4 commits into
hao-ai-lab:mainfrom
Mister-Raggs:perf/wan-batched-cfg

@Mister-Raggs Mister-Raggs commented May 30, 2026

Summary

Adds a batched classifier-free-guidance (CFG) path to the main DenoisingStage, gated by a new `use_batched_cfg` flag: when CFG is on, the cond and uncond passes run as a single batch=2 DiT forward per denoise step instead of two sequential batch=1 forwards. Bit-equivalent to sequential on H100 FA3 eager (SSIM = 1.000000).

Adds the prerequisite: align Wan / Cosmos text encoding to diffusers' canonical recipe (tokenize(padding="max_length") → trim T5 output to real length → re-pad with explicit zeros to max_length). HunyuanVideo and all other DiTs are unaffected — diffusers uses a different recipe for those, so they're gated off.

Default flag value is True; sequential-CFG fallback preserved for V2V/I2V/TI2V/action/camera paths (auto-detected).

Problem

Two issues, addressed together because they're coupled:

  1. Sequential CFG wastes launch overhead and prevents kernel-level batching. Every denoise step runs two current_model(...) calls: one for the conditional prompt, one for the unconditional. On modern GPUs, a batch=2 forward typically takes 1.4-1.6× the wall time of a batch=1 forward (vs 2× for two sequential forwards), so the sequential path leaves performance on the table.

  2. FV's Wan/Cosmos T5 output deviates from their canonical diffusers pipeline. Diffusers' WanPipeline._get_t5_prompt_embeds and CosmosTextToWorldPipeline._get_t5_prompt_embeds both tokenize with `padding="max_length"`, then trim the T5 output to each sample's real length, then re-pad with explicit zeros to `max_length`. That's the recipe these models were trained against; FV's variable-length path is a deviation. Aligning makes the model's input match training distribution, AND gives the matched-shape contract that batched-CFG needs to cat pos + neg along the batch dim.
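
The trim-then-zero-pad step in item 2 can be sketched as below. This is a minimal pure-Python illustration, not the diffusers or FastVideo code: nested lists stand in for tensors, `trim_and_zero_pad` is a hypothetical helper name, and right-padding (real tokens first) is assumed.

```python
from typing import List

Vector = List[float]


def trim_and_zero_pad(encoder_out: List[Vector], attention_mask: List[int],
                      max_length: int) -> List[Vector]:
    """Keep the first `real_len` token embeddings, then pad with zero vectors.

    `encoder_out` is one sample's encoder output of shape [max_length, dim],
    produced from a prompt tokenized with padding="max_length".
    Assumes right-padding, so the real tokens form a prefix.
    """
    real_len = sum(attention_mask)      # number of non-padding tokens
    dim = len(encoder_out[0])
    trimmed = encoder_out[:real_len]    # drop encoder outputs at pad positions
    zeros = [[0.0] * dim for _ in range(max_length - real_len)]
    return trimmed + zeros              # shape is [max_length, dim] again


# Toy sample: max_length=4, dim=2, two real tokens followed by two pad tokens.
sample_out = [[0.5, -1.0], [2.0, 0.25], [0.1, 0.1], [0.1, 0.1]]
sample_mask = [1, 1, 0, 0]
aligned = trim_and_zero_pad(sample_out, sample_mask, max_length=4)
```

Because every sample ends up with the same `[max_length, dim]` shape, positive and negative embeddings can be stacked along the batch dimension regardless of prompt length.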

Solution

`FastVideoArgs`: new `use_batched_cfg: bool = True` field with `--use-batched-cfg` CLI flag. Added to `_FROM_PRETRAINED_CONVENIENCE_KWARGS`.

`TextEncodingStage`: when CFG is on AND the active `dit_config` MRO contains `WanVideoConfig`, `CosmosVideoConfig`, or `Cosmos25VideoConfig`, applies the diffusers-canonical trim+pad-with-zeros recipe to both pos and neg encodings. Other DiTs (HunyuanVideo, etc.) bypass the gate entirely — their behaviour is unchanged.
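
A rough sketch of the MRO-based gate, assuming it works by matching class names as described above. The classes here are empty stand-ins for illustration (`WanVariantConfig` is hypothetical), not the real FastVideo configs.

```python
ZERO_PAD_RECIPE_DITS = {"WanVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig"}


def uses_canonical_recipe(dit_config: object) -> bool:
    """True if any class in the config's MRO carries one of the gated names."""
    mro_names = {cls.__name__ for cls in type(dit_config).__mro__}
    return bool(ZERO_PAD_RECIPE_DITS & mro_names)


# Stand-in classes for illustration only.
class WanVideoConfig: ...
class WanVariantConfig(WanVideoConfig): ...  # hypothetical subclass
class HunyuanVideoConfig: ...


wan_gated = uses_canonical_recipe(WanVariantConfig())
hunyuan_gated = uses_canonical_recipe(HunyuanVideoConfig())
```

Matching on the MRO rather than the exact type means subclasses of a gated config inherit the recipe, while unrelated configs bypass it.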

`DenoisingStage.forward` (main one): new gated branch when `use_batched_cfg=True` and CFG is on. Cats `[neg, pos]` along the batch dim, single forward at batch=2, `chunk(2)` → existing CFG combine + `guidance_rescale` math reused unchanged. Defensive shape-mismatch fallback to sequential if the batch was constructed without going through `TextEncodingStage`.
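
The batched branch can be illustrated with a toy model. This is a sketch under stated assumptions, not the `DenoisingStage` code: a Python list of samples stands in for the batch dimension, `toy_model` is a hypothetical per-sample function, and only the CFG combine formula is taken from the PR.

```python
from typing import List

Sample = List[float]


def cfg_combine(uncond: Sample, cond: Sample, scale: float) -> Sample:
    """CFG combine: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for u, c in zip(uncond, cond)]


def toy_model(latents: List[Sample], contexts: List[Sample]) -> List[Sample]:
    """Hypothetical model: each sample depends only on its own latent and context."""
    return [[x * 0.5 + sum(ctx) / len(ctx) for x in lat]
            for lat, ctx in zip(latents, contexts)]


def sequential_cfg(latent: Sample, pos: Sample, neg: Sample, scale: float) -> Sample:
    cond = toy_model([latent], [pos])[0]    # batch=1 forward, conditional
    uncond = toy_model([latent], [neg])[0]  # batch=1 forward, unconditional
    return cfg_combine(uncond, cond, scale)


def batched_cfg(latent: Sample, pos: Sample, neg: Sample, scale: float) -> Sample:
    # "Cat" [neg, pos] along the batch dim and run one batch=2 forward.
    out = toy_model([latent, latent], [neg, pos])
    uncond, cond = out                      # the chunk(2) step
    return cfg_combine(uncond, cond, scale)


latent, pos, neg = [1.0, 2.0], [1.0, 3.0], [0.0, 0.0]
seq_result = sequential_cfg(latent, pos, neg, scale=5.0)
bat_result = batched_cfg(latent, pos, neg, scale=5.0)
```

The two paths agree exactly here because the toy model treats each batch element independently; on real GPUs, kernel selection can differ between batch=1 and batch=2, which is where the drift reported in the validation table comes from.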

Autodetect-off when V2V/I2V/TI2V/action/camera conditioning is present in the batch — those carry batch=1 conditioning tensors that aren't covered by this PR.

Validation

5 prompts, seed-pinned, single-container two-pass A/B (sequential pass + batched pass through the same `VideoGenerator`). Pairwise SSIM via `pytorch_msssim.ssim` (matches `fastvideo/tests/utils.py:compute_video_ssim_torchvision`).

| Config | Wall delta | SSIM mean | SSIM worst | Note |
| --- | --- | --- | --- | --- |
| H100 14B eager (FA3, SP=2) | -1.96% | 1.000000 | 1.000000 | bit-equivalent; clean mechanism proof |
| H100 14B compile (FA3, SP=2) | -1.78% | 0.85 | 0.68 | compile-induced kernel drift; videos visually identical (verified by frame inspection) |
| L40S 1.3B eager (FA2) | +2.0% | 0.96 | 0.95 | small model + FA2 batched-numerics noise + recipe-alignment cost > batched savings at this scale |

Visual evidence — H100 14B compile, prompt 4 (worst-SSIM prompt)

Frame-by-frame inspection confirms the SSIM 0.68-worst delta in compile mode is bf16 / Inductor kernel-drift noise, not semantic divergence.

Baseline (sequential CFG):

h100_baseline_p4.mp4

Patched (batched CFG):

h100_patched_p4.mp4

Visual evidence — H100 14B eager, prompt 4 (bit-equivalent reference)

For comparison: SSIM = 1.000000 exactly, no kernel-drift contribution. Demonstrates that the batched-CFG mechanism itself is bit-equivalent to sequential.

Baseline (sequential CFG):

h100_eager_baseline_p4.mp4

Patched (batched CFG):

h100_eager_patched_p4.mp4

Trade-offs

Recipe alignment changes Wan/Cosmos output character vs FV's previous variable-length default. The new output matches diffusers' canonical pipeline (= Wan's training distribution). Existing Wan/Cosmos SSIM regression test reference videos in CI may need regeneration if the alignment shifts pixel values beyond the SSIM gate. HunyuanVideo and other DiTs are unaffected.

Compile path shows SSIM 0.85 mean / 0.68 worst on Wan 14B H100. The cause is Inductor picking different fused kernels for batch=1 vs batch=2 paths (not a bug in the batched-CFG mechanism itself, which is bit-equivalent on the eager path). Frame-by-frame visual inspection of the worst-SSIM prompt confirmed the videos are perceptually identical. This is consistent with bf16 numerics drift compounding across 30 layers × 30 steps, well within the perceptual-equivalence range.

L40S 1.3B shows +2% wall regression. The diffusers-canonical recipe pads every cross-attention K/V to `max_length=512` (vs FV's previous variable-length 55-98 tokens). On a small model, this per-step attention overhead exceeds the batched-CFG launch savings. The win shows up at larger models / faster GPUs / more amortizable launch overhead — which is exactly what we see on H100 14B (-1.96% eager, -1.78% compile).
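
As a back-of-the-envelope check on the padding cost: for a fixed number of query tokens, cross-attention work grows linearly with the K/V length, so the ratio of padded to real length bounds the per-layer cross-attention overhead. The sketch below only computes that ratio from the token counts quoted above; it says nothing about whole-step wall time, since cross-attention is only one part of each step.

```python
def kv_length_ratio(padded_len: int, real_len: int) -> float:
    """How many times more K/V positions cross-attention sees after padding."""
    return padded_len / real_len


MAX_LENGTH = 512                                 # fixed pad length in the recipe
ratio_short = kv_length_ratio(MAX_LENGTH, 55)    # shortest prompt quoted above
ratio_long = kv_length_ratio(MAX_LENGTH, 98)     # longest prompt quoted above

print(f"K/V length grows {ratio_long:.1f}x to {ratio_short:.1f}x")
```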

Test plan

  • Unit test: `fastvideo/tests/stages/test_denoising_batched_cfg.py` — stub-transformer equivalence, CFG-off equivalence, autodetect-off on 7 conditioning fields.
  • Math unit test: `fastvideo/tests/stages/test_wan_cross_attn_mask_math.py` — kept in tree as forward-coverage for a model-side-mask approach a future contributor might attempt for HunyuanVideo. Passes 10/11 cases (fp32 exact + bf16 within 1 ULP).
  • Modal A/B harness: `fastvideo/tests/modal/batched_cfg_ab.py` + `_batched_cfg_ab_inner.py`. Single-container two-pass A/B; subprocesses the inner script via `/opt/venv/bin/python` so FastVideo runs in the image's venv. Reusable for future single-flag perf PRs.
  • CI SSIM regression suite: Wan/Cosmos may need reference-video regeneration; HunyuanVideo expected unchanged. Will iterate on review.

Reproduce these numbers yourself

All A/B runs above are reproducible end-to-end via the included Modal harness. ~$10 total compute for the full matrix.

One-time setup:
```bash
pip install modal && modal token new
modal secret create huggingface-token HF_TOKEN=hf_xxxxxxxxxxxx
```

Run the same A/B legs as the validation table:
```bash

# H100 14B eager (the SSIM=1.000000 / -1.96% wall result)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
    --gpu H100 --model wan2_2-t2v-14b

# H100 14B compile (deployment perf number)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
    --gpu H100 --model wan2_2-t2v-14b --enable-compile

# L40S 1.3B eager (correctness validation)
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py \
    --gpu L40S --model wan2_1-t2v-1.3b
```

Each leg writes mp4s to a per-config directory in the `hf-model-weights` Volume (`/root/data/ab_out//<mode_tag>/{baseline,patched}/prompt_NN/`). The harness prints the wall + SSIM table at the end of each run. If your terminal disconnects, recover via:

```bash
modal run --detach fastvideo/tests/modal/batched_cfg_ab.py::recover_ssim \
--git-repo https://github.com/Mister-Raggs/FastVideo.git \
--git-ref perf/wan-batched-cfg \
--model-preset wan2_2-t2v-14b --mode-tag eager
```

Use the harness for your own PRs. The single-container two-pass A/B pattern is generic (modeled on Kuan's `nccl_stream_ab.py` for #1395) — fork the harness, swap the `MODEL_PRESETS` entry and the flag toggle, and you have a reproducible perf+SSIM gate for any single-flag PR.

Follow-ups (not blocking this PR)

  • HunyuanVideo batched-CFG: different recipe required (encoder output used as-is per diffusers). The model-side SDPA + bool attn_mask path explored and reverted in this PR is actually the right tool for HV; the local unit test left in tree provides a head start.
  • Padding-length optimization: pad to `max(pos_real, neg_real)` instead of fixed `max_length=512`. Could recover most of the per-cross-attn cost the recipe alignment introduces; trade-off is deviation from training distribution. Worth measuring SSIM at variable padding.
  • Multi-prompt batching: the batched=2 framework here extends naturally to batched=N for serving workloads.
  • Recipe-alignment quality measurement: worth comparing (a) FV current variable-length Wan vs (b) this PR's recipe-aligned Wan vs (c) diffusers reference. If alignment improves output character independently of batched-CFG, that's a separately-citable quality win.

Related work

Composes with #1372 Adaptive Guidance: when AG runs both passes, batched=2 fires; when AG skips uncond, batched-CFG gracefully falls through to sequential.

Mirrors the upstream pattern in `fastvideo/pipelines/stages/longcat_denoising.py:97-149` which already does batched-CFG natively (LongCat's tokenizer is configured with `padding="max_length"` upstream so the same shape contract holds).

Copilot AI review requested due to automatic review settings May 30, 2026 15:20
@mergify mergify Bot added type: perf Performance improvement scope: inference Inference pipeline, serving, CLI scope: infra CI, tests, Docker, build labels May 30, 2026
mergify Bot commented May 30, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds an opt-in (default-on) "batched CFG" path to DenoisingStage that runs the conditional and unconditional model passes in a single batch=2 forward, plus the upstream text-encoding padding work needed to make pos/neg shapes line up. Also adds unit tests and a Modal-based A/B benchmarking harness.

Changes:

  • New use_batched_cfg flag (default True) plumbed through FastVideoArgs, VideoGenerator, and DenoisingStage, with auto-fallback to sequential CFG when V2V/I2V/action/camera conditioning is present or pos/neg shapes mismatch.
  • TextEncodingStage now forces padding="max_length" and applies a diffusers-style trim+zero-pad to pos/neg embeddings for Wan/Cosmos DiTs when CFG is on, so embeddings can be concatenated along batch.
  • New tests (test_denoising_batched_cfg.py, test_wan_cross_attn_mask_math.py) and Modal A/B harness (batched_cfg_ab.py, _batched_cfg_ab_inner.py).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| fastvideo/fastvideo_args.py | Adds use_batched_cfg config field and CLI flag. |
| fastvideo/entrypoints/video_generator.py | Allows use_batched_cfg to be passed via VideoGenerator. |
| fastvideo/pipelines/stages/denoising.py | Implements batched-CFG forward path with auto-disable + shape-mismatch fallback; honors dit_precision. |
| fastvideo/pipelines/stages/text_encoding.py | Forces max-length tokenization and trim+zero-pad of pos/neg embeddings for Wan/Cosmos when CFG is on. |
| fastvideo/tests/stages/test_denoising_batched_cfg.py | Equivalence + autodetect tests for batched vs sequential CFG. |
| fastvideo/tests/stages/test_wan_cross_attn_mask_math.py | Math-only test confirming SDPA + length mask matches unpadded attention. |
| fastvideo/tests/modal/batched_cfg_ab.py | Modal A/B harness driving baseline vs patched runs and pairwise SSIM. |
| fastvideo/tests/modal/_batched_cfg_ab_inner.py | Inner script: runs a single pass or recomputes SSIM from mp4s. |

Comment on lines +73 to +75
_zero_pad_recipe_dits = {"WanVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig"}
_dit_cfg_mro_names = {cls.__name__ for cls in type(fastvideo_args.pipeline_config.dit_config).__mro__}
_use_canonical_recipe = bool(_zero_pad_recipe_dits & _dit_cfg_mro_names)
Comment on lines +123 to +129
if not _use_canonical_recipe:
    for ne in neg_embeds_list:
        batch.negative_prompt_embeds.append(ne)
    if batch.negative_attention_mask is not None:
        for nm in neg_masks_list:
            batch.negative_attention_mask.append(nm)
    return batch
Comment on lines +492 to +503
noise_pred_text = noise_pred
noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text -
                                                           noise_pred_uncond)

# Apply guidance rescale if needed (CFG-only path)
if batch.do_classifier_free_guidance and batch.guidance_rescale > 0.0:
    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
    noise_pred = self.rescale_noise_cfg(
        noise_pred,
        noise_pred_text,
        guidance_rescale=batch.guidance_rescale,
    )
Comment on lines +104 to +105
label = "batched" if use_batched_cfg else "sequential"
print(f"[inner] [{label}] prompt {i}: {wall:.3f}s -> {mp4s[0]}", flush=True)
Comment thread fastvideo/pipelines/stages/denoising.py Outdated
Comment on lines +105 to +107
from fastvideo.utils import PRECISION_TO_TYPE as _PRECISION_TO_TYPE
_dit_precision = getattr(fastvideo_args.pipeline_config, "dit_precision", "bf16")
target_dtype = _PRECISION_TO_TYPE.get(_dit_precision, torch.bfloat16)
Comment on lines +241 to +244
_cfg_conditioning_present = (batch.video_latent is not None or batch.image_latent is not None
                             or batch.pil_image is not None or len(batch.image_embeds) > 0
                             or batch.mouse_cond is not None or batch.keyboard_cond is not None
                             or batch.c2ws_plucker_emb is not None or batch.camera_states is not None)
Comment on lines +167 to +170
@pytest.mark.parametrize("disabling_field, value", [
    ("video_latent", torch.zeros(1, 4, 2, 4, 4)),
    ("image_latent", torch.zeros(1, 4, 2, 4, 4)),
    ("image_embeds", [torch.zeros(1, 4)]),
Comment on lines +45 to +46
# prebuild-wheels. ~30s install vs ~90min cold source build; sidesteps
# Kuan's #1389 Volume-cache pattern entirely on the inference path.
Comment on lines +296 to +304

combined_cond_kwargs = self.prepare_extra_func_kwargs(
    self.transformer.forward,
    {
        "encoder_hidden_states_2": _cat_list_or_none(batch.clip_embedding_neg, batch.clip_embedding_pos),
        "encoder_attention_mask": _cat_list_or_none(batch.negative_attention_mask,
                                                    batch.prompt_attention_mask),
    },
)

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces batched classifier-free guidance (CFG) to run cond and uncond passes as a single batch=2 DiT forward per denoise step, reducing overhead. It updates the denoising and text encoding stages to handle batched CFG precomputation, padding, and shape alignment, while falling back to sequential CFG when necessary. It also adds an A/B testing harness and unit/equivalence tests. The review feedback suggests catching a broader Exception during video decoding in the test harness to ensure fallback reliability, and copying timesteps_r_kwarg instead of re-creating it to avoid discarding other potential keys.

Comment on lines +148 to +153
try:
    from torchvision.io import read_video
    frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW")
    return frames
except (ImportError, AttributeError):
    pass


medium

Catching only ImportError and AttributeError might not be sufficient if torchvision.io.read_video fails due to runtime or backend issues (e.g., missing FFmpeg/PyAV shared libraries or decoding errors). Catching a broader Exception ensures that the fallback to the manual av decoding path is always executed successfully.

Suggested change
- try:
-     from torchvision.io import read_video
-     frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW")
-     return frames
- except (ImportError, AttributeError):
-     pass
+ try:
+     from torchvision.io import read_video
+     frames, _, _ = read_video(path, pts_unit="sec", output_format="TCHW")
+     return frames
+ except Exception:
+     pass

Comment on lines +431 to +433
timesteps_r_kwarg_cfg = timesteps_r_kwarg
if "timestep_r" in timesteps_r_kwarg and timesteps_r_kwarg["timestep_r"] is not None:
    timesteps_r_kwarg_cfg = {"timestep_r": timesteps_r_kwarg["timestep_r"].repeat(2)}


medium

Reassigning timesteps_r_kwarg_cfg as a new dictionary with only the "timestep_r" key will discard any other keys that might be present in timesteps_r_kwarg (e.g., if a subclass or future update extends prepare_extra_func_kwargs to return additional arguments). Copying the dictionary and updating the key in-place is more robust and future-proof.

Suggested change
- timesteps_r_kwarg_cfg = timesteps_r_kwarg
- if "timestep_r" in timesteps_r_kwarg and timesteps_r_kwarg["timestep_r"] is not None:
-     timesteps_r_kwarg_cfg = {"timestep_r": timesteps_r_kwarg["timestep_r"].repeat(2)}
+ timesteps_r_kwarg_cfg = timesteps_r_kwarg.copy()
+ if "timestep_r" in timesteps_r_kwarg_cfg and timesteps_r_kwarg_cfg["timestep_r"] is not None:
+     timesteps_r_kwarg_cfg["timestep_r"] = timesteps_r_kwarg_cfg["timestep_r"].repeat(2)
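
The difference between the two variants can be seen in a small standalone sketch. Assumptions: a plain list multiplied by 2 stands in for `tensor.repeat(2)`, and `extra_key` is a hypothetical extra entry of the kind the reviewer is worried about.

```python
def recreate_variant(kwargs: dict) -> dict:
    """Original pattern: builds a new dict holding only "timestep_r"."""
    out = kwargs
    if kwargs.get("timestep_r") is not None:
        out = {"timestep_r": kwargs["timestep_r"] * 2}
    return out


def copy_variant(kwargs: dict) -> dict:
    """Suggested pattern: shallow-copy first, then overwrite the one key."""
    out = kwargs.copy()
    if out.get("timestep_r") is not None:
        out["timestep_r"] = out["timestep_r"] * 2
    return out


# "extra_key" is hypothetical; list * 2 mimics repeating along the batch dim.
source = {"timestep_r": [7], "extra_key": "keep me"}
recreated = recreate_variant(source)
copied = copy_variant(source)
```

Both variants leave `source` untouched; only the copy-based one carries `extra_key` through to the batched call.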

mergify Bot commented May 30, 2026

This PR has merge conflicts with the base branch. Please rebase:

git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease

@mergify mergify Bot added the needs-rebase PR has merge conflicts label May 30, 2026
@Mister-Raggs Mister-Raggs force-pushed the perf/wan-batched-cfg branch from 31cf91c to 56f284c Compare May 31, 2026 05:54
@mergify mergify Bot added scope: docs Documentation and removed needs-rebase PR has merge conflicts labels May 31, 2026
@Mister-Raggs


Rebased onto current main (post-#1372 AG merge) — composition note

Two changes in the force-push:

  1. docs/design/inference_schema_parity_inventory.yaml: added use_batched_cfg: generator.engine.use_batched_cfg under surfaces.fastvideo_args.moved, mirroring disable_autocast. Fixes the test_fastvideo_args_fields_are_classified parity test that flagged the new field.

  2. denoising.py merge with #1372 (Adaptive Guidance): AG and batched-CFG are now mutually exclusive at the gate:

    use_batched_cfg = (fastvideo_args.use_batched_cfg
                       and batch.do_classifier_free_guidance
                       and not _cfg_conditioning_present
                       and not _cfg_gate_active)

    Rationale: AG selectively skips the uncond forward to save wall (its design); batched-CFG forces both passes in one forward (its design). Running them together would defeat AG's perf win — every step would compute uncond regardless of the gate fraction. When AG is inactive (FASTVIDEO_CFG_GATE_STEP=1.0, the default), batched-CFG fires normally. When AG is enabled, batched-CFG falls back to sequential and AG's delta-cache reuse runs unchanged.

    The PR body's "composes with #1372" line is now precise: they compose by deferring to whichever is active, not by stacking. If a future caller wants both wins, that's a separate scope (the batched path could maintain its own delta cache across batched-2 forwards; out of scope here).
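
The gate above is a plain conjunction, so it can be checked exhaustively. The sketch below restates it as a standalone function (the function and parameter names are illustrative, mirroring the variables in the snippet) and enumerates all 16 input combinations.

```python
from itertools import product


def should_use_batched_cfg(flag_enabled: bool, do_cfg: bool,
                           conditioning_present: bool, cfg_gate_active: bool) -> bool:
    """Batched CFG fires only when the flag and CFG are on and nothing vetoes it."""
    return (flag_enabled and do_cfg
            and not conditioning_present
            and not cfg_gate_active)


# Enumerate every combination of the four booleans.
firing = [combo for combo in product([False, True], repeat=4)
          if should_use_batched_cfg(*combo)]
```

Exactly one combination fires: flag on, CFG on, no extra conditioning, AG inactive. Every other combination falls back to the sequential path.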

History stays 4 logical commits — fixup squashed into commit 1 via --autosquash, conflict resolved cleanly in commit 1.

@Mister-Raggs Mister-Raggs force-pushed the perf/wan-batched-cfg branch 2 times, most recently from cc63852 to 6cad9f8 Compare May 31, 2026 06:24
rich7420 commented May 31, 2026

@Mister-Raggs thanks for the patch!
I tried a few configs on Modal:

| leg | wall | SSIM worst |
| --- | --- | --- |
| H100 14B eager (FA3) | -1.46% | 1.0 ✅ |
| L40S 1.3B eager (FA2) | +1.95% | 0.95 |
| H100 14B compile (FA3) | +0.13% | 0.52 |

So bit-equiv is real but FA3+eager only — FA2 already drifts at 0.96, and compile (what we actually deploy) gave me ~0 speedup + 0.52 worst. The win's basically just that one H100-eager cell.

Dug into why: Wan cross-attn has no text masking — context_lens is dead, hardcoded None (wanvideo.py:383/533), and the encoder_attention_mask you cat gets dropped by the sig filter. So the recipe's load-bearing (no mask → need equal-length pad), not gratuitous. Fair enough.

Couple nits:

  • default False — it changes output for everyone who isn't on H100-eager, and "output-identical" only holds there
  • recipe gate doesn't check use_batched_cfg, so --use-batched-cfg=False doesn't actually restore old Wan output. one-liner: and fastvideo_args.use_batched_cfg on the _cfg_padding gate
  • drop the `target_dtype` / `dit_precision` bit (unrelated, hits stable_audio fp16), the ltx2_3_base test (rebase drift), and the reverted mask-math test
  • heads up: with_options blows up on modal 1.3.5, had to patch the harness to run it

If you wanna go default-on later, the real fix is wiring that dead context_lens into the mask — then you pad to max(real), get true bit-equiv with zero output change, and the whole text_encoding recipe disappears. Your mask-math test's a head start there.
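
A minimal single-query sketch of the masking argument, in pure Python. It is not Wan's attention code and does not show how `context_lens` would be wired; it only illustrates that masking padded key positions reproduces unpadded attention, while leaving zero-padded keys unmasked does not.

```python
import math
from typing import List, Optional

Vector = List[float]


def attend(query: Vector, keys: List[Vector], values: List[Vector],
           key_mask: Optional[List[bool]] = None) -> Vector:
    """Scaled dot-product attention for one query over a list of keys."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    if key_mask is not None:
        # Masked positions get -inf, so their softmax weight is exactly 0.
        scores = [s if keep else -math.inf for s, keep in zip(scores, key_mask)]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


query = [1.0, 0.0]
real_keys = [[1.0, 0.0], [0.0, 1.0]]
real_values = [[1.0, 2.0], [3.0, 4.0]]
pad = [0.0, 0.0]
padded_keys = real_keys + [pad, pad]        # zero-padded to length 4
padded_values = real_values + [pad, pad]

unpadded = attend(query, real_keys, real_values)
masked = attend(query, padded_keys, padded_values,
                key_mask=[True, True, False, False])
unmasked = attend(query, padded_keys, padded_values)
```

Zero-valued keys still score 0 before the softmax and therefore receive non-zero weight, which is why an unmasked model needs both prompts padded to the same length to behave consistently.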

Run cond + uncond as a single batch=2 DiT forward per denoise step
instead of two sequential batch=1 forwards. Default OFF so this PR
has zero behaviour change for any user who doesn't opt in via
`use_batched_cfg=True`. Composes with hao-ai-lab#1372 Adaptive Guidance —
mutually exclusive at the gate level (AG selectively skips uncond,
batched-CFG forces both; running them together defeats AG's win).

Bit-equivalent to sequential CFG on H100 FA3 eager (SSIM=1.000000,
Wan 14B 720x1280x49f/30steps, 5 prompts). On other configs (compile,
FA2, smaller models) bf16 numerics drift slightly (~0.04 SSIM mean,
visually imperceptible per frame-by-frame inspection) due to
Inductor kernel selection and batched flash-attn numerics. Perf
delta is run-to-run variable.

Changes:
- fastvideo_args.py: new use_batched_cfg: bool = False field +
  --use-batched-cfg CLI. Auto-fallback to sequential when V2V/I2V/
  TI2V/action/camera conditioning is present (those carry batch=1
  conditioning tensors out of scope for this PR) OR when AG is
  active.
- entrypoints/video_generator.py: add use_batched_cfg to
  _FROM_PRETRAINED_CONVENIENCE_KWARGS.
- pipelines/stages/denoising.py: gated batched branch in main
  DenoisingStage.forward. Cats [neg, pos] along batch dim with
  shape-match defensive fallback, single forward, chunk(2),
  existing CFG-combine + guidance_rescale math reused. Other
  DenoisingStage subclasses (Cosmos25, Dmd, ...) unchanged.
- api/schema.py + api/compat.py: EngineConfig.use_batched_cfg
  field + legacy<->typed mappings (mirrors disable_autocast).
- docs/design/inference_schema_parity_inventory.yaml: inventory
  entry under fastvideo_args.moved.
- tests/api/test_parser.py: YAML-roundtrip expected dict updated.

Sequential CFG path preserved bit-for-bit for non-batched callers.
Other DiTs (HunyuanVideo, LongCat, ...) unaffected.

…ated on use_batched_cfg)

Diffusers' WanPipeline._get_t5_prompt_embeds (pipeline_wan.py:173-190)
and CosmosTextToWorldPipeline._get_t5_prompt_embeds (pipeline_cosmos_
text2world.py:197-237) both:
  (a) tokenize(padding="max_length", max_length=N)
  (b) T5 encoder forward (sees padded input)
  (c) trim T5 output to per-sample real length
  (d) re-pad with EXPLICIT ZEROS to max_length

Wan and Cosmos were trained against this recipe. FV's current
variable-length T5 output deviates. This change replicates the
canonical recipe — but only when use_batched_cfg=True, so users who
don't opt in see zero behaviour change.

HunyuanVideo's diffusers pipeline (pipeline_hunyuan_video.py
_get_llama_prompt_embeds) does NOT do (c)+(d) — uses encoder output
as-is, padded positions retain T5's natural bias-driven values.
Gated on dit_config MRO matching {WanVideoConfig, CosmosVideoConfig,
Cosmos25VideoConfig} so HunyuanVideo and every other DiT bypass the
recipe entirely.

Long-term: the right fix is to wire context_lens through to Wan's
attention layer so it can mask padding directly (rather than relying
on training-distribution-zero positions). That would let batched-CFG
pad to max(real_lens) instead of max_length and become bit-equivalent
across more configs. Out of scope for this PR; tracked as a follow-up
(see PR body and ground_truth W3e).

Stub-transformer equivalence between batched and sequential paths,
CFG-off equivalence, autodetect-off on 7 conditioning fields
(V2V, I2V, TI2V, action, camera, image embeds, etc.).

Single-container two-pass A/B harness modeled on Kuan's
nccl_stream_ab.py pattern (used to validate hao-ai-lab#1395). Runs the same
seed-pinned prompts twice through the same VideoGenerator (once
sequential, once batched), computes pairwise SSIM via pytorch_msssim
(matches fastvideo/tests/utils.py:compute_video_ssim_torchvision),
prints a per-prompt + total wall/SSIM table.

Architecture: the harness ferries pass config + records over JSON
files and subprocesses the inner script via /opt/venv/bin/python.
Modal's main function process runs in its own add_python="3.12"
layer where FastVideo isn't importable; the venv subprocess works
around that. Same reason ssim_test.py runs pytest as a subprocess.

Features:
- Mode-tagged output dirs ({eager,compile}) so multiple legs don't
  clobber each other in the hf-model-weights Volume.
- Recovery function reads existing mp4s by mtime and recomputes SSIM
  for post-run gap-fill if the in-run SSIM step missed.
- Prebuilt FA3 wheel install on Hopper (autodetected from gpu kwarg)
  via mjun0812/flash-attention-prebuild-wheels v0.9.4 — the same
  release that supplies the FA2 wheel already baked into the
  fastvideo-dev image. ~30s install vs Kuan's hao-ai-lab#1389 ~90min cold
  source build.
- Modal Secret integration for HF token (no token-on-the-wire).
- --enable-compile + --num-prompts flags for flexible legs.

Reproducible by reviewers — see PR body for exact commands.
@Mister-Raggs Mister-Raggs force-pushed the perf/wan-batched-cfg branch from 6cad9f8 to 8512734 Compare June 1, 2026 04:57
mergify Bot commented Jun 9, 2026

This PR has merge conflicts with the base branch. Please rebase:

git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease

@mergify mergify Bot added the needs-rebase PR has merge conflicts label Jun 9, 2026