
fix(engine): use meta device for non-rank-0 in FSDP memory_efficient_load #1182

Merged
garrett4wade merged 1 commit into main from fix/fsdp-memory-efficient-load
Apr 16, 2026

Conversation

@yulangz (Collaborator) commented Apr 14, 2026

Summary

Fix CPU OOM when using fsdp.memory_efficient_load=True with multiple workers per node. Non-rank-0 processes now use meta device instead of cpu during model creation, reducing per-node CPU memory from ~512 GiB to ~64 GiB for a 32B model.

Problem

When memory_efficient_load=True, all ranks created the full model on CPU via from_config() with torch.device("cpu") context. For a 32B model (bf16, ~64 GiB) with 8 workers per node:

  • Each worker allocates ~64 GiB CPU RAM for the full model parameters
  • 8 workers × 64 GiB = ~512 GiB, far exceeding typical node memory (256 GiB)
  • Result: kernel OOM killer or Ray memory monitor kills workers before FSDP sharding begins

The original design intent was to avoid GPU OOM by loading on CPU first, then broadcasting sharded weights. However, it shifted the OOM problem from GPU to CPU.

Root Cause

In _create_device_model(), the loading device was unconditionally set to "cpu" for all ranks:

if self.config.fsdp.memory_efficient_load:
    loading_device = "cpu"  # All ranks allocate full model on CPU

Combined with from_config() (which creates real tensors, not meta tensors), every rank materialized the entire model in CPU RAM.
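The contrast is easy to demonstrate in isolation: a tensor created under a `torch.device("meta")` context carries only shape and dtype metadata and allocates no storage, while the same construction under `"cpu"` allocates real RAM (a minimal sketch, independent of the PR's code):

```python
import torch

# Under a cpu device context, from_config()-style construction allocates
# real storage: 1024 * 1024 bf16 elements = 2 MiB of actual CPU RAM.
with torch.device("cpu"):
    real = torch.empty(1024, 1024, dtype=torch.bfloat16)
print(real.untyped_storage().nbytes())  # 2097152

# Under a meta device context, the same construction allocates nothing.
with torch.device("meta"):
    shadow = torch.empty(1024, 1024, dtype=torch.bfloat16)
print(shadow.is_meta)  # True -- shape/dtype only, zero bytes of storage
```

Scaled up to a 32B-parameter model, that 2-bytes-per-element cost is exactly the ~64 GiB per rank described above.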

Fix

  • fsdp_engine.py: Only rank-0 loads on CPU; other ranks use "meta" device (zero memory cost). Weights are broadcast from rank-0 after FSDP sharding via fsdp2_load_full_state_dict with broadcast_from_rank0=True.
  • fsdp_utils/__init__.py: Handle meta→device conversion using to_empty() instead of .to() (which fails on meta tensors).
  • saver.py, recover.py: Move AutoProcessor import to TYPE_CHECKING block to avoid eagerly importing torchvision.
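The first two changes can be sketched together in a single process (with a stand-in `rank` variable in place of `dist.get_rank()`; the actual broadcast via `fsdp2_load_full_state_dict` is not reproduced here):

```python
import torch
import torch.nn as nn

rank = 1  # stand-in for dist.get_rank(); a non-rank-0 worker in this sketch
loading_device = "cpu" if rank == 0 else "meta"

# Non-rank-0 builds the model structure only -- no parameter storage.
with torch.device(loading_device):
    model = nn.Linear(256, 256)
assert next(model.parameters()).is_meta

# .to() raises on meta tensors; to_empty() allocates uninitialized storage
# on the target device, to be overwritten by the rank-0 broadcast.
model = model.to_empty(device="cpu")
assert not any(p.is_meta for p in model.parameters())
```

The uninitialized values never matter: `set_model_state_dict` with `broadcast_from_rank0=True` overwrites every shard with rank-0's weights.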

Test Results

Verified with Qwen3-32B, FSDP d4c8, 4 nodes × 8 L20Y GPUs (80 GiB), memory_efficient_load=True:

| Metric | Before (all ranks CPU) | After (rank-0 CPU + others meta) |
|---|---|---|
| Non-rank-0 model creation | ~60 s | 0.19 s |
| Peak CPU RAM per node | ~512 GiB → OOM | ~64 GiB (rank-0 only) |
| GPU memory after train step | N/A (crashed) | 42.32 / 79.33 GiB |
| Training | ❌ CPU OOM | ✅ Completes successfully |

The memory_efficient_load=False path is unchanged.

Copilot AI review requested due to automatic review settings April 14, 2026 08:33
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request implements memory-efficient loading for FSDP by initializing models on the 'meta' device for non-rank-0 processes and materializing them using to_empty() before state dict loading. It also refactors imports in recover.py and saver.py to use TYPE_CHECKING blocks. Feedback highlights that the current implementation is incomplete for vision models because the necessary broadcasting logic is not triggered for them. Additionally, it is recommended to extend the meta-tensor detection to include model buffers to ensure all components are correctly materialized.

Comment thread areal/engine/fsdp_engine.py Outdated
Comment on lines +872 to +875
if dist.get_rank() == 0:
    loading_device = "cpu"
else:
    loading_device = "meta"
gemini-code-assist bot (Contributor) — severity: high

The implementation of memory_efficient_load using the meta device for non-rank-0 processes is currently incomplete for vision models. While _create_device_model correctly sets the loading_device to "meta" for non-rank-0 ranks, the broadcasting logic in initialize() (lines 309-316) only triggers if is_llm_cpu_load is true or LoRA is used. Since is_llm_cpu_load explicitly excludes vision models (line 312), need_broadcast will remain False for VLMs, leaving non-rank-0 processes with uninitialized meta tensors. This will cause a crash during training when the model is accessed.

Comment thread areal/engine/fsdp_utils/__init__.py Outdated
# Handle meta device models (from memory_efficient_load where non-rank-0
# processes use meta tensors). to_empty() materializes meta tensors without
# copying data; the actual weights come from set_model_state_dict broadcast.
has_meta = any(p.device.type == "meta" for p in model.parameters())
gemini-code-assist bot (Contributor) — severity: medium

The check for meta tensors should also include model buffers. Some architectures use persistent buffers (e.g., for rotary embeddings or KV cache constants) that might be initialized on the meta device when using memory_efficient_load. These also need to be materialized before loading the state dict to avoid errors during the broadcast/load phase.

Suggested change
has_meta = any(p.device.type == "meta" for p in model.parameters())
has_meta = any(t.is_meta for t in model.parameters()) or any(t.is_meta for t in model.buffers())
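The distinction matters because a module can be parameter-free yet still carry a persistent buffer, in which case a parameters-only scan reports no meta tensors at all (illustrative sketch with a hypothetical module, not code from this PR):

```python
import torch
import torch.nn as nn

class RotaryCache(nn.Module):
    """Hypothetical buffer-only module, e.g. precomputed rotary frequencies."""
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.ones(8))

with torch.device("meta"):
    m = RotaryCache()

# A parameters-only check misses the meta buffer entirely:
assert not any(p.is_meta for p in m.parameters())  # vacuously True: no params
assert any(b.is_meta for b in m.buffers())         # the buffer IS on meta
```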

Copilot AI (Contributor) left a comment

Pull request overview

Fixes excessive per-node CPU memory usage during FSDP initialization when fsdp.memory_efficient_load=True by avoiding full CPU model materialization on non-rank-0 workers, and reduces import-time side effects from transformers.

Changes:

  • Update FSDP engine initialization to create the model on meta for non-rank-0 ranks when memory_efficient_load is enabled.
  • Update FSDP full-state broadcast loading to materialize meta tensors via to_empty() before set_model_state_dict.
  • Defer AutoProcessor imports behind TYPE_CHECKING in saver/recover utilities to avoid eager torchvision import.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
|---|---|
| areal/utils/saver.py | Defers AutoProcessor import to type-checking time; adds future annotations. |
| areal/utils/recover.py | Same deferred import pattern for AutoProcessor; adds future annotations. |
| areal/engine/fsdp_utils/__init__.py | Uses to_empty() to materialize meta models before broadcasting state dict. |
| areal/engine/fsdp_engine.py | Switches non-rank-0 model creation device to meta under memory_efficient_load. |


Comment thread areal/engine/fsdp_engine.py Outdated
Comment on lines +869 to +873
# Only rank 0 loads on CPU; other ranks use meta device (zero memory)
# to avoid CPU OOM when multiple workers share a node.
# Weights are broadcast from rank 0 after FSDP sharding in initialize().
if dist.get_rank() == 0:
    loading_device = "cpu"

Copilot AI Apr 14, 2026


memory_efficient_load is documented/configured to not use rank-0 broadcast for VLMs (each rank loads weights independently on CPU), but this change sets loading_device="meta" for non-rank-0 whenever memory_efficient_load is enabled. For vision models this will instantiate the model on meta and then proceed without any later broadcast (since is_llm_cpu_load is false when self.is_vision_model), leaving meta tensors and likely breaking initialization. Consider gating the meta-device path to LLM-only (e.g., memory_efficient_load and not self.is_vision_model and not init_from_scratch), or explicitly force loading_device="cpu" for vision models when memory_efficient_load is enabled.

Suggested change

    # Only rank 0 loads on CPU; other ranks use meta device (zero memory)
    # to avoid CPU OOM when multiple workers share a node.
    # Weights are broadcast from rank 0 after FSDP sharding in initialize().
    if dist.get_rank() == 0:
        loading_device = "cpu"

becomes:

    # For LLMs, only rank 0 loads on CPU; other ranks use meta device
    # (zero memory) to avoid CPU OOM when multiple workers share a node.
    # Weights are broadcast from rank 0 after FSDP sharding in initialize().
    # Vision models do not use that broadcast path, so every rank must load
    # them on CPU even when memory_efficient_load is enabled.
    if self.is_vision_model:
        loading_device = "cpu"
    elif dist.get_rank() == 0:
        loading_device = "cpu"

fix(engine): use meta device for non-rank-0 in FSDP memory_efficient_load

When memory_efficient_load=True, all ranks previously created the full
model on CPU via from_config(), causing CPU OOM on nodes with limited
RAM (e.g. 8 workers × 64GB = 512GB on a 256GB node).

Now only rank-0 loads on CPU; other ranks use meta device (zero memory
cost). Weights are broadcast from rank-0 after FSDP sharding via
fsdp2_load_full_state_dict with broadcast_from_rank0=True.

Also defer AutoProcessor import in saver.py and recover.py to avoid
importing torchvision eagerly (which can crash when torchvision version
mismatches torch).

Key changes:
- fsdp_engine.py: non-rank-0 uses "meta" device in memory_efficient_load
- fsdp_utils/__init__.py: use to_empty() for meta→device conversion
- saver.py, recover.py: move AutoProcessor to TYPE_CHECKING block
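The deferred-import pattern mentioned for saver.py/recover.py can be sketched as follows (the helper function is hypothetical; the point is that `transformers`, and with it torchvision, is never imported at runtime):

```python
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright), never at runtime.
    from transformers import AutoProcessor

def save_processor_config(processor: "AutoProcessor", path: str) -> None:
    """Hypothetical helper; the string annotation needs no runtime import."""
    processor.save_pretrained(path)

# transformers (and therefore torchvision) stays out of sys.modules
# until someone actually passes in a processor instance.
assert "transformers" not in sys.modules
```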
@yulangz yulangz force-pushed the fix/fsdp-memory-efficient-load branch from ca0c70d to bfff796 Compare April 14, 2026 08:45
@yulangz yulangz added the safe-to-test Ready to run unit-tests in a PR. label Apr 14, 2026
@garrett4wade (Collaborator) left a comment

LGTM

@garrett4wade

CI blocks due to other errors. Merging this PR.

@garrett4wade garrett4wade merged commit f34bea8 into main Apr 16, 2026
11 of 13 checks passed
@garrett4wade garrett4wade deleted the fix/fsdp-memory-efficient-load branch April 16, 2026 02:48
