Skip to content

Add QwenImageEditLoRA contrib model: Multi-LoRA diffusion on Trainium2#166

Open
jimburtoft wants to merge 1 commit into
aws-neuron:mainfrom
jimburtoft:contrib/qwen-image-edit-lora
Open

Add QwenImageEditLoRA contrib model: Multi-LoRA diffusion on Trainium2#166
jimburtoft wants to merge 1 commit into
aws-neuron:mainfrom
jimburtoft:contrib/qwen-image-edit-lora

Conversation

@jimburtoft
Copy link
Copy Markdown
Contributor

Summary

  • Adds contrib model for Qwen-Image-Edit-2511 with runtime multi-LoRA adapter switching on trn2.3xlarge
  • Uses ModelBuilder SPMD trace with input_output_aliases for zero-copy LoRA buffer updates (<1ms, no recompilation)
  • 60-layer MMDiT transformer compiled with NKI Flash Attention, TP=4 sharding

Model Details

Property Value
HuggingFace ID Qwen/Qwen-Image-Edit-2511
LoRA Adapter fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA
Parameters ~28B total (~20.4B transformer)
Instance trn2.3xlarge (LNC=2, TP=4)
SDK Neuron SDK 2.29, PyTorch 2.9

Key Features

  • Runtime LoRA switching: write_to_neuron_buffer() updates 1,680 LoRA buffers (840 A + 840 B) at runtime without recompilation
  • 14 LoRA targets per block: All Q/K/V/O projections + MLP up/gate/down for both image and text streams
  • NKI Flash Attention: Custom kernel for fused scaled dot-product attention
  • Multi-resolution support: Compile once per resolution; spatial dims parametrize patch count

Benchmark (1024x1024, 28 steps, CFG)

  • Per-step latency: 1,541 ms
  • Full generation: 87.3s (with LoRA)
  • LoRA overhead: <1%

Integration Tests

All 5 tests pass on trn2.3xlarge:

  1. test_compilation_produces_artifacts — verifies NEFF + weight files
  2. test_inference_produces_valid_output — non-zero, finite outputs
  3. test_lora_aliasing_changes_output — LoRA injection changes output
  4. test_lora_zero_restores_baseline — zeroing LoRA restores deterministic baseline
  5. test_lora_buffer_count — exactly 840 A + 840 B keys per rank

Files

contrib/models/QwenImageEditLoRA/
├── README.md
├── src/
│   ├── __init__.py
│   ├── modeling_qwen_image_edit_lora.py  (1261 lines)
│   └── neuron_parallel_utils.py          (593 lines)
└── test/
    ├── __init__.py
    ├── integration/
    │   ├── __init__.py
    │   └── test_model.py                 (305 lines)
    └── unit/
        └── __init__.py

Multi-LoRA runtime adapter switching for Qwen-Image-Edit-2511 on trn2.
Uses ModelBuilder SPMD trace with input_output_aliases for zero-copy
LoRA buffer updates at runtime (<1ms, no recompilation).

- 60-layer MMDiT transformer with NKI Flash Attention
- TP=4 sharding (trn2.3xlarge, LNC=2)
- 14 LoRA targets per block (840 A + 840 B buffers)
- Integration tests: compilation, inference, aliasing, buffer count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant