diff --git a/deepspeed/compile/init_z3.py b/deepspeed/compile/init_z3.py index 11f7eec8e2bd..e9a510cd6055 100644 --- a/deepspeed/compile/init_z3.py +++ b/deepspeed/compile/init_z3.py @@ -44,9 +44,8 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): dc = get_deepcompile_handle() dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size()) - # Unset hooks - for m in engine.module.modules(): - m._parameters = m._original_parameters + # Keep ZeROOrderedDict as a fallback for dynamo-skipped frames that + # run eagerly without compiled allgather ops. if use_opt: optimizer.parameter_offload._remove_module_hooks() diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7708999fcdf7..568f677ffc4b 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2449,9 +2449,29 @@ def forward(self, *inputs, **kwargs): # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue. self.launch_compile_passes(self.global_steps) + # When DeepCompile is active the per-module gather/release hooks are + # removed and all parameter gathering is handled by compiled graph ops. + # However, torch._dynamo may skip frames that contain graph breaks in + # loops. Skipped frames execute eagerly without the compiled ops, so + # the ZeROOrderedDict safety-net must be enabled to auto-gather any + # parameter accessed in those frames. + _dc_z3_eager_fallback = (self.is_deepcompile_active() and self.zero_optimization_partition_weights()) + if _dc_z3_eager_fallback: + for module in self.module.modules(): + if isinstance(module._parameters, ZeROOrderedDict): + module._parameters._in_forward = True + with autocast_if_enabled(self): loss = self.module(*inputs, **kwargs) + if _dc_z3_eager_fallback: + for p in self.module.parameters(): + if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist: + p.partition() + for module in self.module.modules(): + if isinstance(module._parameters, ZeROOrderedDict): + module._parameters._in_forward = False + # Register output backward hooks # preprocess_once_fn is called for preprocessing # preprocess_per_tensor_fn scales a tensor for gradient accumulation diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 1edd666e532d..e859891b49c4 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -61,9 +61,8 @@ def __getitem__(self, key): if param is None: return param - # TODO: only weaken this check during compilation if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - if self._parent_module._parameters._in_forward: + if self._parent_module._parameters._in_forward and not torch.compiler.is_compiling(): register_external_parameter(FWD_MODULE_STACK[-1], param) param.all_gather() print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False) diff --git a/tests/torch_compile/test_deepcompile_skipped_frame.py b/tests/torch_compile/test_deepcompile_skipped_frame.py new file mode 100644 index 000000000000..67183a271b95 --- /dev/null +++ b/tests/torch_compile/test_deepcompile_skipped_frame.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7942 + +When torch._dynamo skips a frame (e.g. because of a graph break inside a +for/while loop), the frame runs in eager mode. DeepCompile removes the +ZeRO-3 parameter-gathering hooks, so parameters accessed in the skipped +frame remain partitioned (shape ``[0]``). For an embedding layer this +causes ``RuntimeError: 'weight' must be 2-D``. + +This test creates a model whose forward contains an embedding lookup +followed by a loop with a deliberate graph break, reproducing the pattern +that triggers the bug. +""" + +import argparse +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed import comm + +torch._dynamo.config.cache_size_limit = 100 + + +class SkippedFrameModel(torch.nn.Module): + """Model that triggers a dynamo frame skip. + + ``forward`` contains an embedding lookup followed by a loop whose body + calls ``print`` (an opaque side-effect), which causes a graph break + inside the loop. Dynamo skips the entire frame, so the embedding lookup + runs in eager mode with ZeRO-3 partitioned weights. + """ + + def __init__(self, vocab_size=128, hidden=64, n_layers=2): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, hidden) + self.layers = torch.nn.ModuleList([torch.nn.Linear(hidden, hidden, bias=False) for _ in range(n_layers)]) + self.head = torch.nn.Linear(hidden, vocab_size, bias=False) + + def forward(self, input_ids): + h = self.embed_tokens(input_ids) + for layer in self.layers: + h = layer(h) + # graph break inside a loop body — dynamo skips the entire frame + if torch.compiler.is_compiling(): + torch._dynamo.graph_break() + return self.head(h) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--deepspeed_config", type=str, default="ds_config_z3.json") + args = parser.parse_args() + + model = SkippedFrameModel() + engine, _, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) + engine.compile() + + device = get_accelerator().current_device_name() + input_ids = torch.randint(0, 128, (1, 16), device=device) + + for step in range(3): + loss = engine(input_ids).sum() + engine.backward(loss) + engine.step() + if comm.get_rank() == 0: + print(f"step={step} loss={loss.item():.4f}") + + if comm.get_rank() == 0: + print("PASS") + + +if __name__ == "__main__": + main()