Skip to content

Add Jepa 2.1 contrib model#160

Open
dstair wants to merge 7 commits into
aws-neuron:mainfrom
dstair:jepa2
Open

Add Jepa 2.1 contrib model#160
dstair wants to merge 7 commits into
aws-neuron:mainfrom
dstair:jepa2

Conversation

@dstair
Copy link
Copy Markdown
Contributor

@dstair dstair commented May 7, 2026

Description

NxDI contrib implementation of V-JEPA 2.1, Meta's self-supervised video foundation model. V-JEPA 2.1 is a Vision Transformer encoder that learns visual representations by predicting masked video segments in representation space. This is a vision encoder — not a causal language model — compiled for inference on AWS Trainium via torch_neuronx.trace().

Key architecture features ported:

  • 3D RoPE: Separate depth/height/width rotations on head_dim slices, using repeat_interleave layout
  • Conv3d tubelet embedding: 3D convolution for video patch embedding (patch_size=16, tubelet_size=2)
  • Hierarchical output: Normed features from 4 intermediate layers
  • Modality embeddings: Separate learned embeddings for image vs video inputs
  • Self-contained port: ~700 lines, no upstream vjepa2 imports at runtime

Model Information

Model Name: V-JEPA 2.1 (vit_base, vit_large, vit_giant, vit_gigantic)

Model Architecture: Vision Transformer encoder with 3D RoPE (86M–1.8B params)

Purpose: Self-supervised video representation learning (feature extraction, not text generation)

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)

    • Integration test validates Neuron vs CPU accuracy via neuron_allclose (rtol=0.01)
    • Test can compile and run the model on Neuron (validated on trn2.3xlarge and trn2.48xlarge)
    • Pretrained weight validation: cosine similarity 0.9998–1.0002 across ViT-B/L/g/G configurations (test/integration/test_pretrained_smoke.py)
  • README.md with the following sections:

    • Usage Example: CPU inference, Neuron compilation, DataParallel
    • Compatibility Matrix: trn2.3xlarge and trn2.48xlarge with SDK 2.28
    • Example Checkpoints: Meta's pretrained weights (auto-download from dl.fbaipublicfiles.com/vjepa2/)
    • Testing Instructions: Commands to run unit and integration test suites
  • Source Code (src/)

    • modeling_jepa21.py (~700 lines): Self-contained encoder implementation, no upstream imports
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU-based, no Neuron device required)
    • test_encoder.py — Construction: 4/4 PASS (ViT-B/L/g construction, invalid arch)
    • test_encoder.py — Forward: 6/6 PASS (video/image/batch shapes, hierarchical output, determinism, resolution)
    • test_encoder.py — Components: 4/4 PASS (PatchEmbed3D, RoPEAttention, Block)

Not Applicable (vision encoder, not causal LM)

  • vLLM Integration — not applicable (not a text generation model)
  • TPOT/TTFT benchmarks — not applicable (no token generation)
  • Logit divergence test — not applicable (no autoregressive decoding)
  • On-device sampling — not applicable

Folder Structure

contrib/models/jepa-2-1/
├── README.md
├── AGENT.md                          # Technical reference for coding agents
├── pyproject.toml
├── examples/
│   └── demo_classify.py              # CPU video classification demo (HF V-JEPA 2 + SSv2)
├── src/
│   ├── __init__.py
│   └── modeling_jepa21.py            # Self-contained encoder (3D RoPE, Conv3d)
└── test/
    ├── __init__.py
    ├── unit/
    │   ├── __init__.py
    │   └── test_encoder.py           # CPU-only: construction, forward, components (14 tests)
    └── integration/
        ├── __init__.py
        ├── test_model.py             # Neuron: trace, accuracy, ViT-B/L (4 tests)
        └── test_pretrained_smoke.py  # Pretrained: CPU + Neuron validation (5 tests)

Testing

How did you test this change?

All tests run on a trn2.3xlarge instance (2 NeuronCores, sa-east-1) using the Neuron SDK 2.28 venv (/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/). Unit tests run on CPU only. Integration tests compile and run the model on Neuron hardware. Pretrained smoke tests download official Meta weights and validate BF16 Neuron output against FP32 CPU reference.

Test Results:

# Unit tests (CPU only)
============================= test session starts ==============================
test/unit/test_encoder.py::TestEncoderConstruction::test_vit_base_construction PASSED
test/unit/test_encoder.py::TestEncoderConstruction::test_vit_large_construction PASSED
test/unit/test_encoder.py::TestEncoderConstruction::test_vit_giant_construction PASSED
test/unit/test_encoder.py::TestEncoderConstruction::test_invalid_arch_raises PASSED
test/unit/test_encoder.py::TestEncoderForward::test_video_forward_shape PASSED
test/unit/test_encoder.py::TestEncoderForward::test_image_forward_shape PASSED
test/unit/test_encoder.py::TestEncoderForward::test_batch_forward PASSED
test/unit/test_encoder.py::TestEncoderForward::test_hierarchical_output PASSED
test/unit/test_encoder.py::TestEncoderForward::test_output_deterministic PASSED
test/unit/test_encoder.py::TestEncoderForward::test_256_resolution PASSED
test/unit/test_encoder.py::TestEncoderComponents::test_patch_embed_3d PASSED
test/unit/test_encoder.py::TestEncoderComponents::test_patch_embed_3d_image PASSED
test/unit/test_encoder.py::TestEncoderComponents::test_rope_attention PASSED
test/unit/test_encoder.py::TestEncoderComponents::test_block PASSED
======================= 14 passed in 31.43s ========================

# Pretrained smoke tests — CPU
test/integration/test_pretrained_smoke.py::TestPretrainedCPU::test_pretrained_loads PASSED
test/integration/test_pretrained_smoke.py::TestPretrainedCPU::test_pretrained_forward_shape PASSED
test/integration/test_pretrained_smoke.py::TestPretrainedCPU::test_pretrained_no_nan PASSED
=================== 3 passed in 96.53s (0:01:36) ====================

# Pretrained smoke tests — Neuron
test/integration/test_pretrained_smoke.py::TestPretrainedNeuron::test_pretrained_neuron_vs_cpu PASSED
test/integration/test_pretrained_smoke.py::TestPretrainedNeuron::test_pretrained_neuron_no_nan PASSED
================== 2 passed in 849.27s (0:14:09) ===================

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.28
  • Instance Type(s): trn2.3xlarge (ViT-B, ViT-L), trn2.48xlarge (ViT-g, ViT-G)
  • PyTorch Version: 2.9.0
  • Python Version: 3.12
Instance NeuronCores Status Notes
trn2.3xlarge 2 PASS ViT-B and ViT-L compiled and benchmarked
trn2.48xlarge 64 PASS ViT-g and ViT-G compiled and benchmarked (ViT-G video exceeds graph limit)

Additional Information

Key porting decisions:

  1. F.scaled_dot_product_attention is not supported by torch_neuronx.trace() — replaced with manual Q @ K^T * scale → softmax → @ V path (use_sdpa=False).
  2. Self-contained port: all upstream imports from Meta's vjepa2 repo replaced with inline implementations. No runtime dependency on the upstream repo.
  3. Conv3d tubelet embedding and 3D RoPE with repeat_interleave both compile natively on Neuron — no workarounds needed.

Known limitations:

  • ViT-G video (16 frames) exceeds neuronx-cc's 10M instruction limit; requires parallel_model_trace to split across NeuronCores
  • ViT-g/G require trn2.48xlarge for compilation (>130GB host RAM); compiled models run on any trn2 instance
  • Not a causal LM — no vLLM integration, no KV cache, no token generation

Performance (single NeuronCore, batch=1, BF16, 16 frames, 384×384):

Model Median Latency
ViT-B (86M) 247.4 ms
ViT-L (300M) 741.8 ms
ViT-g (1.01B) 1029.5 ms

Related Issues

N/A — initial contribution.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

Not applicable. V-JEPA 2.1 is a vision encoder, not a text generation model. It does not use KV cache, token generation, or autoregressive decoding.


By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant