-
Notifications
You must be signed in to change notification settings - Fork 4.9k
[DeepCompile] fix gather params in dynamo skipped frames for ZeRO3 #8059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2449,9 +2449,29 @@ def forward(self, *inputs, **kwargs): | |
| # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue. | ||
| self.launch_compile_passes(self.global_steps) | ||
|
|
||
| # When DeepCompile is active the per-module gather/release hooks are | ||
| # removed and all parameter gathering is handled by compiled graph ops. | ||
| # However, torch._dynamo may skip frames that contain graph breaks in | ||
| # loops. Skipped frames execute eagerly without the compiled ops, so | ||
| # the ZeROOrderedDict safety-net must be enabled to auto-gather any | ||
| # parameter accessed in those frames. | ||
| _dc_z3_eager_fallback = (self.is_deepcompile_active() and self.zero_optimization_partition_weights()) | ||
| if _dc_z3_eager_fallback: | ||
| for module in self.module.modules(): | ||
| if isinstance(module._parameters, ZeROOrderedDict): | ||
| module._parameters._in_forward = True | ||
|
|
||
| with autocast_if_enabled(self): | ||
| loss = self.module(*inputs, **kwargs) | ||
|
|
||
| if _dc_z3_eager_fallback: | ||
| for p in self.module.parameters(): | ||
| if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist: | ||
| p.partition() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for fixing this! One question: I was wondering why we need to walk through all parameters and free those that are still gathered at this point? Does it mean the parameters gathered outside the compiled graphs are all alive till this point? If so, it can increase the peak GPU memory usage, which can hurt training efficiency in some cases. |
||
| for module in self.module.modules(): | ||
| if isinstance(module._parameters, ZeROOrderedDict): | ||
| module._parameters._in_forward = False | ||
|
|
||
| # Register output backward hooks | ||
| # preprocess_once_fn is called for preprocessing | ||
| # preprocess_per_tensor_fn scales a tensor for gradient accumulation | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
| """Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7942 | ||
|
|
||
| When torch._dynamo skips a frame (e.g. because of a graph break inside a | ||
| for/while loop), the frame runs in eager mode. DeepCompile removes the | ||
| ZeRO-3 parameter-gathering hooks, so parameters accessed in the skipped | ||
| frame remain partitioned (shape ``[0]``). For an embedding layer this | ||
| causes ``RuntimeError: 'weight' must be 2-D``. | ||
|
|
||
| This test creates a model whose forward contains an embedding lookup | ||
| followed by a loop with a deliberate graph break, reproducing the pattern | ||
| that triggers the bug. | ||
| """ | ||
|
|
||
| import argparse | ||
| import torch | ||
| import deepspeed | ||
| from deepspeed.accelerator import get_accelerator | ||
| from deepspeed import comm | ||
|
|
||
| torch._dynamo.config.cache_size_limit = 100 | ||
|
|
||
|
|
||
| class SkippedFrameModel(torch.nn.Module): | ||
| """Model that triggers a dynamo frame skip. | ||
|
|
||
| ``forward`` contains an embedding lookup followed by a loop whose body | ||
| calls ``print`` (an opaque side-effect), which causes a graph break | ||
| inside the loop. Dynamo skips the entire frame, so the embedding lookup | ||
| runs in eager mode with ZeRO-3 partitioned weights. | ||
| """ | ||
|
|
||
| def __init__(self, vocab_size=128, hidden=64, n_layers=2): | ||
| super().__init__() | ||
| self.embed_tokens = torch.nn.Embedding(vocab_size, hidden) | ||
| self.layers = torch.nn.ModuleList([torch.nn.Linear(hidden, hidden, bias=False) for _ in range(n_layers)]) | ||
| self.head = torch.nn.Linear(hidden, vocab_size, bias=False) | ||
|
|
||
| def forward(self, input_ids): | ||
| h = self.embed_tokens(input_ids) | ||
| for layer in self.layers: | ||
| h = layer(h) | ||
| # graph break inside a loop body — dynamo skips the entire frame | ||
| if torch.compiler.is_compiling(): | ||
| torch._dynamo.graph_break() | ||
| return self.head(h) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--local_rank", type=int, default=-1) | ||
| parser.add_argument("--deepspeed_config", type=str, default="ds_config_z3.json") | ||
| args = parser.parse_args() | ||
|
|
||
| model = SkippedFrameModel() | ||
| engine, _, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) | ||
| engine.compile() | ||
|
|
||
| device = get_accelerator().current_device_name() | ||
| input_ids = torch.randint(0, 128, (1, 16), device=device) | ||
|
|
||
| for step in range(3): | ||
| loss = engine(input_ids).sum() | ||
| engine.backward(loss) | ||
| engine.step() | ||
| if comm.get_rank() == 0: | ||
| print(f"step={step} loss={loss.item():.4f}") | ||
|
|
||
| if comm.get_rank() == 0: | ||
| print("PASS") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When a Dynamo-skipped frame executes eagerly and touches a ZeRO-3 parameter whose backward needs the weight (for example a
Linearafter an embedding, where grad-input must be computed), this loop immediately callsp.partition()after forward. In DeepCompile the ZeRO module backward hooks have been removed, anddeepcompile_backward_prologue()only starts the compiled runtime, so there is no eager fallback to re-gather that saved weight before the eager autograd node runs; backward will see the released[0]parameter/storage or compute from invalid state. The fallback-gathered params need to stay available until their eager backward use has completed, or get a matching backward-time gather/release path.Useful? React with 👍 / 👎.