Add LoRA-aware JAX-to-PyTorch checkpoint conversion #960
Open
Ret1ehS wants to merge 2 commits into
The previous script silently dropped LoRA adapter weights via load_state_dict(strict=False), producing PyTorch checkpoints that diverged from the JAX original.

This change:
- Detects LoRA configs and merges adapter weights into base weights before applying the existing slice flow
- Handles two non-standard LoRA quirks in openpi's runtime:
  * attn_vec_einsum requires sum_N(lora_b), not per-head outer product
  * MLP FeedForward._dot adds LoRA delta without alpha/rank scaling
- Defaults to float32 for the saved merged checkpoint to avoid bf16 precision loss
- Asserts no unresolved LoRA keys after load

Fixes Physical-Intelligence#958
Related: Physical-Intelligence#840, Physical-Intelligence#729, Physical-Intelligence#810
Summary
This PR adds LoRA-aware handling to examples/convert_jax_model_to_pytorch.py.

Previously, converted checkpoints were loaded with load_state_dict(strict=False), so LoRA adapter tensors could appear as unexpected_keys and be silently ignored. This meant a LoRA-finetuned JAX checkpoint could be converted into a PyTorch checkpoint without the adapter updates being represented in the final weights.

This change folds LoRA adapter weights into the corresponding base weights before the existing slicing/load flow. For configs without LoRA adapters, the merge is a no-op and the previous conversion path is preserved.

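As a rough illustration of what folding an adapter into a base weight means, here is a minimal NumPy sketch for a plain 2-D linear layer. The helper name merge_lora, the shapes, and the scale value are hypothetical and chosen only for this example; this is not the PR's merge_lora_into_base(), which also has to follow the OpenPI-specific einsum layouts.

```python
import numpy as np


def merge_lora(base, lora_a, lora_b, scale=1.0):
    """Fold a low-rank update into a dense weight: W' = W + scale * (A @ B).

    Hypothetical helper for illustration only.
    """
    return base + scale * (lora_a @ lora_b)


rng = np.random.default_rng(0)
d_in, d_out, rank = 8, 6, 2  # illustrative sizes
w = rng.standard_normal((d_in, d_out)).astype(np.float32)
a = rng.standard_normal((d_in, rank)).astype(np.float32)
b = rng.standard_normal((rank, d_out)).astype(np.float32)
x = rng.standard_normal((3, d_in)).astype(np.float32)
scale = 0.5  # illustrative value

# "Online" LoRA: base path plus the scaled low-rank path, computed separately.
online = x @ w + scale * ((x @ a) @ b)

# Merged LoRA: a single matmul against the folded weight.
merged = x @ merge_lora(w, a, b, scale)

# The two agree up to float32 rounding.
max_abs = float(np.abs(online - merged).max())

# With a zero adapter the merge leaves the base weight unchanged.
unchanged = merge_lora(w, a, np.zeros_like(b), scale)
```

If the adapter is dropped instead of merged, the converted model computes only x @ w, which is the divergence described above.
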
Implementation notes

- Adds merge_lora_into_base() to merge LoRA weights before slicing/loading.
- Defaults the saved checkpoint to float32 when LoRA adapters are present, because merged LoRA weights showed noticeably lower drift in float32 than bfloat16.
- Asserts that no LoRA keys remain in unexpected_keys after loading, so failed merges do not silently produce divergent checkpoints.

OpenPI-specific LoRA details

openpi.models.lora.Einsum and openpi.models.lora.FeedForward apply LoRA slightly differently, so the merge follows the runtime behavior:

- attn_vec_einsum sums over the head dimension in the second LoRA einsum, so the merged delta must match that cross-head behavior rather than using an independent per-head outer product.
- FeedForward._dot() adds the LoRA path without applying scaling_value, while Einsum does apply scaling. The MLP merge therefore uses scale=1.0.

Validation

- pre-commit passed.
- ruff check . passed.
- ruff format . passed.
- python3 -m py_compile examples/convert_jax_model_to_pytorch.py passed.
- git diff --check passed.

Additional local validation on a LoRA-finetuned checkpoint:

- [...] unexpected_keys and were ignored.
- [...] unexpected_keys.
- The attn_vec_einsum merge matched the online LoRA computation closely: max_abs = 0.000001669, mean_abs = 0.000000284
- [...] mean_abs = 0.0000929647, max_abs = 0.0016714334
- [...] mean_abs = 0.0011840940, max_abs = 0.0170528293

Fixes #958.
Related: #840, #729, #810.
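
For reviewers, the attn_vec_einsum point under "OpenPI-specific LoRA details" can be sketched in NumPy. The einsum strings below are an assumption reconstructed from the description (base BTNH,NHD->BTD, with the two LoRA steps both contracting the head axis N); they are not copied from openpi.models.lora, and all shapes and the scale value are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, N, H, D, L = 2, 3, 4, 5, 6, 2  # batch, time, heads, head_dim, model_dim, rank
x = rng.standard_normal((B, T, N, H))
w = rng.standard_normal((N, H, D))
lora_a = rng.standard_normal((N, H, L))
lora_b = rng.standard_normal((N, L, D))
scale = 0.5  # illustrative stand-in for the Einsum scaling value

# Online path (assumed equations): the first LoRA einsum contracts N and H,
# and the second contracts N again, so lora_b is effectively summed over heads.
low_rank = np.einsum("BTNH,NHL->BTL", x, lora_a)
online = np.einsum("BTNH,NHD->BTD", x, w) + scale * np.einsum(
    "BTL,NLD->BTD", low_rank, lora_b
)

# Cross-head merge: delta[n] = lora_a[n] @ sum_m lora_b[m].
b_sum = lora_b.sum(axis=0)  # shape (L, D)
delta_cross = np.einsum("NHL,LD->NHD", lora_a, b_sum)
merged_cross = np.einsum("BTNH,NHD->BTD", x, w + scale * delta_cross)

# Naive per-head merge: delta[n] = lora_a[n] @ lora_b[n].
delta_per_head = np.einsum("NHL,NLD->NHD", lora_a, lora_b)
merged_per_head = np.einsum("BTNH,NHD->BTD", x, w + scale * delta_per_head)

err_cross = float(np.abs(online - merged_cross).max())
err_per_head = float(np.abs(online - merged_per_head).max())
```

Under these assumptions, err_cross sits at floating-point rounding level while err_per_head does not, which is why the merge uses the summed lora_b. For the MLP path, the description states that FeedForward._dot() applies no scaling, so the analogous merge would use scale=1.0.
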