Skip to content

Fix ZeRO-3 + PEFT mixed-dtype error for core trainers#6091

Merged
albertvillanova merged 19 commits into
mainfrom
fix-6089
Jun 19, 2026
Merged

Fix ZeRO-3 + PEFT mixed-dtype error for core trainers#6091
albertvillanova merged 19 commits into
mainfrom
fix-6089

Conversation

@albertvillanova

@albertvillanova albertvillanova commented Jun 17, 2026

Copy link
Copy Markdown
Member

Fixes a TypeError: output tensor must have the same type as input tensor that broke training with DeepSpeed ZeRO Stage 3 + non-quantized PEFT (LoRA) after deepspeed 0.19.2 was released.

Fix #6089.

Motivation

DeepSpeed 0.19.2 changed _configure_distributed_model to skip the blanket module.bfloat16() cast for ZeRO-Init models:

Before 0.19.2, that cast was accidentally unifying all parameter dtypes, including PEFT LoRA adapter parameters. After 0.19.2 the cast is skipped, exposing a latent bug in DeepSpeed's _allgather_params_coalesced: output buffers are allocated using the dtype of the first persistent parameter, so when persistent parameters have mixed dtypes the subsequent all_gather_into_tensor call raises a TypeError.

The mixed-dtype situation arises because:

  1. The base model is loaded in bf16 via ZeRO-Init → base model parameters have ds_tensor.dtype = bfloat16
  2. PEFT's default autocast_adapter_dtype=True upcasts LoRA adapter parameters to fp32 (intended for QLoRA stability, not needed for non-quantized bf16 training)
  3. Both base model params and LoRA params end up in persistent_parameters → dtype mismatch on all-gather

A fix has been reported upstream:

However, on TRL's side we can add a short-term workaround. Note that TRL already does something analog for QLoRA at SFTTrainer. We can extend this to also handle non-quantized PEFT + ZeRO3 by passing autocast_adapter_dtype=False to get_peft_model(), which prevents PEFT from upcasting adapters to fp32 (keeping them in bf16 to match the base model).

Solution

Pass autocast_adapter_dtype=False to get_peft_model() when ZeRO Stage 3 is active and the model is not quantized. This prevents PEFT from upcasting LoRA adapter parameters to fp32, keeping them in the base model's dtype (bf16) and eliminating the mismatch.

The fp32 upcast (autocast_adapter_dtype=True) is a QLoRA-specific concern: with a 4-bit quantized base model, higher-precision adapters compensate for the coarse weight representation. For non-quantized bf16 training, keeping LoRA adapters in bf16 is correct and causes no stability regression: this matches how FSDP2 handles non-quantized LoRA.

The existing QLoRA workaround (manual cast to bf16 for is_loaded_in_4bit/is_loaded_in_8bit) is left in place.


Note

Medium Risk
Changes PEFT adapter dtype and DeepSpeed version constraints across multiple core trainers; affects distributed ZeRO-3 + LoRA training paths but is narrowly gated and leaves QLoRA behavior unchanged.

Overview
Fixes DeepSpeed ZeRO Stage 3 training with non-quantized LoRA, which started failing on the first optimizer step with a mixed-dtype TypeError after DeepSpeed 0.19.2 (issue #6089).

DPO, GRPO, RLOO, Reward, and SFT trainers now call get_peft_model(..., autocast_adapter_dtype=False) when ZeRO-3 is active, the model is not 4/8-bit quantized, and PEFT is ≥ 0.12.0, so LoRA weights stay in bf16 instead of being upcast to fp32. Quantized (QLoRA) paths still use the existing manual bf16 cast on trainable params. A shared _is_quantized_model flag replaces repeated getattr checks.

Dependencies: the temporary deepspeed<0.19.2 cap is removed from pyproject.toml (deepspeed and dev extras), allowing newer DeepSpeed now that TRL handles the dtype mismatch.

Reviewed by Cursor Bugbot for commit ed91b94. Bugbot is set up for automated code reviews on this repo. Configure here.

@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec

Copy link
Copy Markdown
Member

Thanks, the other trainers don't require the same fix?

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Want higher recall? High effort reviews run extra passes and find more bugs. A team admin can switch effort levels in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit f515838. Configure here.

Comment thread trl/trainer/dpo_trainer.py Outdated
@albertvillanova albertvillanova changed the title Fix ZeRO-3 + PEFT (LoRA) TypeError caused by mixed-dtype persistent parameters Fix ZeRO-3 + PEFT (LoRA) TypeError caused by mixed-dtype persistent parameters for core trainers Jun 18, 2026
@albertvillanova albertvillanova changed the title Fix ZeRO-3 + PEFT (LoRA) TypeError caused by mixed-dtype persistent parameters for core trainers Fix ZeRO-3 + PEFT mixed-dtype TypeError for core trainers Jun 18, 2026
@albertvillanova albertvillanova changed the title Fix ZeRO-3 + PEFT mixed-dtype TypeError for core trainers Fix ZeRO-3 + PEFT mixed-dtype error for core trainers Jun 18, 2026

@qgallouedec qgallouedec left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

@albertvillanova albertvillanova merged commit bf6a7b5 into main Jun 19, 2026
13 checks passed
@albertvillanova albertvillanova deleted the fix-6089 branch June 19, 2026 08:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

CI fails for distributed smoke tests: TypeError: output tensor must have the same type as input tensor

2 participants