[feat] World model training using third person games #1443

Open
mignonjia wants to merge 26 commits into hao-ai-lab:main from mignonjia:mhuo/longlive-nvl

@mignonjia mignonjia commented Jun 8, 2026

Collaborator

Purpose

Add world-model training helpers for MatrixGame2/Zelda, including LongLive-style streaming long tuning, Zelda validation utilities, action overlays, and synthetic optical-flow validation wiring.

Fixes #

Changes

  • Add StreamingLongTuningMethod for LongLive-style streaming long tuning from a self-forcing checkpoint, with long overlapping chunks and streaming context.
  • Add MatrixGame2 causal attention sink support for long-context streaming.
  • Add Zelda world-model configs, validation metric wiring, and synthetic optical-flow calibration.
  • Add validation action overlays and artifact-save handling.
  • Document Zelda training/validation data downloads and the suggested data/zeldam2-clean local path.
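The "long overlapping chunks" mentioned in the first bullet can be pictured as a sliding window over latent frames. The following is a minimal sketch of that idea only, not the PR's implementation: the helper name `overlapping_chunks` and its parameters are hypothetical, and it computes index ranges without touching any model state.

```python
from typing import List, Tuple


def overlapping_chunks(total: int, chunk: int, overlap: int) -> List[Tuple[int, int]]:
    """Return [start, end) ranges for fixed-size chunks that share `overlap` frames.

    Hypothetical helper for illustration; names and semantics are assumptions.
    Trailing frames that cannot fill a whole chunk are dropped.
    """
    if chunk <= 0 or not 0 <= overlap < chunk:
        raise ValueError("need chunk > 0 and 0 <= overlap < chunk")
    stride = chunk - overlap  # each new chunk advances by this many frames
    ranges = []
    start = 0
    while start + chunk <= total:
        ranges.append((start, start + chunk))
        start += stride
    return ranges
```

For example, `overlapping_chunks(10, 4, 1)` yields the ranges (0, 4), (3, 7), and (6, 10), so each chunk re-reads the last frame of its predecessor as context.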

Test Plan

PYTHONPATH=. pytest fastvideo/tests/train/methods/test_streaming_long_tuning.py -q
PYTHONPATH=. pytest fastvideo/tests/train/callbacks/test_validation.py -q
pre-commit run --files docs/training/train_infra.md examples/train/scenario/worldmodel/README.md examples/train/scenario/worldmodel/self_forcing_causal_i2v_zelda.yaml examples/train/scenario/worldmodel/streaming_long_tuning_causal_i2v.yaml fastvideo/train/methods/distribution_matching/streaming_long_tuning.py fastvideo/tests/train/callbacks/test_validation.py fastvideo/tests/train/methods/test_streaming_long_tuning.py PR_doc.md

Test Results

Test output
PYTHONPATH=. pytest fastvideo/tests/train/methods/test_streaming_long_tuning.py -q
..                                                                       [100%]
2 passed, 14 warnings in 0.55s

PYTHONPATH=. pytest fastvideo/tests/train/callbacks/test_validation.py -q
...........................                                              [100%]
27 passed, 14 warnings in 0.56s

pre-commit run --files ...
yapf.....................................................................Passed
ruff (legacy alias)......................................................Passed
codespell................................................................Passed
PyMarkdown...............................................................Passed
Lint GitHub Actions workflow files...................(no files to check)Skipped
mypy.....................................................................Passed
Check for spaces in all filenames........................................Passed
Suggestion...............................................................Passed

Review Notes

  • Kept the independent no-grad critic rollout in _losses_for_batch, matching existing DMD2Method and SelfForcingMethod behavior.
  • Left _dmd_loss_masked denominator unchanged because it already clamps with clamp_min(1e-6).

Checklist

  • I ran pre-commit on the changed PR files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified SSIM regression tests pass
  • I updated the support matrix if adding a new model

@mignonjia mignonjia requested a review from alexzms June 8, 2026 23:02
@mergify mergify Bot added the scope: training, scope: infra, and scope: model labels Jun 8, 2026
@mergify

mergify Bot commented Jun 8, 2026

Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a LongLive-style multi-stage self-forcing method (StreamingLongTuningMethod) for streaming rollouts, along with training-time validation metrics and action overlays (keyboard/mouse) on validation frames. It also adds several configurations, documentation, and tests supporting these features. The review feedback highlights three key areas for improvement: optimizing the student rollout in _losses_for_batch to avoid redundant computations when update_student is enabled, clamping the denominator in _dmd_loss_masked to prevent division-by-zero and gradient explosion, and replacing copy.deepcopy with copy.copy on attn_metadata to avoid performance and memory overhead.

Comment on lines +469 to +494
        if update_student:
            generator_pred_x0 = self._student_rollout(
                training_batch,
                with_grad=True,
            )
            student_ctx = (
                training_batch.timesteps,
                training_batch.attn_metadata_vsa,
            )
            generator_loss = self._dmd_loss_masked(
                generator_pred_x0,
                training_batch,
                chunk_mask=chunk_mask,
            )

        with torch.no_grad():
            generator_pred_x0 = self._student_rollout(
                training_batch,
                with_grad=False,
            )

        fake_score_loss, critic_ctx, critic_outputs = (self._critic_flow_matching_loss_for_x0(
            generator_pred_x0,
            training_batch,
            chunk_mask=chunk_mask,
        ))

Contributor

high

When update_student is True, the student rollout is executed twice: once with gradients enabled (with_grad=True) and once with gradients disabled (with_grad=False). Since DiT rollouts are computationally expensive, we can optimize this by only running the second rollout when update_student is False, and simply detaching generator_pred_x0 when update_student is True.

        if update_student:
            generator_pred_x0 = self._student_rollout(
                training_batch,
                with_grad=True,
            )
            student_ctx = (
                training_batch.timesteps,
                training_batch.attn_metadata_vsa,
            )
            generator_loss = self._dmd_loss_masked(
                generator_pred_x0,
                training_batch,
                chunk_mask=chunk_mask,
            )
            critic_pred_x0 = generator_pred_x0.detach()
        else:
            with torch.no_grad():
                critic_pred_x0 = self._student_rollout(
                    training_batch,
                    with_grad=False,
                )

        fake_score_loss, critic_ctx, critic_outputs = (self._critic_flow_matching_loss_for_x0(
            critic_pred_x0,
            training_batch,
            chunk_mask=chunk_mask,
        ))

Comment on lines +1227 to +1228
denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean()
grad = (faker_x0 - real_cfg_x0) / denom

Contributor

high

If generator_pred_x0 and real_cfg_x0 are identical or extremely close, denom will be zero or near-zero. This causes division by zero, resulting in inf or NaN values in grad. While torch.nan_to_num is called afterwards, it replaces inf with the maximum representable float value (e.g., 3.4e38), which will cause overflow and gradient explosion during the backward pass. Clamping denom with a small epsilon (e.g., 1e-8) prevents this instability.

Suggested change
- denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean()
- grad = (faker_x0 - real_cfg_x0) / denom
+ denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean().clamp_min(1e-8)
+ grad = (faker_x0 - real_cfg_x0) / denom
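
The effect of the clamp can be checked without a GPU. The sketch below is a plain-float analogue of the normalization, assuming the gradient has the form `(fake - real) / mean(|pred - real|)`. The function name `normalized_grad` is hypothetical, and `max(..., eps)` stands in for `Tensor.clamp_min(eps)`. Note that plain Python raises `ZeroDivisionError` where PyTorch would produce `inf` or `NaN`, so this illustrates only the arithmetic.

```python
from typing import List


def normalized_grad(fake: List[float], real: List[float], pred: List[float],
                    eps: float = 1e-8) -> List[float]:
    """Plain-float stand-in for (fake - real) / mean(|pred - real|).clamp_min(eps)."""
    mean_abs = sum(abs(p - r) for p, r in zip(pred, real)) / len(pred)
    denom = max(mean_abs, eps)  # analogue of Tensor.clamp_min(eps)
    return [(f - r) / denom for f, r in zip(fake, real)]
```

When `pred` equals `real`, the unclamped denominator is exactly zero; with the clamp the result stays finite, although it can still be large, so the choice of epsilon remains a judgment call.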

batch.attn_metadata = None
batch.attn_metadata_vsa = None
build_metadata(batch)
batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata)

Contributor

medium

Using copy.deepcopy on attn_metadata can be extremely slow and memory-intensive because it deep-copies the underlying PyTorch tensors. As documented in MatrixGame2Model.prepare_batch, a shallow copy (copy.copy) should be used instead to keep the lru_cache'd index fields shared while allowing modifications to metadata fields like VSA_sparsity.

Suggested change
- batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata)
+ batch.attn_metadata_vsa = copy.copy(batch.attn_metadata)
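
The difference between the two copies can be seen with a small stand-in object. This is a sketch under stated assumptions: `FakeAttnMetadata` is a hypothetical dataclass, and a Python list plays the role of the large index tensors. Only the `VSA_sparsity` field name is taken from the comment above.

```python
import copy
from dataclasses import dataclass
from typing import List


@dataclass
class FakeAttnMetadata:
    """Hypothetical stand-in for an attention-metadata object."""
    indices: List[int]    # large, read-only index data (a tensor in practice)
    VSA_sparsity: float   # small per-use scalar field


base = FakeAttnMetadata(indices=list(range(1000)), VSA_sparsity=0.0)

shallow = copy.copy(base)
shallow.VSA_sparsity = 0.9  # rebinding a field changes only the copy

deep = copy.deepcopy(base)  # duplicates the index data as well
```

After this runs, `shallow.indices is base.indices` holds while `deep.indices` is a separate object, and `base.VSA_sparsity` is still `0.0`. The shallow copy is only safe as long as the shared fields are never mutated in place.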

@mergify mergify Bot added the scope: inference Inference pipeline, serving, CLI label Jun 9, 2026
@mignonjia mignonjia self-assigned this Jun 9, 2026
@mergify mergify Bot added the scope: docs Documentation label Jun 9, 2026
@mignonjia mignonjia marked this pull request as ready for review June 9, 2026 22:29
@mignonjia mignonjia changed the title World model training helper functions World model training Jun 9, 2026
@mignonjia mignonjia changed the title World model training [feat] World model training using third person games Jun 9, 2026
@mergify mergify Bot added the type: feat New feature or capability label Jun 9, 2026
@hao-ai-lab hao-ai-lab deleted a comment from mergify Bot Jun 10, 2026
@mergify

mergify Bot commented Jun 10, 2026

Contributor

Pre-commit checks failed

Hi @mignonjia, the pre-commit checks have failed. To fix them locally:

# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files

Common fixes:

  • yapf: yapf -i <file> (formatting)
  • ruff: ruff check --fix <file> (linting)
  • codespell: codespell --write-changes <file> (spelling)

After fixing, commit and push the changes. The checks will re-run automatically.

For future commits, pre-commit will run automatically on changed files before each commit.

@mergify mergify Bot added the scope: data Data preprocessing, datasets label Jun 10, 2026
@mignonjia mignonjia force-pushed the mhuo/longlive-nvl branch from a212635 to 4003e70 Compare June 10, 2026 21:37
@mignonjia mignonjia marked this pull request as draft June 10, 2026 21:40
@mignonjia mignonjia marked this pull request as ready for review June 12, 2026 01:50
@mergify

mergify Bot commented Jun 12, 2026

Contributor

This PR has merge conflicts with the base branch. Please rebase:

git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease

@mergify mergify Bot added the needs-rebase PR has merge conflicts label Jun 12, 2026
@mignonjia mignonjia removed their assignment Jun 16, 2026
@mergify

mergify Bot commented Jun 16, 2026

Contributor

Pre-commit checks failed

@mergify mergify Bot removed the needs-rebase PR has merge conflicts label Jun 16, 2026
state.previous_latents = full_chunk.detach()[:, -chunk_size:]

if not dist.is_initialized() or dist.get_rank() == 0:
    print(

Collaborator

This should use the logger instead of print.
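
One way to act on this, sketched with the standard-library `logging` module only: the helper name `log_on_rank_zero` and the logger name are hypothetical, and the project may already provide its own logger that should be preferred. In a distributed job the `rank` argument would come from `dist.get_rank()`, as in the snippet under review.

```python
import logging

# Hypothetical logger name, used here for illustration only.
logger = logging.getLogger("streaming_long_tuning_example")


def log_on_rank_zero(rank: int, message: str, *args: object) -> bool:
    """Emit an INFO record only on rank 0; return whether a record was emitted."""
    if rank != 0:
        return False
    # Lazy %-style formatting: arguments are interpolated only if the record is handled.
    logger.info(message, *args)
    return True
```

Unlike `print`, the record then respects the configured log level and handlers, so verbosity can be changed without editing the training loop.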

)


def _retain_kv_with_sink(

Collaborator

There are duplicate functions in fastvideo/models/dits/matrixgame2/causal_model.py; they should be merged.
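
For readers unfamiliar with the pattern, the sketch below shows what a sink-preserving retention step usually does. It assumes the common attention-sink scheme of keeping the first few cache entries plus the most recent ones. The body of the PR's `_retain_kv_with_sink` is not shown in this conversation, so the signature and behavior here are assumptions, and a Python list stands in for the KV tensors.

```python
from typing import List, TypeVar

T = TypeVar("T")


def retain_kv_with_sink(cache: List[T], sink_size: int, max_size: int) -> List[T]:
    """Keep the first `sink_size` entries plus the most recent ones, up to `max_size`.

    Illustrative only; not the function from the PR.
    """
    if not 0 <= sink_size <= max_size:
        raise ValueError("need 0 <= sink_size <= max_size")
    if len(cache) <= max_size:
        return list(cache)  # nothing to evict yet
    recent = max_size - sink_size
    # Guard the recent == 0 case: cache[-0:] would return the whole list.
    tail = cache[len(cache) - recent:] if recent > 0 else []
    return cache[:sink_size] + tail
```

With ten cached entries, a sink of 2, and a budget of 5, the result keeps entries 0 and 1 plus the three newest (7, 8, 9). Keeping one shared implementation of this step makes it easier to guarantee that training and inference evict the same entries.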

where=("multi_phased_distill_schedule"
f"[{idx}].streaming_fixed_overlap_latents"),
)),
train_first_chunk=_as_bool(

Collaborator

train_first_chunk is read but never used, so it should be deleted.

Labels

  • scope: data (Data preprocessing, datasets)
  • scope: docs (Documentation)
  • scope: inference (Inference pipeline, serving, CLI)
  • scope: infra (CI, tests, Docker, build)
  • scope: model (Model architecture: DiTs, encoders, VAEs)
  • scope: training (Training pipeline, methods, configs)
  • type: feat (New feature or capability)

2 participants