fix(engine): use meta device for non-rank-0 in FSDP memory_efficient_load #1182
garrett4wade merged 1 commit into main
Conversation
Code Review
This pull request implements memory-efficient loading for FSDP by initializing models on the 'meta' device for non-rank-0 processes and materializing them using to_empty() before state dict loading. It also refactors imports in recover.py and saver.py to use TYPE_CHECKING blocks. Feedback highlights that the current implementation is incomplete for vision models because the necessary broadcasting logic is not triggered for them. Additionally, it is recommended to extend the meta-tensor detection to include model buffers to ensure all components are correctly materialized.
```python
if dist.get_rank() == 0:
    loading_device = "cpu"
else:
    loading_device = "meta"
```
The implementation of memory_efficient_load using the meta device for non-rank-0 processes is currently incomplete for vision models. While _create_device_model correctly sets the loading_device to "meta" for non-rank-0 ranks, the broadcasting logic in initialize() (lines 309-316) only triggers if is_llm_cpu_load is true or LoRA is used. Since is_llm_cpu_load explicitly excludes vision models (line 312), need_broadcast will remain False for VLMs, leaving non-rank-0 processes with uninitialized meta tensors. This will cause a crash during training when the model is accessed.
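To make the suggested gating concrete, here is a standalone sketch of a device-selection helper. The function name `choose_loading_device` is hypothetical; the behavior of `is_vision_model` and the broadcast path follows the review's description of the engine, not the actual source:

```python
def choose_loading_device(rank: int, memory_efficient_load: bool,
                          is_vision_model: bool) -> str:
    """Pick where to materialize the model before FSDP sharding.

    Hypothetical helper mirroring the reviewer's suggested gating:
    vision models never take the rank-0 broadcast path, so every rank
    must load real weights on CPU; a meta-device model on a non-rank-0
    VLM worker would never be filled in and would crash at first use.
    """
    if not memory_efficient_load:
        return "cpu"
    if is_vision_model:
        # No later broadcast for VLMs: each rank loads independently.
        return "cpu"
    # LLM path: rank 0 holds real weights, others defer to the broadcast.
    return "cpu" if rank == 0 else "meta"
```

With this gating, only non-rank-0 LLM workers ever see `"meta"`, which is exactly the case covered by the `to_empty()` materialization before `set_model_state_dict`.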
```python
# Handle meta device models (from memory_efficient_load where non-rank-0
# processes use meta tensors). to_empty() materializes meta tensors without
# copying data; the actual weights come from set_model_state_dict broadcast.
has_meta = any(p.device.type == "meta" for p in model.parameters())
```
The check for meta tensors should also include model buffers. Some architectures use persistent buffers (e.g., for rotary embeddings or KV cache constants) that might be initialized on the meta device when using memory_efficient_load. These also need to be materialized before loading the state dict to avoid errors during the broadcast/load phase.
```diff
- has_meta = any(p.device.type == "meta" for p in model.parameters())
+ has_meta = any(t.is_meta for t in model.parameters()) or any(t.is_meta for t in model.buffers())
```
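The parameter-vs-buffer distinction can be demonstrated with a toy module (not the project's actual model) built on the meta device, mimicking what non-rank-0 processes see under `memory_efficient_load`:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Toy module with both a parameter and a persistent buffer."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))

# Construct on the meta device: no storage is allocated for anything.
with torch.device("meta"):
    model = Toy()

# Both parameters and buffers are meta tensors here; the broadened check
# from the suggestion above catches either kind.
has_meta = any(t.is_meta for t in model.parameters()) or any(
    t.is_meta for t in model.buffers()
)
assert has_meta

# to_empty() allocates uninitialized storage for every meta tensor,
# parameters and buffers alike; real values arrive only via the later
# state-dict load/broadcast.
model = model.to_empty(device="cpu")
assert not any(t.is_meta for t in model.parameters())
assert not any(t.is_meta for t in model.buffers())
```

Since `to_empty()` already materializes buffers, the only gap is the detection check, which is what the suggested one-line change addresses.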
Pull request overview
Fixes excessive per-node CPU memory usage during FSDP initialization when fsdp.memory_efficient_load=True by avoiding full CPU model materialization on non-rank-0 workers, and reduces import-time side effects from transformers.
Changes:
- Update FSDP engine initialization to create the model on `meta` for non-rank-0 ranks when `memory_efficient_load` is enabled.
- Update FSDP full-state broadcast loading to materialize meta tensors via `to_empty()` before `set_model_state_dict`.
- Defer `AutoProcessor` imports behind `TYPE_CHECKING` in saver/recover utilities to avoid an eager `torchvision` import.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| areal/utils/saver.py | Defers AutoProcessor import to type-checking time; adds future annotations. |
| areal/utils/recover.py | Same deferred import pattern for AutoProcessor; adds future annotations. |
| areal/engine/fsdp_utils/__init__.py | Uses to_empty() to materialize meta models before broadcasting state dict. |
| areal/engine/fsdp_engine.py | Switches non-rank-0 model creation device to meta under memory_efficient_load. |
```python
# Only rank 0 loads on CPU; other ranks use meta device (zero memory)
# to avoid CPU OOM when multiple workers share a node.
# Weights are broadcast from rank 0 after FSDP sharding in initialize().
if dist.get_rank() == 0:
    loading_device = "cpu"
```
memory_efficient_load is documented/configured to not use rank-0 broadcast for VLMs (each rank loads weights independently on CPU), but this change sets loading_device="meta" for non-rank-0 whenever memory_efficient_load is enabled. For vision models this will instantiate the model on meta and then proceed without any later broadcast (since is_llm_cpu_load is false when self.is_vision_model), leaving meta tensors and likely breaking initialization. Consider gating the meta-device path to LLM-only (e.g., memory_efficient_load and not self.is_vision_model and not init_from_scratch), or explicitly force loading_device="cpu" for vision models when memory_efficient_load is enabled.
```diff
- # Only rank 0 loads on CPU; other ranks use meta device (zero memory)
- # to avoid CPU OOM when multiple workers share a node.
- # Weights are broadcast from rank 0 after FSDP sharding in initialize().
- if dist.get_rank() == 0:
-     loading_device = "cpu"
+ # For LLMs, only rank 0 loads on CPU; other ranks use meta device
+ # (zero memory) to avoid CPU OOM when multiple workers share a node.
+ # Weights are broadcast from rank 0 after FSDP sharding in initialize().
+ # Vision models do not use that broadcast path, so every rank must load
+ # them on CPU even when memory_efficient_load is enabled.
+ if self.is_vision_model:
+     loading_device = "cpu"
+ elif dist.get_rank() == 0:
+     loading_device = "cpu"
```
…load

When memory_efficient_load=True, all ranks previously created the full
model on CPU via from_config(), causing CPU OOM on nodes with limited RAM
(e.g. 8 workers × 64 GB = 512 GB on a 256 GB node). Now only rank-0 loads
on CPU; other ranks use the meta device (zero memory cost). Weights are
broadcast from rank-0 after FSDP sharding via fsdp2_load_full_state_dict
with broadcast_from_rank0=True.

Also defer the AutoProcessor import in saver.py and recover.py to avoid
importing torchvision eagerly (which can crash when the torchvision
version mismatches torch).

Key changes:
- fsdp_engine.py: non-rank-0 uses the "meta" device in memory_efficient_load
- fsdp_utils/__init__.py: use to_empty() for meta→device conversion
- saver.py, recover.py: move AutoProcessor to a TYPE_CHECKING block
Force-pushed from ca0c70d to bfff796
CI is blocked by unrelated errors. Merging this PR.
Summary
Fix CPU OOM when using `fsdp.memory_efficient_load=True` with multiple workers per node. Non-rank-0 processes now use the `meta` device instead of `cpu` during model creation, reducing per-node CPU memory from ~512 GiB to ~64 GiB for a 32B model.

Problem
When `memory_efficient_load=True`, all ranks created the full model on CPU via `from_config()` inside a `torch.device("cpu")` context. For a 32B model (bf16, ~64 GiB) with 8 workers per node, that meant roughly 8 × 64 GiB ≈ 512 GiB of CPU RAM on a single node.

The original design intent was to avoid GPU OOM by loading on CPU first, then broadcasting sharded weights. However, it shifted the OOM problem from GPU to CPU.
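The per-node arithmetic can be sanity-checked directly (computed here in decimal GB, matching the 8 workers × 64 GB = 512 GB figure from the commit message; the summary rounds these loosely to GiB):

```python
# Back-of-the-envelope check: a 32B-parameter model in bf16
# (2 bytes/param), replicated once per worker on an 8-worker node.
params = 32e9
bytes_per_param = 2          # bf16
workers_per_node = 8

per_rank_gb = params * bytes_per_param / 1e9
per_node_gb = per_rank_gb * workers_per_node
print(f"per rank: ~{per_rank_gb:.0f} GB, per node: ~{per_node_gb:.0f} GB")
```

With meta-device creation on non-rank-0 workers, the per-node cost drops to a single CPU copy (~64 GB) regardless of worker count.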
Root Cause
In `_create_device_model()`, the loading device was unconditionally set to `"cpu"` for all ranks. Combined with `from_config()` (which creates real tensors, not meta tensors), every rank materialized the entire model in CPU RAM.
- `fsdp_engine.py`: only rank-0 loads on CPU; other ranks use the `"meta"` device (zero memory cost). Weights are broadcast from rank-0 after FSDP sharding via `fsdp2_load_full_state_dict` with `broadcast_from_rank0=True`.
- `fsdp_utils/__init__.py`: handle meta→device conversion using `to_empty()` instead of `.to()` (which fails on meta tensors).
- `saver.py`, `recover.py`: move the `AutoProcessor` import into a `TYPE_CHECKING` block to avoid eagerly importing `torchvision`.
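The overall flow can be sketched in a single process, with a plain `load_state_dict` standing in for the collective broadcast (the real code uses `fsdp2_load_full_state_dict` with `broadcast_from_rank0=True`; the toy `nn.Linear` is illustrative only):

```python
import torch
import torch.nn as nn

def create_model(rank: int) -> nn.Module:
    """Rank 0 materializes real CPU weights; other ranks stay on meta."""
    device = "cpu" if rank == 0 else "meta"
    with torch.device(device):
        return nn.Linear(4, 4)

rank0_model = create_model(0)
other_model = create_model(1)
# The non-rank-0 model owns no storage yet.
assert all(p.is_meta for p in other_model.parameters())

# Step 1: to_empty() allocates uninitialized storage for the meta tensors.
other_model = other_model.to_empty(device="cpu")
# Step 2: fill the storage from rank 0's full state dict (stand-in for
# the broadcast performed after FSDP sharding).
other_model.load_state_dict(rank0_model.state_dict())
assert torch.equal(other_model.weight, rank0_model.weight)
```

Until step 2 runs, the materialized tensors hold garbage, which is why the broadcast path must fire for every model type that takes the meta route.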
Verified with Qwen3-32B, FSDP `d4c8`, 4 nodes × 8 L20Y GPUs (80 GiB), `memory_efficient_load=True`.

The `memory_efficient_load=False` path is unchanged.