[perf] matrixgame2 causal: cache per-forward RoPE/timestep/blockmask recompute on-device #1415
rich7420 wants to merge 2 commits into
Code Review
This pull request optimizes performance and enables CUDA-graph capture by caching frequency tables, RoPE tables, and block masks directly on the target device, avoiding expensive CPU-to-GPU copies and redundant computations. The review feedback highlights three key improvement opportunities: first, the freqs_cis computation in causal_model.py is completely unused and can be bypassed entirely; second, the block mask cache keys should include the device to prevent runtime mismatches if the model is moved; and third, the lazy initialization in action_module.py should check the device of the cached tensors to ensure they remain on the correct device.
# RoPE tables depend only on (grid, start_frame); both are constant
# within an AR block's DMD steps. Recomputing them on CPU in float64
# and copying H2D every forward is wasteful AND breaks CUDA-graph
# capture (H2D copy is illegal mid-capture). Cache the device tensors
# keyed by (start_frame, grid) so the H2D happens once per block.
_rope_key = (start_frame, post_patch_num_frames, post_patch_height,
             post_patch_width)
_rope_cache = getattr(self, "_rope_device_cache", None)
if _rope_cache is None:
    _rope_cache = {}
    self._rope_device_cache = _rope_cache
if _rope_key in _rope_cache:
    freqs_cos, freqs_sin = _rope_cache[_rope_key]
else:
    freqs_cos, freqs_sin = get_rotary_pos_embed(
        (
            post_patch_num_frames * get_sp_world_size(),
            post_patch_height,
            post_patch_width,
        ),
        self.hidden_size,
        self.num_attention_heads,
        rope_dim_list,
        dtype=torch.float32 if current_platform.is_mps() else torch.float64,
        rope_theta=10000,
        start_frame=start_frame,
    )
    freqs_cos = freqs_cos.to(hidden_states.device)
    freqs_sin = freqs_sin.to(hidden_states.device)
    _rope_cache[_rope_key] = (freqs_cos, freqs_sin)
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
The freqs_cis variable (computed via get_rotary_pos_embed) is completely unused in CausalMatrixGame2SelfAttention.forward (which instead uses causal_rope_apply with its own self._freqs_cache).
Instead of caching this unused computation, we can completely avoid calling get_rotary_pos_embed and simply set freqs_cis = None. This eliminates the CPU overhead, GPU memory allocation, and H2D copy entirely without needing any caching logic.
# freqs_cis is completely unused in CausalMatrixGame2SelfAttention.forward,
# so we can avoid the expensive get_rotary_pos_embed computation entirely.
freqs_cis = None
+ _k = ("main", num_frames, _fsl, _nfb, _las)
+ if _k not in bm_cache:
+     bm_cache[_k] = self._prepare_blockwise_causal_attn_mask(
+         device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
+         num_frame_per_block=_nfb, local_attn_size=_las)
+ block_mask = bm_cache[_k]

  if self.use_rope_keyboard:
-     block_mask_keyboard = self._prepare_blockwise_causal_attn_mask_action(
-         device=hidden_states.device,
-         num_frames=num_frames,
-         frame_seqlen=1,
-         num_frame_per_block=self.num_frame_per_block,
-         local_attn_size=self.local_attn_size,
-     )
+     _k = ("act_kb", num_frames, 1, _nfb, _las)
+     if _k not in bm_cache:
+         bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
+             device=_dev, num_frames=num_frames, frame_seqlen=1,
+             num_frame_per_block=_nfb, local_attn_size=_las)
+     block_mask_keyboard = bm_cache[_k]
  else:
-     block_mask_keyboard = self._prepare_blockwise_causal_attn_mask_keyboard(
-         device=hidden_states.device,
-         num_frames=num_frames,
-         frame_seqlen=post_patch_height * post_patch_width,
-         num_frame_per_block=self.num_frame_per_block,
-         local_attn_size=self.local_attn_size,
-     )
- block_mask_mouse = self._prepare_blockwise_causal_attn_mask_action(
-     device=hidden_states.device,
-     num_frames=num_frames,
-     frame_seqlen=1,
-     num_frame_per_block=self.num_frame_per_block,
-     local_attn_size=self.local_attn_size,
- )
+     _k = ("kb", num_frames, _fsl, _nfb, _las)
+     if _k not in bm_cache:
+         bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_keyboard(
+             device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
+             num_frame_per_block=_nfb, local_attn_size=_las)
+     block_mask_keyboard = bm_cache[_k]

+ _k = ("act_ms", num_frames, 1, _nfb, _las)
+ if _k not in bm_cache:
+     bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
+         device=_dev, num_frames=num_frames, frame_seqlen=1,
+         num_frame_per_block=_nfb, local_attn_size=_las)
+ block_mask_mouse = bm_cache[_k]
The cache key _k for bm_cache does not include the device _dev. If the model is moved to a different device (or if inference is run on multiple devices/GPUs), using a cached BlockMask created on a different device will result in a runtime device mismatch error.
To make the cache robust, include _dev in the cache keys.
_k = ("main", num_frames, _fsl, _nfb, _las, _dev)
if _k not in bm_cache:
    bm_cache[_k] = self._prepare_blockwise_causal_attn_mask(
        device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
        num_frame_per_block=_nfb, local_attn_size=_las)
block_mask = bm_cache[_k]
if self.use_rope_keyboard:
    _k = ("act_kb", num_frames, 1, _nfb, _las, _dev)
    if _k not in bm_cache:
        bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
            device=_dev, num_frames=num_frames, frame_seqlen=1,
            num_frame_per_block=_nfb, local_attn_size=_las)
    block_mask_keyboard = bm_cache[_k]
else:
    _k = ("kb", num_frames, _fsl, _nfb, _las, _dev)
    if _k not in bm_cache:
        bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_keyboard(
            device=_dev, num_frames=num_frames, frame_seqlen=_fsl,
            num_frame_per_block=_nfb, local_attn_size=_las)
    block_mask_keyboard = bm_cache[_k]
_k = ("act_ms", num_frames, 1, _nfb, _las, _dev)
if _k not in bm_cache:
    bm_cache[_k] = self._prepare_blockwise_causal_attn_mask_action(
        device=_dev, num_frames=num_frames, frame_seqlen=1,
        num_frame_per_block=_nfb, local_attn_size=_las)
block_mask_mouse = bm_cache[_k]

  if self._freqs_cos is None or self._freqs_sin is None:
-     self._freqs_cos, self._freqs_sin = self.get_rotary_pos_embed(
+     _fc, _fs = self.get_rotary_pos_embed(
          7500,
          self.patch_size[1],
          self.patch_size[2],
          64,
          self.mouse_qk_dim_list,
          start_offset=0,
      )
+     self._freqs_cos = _fc.to(x.device)
+     self._freqs_sin = _fs.to(x.device)
If the model is moved to a different device after the first forward pass, self._freqs_cos and self._freqs_sin will remain on the old device. This will cause silent H2D copies on every subsequent forward pass inside _apply_rotary_emb_qk, defeating the caching optimization and breaking CUDA-graph capture.
To prevent this, check if the cached tensors are on the correct device (x.device) before reusing them.
if self._freqs_cos is None or self._freqs_sin is None or self._freqs_cos.device != x.device:
    _fc, _fs = self.get_rotary_pos_embed(
        7500,
        self.patch_size[1],
        self.patch_size[2],
        64,
        self.mouse_qk_dim_list,
        start_offset=0,
    )
    self._freqs_cos = _fc.to(x.device)
    self._freqs_sin = _fs.to(x.device)
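The device check suggested above can be illustrated without a GPU. The following is a minimal sketch, not the PR's code: `FakeTensor` and `RotaryCache` are hypothetical stand-ins, and devices are plain strings. It only shows that the expensive build runs once per device and again after a device change.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only tracks which device it lives on."""
    name: str
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return self when the device already matches.
        return self if device == self.device else FakeTensor(self.name, device)


class RotaryCache:
    """Lazily builds a (cos, sin) pair and rebuilds it when the device changes."""

    def __init__(self) -> None:
        self._cos = None
        self._sin = None
        self.builds = 0  # counts how often the expensive path ran

    def _build_on_cpu(self):
        # Placeholder for the real table computation (assumed to run on CPU).
        self.builds += 1
        return FakeTensor("cos"), FakeTensor("sin")

    def get(self, device: str):
        stale = (self._cos is None or self._sin is None
                 or self._cos.device != device)
        if stale:
            cos, sin = self._build_on_cpu()
            self._cos, self._sin = cos.to(device), sin.to(device)
        return self._cos, self._sin
```

Calling `get("cuda:0")` twice builds once; a later `get("cuda:1")` rebuilds, so the cached pair never stays on the old device.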
4decf50 to ac84d6f
Hi @rich7420, automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.
TL;DR
This PR removes repeated matrixgame2 causal inference work by moving timestep frequency creation onto-device, skipping unused RoPE construction, and caching device-specific block masks. I found one must-fix cache invalidation issue: the block-mask cache is built from the model default block size even when the caller overrides the current block size, which can silently apply the wrong attention mask for boundary or variable-sized blocks.
Verdict: request-changes
Findings (formatted for upload)
[S1] Block-mask cache ignores the per-call block size override
What: In
Why it matters: The flex-attention masks depend on the actual block size used for the current forward. If the first boundary block or any future variable-sized block runs with
Suggested fix: Use
Evidence:
Gob (@SolitaryThinker's AI reviewer). Full review (including S3 items) is archived locally.
Wan's DiT forward rebuilt rotary embeddings on every call via a float64 CPU compute (`get_rotary_pos_embed`) + H2D copy of the cos/sin pair. The inputs — post-patch shape + compute device — are constant across every step of a single generation, so the precompute happens once per (shape, device) and is reused for the rest of the denoise loop (180+ forwards for a 50-step CFG run). Mirrors the matrixgame2 fix in hao-ai-lab#1415: same hotspot, different model, same cache-on-self pattern. Composes with hao-ai-lab#1245's fused RoPE Triton kernel (that PR optimizes each `_apply_rotary_emb` call site; this one removes the recurring per-forward precompute upstream of it). Math is identical — first call computes, subsequent calls memoize. Bit-exact equivalence with the prior path expected.
TimestepEmbedder rebuilt the sinusoidal frequency table on every forward via a CPU `arange` + `exp` followed by an H2D copy. The table depends only on (frequency_embedding_size, max_period, freq_dtype), all instance-constant, and the compute device, so it is identical across every denoising step. Cache it on the module per device; only the timestep-dependent product (`args = t * freqs`) runs each call. `timestep_embedding` gains an optional precomputed `freqs` argument so the free function's behavior is unchanged for any direct caller. Same cache-on-self pattern as the Wan rotary precompute in this branch and matrixgame2 PR hao-ai-lab#1415. TimestepEmbedder is shared across the DiT stack (Wan, HunyuanVideo/HV15, hunyuangamecraft, longcat, matrixgame2/3, hyworld), so every model that uses it drops the per-forward recompute + H2D. Bit-exact: the cached tensor is the same value the prior path recomputed each call.
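The split between the timestep-independent table and the per-call product can be sketched in plain Python. This is an illustrative sketch only: the signatures, the `freq_table` helper, and the cos-then-sin ordering are assumptions, and a real module would also key the cached table on the compute device.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def freq_table(half_dim: int, max_period: float = 10000.0) -> tuple:
    """Timestep-independent part: freqs[i] = exp(-ln(max_period) * i / half_dim)."""
    return tuple(
        math.exp(-math.log(max_period) * i / half_dim) for i in range(half_dim)
    )


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0,
                       freqs=None) -> list:
    """Return [cos(t * freqs), sin(t * freqs)].

    `freqs` is optional, so callers that pass nothing get the same result
    as callers that pass a precomputed table.
    """
    if freqs is None:
        freqs = freq_table(dim // 2, max_period)
    args = [t * f for f in freqs]  # the only timestep-dependent work
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```

Passing `freqs=freq_table(dim // 2, max_period)` explicitly and omitting it give identical outputs, which is the property the commit message relies on when it calls the change bit-exact.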
…, skip unused RoPE, on-device timestep)
matrixgame2 causal inference repeated constant work on every DiT forward:
- 4 flex-attention BlockMasks were rebuilt via create_block_mask each forward (compile + a GPU-tensor python loop). They depend only on (num_frames, frame_seqlen, block size, local_attn_size, device), so cache them per param-tuple (device included in the key).
- get_rotary_pos_embed ran a float64 CPU compute + H2D copy each forward to build freqs_cis, which CausalMatrixGame2SelfAttention.forward never reads (it computes its own RoPE via self._freqs_cache). Set freqs_cis = None and skip the computation entirely.
- the sinusoidal timestep table was built on CPU then copied H2D; build it on the target device instead.
- the action-module RoPE freqs were cached as CPU tensors, forcing an H2D copy every forward inside _apply_rotary_emb_qk; cache them device-resident (with a device check) so the per-call .to is a no-op.
All numerically identical. Removes per-forward CPU compute + H2D copies that also blocked CUDA-graph capture.
Measured (H100, 117 frames = 30 DiT forwards): isolated DiT forward 194 ms -> 95 ms; end-to-end inference 12.6 s -> 10.6 s (-16%); denoise-stage GPU-active 23% -> 33.5%.
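The per-param-tuple BlockMask cache from the first bullet can be sketched as a small memoizer. This is a hedged illustration, not the PR's implementation: `MaskCache` and the builder callable are hypothetical, and devices are plain strings here rather than real device objects.

```python
class MaskCache:
    """Memoizes an expensive mask builder by its full parameter tuple."""

    def __init__(self) -> None:
        self._entries = {}

    def get(self, kind, builder, *, device, num_frames, frame_seqlen,
            num_frame_per_block, local_attn_size):
        # Every input that changes the mask, including the device, is in the key.
        key = (kind, num_frames, frame_seqlen, num_frame_per_block,
               local_attn_size, device)
        if key not in self._entries:
            self._entries[key] = builder(
                device=device,
                num_frames=num_frames,
                frame_seqlen=frame_seqlen,
                num_frame_per_block=num_frame_per_block,
                local_attn_size=local_attn_size,
            )
        return self._entries[key]
```

With the device in the key, a model moved to another device gets a fresh entry instead of a mask that lives on the old one; the `kind` tag keeps masks built by different functions with equal parameters apart.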
550bb34 to 471991a
…ock size
The cached BlockMask was keyed/built from self.num_frame_per_block, but the transformer blocks run with effective_num_frame_per_block (the per-call num_frame_per_block override). Align the mask with the block size the blocks actually use; it is part of the cache key so different sizes get distinct entries.
Verified on H100: the size-1 and size-3 masks genuinely differ, so any forward whose effective size diverges from the default (e.g. a boundary block of size 1) would otherwise apply the wrong attention pattern. For the current shipped matrixgame2 configs the two sizes coincide, so output is unchanged; this removes the latent mismatch.
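Why the block size must match can be seen on a toy mask. The sketch below uses a simplified rule that is an assumption, not the repository's mask function: each token may attend to every token up to the end of its own block, with no local-attention window.

```python
def blockwise_causal_mask(num_frames: int, frame_seqlen: int,
                          block_size: int) -> list:
    """mask[q][kv] is True when query token q may attend to key token kv.

    Simplified rule (assumed for illustration): a token sees everything up
    to the end of its own block of `block_size` frames.
    """
    total = num_frames * frame_seqlen
    block_tokens = block_size * frame_seqlen
    mask = []
    for q in range(total):
        block_end = min((q // block_tokens + 1) * block_tokens, total)
        mask.append([kv < block_end for kv in range(total)])
    return mask


size_1 = blockwise_causal_mask(num_frames=3, frame_seqlen=1, block_size=1)
size_3 = blockwise_causal_mask(num_frames=3, frame_seqlen=1, block_size=3)
```

Under this rule `size_1` is lower-triangular while `size_3` lets every token see all three frames, so reusing a mask built for one block size in a forward that runs with the other changes which tokens are visible.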
471991a to e51c063
Summary
matrixgame2 causal inference repeated constant work on every DiT forward. This removes it:
- 4 flex-attention BlockMasks were rebuilt via create_block_mask each forward (compile + a GPU-tensor python loop). They depend only on (num_frames, frame_seqlen, block size, local_attn_size, device), so they are cached per param-tuple (device included in the key).
- get_rotary_pos_embed ran a float64 CPU compute + H2D copy each forward to build freqs_cis, which CausalMatrixGame2SelfAttention.forward never reads (it computes its own RoPE via self._freqs_cache). Now skipped entirely (freqs_cis = None).
- The sinusoidal timestep table is now built on the target device instead of CPU + H2D.
- The action-module RoPE freqs are cached device-resident (with a device check) so the per-call .to is a no-op.
All numerically identical.
How it was found (nsys-ai)
On an H100 matrixgame2 trace, gpu_idle_gaps plus the memcpy breakdown showed heavy per-forward host→device traffic and repeated create_block_mask work; a CUDA-graph capture probe pinpointed the exact per-forward recompute sites.
Result (H100, 117 frames = 30 DiT forwards)
The biggest single contributor is the BlockMask cache: create_block_mask is expensive and was running once per forward.