[feat] World model training using third person games#1443
Conversation
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🔴 PR merge requirementsWaiting for
This rule is failing.
|
There was a problem hiding this comment.
Code Review
This pull request introduces a LongLive-style multi-stage self-forcing method (StreamingLongTuningMethod) for streaming rollouts, along with training-time validation metrics and action overlays (keyboard/mouse) on validation frames. It also adds several configurations, documentation, and tests supporting these features. The review feedback highlights three key areas for improvement: optimizing the student rollout in _losses_for_batch to avoid redundant computations when update_student is enabled, clamping the denominator in _dmd_loss_masked to prevent division-by-zero and gradient explosion, and replacing copy.deepcopy with copy.copy on attn_metadata to avoid performance and memory overhead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if update_student: | ||
| generator_pred_x0 = self._student_rollout( | ||
| training_batch, | ||
| with_grad=True, | ||
| ) | ||
| student_ctx = ( | ||
| training_batch.timesteps, | ||
| training_batch.attn_metadata_vsa, | ||
| ) | ||
| generator_loss = self._dmd_loss_masked( | ||
| generator_pred_x0, | ||
| training_batch, | ||
| chunk_mask=chunk_mask, | ||
| ) | ||
|
|
||
| with torch.no_grad(): | ||
| generator_pred_x0 = self._student_rollout( | ||
| training_batch, | ||
| with_grad=False, | ||
| ) | ||
|
|
||
| fake_score_loss, critic_ctx, critic_outputs = (self._critic_flow_matching_loss_for_x0( | ||
| generator_pred_x0, | ||
| training_batch, | ||
| chunk_mask=chunk_mask, | ||
| )) |
There was a problem hiding this comment.
When update_student is True, the student rollout is executed twice: once with gradients enabled (with_grad=True) and once with gradients disabled (with_grad=False). Since DiT rollouts are computationally expensive, we can optimize this by only running the second rollout when update_student is False, and simply detaching generator_pred_x0 when update_student is True.
if update_student:
generator_pred_x0 = self._student_rollout(
training_batch,
with_grad=True,
)
student_ctx = (
training_batch.timesteps,
training_batch.attn_metadata_vsa,
)
generator_loss = self._dmd_loss_masked(
generator_pred_x0,
training_batch,
chunk_mask=chunk_mask,
)
critic_pred_x0 = generator_pred_x0.detach()
else:
with torch.no_grad():
critic_pred_x0 = self._student_rollout(
training_batch,
with_grad=False,
)
fake_score_loss, critic_ctx, critic_outputs = (self._critic_flow_matching_loss_for_x0(
critic_pred_x0,
training_batch,
chunk_mask=chunk_mask,
))| denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean() | ||
| grad = (faker_x0 - real_cfg_x0) / denom |
There was a problem hiding this comment.
If generator_pred_x0 and real_cfg_x0 are identical or extremely close, denom will be zero or near-zero. This causes division by zero, resulting in inf or NaN values in grad. While torch.nan_to_num is called afterwards, it replaces inf with the maximum representable float value (e.g., 3.4e38), which will cause overflow and gradient explosion during the backward pass. Clamping denom with a small epsilon (e.g., 1e-8) prevents this instability.
| denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean() | |
| grad = (faker_x0 - real_cfg_x0) / denom | |
| denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean().clamp_min(1e-8) | |
| grad = (faker_x0 - real_cfg_x0) / denom |
| batch.attn_metadata = None | ||
| batch.attn_metadata_vsa = None | ||
| build_metadata(batch) | ||
| batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata) |
There was a problem hiding this comment.
Using copy.deepcopy on attn_metadata can be extremely slow and memory-intensive because it deep-copies the underlying PyTorch tensors. As documented in MatrixGame2Model.prepare_batch, a shallow copy (copy.copy) should be used instead to keep the lru_cache'd index fields shared while allowing modifications to metadata fields like VSA_sparsity.
| batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata) | |
| batch.attn_metadata_vsa = copy.copy(batch.attn_metadata) |
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
a212635 to
4003e70
Compare
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
|
This PR has merge conflicts with the base branch. Please rebase: git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease |
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
Pre-commit checks failedHi @mignonjia, the pre-commit checks have failed. To fix them locally: # Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-filesCommon fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits, |
| state.previous_latents = full_chunk.detach()[:, -chunk_size:] | ||
|
|
||
| if not dist.is_initialized() or dist.get_rank() == 0: | ||
| print( |
| ) | ||
|
|
||
|
|
||
| def _retain_kv_with_sink( |
There was a problem hiding this comment.
There are duplicate functions in fastvideo/models/dits/matrixgame2/causal_model.py; they should be merged.
| where=("multi_phased_distill_schedule" | ||
| f"[{idx}].streaming_fixed_overlap_latents"), | ||
| )), | ||
| train_first_chunk=_as_bool( |
There was a problem hiding this comment.
train_first_chunk is read but never used, so it should be deleted.
Purpose
Add world-model training helpers for MatrixGame2/Zelda, including LongLive-style streaming long tuning, Zelda validation utilities, action overlays, and synthetic optical-flow validation wiring.
Fixes #
Changes
StreamingLongTuningMethodfor LongLive-style streaming long tuning from a self-forcing checkpoint, with long overlapping chunks and streaming context.data/zeldam2-cleanlocal path.Test Plan
Test Results
Test output
Review Notes
_losses_for_batch, matching existingDMD2MethodandSelfForcingMethodbehavior._dmd_loss_maskeddenominator unchanged because it already clamps withclamp_min(1e-6).Checklist
For model/pipeline changes, also check: