Add FlashVSR contrib model with video super-resolution on Neuron #165
Open
jimburtoft wants to merge 4 commits into
DiT accuracy test now uses rtol=0.05, atol=0.1 (validated on trn2.3xlarge):
- max_rel_error: 0.025, cosine_similarity: 0.9997
- 5% rtol is standard for 30-layer BF16 DiT with TP=4 (cf. NxDI MLP test rtol=6e-2)
- Added complementary cosine similarity check (>0.999 threshold)
- Updated README accuracy section with measured values
Migrate TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with input_output_aliases for MemBlock state persistence in device HBM.
Key changes:
- tcdecoder.py: Add NeuronTCDecoderStateful (stateful nn.Module with 9 state Parameters), TCDecoderApplication (ModelBuilder wrapper), and decode_video_nxdi() inference helper
- pipeline.py: compile_pipeline() now compiles TCDecoder via NxDI, load_pipeline() loads NxDI TCDecoder with legacy fallback, run_inference() dispatches to NxDI or trace path automatically
- __init__.py: Export new NxDI classes
Performance (validated on trn2.3xlarge, SDK 2.29.1):
- Per-frame latency: 78ms (vs 237ms trace baseline) = 3.0x faster
- Compilation: 2.1s with NEFF cache, ~226s fresh
- Accuracy: cosine >0.9995, neuron_allclose PASS (rtol=0.05)
- Output shape: (4, 3, 768, 1280) per frame
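To illustrate why keeping the MemBlock state on the device can help, here is a toy, pure-Python sketch. It does not use torch_neuronx or NxDI; the class names `RoundTripDecoder` and `ResidentDecoder`, the equal split of the ~100MB across 9 states, and the trivial "decode" arithmetic are all assumptions made for illustration. It only counts how many state bytes would cross the host/device boundary under each calling convention.

```python
TOTAL_STATE_BYTES = 100_000_000  # "~100MB" as quoted in this PR
NUM_STATES = 9                   # 9 MemBlock state tensors, per this PR
# Assumption: equal-sized states; the PR does not list individual sizes.
STATE_BYTES = [TOTAL_STATE_BYTES // NUM_STATES] * NUM_STATES


class RoundTripDecoder:
    """Trace-style toy: the caller owns the state and passes it on every call."""

    def __init__(self):
        self.bytes_moved = 0

    def step(self, frame, states):
        # Counted once per call; whether the quoted figure is one-way or
        # round-trip is not stated in the PR.
        self.bytes_moved += sum(STATE_BYTES)
        new_states = [s + frame for s in states]  # stand-in for real decoding
        return sum(new_states), new_states


class ResidentDecoder:
    """Aliased-state toy: the state stays inside the object between calls."""

    def __init__(self):
        self.bytes_moved = 0
        self._states = [0] * NUM_STATES

    def step(self, frame):
        # State is updated in place, so no state bytes are counted as moved.
        self._states = [s + frame for s in self._states]
        return sum(self._states)


def run(frames):
    round_trip, resident = RoundTripDecoder(), ResidentDecoder()
    states = [0] * NUM_STATES
    outs_a, outs_b = [], []
    for f in frames:
        out, states = round_trip.step(f, states)
        outs_a.append(out)
        outs_b.append(resident.step(f))
    return outs_a, outs_b, round_trip.bytes_moved, resident.bytes_moved


outs_a, outs_b, moved_a, moved_b = run([1, 2, 3])
speedup = 237 / 78  # per-frame latencies quoted in this PR, in ms
```

Both toy decoders produce identical outputs; only the transfer counter differs. The quoted latencies give a ratio of about 3.04, consistent with the "3.0x" figure.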
f6b9828 to
6db2715
Update: NxDI TCDecoder with HBM State Persistence (3.0x decode speedup)
Migrated the TCDecoder from
What changed
Performance (trn2.3xlarge, SDK 2.29.1)
Why it's faster
The trace-based approach transfers 9 MemBlock state tensors (total ~100MB) over PCIe on every frame call. With
Validated
Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.
Description
FlashVSR is a video super-resolution model (4x upscaling) using a streaming DiT architecture based on Wan 2.1 1.3B. This contrib packages the DiT backbone for Neuron via NxDI ModelBuilder with TP=4, using NKI tiled flash attention (attention_cte) for 23040-token sequences that would otherwise OOM.
The pipeline processes video in overlapping chunks: a first chunk (6 latent frames → 24 output frames) followed by stream chunks (2 latent frames → 8 output frames each). Single-step DMD denoising enables efficient 4x upscaling at 768×1280 output resolution.
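The chunk schedule and the 23040-token figure can be sanity-checked with a small sketch. The helper names `plan_chunks` and `tokens_per_chunk` are hypothetical. The 8x spatial VAE stride and 2×2 patch size are assumptions (commonly cited Wan 2.1 defaults, not stated in this PR), and the chunk overlap is ignored because its size is not given here.

```python
FIRST_CHUNK_LATENTS = 6    # first chunk: 6 latent frames -> 24 output frames
STREAM_CHUNK_LATENTS = 2   # stream chunk: 2 latent frames -> 8 output frames
FRAMES_PER_LATENT = 4      # implied by 6 -> 24 and 2 -> 8


def plan_chunks(total_latent_frames):
    """Return (kind, latent_frames, output_frames) per chunk, ignoring overlap."""
    remaining = total_latent_frames - FIRST_CHUNK_LATENTS
    if remaining < 0 or remaining % STREAM_CHUNK_LATENTS:
        raise ValueError("expected 6 + 2*k latent frames")
    plan = [("first", FIRST_CHUNK_LATENTS, FIRST_CHUNK_LATENTS * FRAMES_PER_LATENT)]
    stream = ("stream", STREAM_CHUNK_LATENTS, STREAM_CHUNK_LATENTS * FRAMES_PER_LATENT)
    plan += [stream] * (remaining // STREAM_CHUNK_LATENTS)
    return plan


def tokens_per_chunk(height, width, latent_frames, vae_stride=8, patch=2):
    """Token count for one chunk; vae_stride and patch are assumed values."""
    tokens_h = height // vae_stride // patch
    tokens_w = width // vae_stride // patch
    return latent_frames * tokens_h * tokens_w


plan = plan_chunks(10)
first_chunk_tokens = tokens_per_chunk(768, 1280, FIRST_CHUNK_LATENTS)
```

Under these assumptions, a 768×1280 first chunk yields 6 × 48 × 80 tokens, which matches the 23040-token sequence length quoted above.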
Model Information
Model Name: FlashVSR v1.1 (JunhaoZhuang/FlashVSR-v1.1)
Model Architecture: 30-layer DiT (dim=1536, 12 heads, head_dim=128) with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm
Purpose: Video super-resolution (4x spatial upscaling, 480p → 1920p)
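As a rough check on why a dense attention score matrix is a problem at this sequence length, the sketch below estimates its size. It assumes scores are stored in BF16 (2 bytes per element) and that the 12 heads split evenly across the 4 tensor-parallel ranks; neither detail is confirmed by this PR, and `score_matrix_bytes` is a hypothetical helper.

```python
SEQ_LEN = 23040        # first-chunk sequence length quoted in this PR
NUM_HEADS = 12         # from the model architecture line
TP_DEGREE = 4          # tensor parallelism used in this PR
BYTES_PER_ELEM = 2     # assumption: BF16 scores


def score_matrix_bytes(seq_len, heads, bytes_per_elem=BYTES_PER_ELEM):
    """Bytes needed to materialize a full S x S score matrix for `heads` heads."""
    return heads * seq_len * seq_len * bytes_per_elem


heads_per_rank = NUM_HEADS // TP_DEGREE  # assumption: even split across ranks
per_head = score_matrix_bytes(SEQ_LEN, 1)
per_rank = score_matrix_bytes(SEQ_LEN, heads_per_rank)

print(f"per head: {per_head / 2**30:.2f} GiB")
print(f"per rank ({heads_per_rank} heads): {per_rank / 2**30:.2f} GiB")
```

That is roughly 1 GiB per head for a single layer's scores, before any other activations. A tiled kernel avoids holding the whole matrix at once.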
Checklist
Required Components
Accuracy Test (test/integration/test_dit_accuracy.py)
- neuron_allclose(rtol=0.05, atol=0.1)
README.md with the following sections:
Source Code (src/)
- modeling_flashvsr.py: NxDI-compatible DiT with Application/ModelWrapper/InferenceConfig (1242 lines)
- pipeline.py: Full inference pipeline orchestration
- tcdecoder.py: TCDecoder (latent → RGB) wrapper
- lq_projection.py: LQ conditioning projection wrapper
- weights.py: Weight format detection and conversion (DiffSynth/diffusers → Neuron)
- download_weights.py: HuggingFace weight download utility
Optional Components
test/integration/test_pipeline_e2e.py): PSNR validation
Folder Structure
Testing
How did you test this change?
Compiled and tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.29.1 (DLAMI 20260502, neuronx-cc 2.24.8799.0, NxDI 0.9.17334). DiT first-chunk compiled via NxDI ModelBuilder (TP=4, BF16) and validated against CPU reference model with identical weights.
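The two accuracy criteria used in this PR (an rtol/atol closeness check and a cosine-similarity threshold) can be sketched in plain Python. This is not the `neuron_allclose` implementation. It uses the common `|actual - expected| <= atol + rtol * |expected|` criterion as an assumption, operates on flattened lists rather than tensors, and the sample values are made up.

```python
import math

RTOL, ATOL = 0.05, 0.1       # tolerances quoted in this PR
COSINE_THRESHOLD = 0.999     # complementary check quoted in this PR


def allclose(actual, expected, rtol=RTOL, atol=ATOL):
    """Elementwise closeness: |a - e| <= atol + rtol * |e| for every element."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def cosine_similarity(a, b):
    """Cosine of the angle between two flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Made-up stand-ins for flattened CPU-reference and device outputs.
expected = [1.0, -2.0, 0.5, 4.0]
actual = [1.02, -1.95, 0.48, 4.1]

passes = allclose(actual, expected) and cosine_similarity(actual, expected) > COSINE_THRESHOLD
```

The elementwise check catches localized outliers, while cosine similarity summarizes overall agreement in direction, so the two are complementary.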
Test Results:
Compatibility
Tested with:
Additional Information
- attention_cte NKI kernel from nkilib for tiled flash attention (avoids materializing full S×S attention matrix in HBM)
- attn_mask input is unused in Phase 1 (dense attention); kept for future Phase 2 LCSA block-sparse support on larger instances (trn2.48xlarge TP=16)
Related Issues
None.
vLLM Integration
Not applicable — FlashVSR is a video generation model, not an LLM.
By submitting this PR, I confirm that: