Skip to content

Add LoRA-aware JAX-to-PyTorch checkpoint conversion#960

Open
Ret1ehS wants to merge 2 commits into
Physical-Intelligence:mainfrom
Ret1ehS:feat/lora-aware-conversion
Open

Add LoRA-aware JAX-to-PyTorch checkpoint conversion#960
Ret1ehS wants to merge 2 commits into
Physical-Intelligence:mainfrom
Ret1ehS:feat/lora-aware-conversion

Conversation

@Ret1ehS
Copy link
Copy Markdown

@Ret1ehS Ret1ehS commented Jun 2, 2026

Summary

This PR adds LoRA-aware handling to examples/convert_jax_model_to_pytorch.py.

Previously, converted checkpoints were loaded with load_state_dict(strict=False), so LoRA adapter tensors could appear as unexpected_keys and be silently ignored. This meant a LoRA-finetuned JAX checkpoint could be converted into a PyTorch checkpoint without the adapter updates being represented in the final weights.

This change folds LoRA adapter weights into the corresponding base weights before the existing slicing/load flow. For configs without LoRA adapters, the merge is a no-op and the previous conversion path is preserved.

Implementation notes

  • Adds merge_lora_into_base() to merge LoRA weights before slicing/loading.
  • Detects LoRA from the model config rather than from variant names.
  • Uses the configured PaliGemma and action-expert variants instead of hardcoded model variant strings.
  • Preserves the existing non-LoRA default precision behavior.
  • Defaults to float32 when LoRA adapters are present, because merged LoRA weights showed noticeably lower drift in float32 than bfloat16.
  • Raises an error if LoRA-related keys remain in unexpected_keys after loading, so failed merges do not silently produce divergent checkpoints.

OpenPI-specific LoRA details

openpi.models.lora.Einsum and openpi.models.lora.FeedForward apply LoRA slightly differently, so the merge follows the runtime behavior:

  • attn_vec_einsum sums over the head dimension in the second LoRA einsum, so the merged delta must match that cross-head behavior rather than using an independent per-head outer product.
  • FeedForward._dot() adds the LoRA path without applying scaling_value, while Einsum does apply scaling. The MLP merge therefore uses scale=1.0.

Validation

  • pre-commit passed.
  • ruff check . passed.
  • ruff format . passed.
  • python3 -m py_compile examples/convert_jax_model_to_pytorch.py passed.
  • git diff --check passed.

Additional local validation on a LoRA-finetuned checkpoint:

  • Before this change, LoRA adapter tensors appeared in unexpected_keys and were ignored.
  • After merging, no LoRA keys remained in unexpected_keys.
  • Corrected attn_vec_einsum merge matched the online LoRA computation closely:
    • max_abs = 0.000001669
    • mean_abs = 0.000000284
  • Float32 converted checkpoint matched the JAX reference more closely than bfloat16:
    • float32: mean_abs = 0.0000929647, max_abs = 0.0016714334
    • bfloat16: mean_abs = 0.0011840940, max_abs = 0.0170528293

Fixes #958.
Related: #840, #729, #810.

Ret1ehS added 2 commits June 2, 2026 16:59
The previous script silently dropped LoRA adapter weights via
load_state_dict(strict=False), producing PyTorch checkpoints that
diverged from the JAX original. This change:

- Detects LoRA configs and merges adapter weights into base weights
  before applying the existing slice flow
- Handles two non-standard LoRA quirks in openpi's runtime:
  * attn_vec_einsum requires sum_N(lora_b), not per-head outer product
  * MLP FeedForward._dot adds LoRA delta without alpha/rank scaling
- Defaults to float32 for the saved merged checkpoint to avoid bf16
  precision loss
- Asserts no unresolved LoRA keys after load

Fixes Physical-Intelligence#958
Related: Physical-Intelligence#840, Physical-Intelligence#729, Physical-Intelligence#810
@jimmyt857 jimmyt857 removed their request for review June 2, 2026 14:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

examples/convert_jax_model_to_pytorch.py silently drops LoRA adapter weights for LoRA-finetuned checkpoints

1 participant