[perf] Cache RoPE position-embedding tables across denoising steps#1442
[perf] Cache RoPE position-embedding tables across denoising steps#1442Godmook wants to merge 2 commits into
Conversation
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🔴 PR merge requirementsWaiting for
This rule is failing.
|
There was a problem hiding this comment.
Code Review
This pull request introduces a memoization cache for the get_rotary_pos_embed function to avoid redundant computations of rotary position embeddings, along with a comprehensive suite of unit tests. The review feedback highlights a critical issue regarding unbounded memory growth in the global cache, suggesting a FIFO eviction policy to cap the cache size. Additionally, improvements are suggested for the unit tests, including strengthening a weak assertion and adding a test to verify the cache eviction behavior.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| use_real=use_real, | ||
| ) | ||
| # Cached tensors are shared (read-only); callers copy via .to()/.float(). | ||
| _ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin) |
There was a problem hiding this comment.
The global cache _ROTARY_POS_EMBED_CACHE grows indefinitely without any size limit or eviction policy. In long-running environments (such as a web service or when generating videos with varying resolutions, lengths, or changing start_frame values), this will lead to unbounded memory growth and potential Out-Of-Memory (OOM) errors.
Since each cache entry contains large float64 tensors (e.g., ~150MB for 720p resolution), we should limit the cache size (e.g., to 16 entries) and evict the oldest entries using a simple FIFO policy.
| _ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin) | |
| _ROTARY_POS_EMBED_CACHE[cache_key] = (freqs_cos, freqs_sin) | |
| if len(_ROTARY_POS_EMBED_CACHE) > 16: | |
| _ROTARY_POS_EMBED_CACHE.pop(next(iter(_ROTARY_POS_EMBED_CACHE))) |
| """List-valued rescale factors are hashable and keyed apart from scalars.""" | ||
| _call(theta_rescale_factor=1.0) | ||
| _call(theta_rescale_factor=[1.0, 1.0, 1.0]) | ||
| assert len(_ROTARY_POS_EMBED_CACHE) >= 1 |
There was a problem hiding this comment.
The assertion assert len(_ROTARY_POS_EMBED_CACHE) >= 1 is too weak and does not actually verify that the scalar and list-valued rescale factors are keyed apart (distinct), as the test name and docstring suggest. It should assert == 2 to ensure both entries are correctly cached as separate keys.
| assert len(_ROTARY_POS_EMBED_CACHE) >= 1 | |
| assert len(_ROTARY_POS_EMBED_CACHE) == 2 |
| cos0, _ = _call(start_frame=0) | ||
| cos3, _ = _call(start_frame=3) | ||
| assert not torch.equal(cos0, cos3) | ||
| assert len(_ROTARY_POS_EMBED_CACHE) == 2 |
There was a problem hiding this comment.
To ensure that the cache eviction policy works correctly and prevents memory leaks, we should add a unit test that verifies the cache size is capped (e.g., at 16 entries) and older entries are evicted.
| assert len(_ROTARY_POS_EMBED_CACHE) == 2 | |
| assert len(_ROTARY_POS_EMBED_CACHE) == 2 | |
| def test_cache_eviction_limits_memory(): | |
| """The cache should evict the oldest entries when exceeding the maximum size.""" | |
| for i in range(20): | |
| _call(start_frame=i) | |
| assert len(_ROTARY_POS_EMBED_CACHE) == 16 |
|
Thanks for the contribution! The implementation looks great — clean memoization matching the existing One question before I approve: have you measured the end-to-end wall-clock impact (a full generation before/after)? The micro-benchmark is convincing, but since the saved RoPE cost is host-side it may partly overlap with async GPU work, so I am curious how much it actually moves the needle end to end. Everything else looks great. |
I'll do E2E test! Thanks! |
Quick Following. I'm currently doing other works nowadays. I'll do this E2E test until tomorrow night. I requested GPU for E2E test and I think It will allocated to me at tomorrow 3pm. Thanks! |
Summary
get_rotary_pos_embedis called inside every DiTforward(), i.e. once per denoising step. Its inputs (latent grid size, head dim, theta, dtype, SP shard,start_frame) are constant across the entire denoising loop, yet it rebuilds the fullfloat64cos/sin tables (meshgrid+ 3×outer+cos/sin) from scratch every step and then copies them host→device. For a 50-step run this is the same table computed 50× (100× with CFG).This PR memoizes the result, matching the existing
_ROPE_DICT/get_ropecaching pattern already in the same file. All 11 DiT models that callget_rotary_pos_embedbenefit with no model-side changes.It is a pure-function memoization, so outputs are bitwise-identical and SSIM is unchanged.
What changed
fastvideo/layers/rotary_embedding.py: add a module-level_ROTARY_POS_EMBED_CACHEand return the cached(cos, sin)on hit.rope_sizes,rope_dim_list(post-None-normalization),rope_theta,theta_rescale_factor,interpolation_factor,shard_dim,sp_rank,sp_world_size,dtype,start_frame,use_real.sp_rank/sp_world_sizeare in the key so sequence-parallel ranks nevershare a table (e.g. MatrixGame
do_sp_sharding=True).start_frameis in the key so causal/autoregressive models(causal-Wan, MatrixGame) stay correct frame-to-frame.
Before / After
Micro-benchmark of
get_rotary_pos_embed(CPU,float64— same dtype the DiTcall sites use). "Before" = recompute every step (current behavior); "After" = cache hit.
~99.8% of the per-step RoPE cost removed (double the wall-clock saving when CFG runs cond+uncond).
This is a host-side cost that sits on the critical path at the start of each
forward; the relative win is largest on small / few-step distilled models (FastWan 1.3B, Turbo*), smaller on large many-step models where the DiT blocks dominate.Correctness & safety
(covered by tests), so generated output / SSIM is unchanged.
copy via
.to(device).float(), and no call site mutates the returned tensorsin place (verified by grep). Cache holds only CPU memory, bounded by the number
of distinct resolutions used in a run.
Testing
New CPU-only test suite:
fastvideo/tests/ops/test_rotary_pos_embed_cache.py(20 cases), covering:
use_real)use_real,start_frame,theta,shard_dimNonerope_dim_listshares a key with its normalized explicit listuse_realcontrols last-dim size; degenerate grids(1,1,1),(1,H,W),(T,1,1).to()/.float()copy does not corrupt the cachestart_frameshifts temporal positions (causal path)Run:
Result: 20 passed.
pre-commit(yapf / ruff / mypy / codespell) passes onthe modified source file.