Fix ZeRO-3 + PEFT mixed-dtype error for core trainers#6091
Merged
…lse for non-quantized models
This reverts commit dc17985.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks, the other trainers don't require the same fix?
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
Reviewed by Cursor Bugbot for commit f515838. Configure here.

Fixes a `TypeError: output tensor must have the same type as input tensor` that broke training with DeepSpeed ZeRO Stage 3 + non-quantized PEFT (LoRA) after deepspeed 0.19.2 was released.
Fix #6089.
Motivation
DeepSpeed 0.19.2 changed `_configure_distributed_model` to skip the blanket `module.bfloat16()` cast for ZeRO-Init models.
Before 0.19.2, that cast was accidentally unifying all parameter dtypes, including PEFT LoRA adapter parameters. After 0.19.2 the cast is skipped, exposing a latent bug in DeepSpeed's `_allgather_params_coalesced`: output buffers are allocated using the dtype of the first persistent parameter, so when persistent parameters have mixed dtypes the subsequent `all_gather_into_tensor` call raises a `TypeError`.
The mixed-dtype situation arises because:
- Base model parameters are bf16: `ds_tensor.dtype = bfloat16`
- `autocast_adapter_dtype=True` upcasts LoRA adapter parameters to fp32 (intended for QLoRA stability, not needed for non-quantized bf16 training)
- Both end up in `persistent_parameters` → dtype mismatch on all-gather
A fix has been reported upstream.
However, on TRL's side we can add a short-term workaround. Note that TRL already does something analogous for QLoRA in SFTTrainer. We can extend this to also handle non-quantized PEFT + ZeRO-3 by passing `autocast_adapter_dtype=False` to `get_peft_model()`, which prevents PEFT from upcasting adapters to fp32 (keeping them in bf16 to match the base model).
Solution
Pass `autocast_adapter_dtype=False` to `get_peft_model()` when ZeRO Stage 3 is active and the model is not quantized. This prevents PEFT from upcasting LoRA adapter parameters to fp32, keeping them in the base model's dtype (bf16) and eliminating the mismatch.
The fp32 upcast (`autocast_adapter_dtype=True`) is a QLoRA-specific concern: with a 4-bit quantized base model, higher-precision adapters compensate for the coarse weight representation. For non-quantized bf16 training, keeping LoRA adapters in bf16 is correct and causes no stability regression: this matches how FSDP2 handles non-quantized LoRA.
The existing QLoRA workaround (manual cast to bf16 for `is_loaded_in_4bit`/`is_loaded_in_8bit`) is left in place.
Note
Medium Risk
Changes PEFT adapter dtype and DeepSpeed version constraints across multiple core trainers; affects distributed ZeRO-3 + LoRA training paths but is narrowly gated and leaves QLoRA behavior unchanged.
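To make the failure mode from the Motivation section concrete, here is a toy model of it in plain Python. It is only a sketch: `ToyParam`, `toy_allgather_coalesced`, and `gather_outcome` are hypothetical names invented for illustration, dtypes are represented as strings, and none of this is DeepSpeed's actual code. It only mirrors the described pattern of taking the buffer dtype from the first persistent parameter.

```python
from dataclasses import dataclass


@dataclass
class ToyParam:
    """Hypothetical stand-in for a ZeRO-3 persistent parameter (not a DeepSpeed class)."""
    name: str
    dtype: str  # e.g. "bfloat16" or "float32"


def toy_allgather_coalesced(params):
    """Mimic the described pattern: one buffer dtype, taken from the first parameter."""
    buffer_dtype = params[0].dtype
    gathered = []
    for param in params:
        if param.dtype != buffer_dtype:
            # Reuses the error message quoted in this PR.
            raise TypeError("output tensor must have the same type as input tensor")
        gathered.append(param.name)
    return gathered


def gather_outcome(params):
    """Return "ok" if the toy gather succeeds, else the error text."""
    try:
        toy_allgather_coalesced(params)
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


base = ToyParam("layer.weight", "bfloat16")
lora_fp32 = ToyParam("layer.lora_A.weight", "float32")   # adapter upcast to fp32
lora_bf16 = ToyParam("layer.lora_A.weight", "bfloat16")  # adapter kept in bf16

print(gather_outcome([base, lora_fp32]))
print(gather_outcome([base, lora_bf16]))
```

In this toy, the mixed bf16/fp32 list fails while the all-bf16 list succeeds, which is the effect the workaround relies on when it keeps adapters in the base model's dtype.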
Overview
Fixes DeepSpeed ZeRO Stage 3 training with non-quantized LoRA, which started failing on the first optimizer step with a mixed-dtype `TypeError` after DeepSpeed 0.19.2 (issue #6089).
DPO, GRPO, RLOO, Reward, and SFT trainers now call `get_peft_model(..., autocast_adapter_dtype=False)` when ZeRO-3 is active, the model is not 4/8-bit quantized, and PEFT is ≥ 0.12.0, so LoRA weights stay in bf16 instead of being upcast to fp32. Quantized (QLoRA) paths still use the existing manual bf16 cast on trainable params. A shared `_is_quantized_model` flag replaces repeated `getattr` checks.
Dependencies: the temporary `deepspeed<0.19.2` cap is removed from `pyproject.toml` (`deepspeed` and `dev` extras), allowing newer DeepSpeed now that TRL handles the dtype mismatch.
Reviewed by Cursor Bugbot for commit ed91b94. Bugbot is set up for automated code reviews on this repo.
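The gating condition summarized above (ZeRO-3 active, model not 4/8-bit quantized, PEFT ≥ 0.12.0) can be sketched as a small helper. This is an illustration under stated assumptions, not the code merged in this PR: `parse_release`, `is_quantized_model`, and `peft_model_kwargs` are hypothetical names, the ZeRO-3 flag and PEFT version are passed in explicitly rather than detected, and the version parsing is deliberately simplified (it ignores pre-release suffixes).

```python
import re
from types import SimpleNamespace


def parse_release(version):
    """Simplified parse: '0.12.0.dev0' -> (0, 12, 0). Ignores pre-release tags."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '0.12' compares equal to '0.12.0'.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def is_quantized_model(model):
    """True if the model carries the 4-bit or 8-bit loading flags named in the PR."""
    return bool(
        getattr(model, "is_loaded_in_4bit", False)
        or getattr(model, "is_loaded_in_8bit", False)
    )


def peft_model_kwargs(model, zero3_enabled, peft_version):
    """Hypothetical helper: extra kwargs to forward to get_peft_model()."""
    kwargs = {}
    if (
        zero3_enabled
        and not is_quantized_model(model)
        and parse_release(peft_version) >= (0, 12, 0)
    ):
        # Keep LoRA adapters in the base model's dtype instead of upcasting to fp32.
        kwargs["autocast_adapter_dtype"] = False
    return kwargs


# Stand-in objects; real models would come from transformers.
plain_model = SimpleNamespace()
qlora_model = SimpleNamespace(is_loaded_in_4bit=True)

print(peft_model_kwargs(plain_model, zero3_enabled=True, peft_version="0.12.0"))
print(peft_model_kwargs(qlora_model, zero3_enabled=True, peft_version="0.12.0"))
```

A trainer could then forward the result as `get_peft_model(model, peft_config, **kwargs)`; returning an empty dict in every other case leaves the default PEFT behavior, and therefore the QLoRA path, untouched.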