Skip to content

[perf] matrixgame2 causal: cache per-forward RoPE/timestep/blockmask recompute on-device#1415

Open
rich7420 wants to merge 2 commits into
hao-ai-lab:mainfrom
rich7420:matrixgame2-cache-per-forward-recompute
Open

[perf] matrixgame2 causal: cache per-forward RoPE/timestep/blockmask recompute on-device#1415
rich7420 wants to merge 2 commits into
hao-ai-lab:mainfrom
rich7420:matrixgame2-cache-per-forward-recompute

Conversation

@rich7420

@rich7420 rich7420 commented May 30, 2026

Copy link
Copy Markdown
Contributor

Summary

matrixgame2 causal inference repeated constant work on every DiT forward.
This removes it:

  • Four flex-attention BlockMasks were rebuilt via create_block_mask each
    forward (compile + a GPU-tensor python loop). They depend only on
    (num_frames, frame_seqlen, block size, local_attn_size, device), so they
    are cached per param-tuple (device included in the key).
  • get_rotary_pos_embed ran a float64 CPU compute + H2D copy each forward to
    build freqs_cis, which CausalMatrixGame2SelfAttention.forward never reads
    (it computes its own RoPE via self._freqs_cache). Now skipped entirely
    (freqs_cis = None).
  • the sinusoidal timestep table is built on the target device instead of
    CPU + H2D.
  • the action-module RoPE freqs are cached device-resident (with a device
    check) so the per-call .to is a no-op.

All numerically identical.

How it was found (nsys-ai)

On an H100 matrixgame2 trace, gpu_idle_gaps plus the memcpy breakdown showed
heavy per-forward host→device traffic and repeated create_block_mask work; a
CUDA-graph capture probe pinpointed the exact per-forward recompute sites.

Result (H100, 117 frames = 30 DiT forwards)

  • isolated DiT forward: 194 ms → 95 ms (~2×)
  • end-to-end inference wall: 12.6 s → 10.6 s (−16%)
  • denoise-stage GPU-active share: 23% → 33.5%

The biggest single contributor is the BlockMask cache — create_block_mask is
expensive and was running once per forward.

@mergify mergify Bot added type: perf Performance improvement scope: model Model architecture (DiTs, encoders, VAEs) labels May 30, 2026
@mergify

mergify Bot commented May 30, 2026

Copy link
Copy Markdown
Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes performance and enables CUDA-graph capture by caching frequency tables, RoPE tables, and block masks directly on the target device, avoiding expensive CPU-to-GPU copies and redundant computations. The review feedback highlights three key improvement opportunities: first, the freqs_cis computation in causal_model.py is completely unused and can be bypassed entirely; second, the block mask cache keys should include the device to prevent runtime mismatches if the model is moved; and third, the lazy initialization in action_module.py should check the device of the cached tensors to ensure they remain on the correct device.

Comment on lines 980 to 1010
# RoPE tables depend only on (grid, start_frame); both are constant
# within an AR block's DMD steps. Recomputing them on CPU in float64
# and copying H2D every forward is wasteful AND breaks CUDA-graph
# capture (H2D copy is illegal mid-capture). Cache the device tensors
# keyed by (start_frame, grid) so the H2D happens once per block.
_rope_key = (start_frame, post_patch_num_frames, post_patch_height,
post_patch_width)
_rope_cache = getattr(self, "_rope_device_cache", None)
if _rope_cache is None:
_rope_cache = {}
self._rope_device_cache = _rope_cache
if _rope_key in _rope_cache:
freqs_cos, freqs_sin = _rope_cache[_rope_key]
else:
freqs_cos, freqs_sin = get_rotary_pos_embed(
(
post_patch_num_frames * get_sp_world_size(),
post_patch_height,
post_patch_width,
),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000,
start_frame=start_frame,
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
_rope_cache[_rope_key] = (freqs_cos, freqs_sin)
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The freqs_cis variable (computed via get_rotary_pos_embed) is completely unused in CausalMatrixGame2SelfAttention.forward (which instead uses causal_rope_apply with its own self._freqs_cache).

Instead of caching this unused computation, we can completely avoid calling get_rotary_pos_embed and simply set freqs_cis = None. This eliminates the CPU overhead, GPU memory allocation, and H2D copy entirely without needing any caching logic.

Suggested change
# RoPE tables depend only on (grid, start_frame); both are constant
# within an AR block's DMD steps. Recomputing them on CPU in float64
# and copying H2D every forward is wasteful AND breaks CUDA-graph
# capture (H2D copy is illegal mid-capture). Cache the device tensors
# keyed by (start_frame, grid) so the H2D happens once per block.
_rope_key = (start_frame, post_patch_num_frames, post_patch_height,
post_patch_width)
_rope_cache = getattr(self, "_rope_device_cache", None)
if _rope_cache is None:
_rope_cache = {}
self._rope_device_cache = _rope_cache
if _rope_key in _rope_cache:
freqs_cos, freqs_sin = _rope_cache[_rope_key]
else:
freqs_cos, freqs_sin = get_rotary_pos_embed(
(
post_patch_num_frames * get_sp_world_size(),
post_patch_height,
post_patch_width,
),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000,
start_frame=start_frame,
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
_rope_cache[_rope_key] = (freqs_cos, freqs_sin)
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# freqs_cis is completely unused in CausalMatrixGame2SelfAttention.forward,
# so we can avoid the expensive get_rotary_pos_embed computation entirely.
freqs_cis = None

Comment on lines +1061 to +1088
_k = ("main", num_frames, _fsl, _nfb, _las)
if _k not in bm_cache:
bm_cache[_k] = self._prepare_blockwise_causal_attn_mask(
device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
num_frame_per_block=_nfb, local_attn_size=_las)
block_mask = bm_cache[_k]

if self.use_rope_keyboard:
block_mask_keyboard = self._prepare_blockwise_causal_attn_mask_action(
device=hidden_states.device,
num_frames=num_frames,
frame_seqlen=1,
num_frame_per_block=self.num_frame_per_block,
local_attn_size=self.local_attn_size,
)
_k = ("act_kb", num_frames, 1, _nfb, _las)
if _k not in bm_cache:
bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
device=_dev, num_frames=num_frames, frame_seqlen=1,
num_frame_per_block=_nfb, local_attn_size=_las)
block_mask_keyboard = bm_cache[_k]
else:
block_mask_keyboard = self._prepare_blockwise_causal_attn_mask_keyboard(
device=hidden_states.device,
num_frames=num_frames,
frame_seqlen=post_patch_height * post_patch_width,
num_frame_per_block=self.num_frame_per_block,
local_attn_size=self.local_attn_size,
)
block_mask_mouse = self._prepare_blockwise_causal_attn_mask_action(
device=hidden_states.device,
num_frames=num_frames,
frame_seqlen=1,
num_frame_per_block=self.num_frame_per_block,
local_attn_size=self.local_attn_size,
)
_k = ("kb", num_frames, _fsl, _nfb, _las)
if _k not in bm_cache:
bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_keyboard(
device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
num_frame_per_block=_nfb, local_attn_size=_las)
block_mask_keyboard = bm_cache[_k]

_k = ("act_ms", num_frames, 1, _nfb, _las)
if _k not in bm_cache:
bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
device=_dev, num_frames=num_frames, frame_seqlen=1,
num_frame_per_block=_nfb, local_attn_size=_las)
block_mask_mouse = bm_cache[_k]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The cache key _k for bm_cache does not include the device _dev. If the model is moved to a different device (or if inference is run on multiple devices/GPUs), using a cached BlockMask created on a different device will result in a runtime device mismatch error.

To make the cache robust, include _dev in the cache keys.

        _k = ("main", num_frames, _fsl, _nfb, _las, _dev)
        if _k not in bm_cache:
            bm_cache[_k] = self._prepare_blockwise_causal_attn_mask(
                device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
                num_frame_per_block=_nfb, local_attn_size=_las)
        block_mask = bm_cache[_k]

        if self.use_rope_keyboard:
            _k = ("act_kb", num_frames, 1, _nfb, _las, _dev)
            if _k not in bm_cache:
                bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
                    device=_dev, num_frames=num_frames, frame_seqlen=1,
                    num_frame_per_block=_nfb, local_attn_size=_las)
            block_mask_keyboard = bm_cache[_k]
        else:
            _k = ("kb", num_frames, _fsl, _nfb, _las, _dev)
            if _k not in bm_cache:
                bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_keyboard(
                    device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
                    num_frame_per_block=_nfb, local_attn_size=_las)
            block_mask_keyboard = bm_cache[_k]

        _k = ("act_ms", num_frames, 1, _nfb, _las, _dev)
        if _k not in bm_cache:
            bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
                device=_dev, num_frames=num_frames, frame_seqlen=1,
                num_frame_per_block=_nfb, local_attn_size=_las)
        block_mask_mouse = bm_cache[_k]

Comment on lines +778 to +788
if self._freqs_cos is None or self._freqs_sin is None:
self._freqs_cos, self._freqs_sin = self.get_rotary_pos_embed(
_fc, _fs = self.get_rotary_pos_embed(
7500,
self.patch_size[1],
self.patch_size[2],
64,
self.mouse_qk_dim_list,
start_offset=0,
)
self._freqs_cos = _fc.to(x.device)
self._freqs_sin = _fs.to(x.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If the model is moved to a different device after the first forward pass, self._freqs_cos and self._freqs_sin will remain on the old device. This will cause silent H2D copies on every subsequent forward pass inside _apply_rotary_emb_qk, defeating the caching optimization and breaking CUDA-graph capture.

To prevent this, check if the cached tensors are on the correct device (x.device) before reusing them.

Suggested change
if self._freqs_cos is None or self._freqs_sin is None:
self._freqs_cos, self._freqs_sin = self.get_rotary_pos_embed(
_fc, _fs = self.get_rotary_pos_embed(
7500,
self.patch_size[1],
self.patch_size[2],
64,
self.mouse_qk_dim_list,
start_offset=0,
)
self._freqs_cos = _fc.to(x.device)
self._freqs_sin = _fs.to(x.device)
if self._freqs_cos is None or self._freqs_sin is None or self._freqs_cos.device != x.device:
_fc, _fs = self.get_rotary_pos_embed(
7500,
self.patch_size[1],
self.patch_size[2],
64,
self.mouse_qk_dim_list,
start_offset=0,
)
self._freqs_cos = _fc.to(x.device)
self._freqs_sin = _fs.to(x.device)

@rich7420 rich7420 force-pushed the matrixgame2-cache-per-forward-recompute branch from 4decf50 to ac84d6f Compare May 30, 2026 10:31
@SolitaryThinker

Copy link
Copy Markdown
Collaborator

Hi @rich7420 — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.

TL;DR

This PR removes repeated matrixgame2 causal inference work by moving timestep frequency creation onto-device, skipping unused RoPE construction, and caching device-specific block masks. I found one must-fix cache invalidation issue: the block-mask cache is built from the model default block size even when the caller overrides the current block size, which can silently apply the wrong attention mask for boundary or variable-sized blocks.

Verdict: request-changes

  • S0 (blockers): 0
  • S1 (must-fix): 1
  • S2 (should-fix; surfaced if persistent or important): 0
  • S3 (discussion): not shown here; see review.md

Findings (formatted for upload)

[S1] Block-mask cache ignores the per-call block size override

What: In CausalMatrixGame2WanModel._forward_inference, the cached BlockMask key and builder use _nfb = self.num_frame_per_block at fastvideo/models/dits/matrixgame2/causal_model.py:1032, but the same forward computes effective_num_frame_per_block from kwargs at causal_model.py:902 and passes that effective value into each block at causal_model.py:1104. The MatrixGame2 denoising stage explicitly passes the current block size as num_frame_per_block = num_frames at fastvideo/pipelines/stages/matrixgame2_denoising.py:257-258, and boundary mode can set the first block to a different size.

Why it matters: The flex-attention masks depend on the actual block size used for the current forward. If the first boundary block or any future variable-sized block runs with effective_num_frame_per_block != self.num_frame_per_block, the cache will reuse/build a mask for the wrong attention pattern while the action and KV-cache logic use the effective block size, producing silent wrong attention rather than a crash.

Suggested fix: Use effective_num_frame_per_block for _nfb when constructing all three inference block-mask cache keys and when calling the _prepare_blockwise_* helpers. Add a small parity/unit probe that calls _forward_inference twice with the same num_frames but different num_frame_per_block overrides and asserts separate cache entries or masks.

Evidence: fastvideo/models/dits/matrixgame2/causal_model.py:902, fastvideo/models/dits/matrixgame2/causal_model.py:1032, fastvideo/models/dits/matrixgame2/causal_model.py:1035-1062, fastvideo/models/dits/matrixgame2/causal_model.py:1104, fastvideo/pipelines/stages/matrixgame2_denoising.py:257-258


— Gob (@SolitaryThinker's AI reviewer). Full review (including S3 items) is archived locally.

Mister-Raggs added a commit to Mister-Raggs/FastVideo that referenced this pull request Jun 4, 2026
Wan's DiT forward rebuilt rotary embeddings on every call via a
float64 CPU compute (`get_rotary_pos_embed`) + H2D copy of the cos/sin
pair. The inputs — post-patch shape + compute device — are constant
across every step of a single generation, so the precompute happens
once per (shape, device) and is reused for the rest of the denoise
loop (180+ forwards for a 50-step CFG run).

Mirrors the matrixgame2 fix in hao-ai-lab#1415: same hotspot, different model,
same cache-on-self pattern. Composes with hao-ai-lab#1245's fused RoPE Triton
kernel (that PR optimizes each `_apply_rotary_emb` call site; this
one removes the recurring per-forward precompute upstream of it).

Math is identical — first call computes, subsequent calls memoize.
Bit-exact equivalence with the prior path expected.
Mister-Raggs added a commit to Mister-Raggs/FastVideo that referenced this pull request Jun 5, 2026
TimestepEmbedder rebuilt the sinusoidal frequency table on every forward
via a CPU `arange` + `exp` followed by an H2D copy. The table depends only
on (frequency_embedding_size, max_period, freq_dtype) — all instance-
constant — and the compute device, so it is identical across every
denoising step. Cache it on the module per device; only the timestep-
dependent product (`args = t * freqs`) runs each call.

`timestep_embedding` gains an optional precomputed `freqs` argument so the
free function's behavior is unchanged for any direct caller. Same cache-on-
self pattern as the Wan rotary precompute in this branch and matrixgame2
PR hao-ai-lab#1415. TimestepEmbedder is shared across the DiT stack (Wan,
HunyuanVideo/HV15, hunyuangamecraft, longcat, matrixgame2/3, hyworld), so
every model that uses it drops the per-forward recompute + H2D.

Bit-exact — the cached tensor is the same value the prior path recomputed
each call.
Mister-Raggs added a commit to Mister-Raggs/FastVideo that referenced this pull request Jun 5, 2026
TimestepEmbedder rebuilt the sinusoidal frequency table on every forward
via a CPU `arange` + `exp` followed by an H2D copy. The table depends only
on (frequency_embedding_size, max_period, freq_dtype) — all instance-
constant — and the compute device, so it is identical across every
denoising step. Cache it on the module per device; only the timestep-
dependent product (`args = t * freqs`) runs each call.

`timestep_embedding` gains an optional precomputed `freqs` argument so the
free function's behavior is unchanged for any direct caller. Same cache-on-
self pattern as the Wan rotary precompute in this branch and matrixgame2
PR hao-ai-lab#1415. TimestepEmbedder is shared across the DiT stack (Wan,
HunyuanVideo/HV15, hunyuangamecraft, longcat, matrixgame2/3, hyworld), so
every model that uses it drops the per-forward recompute + H2D.

Bit-exact — the cached tensor is the same value the prior path recomputed
each call.
…, skip unused RoPE, on-device timestep)

matrixgame2 causal inference repeated constant work on every DiT forward:

- 4 flex-attention BlockMasks were rebuilt via create_block_mask each forward
  (compile + a GPU-tensor python loop). They depend only on
  (num_frames, frame_seqlen, block size, local_attn_size, device), so cache
  them per param-tuple (device included in the key).
- get_rotary_pos_embed ran a float64 CPU compute + H2D copy each forward to
  build freqs_cis — which CausalMatrixGame2SelfAttention.forward never reads
  (it computes its own RoPE via self._freqs_cache). Set freqs_cis = None and
  skip the computation entirely.
- the sinusoidal timestep table was built on CPU then copied H2D; build it on
  the target device instead.
- the action-module RoPE freqs were cached as CPU tensors, forcing an H2D copy
  every forward inside _apply_rotary_emb_qk; cache them device-resident (with a
  device check) so the per-call .to is a no-op.

All numerically identical. Removes per-forward CPU compute + H2D copies that
also blocked CUDA-graph capture.

Measured (H100, 117 frames = 30 DiT forwards): isolated DiT forward
194 ms -> 95 ms; end-to-end inference 12.6 s -> 10.6 s (-16%); denoise-stage
GPU-active 23% -> 33.5%.
@rich7420 rich7420 force-pushed the matrixgame2-cache-per-forward-recompute branch 2 times, most recently from 550bb34 to 471991a Compare June 5, 2026 16:49
…ock size

The cached BlockMask was keyed/built from self.num_frame_per_block, but the
transformer blocks run with effective_num_frame_per_block (the per-call
num_frame_per_block override). Align the mask with the block size the blocks
actually use; it is part of the cache key so different sizes get distinct
entries. Verified on H100: the size-1 and size-3 masks genuinely differ, so
any forward whose effective size diverges from the default (e.g. a boundary
block of size 1) would otherwise apply the wrong attention pattern. For the
current shipped matrixgame2 configs the two sizes coincide, so output is
unchanged — this removes the latent mismatch.
@rich7420 rich7420 force-pushed the matrixgame2-cache-per-forward-recompute branch from 471991a to e51c063 Compare June 5, 2026 16:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

scope: model Model architecture (DiTs, encoders, VAEs) type: perf Performance improvement

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants