
[perf] Cache RoPE position-embedding tables across denoising steps #1442

Open
Godmook wants to merge 2 commits into hao-ai-lab:main from Godmook:cache_rotary

Godmook commented Jun 8, 2026

Contributor

Summary

get_rotary_pos_embed is called inside every DiT forward(), i.e. once per denoising step. Its inputs (latent grid size, head dim, theta, dtype, SP shard, start_frame) are constant across the entire denoising loop, yet it rebuilds the full float64 cos/sin tables (meshgrid + 3× outer + cos/sin) from scratch every step and then copies them host→device. For a 50-step run, the same table is computed 50× (100× with CFG).

This PR memoizes the result, matching the existing _ROPE_DICT / get_rope caching pattern already in the same file. All 11 DiT models that call get_rotary_pos_embed benefit with no model-side changes.

It is a pure-function memoization, so outputs are bitwise-identical and SSIM is unchanged.

What changed

  • fastvideo/layers/rotary_embedding.py: add a module-level _ROTARY_POS_EMBED_CACHE and return the cached (cos, sin) on hit.
  • The cache key includes every output-affecting argument:
    rope_sizes, rope_dim_list (post-None-normalization), rope_theta,
    theta_rescale_factor, interpolation_factor, shard_dim, sp_rank,
    sp_world_size, dtype, start_frame, use_real.
    • sp_rank / sp_world_size are in the key so sequence-parallel ranks never
      share a table (e.g. MatrixGame do_sp_sharding=True).
    • start_frame is in the key so causal/autoregressive models
      (causal-Wan, MatrixGame) stay correct frame-to-frame.
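
The keying scheme described above can be sketched in plain Python. This is a minimal illustration, not the PR's code: `_TABLE_CACHE`, `get_table`, and `_build_table` are hypothetical names, the table is a toy 1-D one built from tuples rather than torch tensors, and the contiguous per-rank slicing is an assumption made only to show why `sp_rank` and `sp_world_size` belong in the key.

```python
import math

# Hypothetical module-level cache; the PR's real one is _ROTARY_POS_EMBED_CACHE.
_TABLE_CACHE = {}


def _build_table(length, dim, theta, start_frame):
    """Toy 1-D RoPE table: cos/sin of (position * inverse frequency)."""
    inv_freq = [theta ** (-2.0 * i / dim) for i in range(dim // 2)]
    rows = range(start_frame, start_frame + length)
    cos = tuple(tuple(math.cos(p * f) for f in inv_freq) for p in rows)
    sin = tuple(tuple(math.sin(p * f) for f in inv_freq) for p in rows)
    return cos, sin


def get_table(length, dim, theta=10000.0, start_frame=0,
              sp_rank=0, sp_world_size=1):
    """Return a cached (cos, sin) pair, computing it only on a miss."""
    # Every argument that can change the output goes into the key.
    key = (length, dim, theta, start_frame, sp_rank, sp_world_size)
    hit = _TABLE_CACHE.get(key)
    if hit is not None:
        return hit
    # Assumed sharding: each rank keeps one contiguous slice of positions.
    shard = length // sp_world_size
    table = _build_table(shard, dim, theta, start_frame + sp_rank * shard)
    _TABLE_CACHE[key] = table
    return table


first = get_table(8, 4)
second = get_table(8, 4)
assert first is second                      # second call is a cache hit
assert get_table(8, 4, start_frame=3) is not first
assert get_table(8, 4, sp_rank=1, sp_world_size=2) is not first
```

Leaving any output-affecting argument out of the key would make two different tables collide on one entry, which is the failure mode the `start_frame` and SP-rank bullets guard against.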

Before / After

Micro-benchmark of get_rotary_pos_embed (CPU, float64 — same dtype the DiT
call sites use). "Before" = recompute every step (current behavior); "After" = cache hit.

| Config | Seq len | Before (per step) | After (per step) | 50-step saved |
|---|---|---|---|---|
| Wan1.3B 480p | 32,760 | 13.7 ms | 25.9 µs | ~0.69 s |
| Wan14B 480p | 32,760 | 13.3 ms | 21.8 µs | ~0.67 s |
| Wan14B 720p | 75,600 | 28.7 ms | 49.8 µs | ~1.43 s |

~99.8% of the per-step RoPE cost removed (double the wall-clock saving when CFG runs cond+uncond).
This is a host-side cost that sits on the critical path at the start of each forward; the relative win is largest on small / few-step distilled models (FastWan 1.3B, Turbo*), smaller on large many-step models where the DiT blocks dominate.
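
The arithmetic behind the savings figures can be rechecked with a few lines of Python. The per-step timings are copied from the benchmark rows; the `saving` helper is a hypothetical name, and the sketch assumes the saving per call is simply "before minus after", which lands within about 0.01 s of the reported 50-step column.

```python
# (config, before in ms, after in µs), copied from the micro-benchmark rows.
ROWS = [
    ("Wan1.3B 480p", 13.7, 25.9),
    ("Wan14B 480p", 13.3, 21.8),
    ("Wan14B 720p", 28.7, 49.8),
]
STEPS = 50


def saving(before_ms, after_us, steps=STEPS, cfg=False):
    """Return (seconds saved over a run, fraction of per-step cost removed)."""
    before_s = before_ms * 1e-3
    after_s = after_us * 1e-6
    calls = steps * (2 if cfg else 1)  # CFG runs cond + uncond each step
    return (before_s - after_s) * calls, 1.0 - after_s / before_s


for name, before_ms, after_us in ROWS:
    saved_s, removed = saving(before_ms, after_us)
    print(f"{name}: {saved_s:.2f} s saved, {removed:.1%} of per-step cost removed")
```

Passing `cfg=True` doubles the number of calls, which is where the "double the wall-clock saving" remark comes from.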

Correctness & safety

  • Pure memoization → cached tensors are bitwise-equal to a fresh recompute
    (covered by tests), so generated output / SSIM is unchanged.
  • Cached tensors are CPU (pre-transfer) and treated as read-only; all call sites
    copy via .to(device).float(), and no call site mutates the returned tensors
    in place (verified by grep). Cache holds only CPU memory, bounded by the number
    of distinct resolutions used in a run.

Testing

New CPU-only test suite: fastvideo/tests/ops/test_rotary_pos_embed_cache.py
(20 cases), covering:

  • cache hit returns the same objects; repeated calls keep a single entry
  • cached values == fresh recompute (parametrized over resolutions / dtype / use_real)
  • distinct keys for each of: resolution, dtype, use_real, start_frame, theta, shard_dim
  • None rope_dim_list shares a key with its normalized explicit list
  • use_real controls last-dim size; degenerate grids (1,1,1), (1,H,W), (T,1,1)
  • list-vs-scalar rescale factors are hashable and keyed apart
  • caller .to()/.float() copy does not corrupt the cache
  • start_frame shifts temporal positions (causal path)

Run:

pytest fastvideo/tests/ops/test_rotary_pos_embed_cache.py -v

Result: 20 passed. pre-commit (yapf / ruff / mypy / codespell) passes on
the modified source file.

mergify Bot added labels on Jun 8, 2026: type: perf (Performance improvement); scope: infra (CI, tests, Docker, build); scope: model (Model architecture (DiTs, encoders, VAEs))

mergify Bot commented Jun 8, 2026

Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a memoization cache for the get_rotary_pos_embed function to avoid redundant computations of rotary position embeddings, along with a comprehensive suite of unit tests. The review feedback highlights a critical issue regarding unbounded memory growth in the global cache, suggesting a FIFO eviction policy to cap the cache size. Additionally, improvements are suggested for the unit tests, including strengthening a weak assertion and adding a test to verify the cache eviction behavior.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

use_real=use_real,
)
# Cached tensors are shared (read-only); callers copy via .to()/.float().
_ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin)

Contributor

high

The global cache _ROTARY_POS_EMBED_CACHE grows indefinitely without any size limit or eviction policy. In long-running environments (such as a web service or when generating videos with varying resolutions, lengths, or changing start_frame values), this will lead to unbounded memory growth and potential Out-Of-Memory (OOM) errors.

Since each cache entry contains large float64 tensors (e.g., ~150MB for 720p resolution), we should limit the cache size (e.g., to 16 entries) and evict the oldest entries using a simple FIFO policy.

Suggested change
- _ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin)
+ _ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin)
+ if len(_ROTARY_POS_EMBED_CACHE) > 16:
+     _ROTARY_POS_EMBED_CACHE.pop(next(iter(_ROTARY_POS_EMBED_CACHE)))

"""List-valued rescale factors are hashable and keyed apart from scalars."""
_call(theta_rescale_factor=1.0)
_call(theta_rescale_factor=[1.0, 1.0, 1.0])
assert len(_ROTARY_POS_EMBED_CACHE) >= 1

Contributor

medium

The assertion assert len(_ROTARY_POS_EMBED_CACHE) >= 1 is too weak and does not actually verify that the scalar and list-valued rescale factors are keyed apart (distinct), as the test name and docstring suggest. It should assert == 2 to ensure both entries are correctly cached as separate keys.

Suggested change
- assert len(_ROTARY_POS_EMBED_CACHE) >= 1
+ assert len(_ROTARY_POS_EMBED_CACHE) == 2

cos0, _ = _call(start_frame=0)
cos3, _ = _call(start_frame=3)
assert not torch.equal(cos0, cos3)
assert len(_ROTARY_POS_EMBED_CACHE) == 2

Contributor

medium

To ensure that the cache eviction policy works correctly and prevents memory leaks, we should add a unit test that verifies the cache size is capped (e.g., at 16 entries) and older entries are evicted.

Suggested change
-     assert len(_ROTARY_POS_EMBED_CACHE) == 2
+     assert len(_ROTARY_POS_EMBED_CACHE) == 2
+
+
+ def test_cache_eviction_limits_memory():
+     """The cache should evict the oldest entries when exceeding the maximum size."""
+     for i in range(20):
+         _call(start_frame=i)
+     assert len(_ROTARY_POS_EMBED_CACHE) == 16

alexzms self-requested a review June 8, 2026 20:49

alexzms commented Jun 10, 2026

Collaborator

Thanks for the contribution! The implementation looks great — clean memoization matching the existing _ROPE_DICT pattern in the same file, bitwise-identical output, and solid test coverage.

One question before I approve: have you measured the end-to-end wall-clock impact (a full generation before/after)? The micro-benchmark is convincing, but since the saved RoPE cost is host-side it may partly overlap with async GPU work, so I am curious how much it actually moves the needle end to end. Everything else looks great.
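
One way to take the end-to-end measurement asked about here is a small A/B timing harness. The sketch below is an assumption about how such a test could be structured, not something from the PR: `time_e2e` is a hypothetical helper, `run` stands for a full generation call, and `sync` would be something like `torch.cuda.synchronize` so that asynchronous GPU work is finished before the timer stops.

```python
import statistics
import time


def time_e2e(run, warmup=1, iters=3, sync=None):
    """Return the median wall-clock seconds of run() over iters calls.

    sync is an optional callable that blocks until queued device work is
    done; without it, asynchronous GPU work could still be in flight when
    the timer stops.
    """
    for _ in range(warmup):
        run()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload so the sketch runs anywhere; replace with a real pipeline call.
baseline = time_e2e(lambda: sum(range(100_000)))
print(f"median: {baseline * 1e3:.3f} ms")
```

Running the same harness once on main and once on the PR branch, with identical prompts, seeds, and step counts, would show how much of the host-side saving survives overlap with GPU execution.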

alexzms left a comment

Collaborator

E2E latency test

Godmook commented Jun 10, 2026

Contributor Author

> E2E latency test

I'll do E2E test! Thanks!

Godmook commented Jun 12, 2026

Contributor Author

> > E2E latency test
>
> I'll do E2E test! Thanks!

Quick follow-up: I'm tied up with other work at the moment, but I'll run the E2E test by tomorrow night. I've requested a GPU for it and expect it to be allocated to me at 3 pm tomorrow. Thanks!

