Add FlashVSR contrib model with video super-resolution on Neuron #165

Open: jimburtoft wants to merge 4 commits into aws-neuron:main from jimburtoft:contrib/flashvsr

@jimburtoft (Contributor) commented May 18, 2026

Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.

Description

FlashVSR is a video super-resolution model (4x upscaling) using a streaming DiT architecture based on Wan 2.1 1.3B. This contrib packages the DiT backbone for Neuron via NxDI ModelBuilder with TP=4, using NKI tiled flash attention (attention_cte) for 23040-token sequences that would otherwise OOM.
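The 23040-token figure matches the first-chunk latent shape reported in the test results ([1, 16, 6, 96, 160]) if the DiT patchifies with a (1, 2, 2) patch. That patch size is an assumption carried over from Wan 2.1 and is not stated in this PR. The rough sketch below reproduces the count and estimates how large one dense BF16 score matrix would be, which gives a sense of why a tiled kernel is used:

```python
# Latent dimensions from the test results: [batch, channels, frames, height, width].
latent_frames, latent_h, latent_w = 6, 96, 160

# ASSUMPTION: (1, 2, 2) patch size over (frames, height, width), as in Wan 2.1.
patch_t, patch_h, patch_w = 1, 2, 2

tokens = (latent_frames // patch_t) * (latent_h // patch_h) * (latent_w // patch_w)

# A dense S x S score matrix in BF16 (2 bytes per element), for a single head.
bytes_per_head = tokens * tokens * 2
heads_per_rank = 12 // 4  # 12 heads split across TP=4

print(f"tokens per first chunk: {tokens}")
print(f"dense scores, one head: {bytes_per_head / 2**30:.2f} GiB")
print(f"dense scores, one TP rank: {heads_per_rank * bytes_per_head / 2**30:.2f} GiB")
```

These numbers cover the score matrix only; activations, weights, and compiler scratch space are not included.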

The pipeline processes video in overlapping chunks: a first chunk (6 latent frames → 24 output frames) followed by stream chunks (2 latent frames → 8 output frames each). Single-step DMD denoising enables efficient 4x upscaling at 768×1280 output resolution.
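The chunk sizes above imply 4 output frames per latent frame. As a minimal sketch of the frame accounting only, the hypothetical helper `plan_chunks` below (not a function from `pipeline.py`) splits a latent-frame count into one first chunk and a run of stream chunks. It does not model the overlap between chunks, which this description does not specify.

```python
FIRST_CHUNK_LATENTS = 6    # first chunk: 6 latent frames -> 24 output frames
STREAM_CHUNK_LATENTS = 2   # stream chunk: 2 latent frames -> 8 output frames
FRAMES_PER_LATENT = 4      # 24 / 6 == 8 / 2 == 4


def plan_chunks(total_latent_frames: int) -> list[int]:
    """Hypothetical helper: split a latent-frame count into chunk sizes."""
    remaining = total_latent_frames - FIRST_CHUNK_LATENTS
    if remaining < 0 or remaining % STREAM_CHUNK_LATENTS != 0:
        raise ValueError("expected 6 + 2*k latent frames")
    stream_chunks = remaining // STREAM_CHUNK_LATENTS
    return [FIRST_CHUNK_LATENTS] + [STREAM_CHUNK_LATENTS] * stream_chunks


# Illustrative input: 22 latent frames = 1 first chunk + 8 stream chunks.
chunks = plan_chunks(22)
output_frames = [n * FRAMES_PER_LATENT for n in chunks]
print(chunks)              # [6, 2, 2, 2, 2, 2, 2, 2, 2]
print(sum(output_frames))  # 88
```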

Model Information

Model Name: FlashVSR v1.1 (JunhaoZhuang/FlashVSR-v1.1)

Model Architecture: 30-layer DiT (dim=1536, 12 heads, head_dim=128) with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm

Purpose: Video super-resolution (4x spatial upscaling, e.g. 480p → 1920p; this contrib is compiled for 768×1280 output)
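To illustrate why QK-norm needs a distributed RMSNorm under TP=4: with dim=1536 split across 4 ranks, each rank holds 384 channels (3 heads × 128), so a norm taken over the full width needs the sum of squares from every rank. The pure-Python sketch below simulates that with a plain `sum` standing in for the all-reduce. It assumes the norm spans the full 1536 channels and omits the learned scale; it is not the `DistributedRMSNorm` implementation itself.

```python
import math

DIM, TP = 1536, 4    # model width and tensor-parallel degree from this PR
SHARD = DIM // TP    # 384 channels (3 heads x 128) per rank
EPS = 1e-6           # ASSUMPTION: epsilon chosen for illustration only


def rms_norm(x, eps=EPS):
    """Reference RMSNorm over the full vector (no learned scale)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def sharded_rms_norm(shards, eps=EPS):
    """Each rank holds one shard; only the sum of squares is shared."""
    local_sq = [sum(v * v for v in s) for s in shards]  # per-rank partial sums
    global_sq = sum(local_sq)                           # stands in for all-reduce(sum)
    rms = math.sqrt(global_sq / DIM + eps)
    return [[v / rms for v in s] for s in shards]


x = [math.sin(0.01 * i) for i in range(DIM)]
shards = [x[r * SHARD:(r + 1) * SHARD] for r in range(TP)]

full = rms_norm(x)
gathered = [v for s in sharded_rms_norm(shards) for v in s]
max_diff = max(abs(a - b) for a, b in zip(full, gathered))
print(f"max difference vs unsharded norm: {max_diff:.2e}")
```

Only one scalar per token crosses ranks in this scheme, which keeps the communication cost of the norm small relative to the attention itself.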

Checklist

Required Components

  • Accuracy Test (test/integration/test_dit_accuracy.py)

    • Validates DiT output against CPU reference using neuron_allclose(rtol=0.05, atol=0.1)
    • Measured: max_rel_error=0.025, max_abs_error=0.066, cosine_similarity=0.9997
    • Complementary cosine similarity assertion (>0.999 threshold)
    • 5% rtol justified for 30-layer BF16 DiT with TP=4 (cf. NxDI MLP tests use rtol=6e-2)
  • README.md with the following sections:

    • Usage Example: Clear code example showing compile → load → inference
    • Compatibility Matrix: trn2.3xlarge TP=4 LNC=2, SDK 2.29
    • Example Checkpoints: Link to JunhaoZhuang/FlashVSR-v1.1 on HuggingFace
    • Testing Instructions: pytest commands for accuracy and E2E tests
  • Source Code (src/)

    • modeling_flashvsr.py — NxDI-compatible DiT with Application/ModelWrapper/InferenceConfig (1242 lines)
    • pipeline.py — Full inference pipeline orchestration
    • tcdecoder.py — TCDecoder (latent → RGB) wrapper
    • lq_projection.py — LQ conditioning projection wrapper
    • weights.py — Weight format detection and conversion (DiffSynth/diffusers → Neuron)
    • download_weights.py — HuggingFace weight download utility

Optional Components

  • E2E Pipeline Test (test/integration/test_pipeline_e2e.py) — PSNR validation
  • Unit Tests — Not included

Folder Structure

/contrib/models/FlashVSR/
  README.md
  /src
    __init__.py
    modeling_flashvsr.py
    pipeline.py
    tcdecoder.py
    lq_projection.py
    weights.py
    download_weights.py
  /test
    __init__.py
    /integration
      __init__.py
      test_dit_accuracy.py
      test_pipeline_e2e.py

Testing

How did you test this change?

Compiled and tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.29.1 (DLAMI 20260502, neuronx-cc 2.24.8799.0, NxDI 0.9.17334). DiT first-chunk compiled via NxDI ModelBuilder (TP=4, BF16) and validated against CPU reference model with identical weights.

Test Results:

neuron_allclose(rtol=0.05, atol=0.1):
  allclose: True
  max_rel_error: 0.025
  max_abs_error: 0.066
  cosine_similarity: 0.9997

DiT first-chunk latency (5 iters, post-warmup): 1540 ± 68 ms
Output shape: [1, 16, 6, 96, 160] (correct)
Weight loading: 0 missing, 0 unexpected keys
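For readers unfamiliar with the two checks, here is a minimal pure-Python sketch of a combined tolerance test and a cosine-similarity test. It assumes `neuron_allclose` follows the usual `|actual - expected| <= atol + rtol * |expected|` convention (as `torch.allclose` does); the helper name `allclose_report` and the toy data are illustrative and not taken from `test_dit_accuracy.py`.

```python
import math


def allclose_report(actual, expected, rtol=0.05, atol=0.1):
    """Elementwise |a - e| <= atol + rtol * |e|, plus cosine similarity."""
    pairs = list(zip(actual, expected))
    ok = all(abs(a - e) <= atol + rtol * abs(e) for a, e in pairs)
    max_abs = max(abs(a - e) for a, e in pairs)
    dot = sum(a * e for a, e in pairs)
    norm = math.sqrt(sum(a * a for a in actual)) * math.sqrt(sum(e * e for e in expected))
    return {"allclose": ok, "max_abs_error": max_abs, "cosine_similarity": dot / norm}


# Toy data standing in for flattened CPU-reference and Neuron outputs.
expected = [math.sin(0.1 * i) for i in range(1000)]
actual = [e + 0.01 * math.cos(i) for i, e in enumerate(expected)]

report = allclose_report(actual, expected)
print(report)
```

The two checks are complementary: the tolerance test bounds the worst element, while cosine similarity summarizes agreement in direction across the whole tensor.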

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29.1 (neuronx-cc 2.24.8799.0)
  • Instance Type(s): trn2.3xlarge (LNC=2, TP=4)
  • PyTorch Version: 2.9.0
  • Python Version: 3.12

Additional Information

  • Uses attention_cte NKI kernel from nkilib for tiled flash attention (avoids materializing full S×S attention matrix in HBM)
  • DistributedRMSNorm for QK-norm with all-reduce across TP ranks
  • Single-step DMD (Distribution Matching Distillation) — one DiT forward pass per chunk
  • The attn_mask input is unused in Phase 1 (dense attention); kept for future Phase 2 LCSA block-sparse support on larger instances (trn2.48xlarge TP=16)
  • Weight conversion supports both DiffSynth and diffusers checkpoint formats
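The first bullet relies on the general idea behind tiled flash attention: visit keys and values one tile at a time and keep only a running maximum, a running normalizer, and an accumulator, so the full S×S score matrix never exists. The sketch below shows that generic online-softmax algorithm in pure Python and checks it against a dense reference. It is not the `attention_cte` kernel, and the tile size and toy shapes are arbitrary.

```python
import math
import random


def dense_attention(q, k, v):
    """Reference: builds a full row of scores for each query."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, tile=4):
    """Online softmax: visit K/V in tiles, keeping only running statistics."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dim_v = len(v[0])
    out = []
    for qi in q:
        run_max, run_sum, acc = -math.inf, 0.0, [0.0] * dim_v
        for start in range(0, len(k), tile):
            k_tile, v_tile = k[start:start + tile], v[start:start + tile]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_tile]
            new_max = max(run_max, max(scores))
            correction = math.exp(run_max - new_max)  # rescale earlier tiles
            w = [math.exp(s - new_max) for s in scores]
            run_sum = run_sum * correction + sum(w)
            acc = [a * correction + sum(wj * vj[d] for wj, vj in zip(w, v_tile))
                   for d, a in enumerate(acc)]
            run_max = new_max
        out.append([a / run_sum for a in acc])
    return out


random.seed(0)
q = [[random.gauss(0, 1) for _ in range(8)] for _ in range(5)]
k = [[random.gauss(0, 1) for _ in range(8)] for _ in range(10)]
v = [[random.gauss(0, 1) for _ in range(8)] for _ in range(10)]

dense = dense_attention(q, k, v)
tiled = tiled_attention(q, k, v, tile=4)
max_diff = max(abs(a - b) for r1, r2 in zip(dense, tiled) for a, b in zip(r1, r2))
print(f"max difference vs dense reference: {max_diff:.2e}")
```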

Related Issues

None.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

Not applicable — FlashVSR is a video generation model, not an LLM.


By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

DiT accuracy test now uses rtol=0.05, atol=0.1 (validated on trn2.3xlarge):
- max_rel_error: 0.025, cosine_similarity: 0.9997
- 5% rtol is standard for 30-layer BF16 DiT with TP=4 (cf NxDI MLP test rtol=6e-2)
- Added complementary cosine similarity check (>0.999 threshold)
- Updated README accuracy section with measured values

Migrate TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with
input_output_aliases for MemBlock state persistence in device HBM.

Key changes:
- tcdecoder.py: Add NeuronTCDecoderStateful (stateful nn.Module with
  9 state Parameters), TCDecoderApplication (ModelBuilder wrapper),
  and decode_video_nxdi() inference helper
- pipeline.py: compile_pipeline() now compiles TCDecoder via NxDI,
  load_pipeline() loads NxDI TCDecoder with legacy fallback,
  run_inference() dispatches to NxDI or trace path automatically
- __init__.py: Export new NxDI classes

Performance (validated on trn2.3xlarge, SDK 2.29.1):
- Per-frame latency: 78ms (vs 237ms trace baseline) = 3.0x faster
- Compilation: 2.1s with NEFF cache, ~226s fresh
- Accuracy: cosine >0.9995, neuron_allclose PASS (rtol=0.05)
- Output shape: (4, 3, 768, 1280) per frame

@jimburtoft (Contributor, Author) commented:

Update: NxDI TCDecoder with HBM State Persistence (3.0x decode speedup)

Migrated the TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with input_output_aliases for MemBlock state persistence in device HBM.

What changed

  • TCDecoder now uses NeuronTCDecoderStateful — an nn.Module with 9 state Parameters that persist in HBM between NEFF calls via input_output_aliases
  • compile_pipeline() compiles TCDecoder via NxDI ModelBuilder (no separate trace step needed)
  • load_pipeline() auto-detects and loads NxDI TCDecoder (with legacy trace fallback)
  • New decode_video_nxdi() function — states persist in HBM, no PCIe per frame

Performance (trn2.3xlarge, SDK 2.29.1)

| Metric | trace baseline | NxDI (this PR) | Improvement |
| --- | --- | --- | --- |
| Per-frame latency | 237 ms | 78 ms | 3.04x faster |
| Total decode (22 frames) | 5,210 ms | 1,717 ms | 3.03x faster |
| Compilation time | ~5 min | 2.1 s (cached) / 226 s (fresh) | ~1.3x faster |
| Output shape | (4, 3, 768, 1280) | (4, 3, 768, 1280) | Identical |
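The improvement column can be re-derived from the raw timings. The short check below uses only the figures reported above, taking "~5 min" as 300 s for the compilation comparison:

```python
trace_ms, nxdi_ms = 237, 78                  # per-frame latency
total_trace_ms, total_nxdi_ms = 5210, 1717   # total decode over 22 frames
frames = 22
trace_compile_s, nxdi_fresh_compile_s = 300, 226  # "~5 min" taken as 300 s

per_frame_speedup = round(trace_ms / nxdi_ms, 2)
total_speedup = round(total_trace_ms / total_nxdi_ms, 2)
nxdi_ms_per_frame = round(total_nxdi_ms / frames, 1)
compile_speedup = round(trace_compile_s / nxdi_fresh_compile_s, 2)

print(per_frame_speedup)   # 3.04
print(total_speedup)       # 3.03
print(nxdi_ms_per_frame)   # 78.0
print(compile_speedup)     # 1.33
```

The ~1.3x compilation figure therefore corresponds to a fresh compile; the cached 2.1 s case is a separate, much larger ratio.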

Why it's faster

The trace-based approach transfers 9 MemBlock state tensors (total ~100MB) over PCIe on every frame call. With input_output_aliases, states remain in device HBM — the compiler writes updated states back to the same memory locations as zero-copy aliases. Only the 784-channel input frame crosses PCIe per call.
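A toy accounting of that difference, under stated assumptions: the ~100 MB state size and the 22-frame count come from this PR, while the input-frame size `INPUT_MB` is a made-up placeholder because the PR does not give it. The sketch counts host-to-device traffic only and does not call any Neuron API; returning updated states to the host on the trace path would add further traffic.

```python
from dataclasses import dataclass

STATE_MB = 100.0   # ~100 MB total for the 9 MemBlock state tensors (from this PR)
FRAMES = 22        # frame count used in the benchmark table
INPUT_MB = 1.0     # HYPOTHETICAL placeholder: input-frame size is not given in the PR


@dataclass
class TransferLog:
    host_to_device_mb: float = 0.0


def run_trace_style(frames: int, log: TransferLog) -> None:
    """States are passed in from the host on every call."""
    for _ in range(frames):
        log.host_to_device_mb += INPUT_MB + STATE_MB


def run_aliased_style(frames: int, log: TransferLog) -> None:
    """States stay resident on the device; only the input frame is sent."""
    for _ in range(frames):
        log.host_to_device_mb += INPUT_MB


trace_log, aliased_log = TransferLog(), TransferLog()
run_trace_style(FRAMES, trace_log)
run_aliased_style(FRAMES, aliased_log)

saved_mb = trace_log.host_to_device_mb - aliased_log.host_to_device_mb
print(f"state traffic avoided over {FRAMES} frames: {saved_mb:.0f} MB")
```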

Validated

  • Compilation: PASS
  • Load + weight initialization: PASS
  • Output shape: PASS (4, 3, 768, 1280)
  • Latency: 78 ms/frame (target was <80ms)
  • Speedup: 3.04x (target was 3.0x)
