Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions deepspeed/compile/init_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
dc = get_deepcompile_handle()
dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size())

# Unset hooks
for m in engine.module.modules():
m._parameters = m._original_parameters
# Keep ZeROOrderedDict as a fallback for dynamo-skipped frames that
# run eagerly without compiled allgather ops.

if use_opt:
optimizer.parameter_offload._remove_module_hooks()
Expand Down
20 changes: 20 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,9 +2449,29 @@ def forward(self, *inputs, **kwargs):
# We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
self.launch_compile_passes(self.global_steps)

# When DeepCompile is active the per-module gather/release hooks are
# removed and all parameter gathering is handled by compiled graph ops.
# However, torch._dynamo may skip frames that contain graph breaks in
# loops. Skipped frames execute eagerly without the compiled ops, so
# the ZeROOrderedDict safety-net must be enabled to auto-gather any
# parameter accessed in those frames.
_dc_z3_eager_fallback = (self.is_deepcompile_active() and self.zero_optimization_partition_weights())
if _dc_z3_eager_fallback:
for module in self.module.modules():
if isinstance(module._parameters, ZeROOrderedDict):
module._parameters._in_forward = True

with autocast_if_enabled(self):
loss = self.module(*inputs, **kwargs)

if _dc_z3_eager_fallback:
for p in self.module.parameters():
if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist:
p.partition()
Comment on lines +2468 to +2470

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Don't free eager-gathered weights before backward

When a Dynamo-skipped frame executes eagerly and touches a ZeRO-3 parameter whose backward needs the weight (for example a Linear after an embedding, where grad-input must be computed), this loop immediately calls p.partition() after forward. In DeepCompile the ZeRO module backward hooks have been removed, and deepcompile_backward_prologue() only starts the compiled runtime, so there is no eager fallback to re-gather that saved weight before the eager autograd node runs; backward will see the released [0] parameter/storage or compute from invalid state. The fallback-gathered params need to stay available until their eager backward use has completed, or get a matching backward-time gather/release path.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

One question: I was wondering why we need to walk through all parameters and free those that are still gathered at this point? Does it mean the parameters gathered outside the compiled graphs are all alive till this point? If so, it can increase the peak GPU memory usage, which can hurt training efficiency in some cases.

for module in self.module.modules():
if isinstance(module._parameters, ZeROOrderedDict):
module._parameters._in_forward = False

# Register output backward hooks
# preprocess_once_fn is called for preprocessing
# preprocess_per_tensor_fn scales a tensor for gradient accumulation
Expand Down
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def __getitem__(self, key):
if param is None:
return param

# TODO: only weaken this check during compilation
if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if self._parent_module._parameters._in_forward:
if self._parent_module._parameters._in_forward and not torch.compiler.is_compiling():
register_external_parameter(FWD_MODULE_STACK[-1], param)
param.all_gather()
print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)
Expand Down
77 changes: 77 additions & 0 deletions tests/torch_compile/test_deepcompile_skipped_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7942

When torch._dynamo skips a frame (e.g. because of a graph break inside a
for/while loop), the frame runs in eager mode. DeepCompile removes the
ZeRO-3 parameter-gathering hooks, so parameters accessed in the skipped
frame remain partitioned (shape ``[0]``). For an embedding layer this
causes ``RuntimeError: 'weight' must be 2-D``.

This test creates a model whose forward contains an embedding lookup
followed by a loop with a deliberate graph break, reproducing the pattern
that triggers the bug.
"""

import argparse
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed import comm

torch._dynamo.config.cache_size_limit = 100


class SkippedFrameModel(torch.nn.Module):
"""Model that triggers a dynamo frame skip.

``forward`` contains an embedding lookup followed by a loop whose body
calls ``print`` (an opaque side-effect), which causes a graph break
inside the loop. Dynamo skips the entire frame, so the embedding lookup
runs in eager mode with ZeRO-3 partitioned weights.
"""

def __init__(self, vocab_size=128, hidden=64, n_layers=2):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, hidden)
self.layers = torch.nn.ModuleList([torch.nn.Linear(hidden, hidden, bias=False) for _ in range(n_layers)])
self.head = torch.nn.Linear(hidden, vocab_size, bias=False)

def forward(self, input_ids):
h = self.embed_tokens(input_ids)
for layer in self.layers:
h = layer(h)
# graph break inside a loop body — dynamo skips the entire frame
if torch.compiler.is_compiling():
torch._dynamo.graph_break()
return self.head(h)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed_config", type=str, default="ds_config_z3.json")
args = parser.parse_args()

model = SkippedFrameModel()
engine, _, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters())
engine.compile()

device = get_accelerator().current_device_name()
input_ids = torch.randint(0, 128, (1, 16), device=device)

for step in range(3):
loss = engine(input_ids).sum()
engine.backward(loss)
engine.step()
if comm.get_rank() == 0:
print(f"step={step} loss={loss.item():.4f}")

if comm.get_rank() == 0:
print("PASS")


if __name__ == "__main__":
main()
Loading