From d52a4ecb342cf89d773d15e3cd579febac9d9be3 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Mon, 1 Jun 2026 00:36:51 +0000 Subject: [PATCH 01/13] Fix ZenFlow NaN under PyTorch-style backward via backward_prologue The PyTorch-style backward API drives backward through loss.backward() and the engine's autograd hooks, which call optimizer.backward_prologue() at the start of each backward pass instead of ZenFlow's own backward(). ZenFlow's per-microbatch setup therefore never ran, leaving micro_step unadvanced and the selective optimizer unsynced at a selection boundary, so the top-k update operated on stale state and the loss went NaN. - Override backward_prologue() with ZenFlow's per-microbatch setup: advance micro_step, refresh the auto-update bookkeeping, and on a selection boundary resync the fp32 master partition and clear the selective optimizer's moments. - Remove the standalone backward() override, which the PyTorch-style engine no longer calls. Validated on Qwen2.5-0.5B + Alpaca (ZeRO-2 offload, overlap step): loss now matches the old-version ZenFlow step-for-step instead of diverging to NaN. Signed-off-by: Tingfeng Lan --- .../runtime/zenflow/zenflow_stage_1_and_2.py | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 2f5e423f1320..27882b1dfb20 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -538,15 +538,15 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type) self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop() - def backward(self, loss, retain_graph=False): + def backward_prologue(self): + """Prepare ZenFlow's per-microbatch state before the backward pass. + + Called by the engine at the start of each backward. Advances the + micro-step counter and, on an auto-update step, refreshes the + update-interval bookkeeping. At a selection boundary, resyncs the fp32 + master partition from the bit16 weights and clears the selective + optimizer's moments so the next top-k update starts clean. """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - self.backward_prologue() self.micro_step += 1 if self.auto_update: @@ -565,16 +565,6 @@ def backward(self, loss, retain_graph=False): self.selective_optimizer.clear_selected_mv() self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).stop() - self.enter_backward() - if self.custom_loss_scaler: - scaled_loss = self.external_loss_scale * loss - scaled_loss.backward(retain_graph=retain_graph) - else: - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - self.backward_epilogue() - self.exit_backward() - def log_selective_optimizer_timers(self): self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) From 1d9f3ccee3b3d6408a2a8873cfcd30de430176fc Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 04:39:31 +0000 Subject: [PATCH 02/13] Add fused multi-tensor CPU Adam for ZenFlow overlap step ZenFlow's overlapped CPU optimizer stepped each parameter through a separate `adam_update` call from Python and kept a stale snapshot for the GPU sync via `p.stale_param.data.copy_(p.data.clone())`. For a group with many parameters this pays one Python<->C++ crossing (and one OpenMP region spawn) per parameter, and the `clone()` adds a full allocation plus an extra memory pass every step. Add a fused multi-tensor entry that drives the whole group in C++ and writes the stale snapshot natively, so the overlapped step issues a single native call. - Add `ds_adam_step_multi` (bound as `adam_update_multi`): one call updates a list of params/grads/exp_avg/exp_avg_sq, advancing the bias-correction state once for the shared step; when a stale list is provided, each post-update parameter is snapshotted into it via a native copy. - Rewrite `ZenFlowCPUAdam._parallel_step` to collect the group's tensors and issue a single `adam_update_multi`, dropping the per-parameter calls and the Python-side `clone()`. - Leave the existing per-parameter `ds_adam_step` path unchanged. - Add a numerical-equivalence test: fused vs per-parameter is bit-for-bit equal across fp16/bf16/fp32 (params, moments, and the stale snapshot), plus the empty-stale path. Behavior is identical to the per-parameter path, verified bit-for-bit at the op level and as an unchanged end-to-end loss trajectory across ZeRO stages 1/2/3. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam.cpp | 3 + csrc/adam/cpu_adam_impl.cpp | 48 ++++++++++++++++ csrc/includes/cpu_adam.h | 14 +++++ deepspeed/ops/adam/zenflow_cpu_adam.py | 31 +++++++--- tests/unit/ops/adam/test_cpu_adam.py | 79 ++++++++++++++++++++++++++ 5 files changed, 166 insertions(+), 9 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index f4c242ff9229..b3496a02ef1d 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -8,6 +8,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("adam_update_multi", + &ds_adam_step_multi, + "DeepSpeed CPU Adam fused multi-tensor update (C++)"); m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 1f2b8cf0df47..ea87d94d9715 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -236,6 +236,54 @@ int ds_adam_step(int optimizer_id, return 0; } +// Fused multi-tensor variant used by ZenFlow's overlapped optimizer. Driving the +// per-parameter loop in C++ avoids one Python<->C++ crossing (and one OpenMP region +// spawn) per parameter, which dominates on ZeRO Stage 1/2 where a group holds many +// small parameters. When stale_params is non-empty, the post-update parameter is +// snapshotted into it here, replacing the Python-side clone()+copy that ZenFlow used +// to keep a stale copy for the GPU sync. +int ds_adam_step_multi(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + std::vector& params, + std::vector& grads, + std::vector& exp_avgs, + std::vector& exp_avg_sqs, + std::vector& stale_params) +{ + const size_t num_tensors = params.size(); + TORCH_CHECK(grads.size() == num_tensors && exp_avgs.size() == num_tensors && + exp_avg_sqs.size() == num_tensors, + "ds_adam_step_multi: params/grads/exp_avgs/exp_avg_sqs length mismatch"); + const bool has_stale = !stale_params.empty(); + TORCH_CHECK(!has_stale || stale_params.size() == num_tensors, + "ds_adam_step_multi: stale_params length mismatch"); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + // All tensors share one optimizer step, so advance bias-correction state once. + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + for (size_t i = 0; i < num_tensors; ++i) { + auto params_c = params[i].contiguous(); + auto grads_c = grads[i].contiguous(); + auto exp_avg_c = exp_avgs[i].contiguous(); + auto exp_avg_sq_c = exp_avg_sqs[i].contiguous(); + + invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel()); + + if (has_stale) { stale_params[i].copy_(params_c); } + } + + return 0; +} + void adamw_rollback_inplace(float* params, const float* grads, float* momentum, diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index f07a14e08438..d466bc410aab 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -220,6 +220,20 @@ int ds_adam_step(int optimizer_id, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq); +int ds_adam_step_multi(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + std::vector& params, + std::vector& grads, + std::vector& exp_avgs, + std::vector& exp_avg_sqs, + std::vector& stale_params); + int ds_adam_rollback(int optimizer_id, size_t step, float lr, diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 0809d7a0f7e0..5b8ab17622f7 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -103,7 +103,15 @@ def _parallel_step(self, step_id, now_state, group_info, closure=None): # intended device for step device = torch.device('cpu') - stale_param = None + # Collect the per-group tensors and drive the whole group through a single fused + # native call. This keeps the per-parameter loop in C++, avoiding one + # Python<->C++ crossing per parameter, and lets the stale snapshot be written + # natively (no Python-side clone()). + params = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + stale_params = [] for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group['params']): @@ -117,8 +125,6 @@ def _parallel_step(self, step_id, now_state, group_info, closure=None): state = self.state[p] # State initialization if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - # print("creating", flush=True) state['step'] = 0 #use full precision by default unless self.fp32_optimizer_states is off @@ -129,10 +135,17 @@ def _parallel_step(self, step_id, now_state, group_info, closure=None): state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] state['step'] = step_id - beta1, beta2 = group_info['betas'] - self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2, - group_info['eps'], group_info['weight_decay'], - group_info['bias_correction'], p.data, p.overlap_grad[now_state].data, - state['exp_avg'][now_state], state['exp_avg_sq'][now_state]) - p.stale_param.data.copy_(p.data.clone()) + params.append(p.data) + grads.append(p.overlap_grad[now_state].data) + exp_avgs.append(state['exp_avg'][now_state]) + exp_avg_sqs.append(state['exp_avg_sq'][now_state]) + stale_params.append(p.stale_param.data) + + if not params: + return loss + + beta1, beta2 = group_info['betas'] + self.ds_opt_adam.adam_update_multi(self.opt_id, step_id, group_info['lr'], beta1, beta2, group_info['eps'], + group_info['weight_decay'], group_info['bias_correction'], params, grads, + exp_avgs, exp_avg_sqs, stale_params) return loss diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 003a6f8f6a46..42f817bd1780 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -175,6 +175,85 @@ def test_bf16_optimizer_states_match_fp32(self, model_size): check_equal(param_fp32_states.float().norm(), param_bf16_states.float().norm(), atol=tolerance) +class TestCPUAdamFusedMultiTensor(DistributedTest): + """adam_update_multi (fused multi-tensor, used by ZenFlow overlap) must match a + per-parameter sequence of adam_update bit-for-bit, and write the post-update + parameter snapshot into the stale buffer.""" + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) + def test_multi_matches_single(self, dtype): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ds_opt_adam = CPUAdamBuilder().load() + + lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 + adamw_mode, bias_correction = True, True + # Mixed sizes (including ones that don't divide the SIMD width) exercise both the + # vectorized and scalar tails inside the fused C++ loop. + sizes = [64, 22, 1024, 1048576] + + opt_single, opt_multi = 0, 1 + ds_opt_adam.create_adam(opt_single, lr, beta1, beta2, eps, weight_decay, adamw_mode, False) + ds_opt_adam.create_adam(opt_multi, lr, beta1, beta2, eps, weight_decay, adamw_mode, False) + + torch.manual_seed(0) + params_single = [torch.randn(n, dtype=dtype) for n in sizes] + params_multi = [p.clone() for p in params_single] + exp_avg_single = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_sq_single = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_multi = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_sq_multi = [torch.zeros(n, dtype=torch.float) for n in sizes] + stale_multi = [torch.zeros(n, dtype=dtype) for n in sizes] + + try: + for step in range(1, 6): + grads = [torch.randn(n, dtype=dtype) for n in sizes] + + for i in range(len(sizes)): + ds_opt_adam.adam_update(opt_single, step, lr, beta1, beta2, eps, weight_decay, bias_correction, + params_single[i], grads[i].clone(), exp_avg_single[i], + exp_avg_sq_single[i]) + + ds_opt_adam.adam_update_multi(opt_multi, step, lr, beta1, beta2, eps, weight_decay, bias_correction, + params_multi, [g.clone() for g in grads], exp_avg_multi, + exp_avg_sq_multi, stale_multi) + + for i in range(len(sizes)): + assert torch.equal(params_single[i], params_multi[i]), f"param mismatch at size {sizes[i]}" + assert torch.equal(exp_avg_single[i], exp_avg_multi[i]), f"exp_avg mismatch at size {sizes[i]}" + assert torch.equal(exp_avg_sq_single[i], + exp_avg_sq_multi[i]), f"exp_avg_sq mismatch at size {sizes[i]}" + # stale must hold the post-update parameter snapshot + assert torch.equal(stale_multi[i], params_multi[i]), f"stale mismatch at size {sizes[i]}" + finally: + ds_opt_adam.destroy_adam(opt_single) + ds_opt_adam.destroy_adam(opt_multi) + + def test_multi_without_stale(self): + """An empty stale list is allowed and simply skips the snapshot.""" + ds_opt_adam = CPUAdamBuilder().load() + opt_id = 2 + ds_opt_adam.create_adam(opt_id, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False) + try: + params = [torch.randn(64, dtype=torch.float)] + grads = [torch.randn(64, dtype=torch.float)] + exp_avg = [torch.zeros(64, dtype=torch.float)] + exp_avg_sq = [torch.zeros(64, dtype=torch.float)] + before = params[0].clone() + ds_opt_adam.adam_update_multi(opt_id, 1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, params, grads, exp_avg, + exp_avg_sq, []) + assert not torch.equal(params[0], before), "params should be updated even without stale buffers" + finally: + ds_opt_adam.destroy_adam(opt_id) + + class TestCPUAdamGPUError(DistributedTest): def test_cpu_adam_gpu_error(self): From 790a83a83b7a8dd8ef1e339f95624f3922d23547 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 05:06:15 +0000 Subject: [PATCH 03/13] Let CPU Adam kernel run serially without OpenMP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prepare the kernel for ZenFlow's in-process optimizer thread (L2). When the optimizer runs on a background thread pinned to a dedicated set of cores, it must not spawn OpenMP teams from the global libgomp pool — that pool is shared with the training thread's torch ops and would defeat the core partitioning. Thread a `parallel` flag through the step path (`Step_1/4/8`, `Step_AVX`, `step_invoker`, the dtype dispatch map, and `invoke`) and turn the two `#pragma omp parallel for` into `if (parallel)`. With `parallel=true` (the default everywhere) the region is identical to before; with `parallel=false` the loop runs serially in the calling thread, so a pinned pool can drive each element slice itself. - Expose the flag as an optional `parallel` argument on `adam_update_multi` (defaults to true, so existing callers are unchanged). - Add a test that the serial path matches the OpenMP path bit-for-bit across fp16/bf16/fp32. No behavior change for existing paths; Adam math is untouched. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam.cpp | 19 ++++++++- csrc/adam/cpu_adam_impl.cpp | 45 +++++++++++++-------- csrc/includes/cpu_adam.h | 14 ++++--- tests/unit/ops/adam/test_cpu_adam.py | 58 ++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 23 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index b3496a02ef1d..ed27ffa941b7 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -7,10 +7,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + using namespace pybind11::literals; m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + // `parallel` defaults to true (OpenMP across elements, as before). ZenFlow's native + // pinned thread pool sets it false so each pool thread runs its slice serially. m.def("adam_update_multi", &ds_adam_step_multi, - "DeepSpeed CPU Adam fused multi-tensor update (C++)"); + "DeepSpeed CPU Adam fused multi-tensor update (C++)", + "optimizer_id"_a, + "step"_a, + "lr"_a, + "beta1"_a, + "beta2"_a, + "epsilon"_a, + "weight_decay"_a, + "bias_correction"_a, + "params"_a, + "grads"_a, + "exp_avgs"_a, + "exp_avg_sqs"_a, + "stale_params"_a, + "parallel"_a = true); m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index ea87d94d9715..231c26c6930a 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -23,11 +23,12 @@ void Adam_Optimizer::Step_1(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -40,7 +41,7 @@ void Adam_Optimizer::Step_1(ds_params_precision_t* _params, size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; -#pragma omp parallel for +#pragma omp parallel for if (parallel) for (size_t k = t; k < offset; k++) { float grad = (float)grads[k]; float param = (float)_params[k]; @@ -72,18 +73,20 @@ void Adam_Optimizer::Step_4(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size)); + (_param_size - rounded_size), + parallel); } int create_adam_optimizer(int optimizer_id, @@ -131,18 +134,20 @@ void Adam_Optimizer::Step_8(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size)); + (_param_size - rounded_size), + parallel); } template @@ -151,17 +156,20 @@ void step_invoker(std::shared_ptr opt, void* grads, void* _exp_avg, void* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { opt->Step_8((ds_params_precision_t*)(_params), (ds_params_precision_t*)(grads), (ds_state_precision_t*)(_exp_avg), (ds_state_precision_t*)(_exp_avg_sq), - _param_size); + _param_size, + parallel); } -std::map, - std::function, void*, void*, void*, void*, size_t)>> +std::map< + std::tuple, + std::function, void*, void*, void*, void*, size_t, bool)>> invokers; // Fill map with template functions for each type @@ -188,7 +196,8 @@ void invoke(std::shared_ptr opt, torch::Tensor& grads, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq, - size_t param_size) + size_t param_size, + bool parallel = true) { c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); @@ -205,7 +214,8 @@ void invoke(std::shared_ptr opt, grads.data_ptr(), exp_avg.data_ptr(), exp_avg_sq.data_ptr(), - param_size); + param_size, + parallel); } int ds_adam_step(int optimizer_id, @@ -254,7 +264,8 @@ int ds_adam_step_multi(int optimizer_id, std::vector& grads, std::vector& exp_avgs, std::vector& exp_avg_sqs, - std::vector& stale_params) + std::vector& stale_params, + bool parallel) { const size_t num_tensors = params.size(); TORCH_CHECK(grads.size() == num_tensors && exp_avgs.size() == num_tensors && @@ -276,7 +287,7 @@ int ds_adam_step_multi(int optimizer_id, auto exp_avg_c = exp_avgs[i].contiguous(); auto exp_avg_sq_c = exp_avg_sqs[i].contiguous(); - invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel()); + invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel(), parallel); if (has_stale) { stale_params[i].copy_(params_c); } } diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index d466bc410aab..3a26f1c17d80 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -19,7 +19,8 @@ ds_params_precision_t* grads, \ ds_state_precision_t* _exp_avg, \ ds_state_precision_t* _exp_avg_sq, \ - size_t _param_size); + size_t _param_size, \ + bool parallel = true); class Adam_Optimizer { public: @@ -49,7 +50,8 @@ class Adam_Optimizer { ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t param_size); + size_t param_size, + bool parallel = true); #endif STEP(1) STEP(4) @@ -115,7 +117,8 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { #if !defined(__AVX512__) if (std::is_same_v || @@ -156,7 +159,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#pragma omp parallel for +#pragma omp parallel for if (parallel) for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; simd_load(grad_4, grads + i); @@ -232,7 +235,8 @@ int ds_adam_step_multi(int optimizer_id, std::vector& grads, std::vector& exp_avgs, std::vector& exp_avg_sqs, - std::vector& stale_params); + std::vector& stale_params, + bool parallel = true); int ds_adam_rollback(int optimizer_id, size_t step, diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 42f817bd1780..539bbc70548d 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -236,6 +236,64 @@ def test_multi_matches_single(self, dtype): ds_opt_adam.destroy_adam(opt_single) ds_opt_adam.destroy_adam(opt_multi) + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) + def test_serial_matches_parallel(self, dtype): + """The serial kernel path (parallel=False, used by ZenFlow's pinned thread pool) + must match the OpenMP path (parallel=True) bit-for-bit.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ds_opt_adam = CPUAdamBuilder().load() + lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 + sizes = [64, 22, 1024, 1048576] + + opt_par, opt_ser = 3, 4 + ds_opt_adam.create_adam(opt_par, lr, beta1, beta2, eps, weight_decay, True, False) + ds_opt_adam.create_adam(opt_ser, lr, beta1, beta2, eps, weight_decay, True, False) + + torch.manual_seed(0) + params_par = [torch.randn(n, dtype=dtype) for n in sizes] + params_ser = [p.clone() for p in params_par] + ea_par = [torch.zeros(n) for n in sizes] + eq_par = [torch.zeros(n) for n in sizes] + ea_ser = [torch.zeros(n) for n in sizes] + eq_ser = [torch.zeros(n) for n in sizes] + + try: + for step in range(1, 4): + grads = [torch.randn(n, dtype=dtype) for n in sizes] + ds_opt_adam.adam_update_multi(opt_par, + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + params_par, [g.clone() for g in grads], + ea_par, + eq_par, [], + parallel=True) + ds_opt_adam.adam_update_multi(opt_ser, + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + params_ser, [g.clone() for g in grads], + ea_ser, + eq_ser, [], + parallel=False) + for i in range(len(sizes)): + assert torch.equal(params_par[i], params_ser[i]), f"param mismatch at size {sizes[i]}" + assert torch.equal(ea_par[i], ea_ser[i]), f"exp_avg mismatch at size {sizes[i]}" + assert torch.equal(eq_par[i], eq_ser[i]), f"exp_avg_sq mismatch at size {sizes[i]}" + finally: + ds_opt_adam.destroy_adam(opt_par) + ds_opt_adam.destroy_adam(opt_ser) + def test_multi_without_stale(self): """An empty stale list is allowed and simply skips the snapshot.""" ds_opt_adam = CPUAdamBuilder().load() From 1640828dd7a82d86a97c5f71c40f1059386a9f45 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 05:44:06 +0000 Subject: [PATCH 04/13] Add ZenFlowAdam: in-process overlapped CPU Adam Add the native side of ZenFlow's overlapped optimizer so the CPU Adam step can run concurrently with the Python training thread without a separate process. The existing design dodges the GIL by running the step in a multiprocessing subprocess, which costs process spawn, shared-memory tensors, a pipe, and per-step rebinding. With the step in native code that releases the GIL, a background thread in the same process achieves the same overlap and touches the same tensors directly. ZenFlowAdam owns a dispatcher thread and a pool of worker threads pinned to ZenFlow's dedicated cores. submit_step() hands a step to the dispatcher and returns immediately; wait_step() blocks (with the GIL released) until it finishes. The dispatcher advances the shared optimizer's bias-correction state per group, then fans each group's elements out to the pinned pool, where every thread runs its slice through the serial (parallel=false) kernel -- so the pool, not OpenMP, provides the parallelism and stays on the ZenFlow cores. - Pin pool threads with pthread_setaffinity_np (Linux); slice boundaries are rounded to the SIMD block so each slice's AVX/scalar split matches the whole-tensor kernel and the result is bit-identical. - Expose a small C handle API (zenflow_adam_create/register_group/submit/wait/ destroy); submit/wait/destroy release the GIL. - Tests: ZenFlowAdam matches the fused reference bit-for-bit with alternating double buffers and multiple groups, and the pipelined submit/wait (including the engine's skipped post-warmup wait) does not desync. Packaged inside the cpu_adam op to reuse Adam_Optimizer and the dtype dispatch; not yet wired into the ZenFlow engine. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam.cpp | 19 ++ csrc/adam/cpu_adam_impl.cpp | 355 +++++++++++++++++++++++++++ csrc/includes/cpu_adam.h | 28 +++ tests/unit/ops/adam/test_cpu_adam.py | 108 ++++++++ 4 files changed, 510 insertions(+) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index ed27ffa941b7..386eb7b03fe2 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -31,4 +31,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); + + // ZenFlowAdam: in-process overlapped CPU Adam. wait/submit/destroy release the GIL + // so the optimizer thread overlaps the Python training thread. + m.def("zenflow_adam_create", &zenflow_adam_create, "ZenFlowAdam create (C++)"); + m.def("zenflow_adam_register_group", + &zenflow_adam_register_group, + "ZenFlowAdam register a parameter group (C++)"); + m.def("zenflow_adam_submit", + &zenflow_adam_submit, + "ZenFlowAdam submit an overlapped step (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_wait", + &zenflow_adam_wait, + "ZenFlowAdam wait for a submitted step (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_destroy", + &zenflow_adam_destroy, + "ZenFlowAdam destroy (C++)", + pybind11::call_guard()); } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 231c26c6930a..7a8398f4cb9e 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -4,14 +4,25 @@ // DeepSpeed Team #include +#include #include +#include +#include +#include #include #include #include #include +#include +#include #include #include +#include #include "cpu_adam.h" +#if defined(__linux__) +#include +#include +#endif using namespace std::string_literals; static std::unordered_map> s_optimizers; @@ -397,3 +408,347 @@ int destroy_adam_optimizer(int optimizer_id) return 0; } + +// --------------------------------------------------------------------------- +// ZenFlowAdam: in-process, GIL-released CPU Adam for ZenFlow's overlapped step. +// +// Replaces the multiprocessing optimizer subprocess. The optimizer step runs on +// a background dispatcher thread; the heavy per-element math is fanned out to a +// pool of worker threads pinned to ZenFlow's dedicated cores, each running its +// element slice through the serial (parallel=false) kernel. Because the workers +// hold no GIL while computing, the Python training thread keeps running. Since +// everything lives in one process the optimizer touches the same tensors the +// main thread holds -- no shared memory, pipe, or per-step rebinding. +// --------------------------------------------------------------------------- + +// A persistent pool of threads pinned to a fixed core set. parallel_for() splits +// [0, total) into one contiguous chunk per thread and blocks until all finish. +class PinnedThreadPool { +public: + explicit PinnedThreadPool(const std::vector& affinity) + { + n_ = std::max(1, affinity.size()); + for (size_t i = 0; i < n_; ++i) { + int core = affinity.empty() ? -1 : affinity[i % affinity.size()]; + threads_.emplace_back([this, i, core] { worker(i, core); }); + } + } + + ~PinnedThreadPool() + { + { + std::lock_guard lk(m_); + stop_ = true; + ++gen_; + } + cv_start_.notify_all(); + for (auto& t : threads_) t.join(); + } + + size_t size() const { return n_; } + + // Split [0, total) into one chunk per thread. Chunk boundaries are rounded up to a + // multiple of `align` so each slice's AVX/scalar split lines up with the whole-tensor + // kernel's split -- otherwise an element could be computed by AVX (FMA) in one layout + // and the scalar tail (mul+add) in another, which differ in the last bit. + void parallel_for(size_t total, size_t align, std::function fn) + { + { + std::unique_lock lk(m_); + fn_ = std::move(fn); + total_ = total; + align_ = std::max(1, align); + done_count_ = 0; + ++gen_; + } + cv_start_.notify_all(); + std::unique_lock lk(m_); + cv_done_.wait(lk, [this] { return done_count_ == n_; }); + } + +private: + void worker(size_t tid, int core) + { +#if defined(__linux__) + if (core >= 0) { + cpu_set_t set; + CPU_ZERO(&set); + CPU_SET(core, &set); + pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &set); + } +#endif + long seen = 0; + while (true) { + std::function fn; + size_t total = 0; + size_t align = 1; + { + std::unique_lock lk(m_); + cv_start_.wait(lk, [this, seen] { return gen_ != seen; }); + seen = gen_; + if (stop_) return; + fn = fn_; + total = total_; + align = align_; + } + size_t chunk = (total + n_ - 1) / n_; + chunk = ((chunk + align - 1) / align) * align; // round up to SIMD-block alignment + size_t begin = std::min(tid * chunk, total); + size_t end = std::min(begin + chunk, total); + if (end > begin) fn(begin, end); + { + std::lock_guard lk(m_); + ++done_count_; + if (done_count_ == n_) cv_done_.notify_one(); + } + } + } + + size_t n_; + std::vector threads_; + std::mutex m_; + std::condition_variable cv_start_, cv_done_; + std::function fn_; + size_t total_ = 0; + size_t align_ = 1; + size_t done_count_ = 0; + long gen_ = 0; + bool stop_ = false; +}; + +// SIMD block the Adam AVX kernel rounds to (Step_8 => span 8). Slicing on multiples of +// this keeps each slice's AVX/scalar boundary identical to the whole-tensor kernel. +#if defined(__AVX512__) or defined(__AVX256__) +static constexpr size_t kZenAdamAlign = SIMD_WIDTH * 8; +#else +static constexpr size_t kZenAdamAlign = 1; +#endif + +struct ZenHP { + float lr, beta1, beta2, eps, weight_decay; + bool bias_correction; +}; + +struct ZenGroup { + torch::Tensor param; + torch::Tensor grad[2]; + torch::Tensor exp_avg[2]; + torch::Tensor exp_avg_sq[2]; + torch::Tensor stale; // may be undefined -> stale snapshot skipped +}; + +class ZenFlowAdam { +public: + ZenFlowAdam(int optimizer_id, std::vector zf_affinity) : opt_id_(optimizer_id) + { + pool_ = std::make_unique(zf_affinity); + dispatcher_ = std::thread(&ZenFlowAdam::dispatcher_main, this); + } + + ~ZenFlowAdam() { shutdown(); } + + void register_group(torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale) + { + TORCH_CHECK(param.is_contiguous(), "ZenFlowAdam: param must be contiguous"); + ZenGroup g; + g.param = param; + g.grad[0] = grad0; + g.grad[1] = grad1; + g.exp_avg[0] = exp_avg0; + g.exp_avg[1] = exp_avg1; + g.exp_avg_sq[0] = exp_avg_sq0; + g.exp_avg_sq[1] = exp_avg_sq1; + g.stale = stale; + groups_.push_back(std::move(g)); + } + + // Hand a step to the dispatcher and return immediately (non-blocking). + void submit_step(int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction) + { + const size_t ng = groups_.size(); + TORCH_CHECK(lr.size() == ng && beta1.size() == ng && beta2.size() == ng && + eps.size() == ng && weight_decay.size() == ng && + bias_correction.size() == ng, + "ZenFlowAdam::submit_step: hyperparameter length must match group count"); + std::vector hps(ng); + for (size_t g = 0; g < ng; ++g) { + hps[g] = {lr[g], beta1[g], beta2[g], eps[g], weight_decay[g], (bool)bias_correction[g]}; + } + { + std::lock_guard lk(mtx_); + TORCH_CHECK(!has_work_, + "ZenFlowAdam::submit_step called before previous step was consumed"); + now_state_ = now_state; + step_ = step; + hps_ = std::move(hps); + has_work_ = true; + } + cv_.notify_all(); + } + + // Block until one submitted step has completed. Uses a completion counter so a + // skipped wait (the engine's first post-warmup round) does not desync: each + // wait consumes exactly one completion, like draining one message from the pipe. + void wait_step() + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [this] { return completed_ > waited_; }); + ++waited_; + } + + void shutdown() + { + { + std::lock_guard lk(mtx_); + if (exit_) return; + exit_ = true; + } + cv_.notify_all(); + if (dispatcher_.joinable()) dispatcher_.join(); + pool_.reset(); + } + +private: + void dispatcher_main() + { + while (true) { + int now_state; + int64_t step; + std::vector hps; + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [this] { return has_work_ || exit_; }); + if (exit_) return; + now_state = now_state_; + step = step_; + hps = hps_; + has_work_ = false; + } + run_step(now_state, step, hps); + { + std::lock_guard lk(mtx_); + ++completed_; + } + cv_.notify_all(); + } + } + + void run_step(int now_state, int64_t step, const std::vector& hps) + { + auto opt = std::static_pointer_cast(s_optimizers[opt_id_]); + for (size_t g = 0; g < groups_.size(); ++g) { + const ZenHP& hp = hps[g]; + // Groups share one Adam_Optimizer; advance its bias-correction state for + // this group before the pool reads it (pool is idle here -> no race). + opt->IncrementStep(step, hp.beta1, hp.beta2); + opt->update_state(hp.lr, hp.eps, hp.weight_decay, hp.bias_correction); + + ZenGroup& grp = groups_[g]; + torch::Tensor& P = grp.param; + torch::Tensor& G = grp.grad[now_state]; + torch::Tensor& M = grp.exp_avg[now_state]; + torch::Tensor& V = grp.exp_avg_sq[now_state]; + + auto it = invokers.find(std::tuple(P.scalar_type(), M.scalar_type())); + TORCH_CHECK(it != invokers.end(), + "ZenFlowAdam: unsupported param/state dtype combination"); + auto fn = it->second; + + char* pp = static_cast(P.data_ptr()); + char* gp = static_cast(G.data_ptr()); + char* mp = static_cast(M.data_ptr()); + char* vp = static_cast(V.data_ptr()); + char* sp = grp.stale.defined() ? static_cast(grp.stale.data_ptr()) : nullptr; + const size_t pe = P.element_size(); + const size_t se = M.element_size(); + const size_t numel = P.numel(); + + pool_->parallel_for(numel, kZenAdamAlign, [=](size_t b, size_t e) { + const size_t len = e - b; + // parallel=false: each pinned thread runs its slice serially. + fn(opt, pp + b * pe, gp + b * pe, mp + b * se, vp + b * se, len, false); + if (sp) std::memcpy(sp + b * pe, pp + b * pe, len * pe); + }); + } + } + + int opt_id_; + std::vector groups_; + std::unique_ptr pool_; + std::thread dispatcher_; + + std::mutex mtx_; + std::condition_variable cv_; + bool has_work_ = false; + bool exit_ = false; + int now_state_ = 0; + int64_t step_ = 0; + std::vector hps_; + uint64_t completed_ = 0; + uint64_t waited_ = 0; +}; + +// Handle-indexed registry, mirroring s_optimizers, so the Python side refers to a +// ZenFlowAdam by an int handle and the class itself stays encapsulated here. +static std::unordered_map> s_zenflow_adams; +static int s_next_zenflow_id = 0; + +int zenflow_adam_create(int optimizer_id, std::vector zf_affinity) +{ + int handle = s_next_zenflow_id++; + s_zenflow_adams[handle] = std::make_unique(optimizer_id, std::move(zf_affinity)); + return handle; +} + +void zenflow_adam_register_group(int handle, + torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale) +{ + s_zenflow_adams.at(handle)->register_group( + param, grad0, grad1, exp_avg0, exp_avg1, exp_avg_sq0, exp_avg_sq1, stale); +} + +void zenflow_adam_submit(int handle, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction) +{ + s_zenflow_adams.at(handle)->submit_step( + now_state, step, lr, beta1, beta2, eps, weight_decay, bias_correction); +} + +void zenflow_adam_wait(int handle) { s_zenflow_adams.at(handle)->wait_step(); } + +void zenflow_adam_destroy(int handle) +{ + auto it = s_zenflow_adams.find(handle); + if (it != s_zenflow_adams.end()) { + it->second->shutdown(); + s_zenflow_adams.erase(it); + } +} diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 3a26f1c17d80..eb520ac34ae5 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -252,3 +252,31 @@ int ds_adam_rollback(int optimizer_id, torch::Tensor& exp_avg_sq); int destroy_adam_optimizer(int optimizer_id); + +// ZenFlowAdam: in-process, GIL-released overlapped CPU Adam for ZenFlow. The handle +// indexes a background dispatcher + pinned thread pool that drives the step. +int zenflow_adam_create(int optimizer_id, std::vector zf_affinity); + +void zenflow_adam_register_group(int handle, + torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale); + +void zenflow_adam_submit(int handle, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction); + +void zenflow_adam_wait(int handle); + +void zenflow_adam_destroy(int handle); diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 539bbc70548d..304925f5542d 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -312,6 +312,114 @@ def test_multi_without_stale(self): ds_opt_adam.destroy_adam(opt_id) +class TestZenFlowAdamNative(DistributedTest): + """ZenFlowAdam (in-process background thread + pinned pool, sliced serial kernel) + must produce the same update as the reference fused path, with the alternating + double-buffered grads/moments that ZenFlow's overlap uses.""" + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16], ids=["fp32", "bf16"]) + def test_matches_reference(self, dtype): + import os + ds = CPUAdamBuilder().load() + lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 + # Sizes that exercise the multi-thread slicing: smaller than the pool, not a + # multiple of it, and large. + sizes = [3, 1000, 100003] + + opt_zf, opt_ref = 5, 6 + ds.create_adam(opt_zf, lr, beta1, beta2, eps, weight_decay, True, False) + ds.create_adam(opt_ref, lr, beta1, beta2, eps, weight_decay, True, False) + + affinity = list(range(min(4, os.cpu_count() or 1))) + handle = ds.zenflow_adam_create(opt_zf, affinity) + + torch.manual_seed(0) + # ZenFlowAdam state (double-buffered) and the reference mirror of it. + params_zf = [torch.randn(n, dtype=dtype) for n in sizes] + params_ref = [p.clone() for p in params_zf] + grad = [[torch.zeros(n, dtype=dtype) for n in sizes] for _ in range(2)] + ea = [[torch.zeros(n) for n in sizes] for _ in range(2)] + eq = [[torch.zeros(n) for n in sizes] for _ in range(2)] + stale = [torch.zeros(n, dtype=dtype) for n in sizes] + ea_ref = [[t.clone() for t in ea[s]] for s in range(2)] + eq_ref = [[t.clone() for t in eq[s]] for s in range(2)] + stale_ref = [t.clone() for t in stale] + + for i in range(len(sizes)): + ds.zenflow_adam_register_group(handle, params_zf[i], grad[0][i], grad[1][i], ea[0][i], ea[1][i], eq[0][i], + eq[1][i], stale[i]) + + try: + for step in range(1, 6): + now = step & 1 + grads = [torch.randn(n, dtype=dtype) for n in sizes] + for i in range(len(sizes)): + grad[now][i].copy_(grads[i]) + + ds.zenflow_adam_submit(handle, now, step, [lr] * len(sizes), [beta1] * len(sizes), + [beta2] * len(sizes), [eps] * len(sizes), [weight_decay] * len(sizes), + [1] * len(sizes)) + ds.zenflow_adam_wait(handle) + + ds.adam_update_multi(opt_ref, step, lr, beta1, beta2, eps, weight_decay, True, params_ref, + [g.clone() for g in grads], ea_ref[now], eq_ref[now], stale_ref) + + for i in range(len(sizes)): + assert torch.equal(params_zf[i], params_ref[i]), f"param mismatch size {sizes[i]} step {step}" + assert torch.equal(ea[now][i], ea_ref[now][i]), f"exp_avg mismatch size {sizes[i]} step {step}" + assert torch.equal(eq[now][i], eq_ref[now][i]), f"exp_avg_sq mismatch size {sizes[i]}" + assert torch.equal(stale[i], stale_ref[i]), f"stale mismatch size {sizes[i]} step {step}" + finally: + ds.zenflow_adam_destroy(handle) + ds.destroy_adam(opt_zf) + ds.destroy_adam(opt_ref) + + def test_pipelined_submit_wait(self): + """Mirror the engine's pipeline: warmup does submit-then-wait, steady state does + wait-then-submit (each wait drains the *previous* submit), leaving one undrained + completion that destroy() cleans up. Must not hang or desync.""" + import os + ds = CPUAdamBuilder().load() + lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.0 + n = 1024 + opt_id = 7 + ds.create_adam(opt_id, lr, beta1, beta2, eps, wd, True, False) + handle = ds.zenflow_adam_create(opt_id, list(range(min(4, os.cpu_count() or 1)))) + + param = torch.randn(n) + g = [torch.zeros(n), torch.zeros(n)] + ea = [torch.zeros(n), torch.zeros(n)] + eq = [torch.zeros(n), torch.zeros(n)] + stale = torch.zeros(n) + ds.zenflow_adam_register_group(handle, param, g[0], g[1], ea[0], ea[1], eq[0], eq[1], stale) + + def submit(now, step): + g[now].copy_(torch.randn(n)) + ds.zenflow_adam_submit(handle, now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) + + try: + # warmup: submit then wait (no overlap) + submit(1, 1) + ds.zenflow_adam_wait(handle) + # steady: the first post-warmup wait is skipped, so this round is submit-only, + # and every later wait drains the submit from the previous round. + submit(0, 2) + for step in range(3, 8): + ds.zenflow_adam_wait(handle) + submit(step & 1, step) + ds.zenflow_adam_wait(handle) # drain the last submitted step + assert torch.all(torch.isfinite(param)) + finally: + ds.zenflow_adam_destroy(handle) + ds.destroy_adam(opt_id) + + class TestCPUAdamGPUError(DistributedTest): def test_cpu_adam_gpu_error(self): From d60c77793d16ea9225aee4c71be13a1a25883507 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 06:02:20 +0000 Subject: [PATCH 05/13] Run ZenFlow stage 1/2 overlapped optimizer in-process Replace the multiprocessing optimizer subprocess with the in-process ZenFlowAdam handle for ZeRO stage 1/2. The subprocess existed only to dodge the GIL; now that the step runs in native code that releases the GIL, a background dispatcher plus a pinned thread pool in the same process give the same overlap and operate on the same tensors directly -- removing the pipe, shared-memory sharing, the manager dict, and the per-step rebinding. - `start_optimizer_process` branches: stage 1/2 builds an in-process ZenFlowCPUAdam, eagerly allocates the double-buffered moments, registers each group with the native handle, and confines the training thread to the PyTorch core set (affinity + OMP_NUM_THREADS + torch.set_num_threads) so it does not contend with the optimizer's pinned pool. Stage 3 keeps the subprocess for now. - `ZenFlowCPUAdam` gains init_native_overlap/submit_overlap_step/wait_overlap_step and destroys the handle on teardown. - stage 1/2 `zenflow_cpu_optimizer_step`/`wait_last_update_and_copy` call the handle's submit/wait instead of pipe send/recv. - Factor the zf/pt core split into `_compute_zf_pt_affinity`, shared by both paths. - Add an overlap_step=True unit test for stage 1/2 (the in-process path runs under the test harness; the stage 3 subprocess cannot spawn from the daemonic test process, which is itself a reason to migrate it). Verified: native and subprocess paths produce bit-identical loss trajectories for stage 1/2 over a seeded run. Signed-off-by: Tingfeng Lan --- deepspeed/ops/adam/zenflow_cpu_adam.py | 53 ++++++++++ .../runtime/zenflow/zenflow_stage_1_and_2.py | 12 +-- deepspeed/runtime/zenflow/zenflow_utils.py | 96 +++++++++++++------ tests/unit/runtime/zenflow/test_zf.py | 33 ++++++- 4 files changed, 153 insertions(+), 41 deletions(-) diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 5b8ab17622f7..5fbee66d004d 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -149,3 +149,56 @@ def _parallel_step(self, step_id, now_state, group_info, closure=None): group_info['weight_decay'], group_info['bias_correction'], params, grads, exp_avgs, exp_avg_sqs, stale_params) return loss + + @torch.no_grad() + def init_native_overlap(self, zf_affinity): + """Create the native ZenFlowAdam handle and register every parameter group with + it. The optimizer state (double-buffered moments) is allocated eagerly here, + since the in-process worker needs the tensors registered before the first step. + Replaces the multiprocessing optimizer subprocess.""" + device = torch.device('cpu') + self.zf_handle = self.ds_opt_adam.zenflow_adam_create(self.opt_id, list(zf_affinity)) + + for group in self.param_groups: + for p in group['params']: + if not hasattr(p, 'overlap_grad'): + continue + assert p.data.device == device, "ZenFlowCPUAdam params must be on CPU" + + state = self.state[p] + if len(state) == 0: + state['step'] = 0 + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) + exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) + state['exp_avg'] = [exp_avg, exp_avg.clone()] + state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] + + self.ds_opt_adam.zenflow_adam_register_group(self.zf_handle, p.data, p.overlap_grad[0].data, + p.overlap_grad[1].data, state['exp_avg'][0], + state['exp_avg'][1], state['exp_avg_sq'][0], + state['exp_avg_sq'][1], p.stale_param.data) + + def submit_overlap_step(self, now_state, step_id, group_infos): + """Hand one overlapped step to the native worker (non-blocking).""" + for group_id, group in enumerate(self.param_groups): + self.state[group['params'][0]]['step'] = step_id + lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] + for info in group_infos: + lr.append(info['lr']) + beta1.append(info['betas'][0]) + beta2.append(info['betas'][1]) + eps.append(info['eps']) + weight_decay.append(info['weight_decay']) + bias_correction.append(1 if info['bias_correction'] else 0) + self.ds_opt_adam.zenflow_adam_submit(self.zf_handle, now_state, step_id, lr, beta1, beta2, eps, weight_decay, + bias_correction) + + def wait_overlap_step(self): + """Block (GIL released in C++) until the last submitted step finishes.""" + self.ds_opt_adam.zenflow_adam_wait(self.zf_handle) + + def __del__(self): + if hasattr(self, 'zf_handle'): + self.ds_opt_adam.zenflow_adam_destroy(self.zf_handle) + super().__del__() diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 27882b1dfb20..31bb843931f5 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -667,7 +667,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): def wait_last_update_and_copy(self): - if not hasattr(self, 'parent_conn'): + if not getattr(self, 'process_optimizer_established', False): return if self.micro_step + 1 > self.full_warm_up_rounds and self.first_update_round_after_warmup: @@ -675,8 +675,7 @@ def wait_last_update_and_copy(self): return self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() - msg = self.parent_conn.recv() - assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + self.zf_cpu_adam.wait_overlap_step() self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() for i, group in enumerate(self.bit16_groups): @@ -730,12 +729,7 @@ def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): group_infos.append(group_info) - self.parent_conn.send({ - "type": "step", - "now_state": now_state, - "micro_step": self.micro_step, - "group_infos": group_infos - }) + self.zf_cpu_adam.submit_overlap_step(now_state, self.micro_step + 1, group_infos) def step(self, closure=None): """ diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index f238b3626506..2d701c8ac8c7 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -107,7 +107,73 @@ def all_tensors_equal(tensor_list): return True +def _compute_zf_pt_affinity(zf_optimizer): + """Split this rank's cores into a ZenFlow-optimizer set and a training (PyTorch) set. + When every rank reports the same affinity the launcher did not bind workers, so do a + soft per-rank bind first, then carve off pt_reserved_cores_perc for training.""" + curr_rank = dist.get_rank() + total_rank = dist.get_world_size() + + current_affinity = psutil.Process().cpu_affinity() + all_affinities = [ + torch.zeros(len(current_affinity), + dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name()) for _ in range(total_rank) + ] + dist.all_gather( + all_affinities, + torch.tensor(current_affinity, dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name())) + if all_tensors_equal(all_affinities): + num_phy_cores = psutil.cpu_count(logical=False) + available_phy_cores = [i for i in current_affinity if i < num_phy_cores] + cores_per_rank = len(available_phy_cores) // total_rank + current_affinity = available_phy_cores[curr_rank * cores_per_rank:(curr_rank + 1) * cores_per_rank] + + pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity)) + if pt_num_cores > 0 and pt_num_cores < len(current_affinity): + zf_affinity = current_affinity[pt_num_cores:] + pt_affinity = current_affinity[:pt_num_cores] + else: + zf_affinity = current_affinity + pt_affinity = current_affinity + return zf_affinity, pt_affinity + + +def _start_native_optimizer(zf_optimizer): + """In-process overlapped optimizer (ZeRO stage 1/2): a native ZenFlowAdam handle with + a background dispatcher and a pinned thread pool, replacing the optimizer subprocess. + Tensors are shared directly (same process), so there is no pipe/shared-memory plumbing. + The main thread is then confined to the training cores so it does not contend with the + optimizer's pinned pool.""" + from deepspeed.ops.adam import ZenFlowCPUAdam + + # Shallow-copy the param groups so building the in-process optimizer does not mutate the + # client optimizer's groups; the parameter tensors themselves stay shared. + param_groups_data = [dict(group) for group in zf_optimizer.optimizer.param_groups] + for group in param_groups_data: + for param in group["params"]: + if not hasattr(param, "stale_param"): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + + zf_affinity, pt_affinity = _compute_zf_pt_affinity(zf_optimizer) + + optimizer = ZenFlowCPUAdam(param_groups_data, overlap_step=True) + optimizer.init_native_overlap(zf_affinity) + zf_optimizer.zf_cpu_adam = optimizer + + psutil.Process().cpu_affinity(pt_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) + torch.set_num_threads(len(pt_affinity)) + + zf_optimizer.process_optimizer_established = True + + def start_optimizer_process(zf_optimizer): + if not zf_optimizer.zf_stage3: + _start_native_optimizer(zf_optimizer) + return + from multiprocessing import Pipe, get_context, Manager ctx = get_context("spawn") @@ -143,36 +209,8 @@ def start_optimizer_process(zf_optimizer): } for param in zf_optimizer.fp32_partitioned_groups_flat] if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups) - curr_rank = dist.get_rank() - total_rank = dist.get_world_size() - current_process = psutil.Process() - current_affinity = current_process.cpu_affinity() - all_affinities = [ - torch.zeros(len(current_affinity), - dtype=type(current_affinity[0]), - device=get_accelerator().current_device_name()) for _ in range(total_rank) - ] - dist.all_gather( - all_affinities, - torch.tensor(current_affinity, dtype=type(current_affinity[0]), - device=get_accelerator().current_device_name())) - # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here - if all_tensors_equal(all_affinities): - num_phy_cores = psutil.cpu_count(logical=False) - available_phy_cores = [i for i in current_affinity if i < num_phy_cores] - num_available_phy_cores = len(available_phy_cores) - my_rank = curr_rank - my_size = total_rank - cores_per_rank = num_available_phy_cores // my_size - current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank] - pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity)) - if pt_num_cores > 0 and pt_num_cores < len(current_affinity): - zf_affinity = current_affinity[pt_num_cores:] - pt_affinity = current_affinity[:pt_num_cores] - else: - zf_affinity = current_affinity - pt_affinity = current_affinity + zf_affinity, pt_affinity = _compute_zf_pt_affinity(zf_optimizer) zf_optimizer.process = ctx.Process( target=zenflow_optimizer_process, diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index 7adcdb784972..d176d82b817e 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -17,8 +17,14 @@ class BaseZenFlowTest: batch_size = 4 grad_acc_steps = 1 - def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, - full_warm_up_rounds): + def get_config_dict(self, + stage, + offload_selective_optimizer, + select_strategy, + select_interval, + update_interval, + full_warm_up_rounds, + overlap_step=False): config = { "train_batch_size": self.batch_size, "gradient_accumulation_steps": self.grad_acc_steps, @@ -40,7 +46,7 @@ def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, s "select_strategy": select_strategy, "select_interval": select_interval, "update_interval": update_interval, - "overlap_step": False, + "overlap_step": overlap_step, "offload": offload_selective_optimizer, "auto_ratio": 0.99, "full_warm_up_rounds": full_warm_up_rounds, @@ -109,3 +115,24 @@ def test_zenflow_distributed(self, stage, offload_selective_optimizer, select_st config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) self.run_training_distributed(config_dict) + + +# Stage 3 overlap still uses the optimizer subprocess, which cannot be spawned from the +# daemonic process the test harness runs in; it is covered once stage 3 moves to the +# in-process ZenFlowAdam path. Stage 1/2 use the in-process path and run fine here. +@pytest.mark.parametrize("stage", [1, 2]) +@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) +class TestZenFlowOverlapSingleGPU(DistributedTest, BaseZenFlowTest): + """overlap_step=True exercises the in-process ZenFlowAdam optimizer path for ZeRO + stage 1/2 (background dispatcher + pinned pool, no subprocess). Must stay finite.""" + world_size = 1 + + def test_zenflow_overlap(self, stage, full_warm_up_rounds): + config_dict = self.get_config_dict(stage, + offload_selective_optimizer=False, + select_strategy="auto", + select_interval="auto", + update_interval=4, + full_warm_up_rounds=full_warm_up_rounds, + overlap_step=True) + self.run_training_distributed(config_dict) From 40491c90660375e67e08cfcc2e3392c07cfeb524 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 15:16:12 +0000 Subject: [PATCH 06/13] Run ZenFlow stage 1/2 overlapped optimizer in a separate native process Profiling the in-process design showed it regressed ~18% on large, memory- bandwidth-bound updates: the Adam moments (two thirds of the step's memory traffic) were allocated by the training thread and ended up NUMA-remote from the optimizer's pinned pool, and the pool contended with the training thread inside one process. A separate process avoids both -- it allocates its state locally on its own NUMA node and is isolated -- which is why the old subprocess was faster there. The old subprocess was only slow on small models because of its per-step Python/pickle/Manager overhead. So keep the separate process but make the coordination native: the optimizer runs the ZenFlowAdam pinned pool in its own process and talks to the training process through two process-shared semaphores in a shared-memory control block, instead of a pickling pipe. No Python in the optimizer loop, no per-step rebinding. Measured (ms/step, best of 3): 0.5M 7.6 vs 9.9, 134M 114 vs 119 -- faster than the old subprocess at both ends. - C++: ZenControl shared-memory block (sem_t cmd_ready/done, command, per-group hyperparameters); ZenFlowAdam::run_worker drives the pool from it; zenflow_adam_ctrl_{size,init,submit,wait,exit} for the training side. Reuses the pinned pool and run_step; in-process submit/wait kept only as a fast unit-test driver for the pool. Linux-only (POSIX semaphores). - Python: the optimizer process builds the pool, allocates state locally, and runs the worker loop; stage 1/2 submit/wait call the control functions. Drops the in-process ZenFlowCPUAdam overlap helpers. - Test: a cross-process op test (plain, not DistributedTest, so the non-daemonic pytest process can spawn the optimizer) checks bit-for-bit equality with the fused reference across alternating double buffers. The engine-level overlap test is removed again: like the subprocess, the optimizer process cannot be spawned from the daemonic test worker. Stage 3 still uses the pickling subprocess; migrating it is a follow-up. Verified: stage 1/2 training loss is bit-identical to the subprocess over a seeded run. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam.cpp | 23 ++++ csrc/adam/cpu_adam_impl.cpp | 106 ++++++++++++++++++ csrc/includes/cpu_adam.h | 20 ++++ deepspeed/ops/adam/zenflow_cpu_adam.py | 53 --------- .../runtime/zenflow/zenflow_stage_1_and_2.py | 23 ++-- deepspeed/runtime/zenflow/zenflow_utils.py | 76 ++++++++++--- tests/unit/ops/adam/test_cpu_adam.py | 73 ++++++++++++ tests/unit/runtime/zenflow/test_zf.py | 33 +----- 8 files changed, 295 insertions(+), 112 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 386eb7b03fe2..b7484bc7f5c9 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -50,4 +50,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &zenflow_adam_destroy, "ZenFlowAdam destroy (C++)", pybind11::call_guard()); + +#if defined(__linux__) + // Cross-process driver (optimizer in a separate process, shared-memory semaphore control). + m.def( + "zenflow_adam_ctrl_size", &zenflow_adam_ctrl_size, "ZenFlowAdam control block size (C++)"); + m.def("zenflow_adam_ctrl_init", &zenflow_adam_ctrl_init, "ZenFlowAdam control init (C++)"); + m.def("zenflow_adam_run_worker", + &zenflow_adam_run_worker, + "ZenFlowAdam optimizer-process worker loop (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_ctrl_submit", + &zenflow_adam_ctrl_submit, + "ZenFlowAdam cross-process submit (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_ctrl_wait", + &zenflow_adam_ctrl_wait, + "ZenFlowAdam cross-process wait (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_ctrl_exit", + &zenflow_adam_ctrl_exit, + "ZenFlowAdam cross-process exit (C++)", + pybind11::call_guard()); +#endif } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 7a8398f4cb9e..cef0a8d0ed52 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -22,6 +22,7 @@ #if defined(__linux__) #include #include +#include #endif using namespace std::string_literals; @@ -537,6 +538,27 @@ struct ZenGroup { torch::Tensor stale; // may be undefined -> stale snapshot skipped }; +#if defined(__linux__) +// Control block placed in a shared-memory buffer (a shared torch tensor's storage) so the +// main process and the optimizer process coordinate through two process-shared semaphores +// instead of a pickling pipe. The main process writes a command + per-group hyperparameters +// and posts cmd_ready; the worker runs the step and posts done. `done` is a counting +// semaphore, so a skipped wait (the engine's post-warmup transition) is drained later. +static constexpr int ZEN_MAX_GROUPS = 1024; +enum { ZEN_CMD_STEP = 0, ZEN_CMD_EXIT = 1 }; + +struct ZenControl { + sem_t cmd_ready; + sem_t done; + int cmd; + int now_state; + int64_t step; + int num_groups; + float hp[ZEN_MAX_GROUPS * 5]; // lr, beta1, beta2, eps, weight_decay per group + uint8_t bias_correction[ZEN_MAX_GROUPS]; +}; +#endif + class ZenFlowAdam { public: ZenFlowAdam(int optimizer_id, std::vector zf_affinity) : opt_id_(optimizer_id) @@ -622,6 +644,31 @@ class ZenFlowAdam { pool_.reset(); } +#if defined(__linux__) + // Process-mode driver: run in the optimizer process, block on the shared-memory control + // block, and run each requested step on the pinned pool. Returns on the exit command. + void run_worker(void* control_ptr) + { + ZenControl* ctrl = reinterpret_cast(control_ptr); + while (true) { + while (sem_wait(&ctrl->cmd_ready) != 0) {} // retry on EINTR + if (ctrl->cmd == ZEN_CMD_EXIT) break; + const int ng = ctrl->num_groups; + std::vector hps(ng); + for (int g = 0; g < ng; ++g) { + hps[g] = {ctrl->hp[g * 5 + 0], + ctrl->hp[g * 5 + 1], + ctrl->hp[g * 5 + 2], + ctrl->hp[g * 5 + 3], + ctrl->hp[g * 5 + 4], + (bool)ctrl->bias_correction[g]}; + } + run_step(ctrl->now_state, ctrl->step, hps); + sem_post(&ctrl->done); + } + } +#endif + private: void dispatcher_main() { @@ -752,3 +799,62 @@ void zenflow_adam_destroy(int handle) s_zenflow_adams.erase(it); } } + +#if defined(__linux__) +// Size (bytes) the shared control tensor must hold. +int64_t zenflow_adam_ctrl_size() { return (int64_t)sizeof(ZenControl); } + +// Called once by the main process before spawning the optimizer process. +void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups) +{ + TORCH_CHECK(num_groups <= ZEN_MAX_GROUPS, "ZenFlowAdam: too many param groups"); + auto* ctrl = reinterpret_cast(control_ptr); + ctrl->num_groups = num_groups; + ctrl->cmd = ZEN_CMD_STEP; + sem_init(&ctrl->cmd_ready, /*pshared=*/1, 0); + sem_init(&ctrl->done, /*pshared=*/1, 0); +} + +// Called in the optimizer process; blocks running steps until the exit command. +void zenflow_adam_run_worker(int handle, uintptr_t control_ptr) +{ s_zenflow_adams.at(handle)->run_worker(reinterpret_cast(control_ptr)); } + +void zenflow_adam_ctrl_submit(uintptr_t control_ptr, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction) +{ + auto* ctrl = reinterpret_cast(control_ptr); + const int ng = (int)lr.size(); + for (int g = 0; g < ng; ++g) { + ctrl->hp[g * 5 + 0] = lr[g]; + ctrl->hp[g * 5 + 1] = beta1[g]; + ctrl->hp[g * 5 + 2] = beta2[g]; + ctrl->hp[g * 5 + 3] = eps[g]; + ctrl->hp[g * 5 + 4] = weight_decay[g]; + ctrl->bias_correction[g] = bias_correction[g]; + } + ctrl->now_state = now_state; + ctrl->step = step; + ctrl->cmd = ZEN_CMD_STEP; + sem_post(&ctrl->cmd_ready); // release: hyperparameters above are visible to the worker +} + +void zenflow_adam_ctrl_wait(uintptr_t control_ptr) +{ + auto* ctrl = reinterpret_cast(control_ptr); + while (sem_wait(&ctrl->done) != 0) {} // retry on EINTR +} + +void zenflow_adam_ctrl_exit(uintptr_t control_ptr) +{ + auto* ctrl = reinterpret_cast(control_ptr); + ctrl->cmd = ZEN_CMD_EXIT; + sem_post(&ctrl->cmd_ready); +} +#endif diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index eb520ac34ae5..06cc8d02ed10 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "simd.h" #define STEP(SPAN) \ @@ -280,3 +281,22 @@ void zenflow_adam_submit(int handle, void zenflow_adam_wait(int handle); void zenflow_adam_destroy(int handle); + +#if defined(__linux__) +// Cross-process driver: the optimizer runs in a separate process and coordinates with the +// main process through two process-shared semaphores in a shared-memory control block. +int64_t zenflow_adam_ctrl_size(); +void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups); +void zenflow_adam_run_worker(int handle, uintptr_t control_ptr); +void zenflow_adam_ctrl_submit(uintptr_t control_ptr, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction); +void zenflow_adam_ctrl_wait(uintptr_t control_ptr); +void zenflow_adam_ctrl_exit(uintptr_t control_ptr); +#endif diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 5fbee66d004d..5b8ab17622f7 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -149,56 +149,3 @@ def _parallel_step(self, step_id, now_state, group_info, closure=None): group_info['weight_decay'], group_info['bias_correction'], params, grads, exp_avgs, exp_avg_sqs, stale_params) return loss - - @torch.no_grad() - def init_native_overlap(self, zf_affinity): - """Create the native ZenFlowAdam handle and register every parameter group with - it. The optimizer state (double-buffered moments) is allocated eagerly here, - since the in-process worker needs the tensors registered before the first step. - Replaces the multiprocessing optimizer subprocess.""" - device = torch.device('cpu') - self.zf_handle = self.ds_opt_adam.zenflow_adam_create(self.opt_id, list(zf_affinity)) - - for group in self.param_groups: - for p in group['params']: - if not hasattr(p, 'overlap_grad'): - continue - assert p.data.device == device, "ZenFlowCPUAdam params must be on CPU" - - state = self.state[p] - if len(state) == 0: - state['step'] = 0 - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) - exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) - state['exp_avg'] = [exp_avg, exp_avg.clone()] - state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] - - self.ds_opt_adam.zenflow_adam_register_group(self.zf_handle, p.data, p.overlap_grad[0].data, - p.overlap_grad[1].data, state['exp_avg'][0], - state['exp_avg'][1], state['exp_avg_sq'][0], - state['exp_avg_sq'][1], p.stale_param.data) - - def submit_overlap_step(self, now_state, step_id, group_infos): - """Hand one overlapped step to the native worker (non-blocking).""" - for group_id, group in enumerate(self.param_groups): - self.state[group['params'][0]]['step'] = step_id - lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] - for info in group_infos: - lr.append(info['lr']) - beta1.append(info['betas'][0]) - beta2.append(info['betas'][1]) - eps.append(info['eps']) - weight_decay.append(info['weight_decay']) - bias_correction.append(1 if info['bias_correction'] else 0) - self.ds_opt_adam.zenflow_adam_submit(self.zf_handle, now_state, step_id, lr, beta1, beta2, eps, weight_decay, - bias_correction) - - def wait_overlap_step(self): - """Block (GIL released in C++) until the last submitted step finishes.""" - self.ds_opt_adam.zenflow_adam_wait(self.zf_handle) - - def __del__(self): - if hasattr(self, 'zf_handle'): - self.ds_opt_adam.zenflow_adam_destroy(self.zf_handle) - super().__del__() diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 31bb843931f5..41a61c9f4da5 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -675,7 +675,7 @@ def wait_last_update_and_copy(self): return self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() - self.zf_cpu_adam.wait_overlap_step() + self.zf_op.zenflow_adam_ctrl_wait(self.zf_ctrl.data_ptr()) self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() for i, group in enumerate(self.bit16_groups): @@ -714,22 +714,21 @@ def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): if not self.process_optimizer_established: self.start_optimizer_process() - group_infos = [] + lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] for group_no, group in enumerate(self.bit16_groups): single_grad_partition = self.single_partition_of_fp32_groups[group_no].overlap_grad[now_state] self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) - group_info = { - "lr": self.optimizer.param_groups[group_no]["lr"], - "betas": self.optimizer.param_groups[group_no]["betas"], - "eps": self.optimizer.param_groups[group_no]["eps"], - "weight_decay": self.optimizer.param_groups[group_no]["weight_decay"], - "bias_correction": self.optimizer.param_groups[group_no]["bias_correction"], - } + pg = self.optimizer.param_groups[group_no] + lr.append(pg["lr"]) + beta1.append(pg["betas"][0]) + beta2.append(pg["betas"][1]) + eps.append(pg["eps"]) + weight_decay.append(pg["weight_decay"]) + bias_correction.append(1 if pg["bias_correction"] else 0) - group_infos.append(group_info) - - self.zf_cpu_adam.submit_overlap_step(now_state, self.micro_step + 1, group_infos) + self.zf_op.zenflow_adam_ctrl_submit(self.zf_ctrl.data_ptr(), now_state, self.micro_step + 1, lr, beta1, beta2, + eps, weight_decay, bias_correction) def step(self, closure=None): """ diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index 2d701c8ac8c7..6cef4cc16d8e 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -140,31 +140,73 @@ def _compute_zf_pt_affinity(zf_optimizer): return zf_affinity, pt_affinity +def zenflow_optimizer_process_native(groups, ctrl, ready, zf_affinity, adamw_mode): + """ZeRO stage 1/2 optimizer process. Builds the native ZenFlowAdam pinned pool and runs + the worker loop driven by the shared-memory control block (no pickling pipe). The Adam + state is allocated here, in this process pinned to the optimizer cores, so it is + NUMA-local to the pool -- which is what makes a separate process worthwhile over an + in-process thread for large, memory-bandwidth-bound updates.""" + disable_accelerator() + current_process = psutil.Process() + current_process.cpu_affinity(zf_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) + + from deepspeed.ops.op_builder import CPUAdamBuilder + op = CPUAdamBuilder().load() + op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, adamw_mode, False) + handle = op.zenflow_adam_create(0, list(zf_affinity)) + for param, overlap_grad0, overlap_grad1, stale in groups: + exp_avg0 = torch.zeros_like(param) + exp_avg1 = torch.zeros_like(param) + exp_avg_sq0 = torch.zeros_like(param) + exp_avg_sq1 = torch.zeros_like(param) + op.zenflow_adam_register_group(handle, param, overlap_grad0, overlap_grad1, exp_avg0, exp_avg1, exp_avg_sq0, + exp_avg_sq1, stale) + ready.set() + op.zenflow_adam_run_worker(handle, ctrl.data_ptr()) + op.zenflow_adam_destroy(handle) + op.destroy_adam(0) + + def _start_native_optimizer(zf_optimizer): - """In-process overlapped optimizer (ZeRO stage 1/2): a native ZenFlowAdam handle with - a background dispatcher and a pinned thread pool, replacing the optimizer subprocess. - Tensors are shared directly (same process), so there is no pipe/shared-memory plumbing. - The main thread is then confined to the training cores so it does not contend with the - optimizer's pinned pool.""" - from deepspeed.ops.adam import ZenFlowCPUAdam + """ZeRO stage 1/2 overlapped optimizer: run the native ZenFlowAdam in a separate process, + coordinated through a shared-memory semaphore control block instead of a pickling pipe. + Keeps the isolation of a separate process (NUMA-local state, no contention with the + training thread) while removing the per-step Python/IPC overhead of the old subprocess.""" + from multiprocessing import get_context + from deepspeed.ops.op_builder import CPUAdamBuilder + + op = CPUAdamBuilder().load() + zf_optimizer.zf_op = op + + # Share the tensors the optimizer process reads/writes; the Adam state stays process-local. + groups = [] + for group in zf_optimizer.optimizer.param_groups: + param = group["params"][0] + param.data.share_memory_() + if not hasattr(param, "stale_param"): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + param.stale_param.data.share_memory_() + param.overlap_grad[0].data.share_memory_() + param.overlap_grad[1].data.share_memory_() + groups.append((param.data, param.overlap_grad[0].data, param.overlap_grad[1].data, param.stale_param.data)) - # Shallow-copy the param groups so building the in-process optimizer does not mutate the - # client optimizer's groups; the parameter tensors themselves stay shared. - param_groups_data = [dict(group) for group in zf_optimizer.optimizer.param_groups] - for group in param_groups_data: - for param in group["params"]: - if not hasattr(param, "stale_param"): - param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + ctrl = torch.zeros(op.zenflow_adam_ctrl_size(), dtype=torch.uint8).share_memory_() + op.zenflow_adam_ctrl_init(ctrl.data_ptr(), len(groups)) + zf_optimizer.zf_ctrl = ctrl zf_affinity, pt_affinity = _compute_zf_pt_affinity(zf_optimizer) - optimizer = ZenFlowCPUAdam(param_groups_data, overlap_step=True) - optimizer.init_native_overlap(zf_affinity) - zf_optimizer.zf_cpu_adam = optimizer + ctx = get_context("spawn") + ready = ctx.Event() + proc = ctx.Process(target=zenflow_optimizer_process_native, args=(groups, ctrl, ready, zf_affinity, True)) + proc.daemon = True + proc.start() + ready.wait() + zf_optimizer.process = proc psutil.Process().cpu_affinity(pt_affinity) os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) - torch.set_num_threads(len(pt_affinity)) zf_optimizer.process_optimizer_established = True diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 304925f5542d..09c785046563 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -3,6 +3,8 @@ # DeepSpeed Team +import os +import sys import torch import numpy as np import pytest @@ -420,6 +422,77 @@ def submit(now, step): ds.destroy_adam(opt_id) +def _zenflow_adam_proc_worker(param, g0, g1, ea0, ea1, eq0, eq1, stale, ctrl, ready, affinity): + op = CPUAdamBuilder().load() + op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False) + handle = op.zenflow_adam_create(0, affinity) + op.zenflow_adam_register_group(handle, param, g0, g1, ea0, ea1, eq0, eq1, stale) + ready.set() + op.zenflow_adam_run_worker(handle, ctrl.data_ptr()) # blocks until the exit command + op.zenflow_adam_destroy(handle) + op.destroy_adam(0) + + +@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="cross-process ZenFlowAdam is Linux-only") +def test_zenflow_adam_cross_process(): + """The optimizer-process driver (shared-memory semaphore control + native worker, the + production path for ZenFlow stage 1/2 overlap) must match the fused reference bit-for-bit + with alternating double buffers. Run as a plain test, not DistributedTest, so the pytest + process (non-daemonic) can spawn the optimizer process.""" + import torch.multiprocessing as mp + + op = CPUAdamBuilder().load() + if not hasattr(op, "zenflow_adam_ctrl_size"): + pytest.skip("cross-process ZenFlowAdam not available in this build") + + lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.0 + n = 100003 # non-SIMD-aligned, exercises the scalar tail + affinity = list(range(min(4, os.cpu_count() or 1))) + + ctrl = torch.zeros(op.zenflow_adam_ctrl_size(), dtype=torch.uint8).share_memory_() + op.zenflow_adam_ctrl_init(ctrl.data_ptr(), 1) + + torch.manual_seed(0) + param = torch.randn(n).share_memory_() + g = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + ea = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + eq = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + stale = torch.zeros(n).share_memory_() + + op.create_adam(1, lr, beta1, beta2, eps, wd, True, False) + p_ref = param.clone() + ea_ref = [ea[0].clone(), ea[1].clone()] + eq_ref = [eq[0].clone(), eq[1].clone()] + st_ref = stale.clone() + + ctx = mp.get_context("spawn") + ready = ctx.Event() + proc = ctx.Process(target=_zenflow_adam_proc_worker, + args=(param, g[0], g[1], ea[0], ea[1], eq[0], eq[1], stale, ctrl, ready, affinity)) + proc.start() + try: + assert ready.wait(timeout=60), "optimizer process did not start" + for step in range(1, 6): + now = step & 1 + grad = torch.randn(n) + g[now].copy_(grad) + op.zenflow_adam_ctrl_submit(ctrl.data_ptr(), now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) + op.zenflow_adam_ctrl_wait(ctrl.data_ptr()) + op.adam_update_multi(1, step, lr, beta1, beta2, eps, wd, True, [p_ref], [grad.clone()], [ea_ref[now]], + [eq_ref[now]], [st_ref]) + assert torch.equal(param, p_ref), f"param mismatch step {step}" + assert torch.equal(ea[now], ea_ref[now]), f"exp_avg mismatch step {step}" + assert torch.equal(eq[now], eq_ref[now]), f"exp_avg_sq mismatch step {step}" + assert torch.equal(stale, st_ref), f"stale mismatch step {step}" + op.zenflow_adam_ctrl_exit(ctrl.data_ptr()) + proc.join(timeout=10) + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + op.destroy_adam(1) + + class TestCPUAdamGPUError(DistributedTest): def test_cpu_adam_gpu_error(self): diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py index d176d82b817e..7adcdb784972 100644 --- a/tests/unit/runtime/zenflow/test_zf.py +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -17,14 +17,8 @@ class BaseZenFlowTest: batch_size = 4 grad_acc_steps = 1 - def get_config_dict(self, - stage, - offload_selective_optimizer, - select_strategy, - select_interval, - update_interval, - full_warm_up_rounds, - overlap_step=False): + def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, + full_warm_up_rounds): config = { "train_batch_size": self.batch_size, "gradient_accumulation_steps": self.grad_acc_steps, @@ -46,7 +40,7 @@ def get_config_dict(self, "select_strategy": select_strategy, "select_interval": select_interval, "update_interval": update_interval, - "overlap_step": overlap_step, + "overlap_step": False, "offload": offload_selective_optimizer, "auto_ratio": 0.99, "full_warm_up_rounds": full_warm_up_rounds, @@ -115,24 +109,3 @@ def test_zenflow_distributed(self, stage, offload_selective_optimizer, select_st config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, full_warm_up_rounds) self.run_training_distributed(config_dict) - - -# Stage 3 overlap still uses the optimizer subprocess, which cannot be spawned from the -# daemonic process the test harness runs in; it is covered once stage 3 moves to the -# in-process ZenFlowAdam path. Stage 1/2 use the in-process path and run fine here. -@pytest.mark.parametrize("stage", [1, 2]) -@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) -class TestZenFlowOverlapSingleGPU(DistributedTest, BaseZenFlowTest): - """overlap_step=True exercises the in-process ZenFlowAdam optimizer path for ZeRO - stage 1/2 (background dispatcher + pinned pool, no subprocess). Must stay finite.""" - world_size = 1 - - def test_zenflow_overlap(self, stage, full_warm_up_rounds): - config_dict = self.get_config_dict(stage, - offload_selective_optimizer=False, - select_strategy="auto", - select_interval="auto", - update_interval=4, - full_warm_up_rounds=full_warm_up_rounds, - overlap_step=True) - self.run_training_distributed(config_dict) From 4164f14d8dfefe82ab822bf9a2966e076fa710ab Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 15:38:55 +0000 Subject: [PATCH 07/13] Run ZenFlow stage 3 overlapped optimizer in the native process Migrate ZeRO stage 3 overlap to the same separate native-process optimizer used for stage 1/2: the optimizer process runs the ZenFlowAdam pinned pool driven by the shared-memory semaphore control block, instead of the pickling subprocess. - Generalize the optimizer-process startup to gather groups from fp32_partitioned_groups_flat for stage 3 (one flat partition per sub-group) and from the param groups for stage 1/2; both carry overlap_grad double buffers and a stale snapshot. start_optimizer_process now always takes the native path. - engine_stage3 submit/wait call zenflow_adam_ctrl_submit/ctrl_wait instead of the pipe; the warm-up transition guard is unchanged. - Remove the now-unreachable pickling optimizer loop (zenflow_optimizer_process) and its subprocess setup. Verified: stage 3 training loss is bit-identical to the old subprocess over a seeded run. Note: ZenFlowCPUAdam._parallel_step (and the adam_update_multi Python caller) are now only reachable from tests; pruning those superseded layers is left to a dedicated cleanup. Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zenflow/engine_stage3.py | 33 +++--- deepspeed/runtime/zenflow/zenflow_utils.py | 127 +++------------------ 2 files changed, 31 insertions(+), 129 deletions(-) diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py index 0d95be9f2f8a..2539c6e4dfb8 100644 --- a/deepspeed/runtime/zenflow/engine_stage3.py +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -542,40 +542,33 @@ def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_gr if not optimizer_z3.process_optimizer_established: optimizer_z3.start_optimizer_process() - group_infos = [] + lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] for group_no, group in enumerate(optimizer_z3.fp16_groups): optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state) param_group_id = optimizer_z3.sub_group_to_group_id[group_no] + pg = optimizer_z3.optimizer.param_groups[param_group_id] + lr.append(pg["lr"]) + beta1.append(pg["betas"][0]) + beta2.append(pg["betas"][1]) + eps.append(pg["eps"]) + weight_decay.append(pg["weight_decay"]) + bias_correction.append(1 if pg["bias_correction"] else 0) - group_info = { - "lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"], - "betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"], - "eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"], - "weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"], - "bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"], - } - - group_infos.append(group_info) - - optimizer_z3.parent_conn.send({ - "type": "step", - "now_state": now_state, - "micro_step": optimizer_z3.micro_step, - "group_infos": group_infos - }) + optimizer_z3.zf_op.zenflow_adam_ctrl_submit(optimizer_z3.zf_ctrl.data_ptr(), now_state, + optimizer_z3.micro_step + 1, lr, beta1, beta2, eps, weight_decay, + bias_correction) def wait_last_update_and_copy(optimizer_z3, timer_names): - if not hasattr(optimizer_z3, 'parent_conn'): + if not getattr(optimizer_z3, 'process_optimizer_established', False): return if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup: optimizer_z3.first_update_round_after_warmup = False return - msg = optimizer_z3.parent_conn.recv() - assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + optimizer_z3.zf_op.zenflow_adam_ctrl_wait(optimizer_z3.zf_ctrl.data_ptr()) for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index 6cef4cc16d8e..c567166dc0bb 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -57,48 +57,6 @@ def disable_accelerator(): accelerator._initialized = True -def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity): - disable_accelerator() - - current_process = psutil.Process() - current_process.cpu_affinity(zf_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) - - from deepspeed.ops.adam import ZenFlowCPUAdam - optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) - - pipe.send({"type": "ready"}) - - # TODO: replace this with rpc - - while True: - cmd = pipe.recv() - if cmd["type"] == "step": - now_state = cmd["now_state"] - micro_step = cmd["micro_step"] - group_infos = cmd["group_infos"] - - for group_no, group_info in enumerate(group_infos): - original_param_groups = optimizer.param_groups - optimizer.param_groups = [original_param_groups[group_no]] - group = optimizer.param_groups[0] - - for param_idx, param in enumerate(group["params"]): - key = (group_no, param_idx) - if key in shared_overlap_grad_map: - param.overlap_grad = shared_overlap_grad_map[key] - if key in shared_stale_param_map: - param.stale_param = shared_stale_param_map[key] - - optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info) - - optimizer.param_groups = original_param_groups - - pipe.send({"type": "done"}) - elif cmd["type"] == "exit": - break - - def all_tensors_equal(tensor_list): first_tensor = tensor_list[0] for tensor in tensor_list[1:]: @@ -141,10 +99,10 @@ def _compute_zf_pt_affinity(zf_optimizer): def zenflow_optimizer_process_native(groups, ctrl, ready, zf_affinity, adamw_mode): - """ZeRO stage 1/2 optimizer process. Builds the native ZenFlowAdam pinned pool and runs - the worker loop driven by the shared-memory control block (no pickling pipe). The Adam - state is allocated here, in this process pinned to the optimizer cores, so it is - NUMA-local to the pool -- which is what makes a separate process worthwhile over an + """ZenFlow overlapped optimizer process (ZeRO stage 1/2/3). Builds the native ZenFlowAdam + pinned pool and runs the worker loop driven by the shared-memory control block (no pickling + pipe). The Adam state is allocated here, in this process pinned to the optimizer cores, so + it is NUMA-local to the pool -- which is what makes a separate process worthwhile over an in-process thread for large, memory-bandwidth-bound updates.""" disable_accelerator() current_process = psutil.Process() @@ -169,20 +127,27 @@ def zenflow_optimizer_process_native(groups, ctrl, ready, zf_affinity, adamw_mod def _start_native_optimizer(zf_optimizer): - """ZeRO stage 1/2 overlapped optimizer: run the native ZenFlowAdam in a separate process, - coordinated through a shared-memory semaphore control block instead of a pickling pipe. - Keeps the isolation of a separate process (NUMA-local state, no contention with the - training thread) while removing the per-step Python/IPC overhead of the old subprocess.""" + """ZenFlow overlapped optimizer (ZeRO stage 1/2/3): run the native ZenFlowAdam in a + separate process, coordinated through a shared-memory semaphore control block instead of a + pickling pipe. Keeps the isolation of a separate process (NUMA-local state, no contention + with the training thread) while removing the per-step Python/IPC overhead of the old + subprocess.""" from multiprocessing import get_context from deepspeed.ops.op_builder import CPUAdamBuilder op = CPUAdamBuilder().load() zf_optimizer.zf_op = op + # Stage 3 steps each flattened sub-group partition; stage 1/2 steps one flat partition per + # param group. Both carry overlap_grad double buffers and a stale snapshot. + if zf_optimizer.zf_stage3: + params = list(zf_optimizer.fp32_partitioned_groups_flat) + else: + params = [group["params"][0] for group in zf_optimizer.optimizer.param_groups] + # Share the tensors the optimizer process reads/writes; the Adam state stays process-local. groups = [] - for group in zf_optimizer.optimizer.param_groups: - param = group["params"][0] + for param in params: param.data.share_memory_() if not hasattr(param, "stale_param"): param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) @@ -212,60 +177,4 @@ def _start_native_optimizer(zf_optimizer): def start_optimizer_process(zf_optimizer): - if not zf_optimizer.zf_stage3: - _start_native_optimizer(zf_optimizer) - return - - from multiprocessing import Pipe, get_context, Manager - - ctx = get_context("spawn") - zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe() - - manager = Manager() - zf_optimizer.shared_overlap_grad_map = manager.dict() - zf_optimizer.shared_stale_param_map = manager.dict() - - if zf_optimizer.zf_stage3: - params_iter = [((group_no, 0), param) - for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)] - else: - params_iter = [((group_no, param_idx), param) - for group_no, group in enumerate(zf_optimizer.optimizer.param_groups) - for param_idx, param in enumerate(group["params"])] - - for key, param in params_iter: - param.data.share_memory_() - - if not hasattr(param, "stale_param"): - param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) - param.stale_param.data.share_memory_() - zf_optimizer.shared_stale_param_map[key] = param.stale_param - - if getattr(param, "overlap_grad", None) is not None: - param.overlap_grad[0].data.share_memory_() - param.overlap_grad[1].data.share_memory_() - zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad - - param_groups_data = ([{ - "params": [param] - } for param in zf_optimizer.fp32_partitioned_groups_flat] - if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups) - - current_process = psutil.Process() - zf_affinity, pt_affinity = _compute_zf_pt_affinity(zf_optimizer) - - zf_optimizer.process = ctx.Process( - target=zenflow_optimizer_process, - args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map, - zf_optimizer.shared_stale_param_map, zf_affinity), - ) - zf_optimizer.process.daemon = True - zf_optimizer.process.start() - - current_process.cpu_affinity(pt_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) - - msg = zf_optimizer.parent_conn.recv() - assert msg["type"] == "ready", "Optimizer process did not initialize correctly." - - zf_optimizer.process_optimizer_established = True + _start_native_optimizer(zf_optimizer) From 60181d99003fe054f0f57b7395a1ba6e88cea785 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 17:23:51 +0000 Subject: [PATCH 08/13] Fail fast if the ZenFlow optimizer process does not start The training process waited unbounded on the optimizer process's ready signal. If that process crashed during initialization (for example a SIGBUS when /dev/shm is exhausted, or a bad spawn), the training process blocked forever on the first step's wait with no indication of what went wrong. Bound the wait and raise a clear error if the optimizer process never signals ready, so the failure surfaces instead of hanging. Verified at scale: ZeRO stage 1/2/3 overlap trains 0.5B and 1.5B parameter models on 1 and 2 GPUs (the optimizer process registers the flattened partitions, signals ready, and steps to finite loss). Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zenflow/zenflow_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index c567166dc0bb..b80814b50ce0 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -167,7 +167,13 @@ def _start_native_optimizer(zf_optimizer): proc = ctx.Process(target=zenflow_optimizer_process_native, args=(groups, ctrl, ready, zf_affinity, True)) proc.daemon = True proc.start() - ready.wait() + # Wait for the optimizer process to finish building its pool and registering tensors. + # If it crashed during init (e.g. it never signals), fail loudly instead of blocking the + # training process forever on the first step's wait. + if not ready.wait(timeout=600): + proc.terminate() + raise RuntimeError("ZenFlow optimizer process failed to become ready (it likely crashed " + "during initialization; check the optimizer process traceback above)") zf_optimizer.process = proc psutil.Process().cpu_affinity(pt_affinity) From 59daa9e16a6341a8f939d6e6e675e090fc95cd3c Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 21:22:25 +0000 Subject: [PATCH 09/13] Stream ZenFlow optimizer copyback in chunks to bound the GPU memory peak When the overlapped CPU optimizer finishes, the updated fp32 master partition is copied back to its GPU bit16 partition via bit16.copy_(fp32.to(device)). The .to(device) first materializes the entire fp32 partition on the GPU -- a transient spike of ~2x the bit16 partition (measured ~2944 MiB for a 0.75B-param partition) stacked on top of the model, which is exactly the memory CPU offload is meant to save. Stream the copy in fixed-size chunks so only one chunk's fp32 staging tensor is resident at a time; the transient peak drops to the chunk size (measured ~256 MiB) and the bit16 result is unchanged. End-to-end throughput is unaffected. Signed-off-by: Tingfeng Lan --- .../runtime/zenflow/zenflow_stage_1_and_2.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 41a61c9f4da5..a8190e8e0c51 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -34,6 +34,11 @@ ] INITIAL_MICRO_STEP_ID = -1 +# Number of elements copied per chunk when streaming the updated fp32 master partition +# back to the GPU bit16 partition. Bounds the transient fp32 staging tensor on the GPU +# (chunk * 4 bytes) instead of materializing the whole partition at once. +ZENFLOW_COPYBACK_CHUNK_NUMEL = 32 * 1024 * 1024 + SELECTIVE_OPTIMIZER_UPDATE_TIMER = 'selective_optimizer_update' SELECTIVE_OPTIMIZER_PROCESS_TIMER = 'selective_optimizer_process' SELECTIVE_OPTIMIZER_STEP_TIMER = 'selective_optimizer_step' @@ -683,7 +688,7 @@ def wait_last_update_and_copy(self): bit16_partitions = self.parallel_partitioned_bit16_groups[i] fp32_partition = self.optimizer.param_groups[i]['params'][0].stale_param.data self.timers(OPTIMIZER_TRANSMIT_TIMER).start() - bit16_partitions[partition_id].data.copy_(fp32_partition.to(get_accelerator().current_device_name()).data) + self._copyback_fp32_partition_to_bit16(fp32_partition, bit16_partitions[partition_id].data) self.timers(OPTIMIZER_TRANSMIT_TIMER).stop() see_memory_usage('After optimizer before all-gather') @@ -709,6 +714,24 @@ def wait_last_update_and_copy(self): self.timers.log(OPTIMIZER_TIMERS) see_memory_usage('After zero_optimizer step') + def _copyback_fp32_partition_to_bit16(self, fp32_partition, bit16_partition): + """Stream the updated fp32 master partition back to its GPU bit16 partition in chunks. + + The straightforward ``bit16.copy_(fp32.to(device))`` first materializes the whole + fp32 partition on the GPU, a transient spike of ~2x the bit16 partition (~3GB for a + 0.75B-param partition) stacked on top of the model -- exactly the memory the offload + is meant to save. Copying chunk by chunk keeps only one chunk's fp32 staging tensor + resident, so the peak drops to the chunk size; the bit16 result is unchanged. + """ + device = get_accelerator().current_device_name() + fp32_flat = fp32_partition.view(-1) + bit16_flat = bit16_partition.view(-1) + numel = fp32_flat.numel() + for offset in range(0, numel, ZENFLOW_COPYBACK_CHUNK_NUMEL): + end = min(offset + ZENFLOW_COPYBACK_CHUNK_NUMEL, numel) + gpu_chunk = fp32_flat[offset:end].to(device, non_blocking=True) + bit16_flat[offset:end].copy_(gpu_chunk) + def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): if not self.process_optimizer_established: From 2f9590bb8199c706a2483e18fb37aa3dc5fb752d Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 21:22:44 +0000 Subject: [PATCH 10/13] Remove ZenFlow's superseded in-process overlapped optimizer path ZenFlow's overlapped optimizer now always runs in a dedicated process driven by a shared-memory semaphore control block (ZenFlowAdam::run_worker). The earlier in-process variant -- a background dispatcher thread with submit_step/wait_step, exposed as zenflow_adam_submit/wait(handle) and ZenFlowCPUAdam._parallel_step -- was kept only as a unit-test driver and is no longer reachable in production. Remove it: drop the dispatcher thread and its sync state from ZenFlowAdam, delete the handle-based submit/wait bindings and _parallel_step, and delete the TestZenFlowAdamNative test. With the in-process submit/wait gone, the cross-process control-block ops reclaim the plain names zenflow_adam_submit/wait. The fused adam_update_multi op (still used by the worker kernel and its own tests) is kept. No functional change to the production cross-process path; cross-process and fused unit tests and a stage 1/2 end-to-end run remain bit-identical. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam.cpp | 28 ++-- csrc/adam/cpu_adam_impl.cpp | 152 +++--------------- csrc/includes/cpu_adam.h | 39 ++--- deepspeed/ops/adam/zenflow_cpu_adam.py | 80 +-------- deepspeed/runtime/zenflow/engine_stage3.py | 7 +- .../runtime/zenflow/zenflow_stage_1_and_2.py | 6 +- deepspeed/runtime/zenflow/zenflow_utils.py | 19 +-- tests/unit/ops/adam/test_cpu_adam.py | 112 +------------ 8 files changed, 64 insertions(+), 379 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index b7484bc7f5c9..4c14d0552e63 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -32,27 +32,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); - // ZenFlowAdam: in-process overlapped CPU Adam. wait/submit/destroy release the GIL - // so the optimizer thread overlaps the Python training thread. + // ZenFlowAdam: the native CPU Adam backing ZenFlow's overlapped optimizer step. create / + // register_group / destroy set up the handle-indexed pinned pool (used by the worker process). m.def("zenflow_adam_create", &zenflow_adam_create, "ZenFlowAdam create (C++)"); m.def("zenflow_adam_register_group", &zenflow_adam_register_group, "ZenFlowAdam register a parameter group (C++)"); - m.def("zenflow_adam_submit", - &zenflow_adam_submit, - "ZenFlowAdam submit an overlapped step (C++)", - pybind11::call_guard()); - m.def("zenflow_adam_wait", - &zenflow_adam_wait, - "ZenFlowAdam wait for a submitted step (C++)", - pybind11::call_guard()); m.def("zenflow_adam_destroy", &zenflow_adam_destroy, "ZenFlowAdam destroy (C++)", pybind11::call_guard()); #if defined(__linux__) - // Cross-process driver (optimizer in a separate process, shared-memory semaphore control). + // The optimizer runs in a separate process, coordinated through a shared-memory semaphore + // control block. submit/wait/run_worker release the GIL so the optimizer process overlaps + // the Python training thread. m.def( "zenflow_adam_ctrl_size", &zenflow_adam_ctrl_size, "ZenFlowAdam control block size (C++)"); m.def("zenflow_adam_ctrl_init", &zenflow_adam_ctrl_init, "ZenFlowAdam control init (C++)"); @@ -60,13 +54,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &zenflow_adam_run_worker, "ZenFlowAdam optimizer-process worker loop (C++)", pybind11::call_guard()); - m.def("zenflow_adam_ctrl_submit", - &zenflow_adam_ctrl_submit, - "ZenFlowAdam cross-process submit (C++)", + m.def("zenflow_adam_submit", + &zenflow_adam_submit, + "ZenFlowAdam submit an overlapped step (C++)", pybind11::call_guard()); - m.def("zenflow_adam_ctrl_wait", - &zenflow_adam_ctrl_wait, - "ZenFlowAdam cross-process wait (C++)", + m.def("zenflow_adam_wait", + &zenflow_adam_wait, + "ZenFlowAdam wait for a submitted step (C++)", pybind11::call_guard()); m.def("zenflow_adam_ctrl_exit", &zenflow_adam_ctrl_exit, diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index cef0a8d0ed52..8d2643739de8 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -411,15 +411,13 @@ int destroy_adam_optimizer(int optimizer_id) } // --------------------------------------------------------------------------- -// ZenFlowAdam: in-process, GIL-released CPU Adam for ZenFlow's overlapped step. +// ZenFlowAdam: the native CPU Adam that backs ZenFlow's overlapped optimizer step. // -// Replaces the multiprocessing optimizer subprocess. The optimizer step runs on -// a background dispatcher thread; the heavy per-element math is fanned out to a -// pool of worker threads pinned to ZenFlow's dedicated cores, each running its -// element slice through the serial (parallel=false) kernel. Because the workers -// hold no GIL while computing, the Python training thread keeps running. Since -// everything lives in one process the optimizer touches the same tensors the -// main thread holds -- no shared memory, pipe, or per-step rebinding. +// The optimizer runs in a dedicated process (see zenflow_utils.start_optimizer_process): +// run_worker() blocks on a shared-memory control block and, for each requested step, fans +// the heavy per-element math out to a pool of worker threads pinned to ZenFlow's dedicated +// cores, each running its element slice through the serial (parallel=false) kernel. The +// Adam state lives in that process, NUMA-local to the pool. // --------------------------------------------------------------------------- // A persistent pool of threads pinned to a fixed core set. parallel_for() splits @@ -562,12 +560,9 @@ struct ZenControl { class ZenFlowAdam { public: ZenFlowAdam(int optimizer_id, std::vector zf_affinity) : opt_id_(optimizer_id) - { - pool_ = std::make_unique(zf_affinity); - dispatcher_ = std::thread(&ZenFlowAdam::dispatcher_main, this); - } + { pool_ = std::make_unique(zf_affinity); } - ~ZenFlowAdam() { shutdown(); } + ~ZenFlowAdam() = default; void register_group(torch::Tensor param, torch::Tensor grad0, @@ -591,59 +586,6 @@ class ZenFlowAdam { groups_.push_back(std::move(g)); } - // Hand a step to the dispatcher and return immediately (non-blocking). - void submit_step(int now_state, - int64_t step, - std::vector lr, - std::vector beta1, - std::vector beta2, - std::vector eps, - std::vector weight_decay, - std::vector bias_correction) - { - const size_t ng = groups_.size(); - TORCH_CHECK(lr.size() == ng && beta1.size() == ng && beta2.size() == ng && - eps.size() == ng && weight_decay.size() == ng && - bias_correction.size() == ng, - "ZenFlowAdam::submit_step: hyperparameter length must match group count"); - std::vector hps(ng); - for (size_t g = 0; g < ng; ++g) { - hps[g] = {lr[g], beta1[g], beta2[g], eps[g], weight_decay[g], (bool)bias_correction[g]}; - } - { - std::lock_guard lk(mtx_); - TORCH_CHECK(!has_work_, - "ZenFlowAdam::submit_step called before previous step was consumed"); - now_state_ = now_state; - step_ = step; - hps_ = std::move(hps); - has_work_ = true; - } - cv_.notify_all(); - } - - // Block until one submitted step has completed. Uses a completion counter so a - // skipped wait (the engine's first post-warmup round) does not desync: each - // wait consumes exactly one completion, like draining one message from the pipe. - void wait_step() - { - std::unique_lock lk(mtx_); - cv_.wait(lk, [this] { return completed_ > waited_; }); - ++waited_; - } - - void shutdown() - { - { - std::lock_guard lk(mtx_); - if (exit_) return; - exit_ = true; - } - cv_.notify_all(); - if (dispatcher_.joinable()) dispatcher_.join(); - pool_.reset(); - } - #if defined(__linux__) // Process-mode driver: run in the optimizer process, block on the shared-memory control // block, and run each requested step on the pinned pool. Returns on the exit command. @@ -670,30 +612,6 @@ class ZenFlowAdam { #endif private: - void dispatcher_main() - { - while (true) { - int now_state; - int64_t step; - std::vector hps; - { - std::unique_lock lk(mtx_); - cv_.wait(lk, [this] { return has_work_ || exit_; }); - if (exit_) return; - now_state = now_state_; - step = step_; - hps = hps_; - has_work_ = false; - } - run_step(now_state, step, hps); - { - std::lock_guard lk(mtx_); - ++completed_; - } - cv_.notify_all(); - } - } - void run_step(int now_state, int64_t step, const std::vector& hps) { auto opt = std::static_pointer_cast(s_optimizers[opt_id_]); @@ -736,17 +654,6 @@ class ZenFlowAdam { int opt_id_; std::vector groups_; std::unique_ptr pool_; - std::thread dispatcher_; - - std::mutex mtx_; - std::condition_variable cv_; - bool has_work_ = false; - bool exit_ = false; - int now_state_ = 0; - int64_t step_ = 0; - std::vector hps_; - uint64_t completed_ = 0; - uint64_t waited_ = 0; }; // Handle-indexed registry, mirroring s_optimizers, so the Python side refers to a @@ -775,29 +682,10 @@ void zenflow_adam_register_group(int handle, param, grad0, grad1, exp_avg0, exp_avg1, exp_avg_sq0, exp_avg_sq1, stale); } -void zenflow_adam_submit(int handle, - int now_state, - int64_t step, - std::vector lr, - std::vector beta1, - std::vector beta2, - std::vector eps, - std::vector weight_decay, - std::vector bias_correction) -{ - s_zenflow_adams.at(handle)->submit_step( - now_state, step, lr, beta1, beta2, eps, weight_decay, bias_correction); -} - -void zenflow_adam_wait(int handle) { s_zenflow_adams.at(handle)->wait_step(); } - void zenflow_adam_destroy(int handle) { - auto it = s_zenflow_adams.find(handle); - if (it != s_zenflow_adams.end()) { - it->second->shutdown(); - s_zenflow_adams.erase(it); - } + // Erasing the unique_ptr runs ~ZenFlowAdam, which tears down the pinned pool. + s_zenflow_adams.erase(handle); } #if defined(__linux__) @@ -819,15 +707,15 @@ void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups) void zenflow_adam_run_worker(int handle, uintptr_t control_ptr) { s_zenflow_adams.at(handle)->run_worker(reinterpret_cast(control_ptr)); } -void zenflow_adam_ctrl_submit(uintptr_t control_ptr, - int now_state, - int64_t step, - std::vector lr, - std::vector beta1, - std::vector beta2, - std::vector eps, - std::vector weight_decay, - std::vector bias_correction) +void zenflow_adam_submit(uintptr_t control_ptr, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction) { auto* ctrl = reinterpret_cast(control_ptr); const int ng = (int)lr.size(); @@ -845,7 +733,7 @@ void zenflow_adam_ctrl_submit(uintptr_t control_ptr, sem_post(&ctrl->cmd_ready); // release: hyperparameters above are visible to the worker } -void zenflow_adam_ctrl_wait(uintptr_t control_ptr) +void zenflow_adam_wait(uintptr_t control_ptr) { auto* ctrl = reinterpret_cast(control_ptr); while (sem_wait(&ctrl->done) != 0) {} // retry on EINTR diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 06cc8d02ed10..b85799da6e9f 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -254,8 +254,9 @@ int ds_adam_rollback(int optimizer_id, int destroy_adam_optimizer(int optimizer_id); -// ZenFlowAdam: in-process, GIL-released overlapped CPU Adam for ZenFlow. The handle -// indexes a background dispatcher + pinned thread pool that drives the step. +// ZenFlowAdam: the native CPU Adam backing ZenFlow's overlapped optimizer step. The handle +// indexes a pinned thread pool; the optimizer runs in a dedicated process (run_worker) and +// is driven from the main process through the shared-memory control block below. int zenflow_adam_create(int optimizer_id, std::vector zf_affinity); void zenflow_adam_register_group(int handle, @@ -268,7 +269,17 @@ void zenflow_adam_register_group(int handle, torch::Tensor exp_avg_sq1, torch::Tensor stale); -void zenflow_adam_submit(int handle, +void zenflow_adam_destroy(int handle); + +#if defined(__linux__) +// The optimizer runs in a separate process and coordinates with the main process through two +// process-shared semaphores in a shared-memory control block. ctrl_size/ctrl_init/ctrl_exit +// set it up and tear it down; the worker process loops in run_worker; the main process drives +// each step with submit (non-blocking) / wait. +int64_t zenflow_adam_ctrl_size(); +void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups); +void zenflow_adam_run_worker(int handle, uintptr_t control_ptr); +void zenflow_adam_submit(uintptr_t control_ptr, int now_state, int64_t step, std::vector lr, @@ -277,26 +288,6 @@ void zenflow_adam_submit(int handle, std::vector eps, std::vector weight_decay, std::vector bias_correction); - -void zenflow_adam_wait(int handle); - -void zenflow_adam_destroy(int handle); - -#if defined(__linux__) -// Cross-process driver: the optimizer runs in a separate process and coordinates with the -// main process through two process-shared semaphores in a shared-memory control block. -int64_t zenflow_adam_ctrl_size(); -void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups); -void zenflow_adam_run_worker(int handle, uintptr_t control_ptr); -void zenflow_adam_ctrl_submit(uintptr_t control_ptr, - int now_state, - int64_t step, - std::vector lr, - std::vector beta1, - std::vector beta2, - std::vector eps, - std::vector weight_decay, - std::vector bias_correction); -void zenflow_adam_ctrl_wait(uintptr_t control_ptr); +void zenflow_adam_wait(uintptr_t control_ptr); void zenflow_adam_ctrl_exit(uintptr_t control_ptr); #endif diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 5b8ab17622f7..bb62f3893ee5 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -12,12 +12,11 @@ class ZenFlowCPUAdam(DeepSpeedCPUAdam): def __init__(self, *args, overlap_step=False, **kwargs): super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) self.overlap_step = overlap_step + # In the overlapped path the optimizer step is driven natively in the ZenFlow optimizer + # process (see ZenFlowAdam / zenflow_utils.start_optimizer_process), so this object's own + # step() is unused there. Only the sequential (non-overlap) offload path steps here. if not self.overlap_step: - print("ZenFlowCPUAdam initialized with normal step.") self.step = self._sequential_step - else: - print("ZenFlowCPUAdam initialized with overlap step.") - self.step = self._parallel_step @torch.no_grad() def _sequential_step(self, step_id, closure=None): @@ -76,76 +75,3 @@ def _sequential_step(self, step_id, closure=None): group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq']) return loss - - @torch.no_grad() - def _parallel_step(self, step_id, now_state, group_info, closure=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. - - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - # intended device for step - device = torch.device('cpu') - - # Collect the per-group tensors and drive the whole group through a single fused - # native call. This keeps the per-parameter loop in C++, avoiding one - # Python<->C++ crossing per parameter, and lets the stale snapshot be written - # natively (no Python-side clone()). - params = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - stale_params = [] - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - assert p.data.is_shared(), "param.data must be in shared memory" - if not hasattr(p, 'overlap_grad'): - continue - - assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ - "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." - - state = self.state[p] - # State initialization - if len(state) == 0: - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) - exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) - state['exp_avg'] = [exp_avg, exp_avg.clone()] - state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] - - state['step'] = step_id - params.append(p.data) - grads.append(p.overlap_grad[now_state].data) - exp_avgs.append(state['exp_avg'][now_state]) - exp_avg_sqs.append(state['exp_avg_sq'][now_state]) - stale_params.append(p.stale_param.data) - - if not params: - return loss - - beta1, beta2 = group_info['betas'] - self.ds_opt_adam.adam_update_multi(self.opt_id, step_id, group_info['lr'], beta1, beta2, group_info['eps'], - group_info['weight_decay'], group_info['bias_correction'], params, grads, - exp_avgs, exp_avg_sqs, stale_params) - return loss diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py index 2539c6e4dfb8..c2d2b51870e1 100644 --- a/deepspeed/runtime/zenflow/engine_stage3.py +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -554,9 +554,8 @@ def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_gr weight_decay.append(pg["weight_decay"]) bias_correction.append(1 if pg["bias_correction"] else 0) - optimizer_z3.zf_op.zenflow_adam_ctrl_submit(optimizer_z3.zf_ctrl.data_ptr(), now_state, - optimizer_z3.micro_step + 1, lr, beta1, beta2, eps, weight_decay, - bias_correction) + optimizer_z3.zf_op.zenflow_adam_submit(optimizer_z3.zf_ctrl.data_ptr(), now_state, optimizer_z3.micro_step + 1, lr, + beta1, beta2, eps, weight_decay, bias_correction) def wait_last_update_and_copy(optimizer_z3, timer_names): @@ -568,7 +567,7 @@ def wait_last_update_and_copy(optimizer_z3, timer_names): optimizer_z3.first_update_round_after_warmup = False return - optimizer_z3.zf_op.zenflow_adam_ctrl_wait(optimizer_z3.zf_ctrl.data_ptr()) + optimizer_z3.zf_op.zenflow_adam_wait(optimizer_z3.zf_ctrl.data_ptr()) for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index a8190e8e0c51..79924ead162b 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -680,7 +680,7 @@ def wait_last_update_and_copy(self): return self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() - self.zf_op.zenflow_adam_ctrl_wait(self.zf_ctrl.data_ptr()) + self.zf_op.zenflow_adam_wait(self.zf_ctrl.data_ptr()) self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() for i, group in enumerate(self.bit16_groups): @@ -750,8 +750,8 @@ def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): weight_decay.append(pg["weight_decay"]) bias_correction.append(1 if pg["bias_correction"] else 0) - self.zf_op.zenflow_adam_ctrl_submit(self.zf_ctrl.data_ptr(), now_state, self.micro_step + 1, lr, beta1, beta2, - eps, weight_decay, bias_correction) + self.zf_op.zenflow_adam_submit(self.zf_ctrl.data_ptr(), now_state, self.micro_step + 1, lr, beta1, beta2, eps, + weight_decay, bias_correction) def step(self, closure=None): """ diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index b80814b50ce0..f530b7e783d1 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -98,7 +98,7 @@ def _compute_zf_pt_affinity(zf_optimizer): return zf_affinity, pt_affinity -def zenflow_optimizer_process_native(groups, ctrl, ready, zf_affinity, adamw_mode): +def zenflow_optimizer_process(groups, ctrl, ready, zf_affinity, adamw_mode): """ZenFlow overlapped optimizer process (ZeRO stage 1/2/3). Builds the native ZenFlowAdam pinned pool and runs the worker loop driven by the shared-memory control block (no pickling pipe). The Adam state is allocated here, in this process pinned to the optimizer cores, so @@ -126,12 +126,11 @@ def zenflow_optimizer_process_native(groups, ctrl, ready, zf_affinity, adamw_mod op.destroy_adam(0) -def _start_native_optimizer(zf_optimizer): - """ZenFlow overlapped optimizer (ZeRO stage 1/2/3): run the native ZenFlowAdam in a - separate process, coordinated through a shared-memory semaphore control block instead of a - pickling pipe. Keeps the isolation of a separate process (NUMA-local state, no contention - with the training thread) while removing the per-step Python/IPC overhead of the old - subprocess.""" +def start_optimizer_process(zf_optimizer): + """Start ZenFlow's overlapped optimizer (ZeRO stage 1/2/3) in a dedicated process, + coordinated through a shared-memory semaphore control block. A separate process keeps the + Adam state NUMA-local to the optimizer cores and free of contention with the training + thread, while the native control block avoids per-step Python/IPC overhead.""" from multiprocessing import get_context from deepspeed.ops.op_builder import CPUAdamBuilder @@ -164,7 +163,7 @@ def _start_native_optimizer(zf_optimizer): ctx = get_context("spawn") ready = ctx.Event() - proc = ctx.Process(target=zenflow_optimizer_process_native, args=(groups, ctrl, ready, zf_affinity, True)) + proc = ctx.Process(target=zenflow_optimizer_process, args=(groups, ctrl, ready, zf_affinity, True)) proc.daemon = True proc.start() # Wait for the optimizer process to finish building its pool and registering tensors. @@ -180,7 +179,3 @@ def _start_native_optimizer(zf_optimizer): os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) zf_optimizer.process_optimizer_established = True - - -def start_optimizer_process(zf_optimizer): - _start_native_optimizer(zf_optimizer) diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 09c785046563..0082d4097d29 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -314,114 +314,6 @@ def test_multi_without_stale(self): ds_opt_adam.destroy_adam(opt_id) -class TestZenFlowAdamNative(DistributedTest): - """ZenFlowAdam (in-process background thread + pinned pool, sliced serial kernel) - must produce the same update as the reference fused path, with the alternating - double-buffered grads/moments that ZenFlow's overlap uses.""" - world_size = 1 - reuse_dist_env = True - requires_cuda_env = False - if not get_accelerator().is_available(): - init_distributed = False - set_dist_env = False - - @pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16], ids=["fp32", "bf16"]) - def test_matches_reference(self, dtype): - import os - ds = CPUAdamBuilder().load() - lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 - # Sizes that exercise the multi-thread slicing: smaller than the pool, not a - # multiple of it, and large. - sizes = [3, 1000, 100003] - - opt_zf, opt_ref = 5, 6 - ds.create_adam(opt_zf, lr, beta1, beta2, eps, weight_decay, True, False) - ds.create_adam(opt_ref, lr, beta1, beta2, eps, weight_decay, True, False) - - affinity = list(range(min(4, os.cpu_count() or 1))) - handle = ds.zenflow_adam_create(opt_zf, affinity) - - torch.manual_seed(0) - # ZenFlowAdam state (double-buffered) and the reference mirror of it. - params_zf = [torch.randn(n, dtype=dtype) for n in sizes] - params_ref = [p.clone() for p in params_zf] - grad = [[torch.zeros(n, dtype=dtype) for n in sizes] for _ in range(2)] - ea = [[torch.zeros(n) for n in sizes] for _ in range(2)] - eq = [[torch.zeros(n) for n in sizes] for _ in range(2)] - stale = [torch.zeros(n, dtype=dtype) for n in sizes] - ea_ref = [[t.clone() for t in ea[s]] for s in range(2)] - eq_ref = [[t.clone() for t in eq[s]] for s in range(2)] - stale_ref = [t.clone() for t in stale] - - for i in range(len(sizes)): - ds.zenflow_adam_register_group(handle, params_zf[i], grad[0][i], grad[1][i], ea[0][i], ea[1][i], eq[0][i], - eq[1][i], stale[i]) - - try: - for step in range(1, 6): - now = step & 1 - grads = [torch.randn(n, dtype=dtype) for n in sizes] - for i in range(len(sizes)): - grad[now][i].copy_(grads[i]) - - ds.zenflow_adam_submit(handle, now, step, [lr] * len(sizes), [beta1] * len(sizes), - [beta2] * len(sizes), [eps] * len(sizes), [weight_decay] * len(sizes), - [1] * len(sizes)) - ds.zenflow_adam_wait(handle) - - ds.adam_update_multi(opt_ref, step, lr, beta1, beta2, eps, weight_decay, True, params_ref, - [g.clone() for g in grads], ea_ref[now], eq_ref[now], stale_ref) - - for i in range(len(sizes)): - assert torch.equal(params_zf[i], params_ref[i]), f"param mismatch size {sizes[i]} step {step}" - assert torch.equal(ea[now][i], ea_ref[now][i]), f"exp_avg mismatch size {sizes[i]} step {step}" - assert torch.equal(eq[now][i], eq_ref[now][i]), f"exp_avg_sq mismatch size {sizes[i]}" - assert torch.equal(stale[i], stale_ref[i]), f"stale mismatch size {sizes[i]} step {step}" - finally: - ds.zenflow_adam_destroy(handle) - ds.destroy_adam(opt_zf) - ds.destroy_adam(opt_ref) - - def test_pipelined_submit_wait(self): - """Mirror the engine's pipeline: warmup does submit-then-wait, steady state does - wait-then-submit (each wait drains the *previous* submit), leaving one undrained - completion that destroy() cleans up. Must not hang or desync.""" - import os - ds = CPUAdamBuilder().load() - lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.0 - n = 1024 - opt_id = 7 - ds.create_adam(opt_id, lr, beta1, beta2, eps, wd, True, False) - handle = ds.zenflow_adam_create(opt_id, list(range(min(4, os.cpu_count() or 1)))) - - param = torch.randn(n) - g = [torch.zeros(n), torch.zeros(n)] - ea = [torch.zeros(n), torch.zeros(n)] - eq = [torch.zeros(n), torch.zeros(n)] - stale = torch.zeros(n) - ds.zenflow_adam_register_group(handle, param, g[0], g[1], ea[0], ea[1], eq[0], eq[1], stale) - - def submit(now, step): - g[now].copy_(torch.randn(n)) - ds.zenflow_adam_submit(handle, now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) - - try: - # warmup: submit then wait (no overlap) - submit(1, 1) - ds.zenflow_adam_wait(handle) - # steady: the first post-warmup wait is skipped, so this round is submit-only, - # and every later wait drains the submit from the previous round. - submit(0, 2) - for step in range(3, 8): - ds.zenflow_adam_wait(handle) - submit(step & 1, step) - ds.zenflow_adam_wait(handle) # drain the last submitted step - assert torch.all(torch.isfinite(param)) - finally: - ds.zenflow_adam_destroy(handle) - ds.destroy_adam(opt_id) - - def _zenflow_adam_proc_worker(param, g0, g1, ea0, ea1, eq0, eq1, stale, ctrl, ready, affinity): op = CPUAdamBuilder().load() op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False) @@ -476,8 +368,8 @@ def test_zenflow_adam_cross_process(): now = step & 1 grad = torch.randn(n) g[now].copy_(grad) - op.zenflow_adam_ctrl_submit(ctrl.data_ptr(), now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) - op.zenflow_adam_ctrl_wait(ctrl.data_ptr()) + op.zenflow_adam_submit(ctrl.data_ptr(), now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) + op.zenflow_adam_wait(ctrl.data_ptr()) op.adam_update_multi(1, step, lr, beta1, beta2, eps, wd, True, [p_ref], [grad.clone()], [ea_ref[now]], [eq_ref[now]], [st_ref]) assert torch.equal(param, p_ref), f"param mismatch step {step}" From 13ce892f2c9f39a57fe6c832145655eae88adae4 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 21:22:51 +0000 Subject: [PATCH 11/13] Recognize ZenFlowCPUAdam as a supported ZeRO optimizer is_zero_supported_optimizer matches the optimizer type exactly, so ZenFlowCPUAdam (a DeepSpeedCPUAdam subclass used by ZenFlow's CPU offload) was treated as untested and required zero_allow_untested_optimizer: true in every ZenFlow config. Add it to ZERO_SUPPORTED_OPTIMIZERS so ZenFlow runs without that flag. Signed-off-by: Tingfeng Lan --- deepspeed/runtime/zero/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 139419563352..acbce2a8a41d 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -10,7 +10,7 @@ import torch from deepspeed import comm as dist from deepspeed.utils import logger -from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import DeepSpeedCPUAdam, ZenFlowCPUAdam from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad from deepspeed.ops.adam import FusedAdam from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion @@ -43,8 +43,8 @@ class ZeRORuntimeException(Exception): ZERO_SUPPORTED_OPTIMIZERS = [ - torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad, - DeepSpeedCPULion, FusedLion + torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, ZenFlowCPUAdam, torch.optim.Adagrad, + DeepSpeedCPUAdagrad, DeepSpeedCPULion, FusedLion ] # Add MuonWithAuxAdam to supported list if muon is installed From 1c5822128a2831935eb4b5f9b18c04f230c695bd Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 21:42:12 +0000 Subject: [PATCH 12/13] Apply clang-format formatting to ZenFlowAdam definitions CI's clang-format (18.1.3) expands the single-line constructor and zenflow_adam_run_worker bodies to multi-line; match it. Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam_impl.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 8d2643739de8..4ab59a8dd438 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -560,7 +560,9 @@ struct ZenControl { class ZenFlowAdam { public: ZenFlowAdam(int optimizer_id, std::vector zf_affinity) : opt_id_(optimizer_id) - { pool_ = std::make_unique(zf_affinity); } + { + pool_ = std::make_unique(zf_affinity); + } ~ZenFlowAdam() = default; @@ -705,7 +707,9 @@ void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups) // Called in the optimizer process; blocks running steps until the exit command. void zenflow_adam_run_worker(int handle, uintptr_t control_ptr) -{ s_zenflow_adams.at(handle)->run_worker(reinterpret_cast(control_ptr)); } +{ + s_zenflow_adams.at(handle)->run_worker(reinterpret_cast(control_ptr)); +} void zenflow_adam_submit(uintptr_t control_ptr, int now_state, From 3a0d10ad14a8df462c733e7100086064e3a7a4f6 Mon Sep 17 00:00:00 2001 From: Tingfeng Lan Date: Wed, 10 Jun 2026 21:52:18 +0000 Subject: [PATCH 13/13] Fail loudly if the ZenFlow optimizer process dies mid-step If the optimizer process exited after signalling ready but before posting a completion (e.g. an OOM or TORCH_CHECK in run_step), the training side blocked forever on the done semaphore, hanging the whole distributed job -- unlike the old Pipe path, which surfaced a closed-pipe error. Make zenflow_adam_wait a bounded wait (sem_timedwait) returning whether a completion was consumed. The training side (ZeRO stage 1/2 and 3) now loops on it and, on each timeout, checks the optimizer process is still alive, raising a clear error instead of hanging if it died. Normal steps are unaffected (the wait returns as soon as the worker posts done). Signed-off-by: Tingfeng Lan --- csrc/adam/cpu_adam_impl.cpp | 22 +++++++++++++++++-- csrc/includes/cpu_adam.h | 2 +- deepspeed/runtime/zenflow/engine_stage3.py | 13 +++++++++-- .../runtime/zenflow/zenflow_stage_1_and_2.py | 18 +++++++++++++-- deepspeed/runtime/zenflow/zenflow_utils.py | 5 +++++ tests/unit/ops/adam/test_cpu_adam.py | 5 ++++- 6 files changed, 57 insertions(+), 8 deletions(-) diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 4ab59a8dd438..02f2e7773111 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #endif using namespace std::string_literals; @@ -737,10 +739,26 @@ void zenflow_adam_submit(uintptr_t control_ptr, sem_post(&ctrl->cmd_ready); // release: hyperparameters above are visible to the worker } -void zenflow_adam_wait(uintptr_t control_ptr) +// Wait up to timeout_s for the optimizer process to post one completion. Returns true if a +// completion was consumed, false on timeout -- so the training side can re-check that the +// optimizer process is still alive and fail loudly instead of blocking forever if the process +// died mid-step (e.g. an OOM or TORCH_CHECK in run_step after it signalled ready). +bool zenflow_adam_wait(uintptr_t control_ptr, double timeout_s) { auto* ctrl = reinterpret_cast(control_ptr); - while (sem_wait(&ctrl->done) != 0) {} // retry on EINTR + struct timespec deadline; + clock_gettime(CLOCK_REALTIME, &deadline); + deadline.tv_sec += (time_t)timeout_s; + deadline.tv_nsec += (long)((timeout_s - (double)(time_t)timeout_s) * 1e9); + if (deadline.tv_nsec >= 1000000000L) { + deadline.tv_sec += 1; + deadline.tv_nsec -= 1000000000L; + } + while (sem_timedwait(&ctrl->done, &deadline) != 0) { + if (errno == EINTR) continue; // retry on signal + return false; // timed out (or error): caller re-checks process liveness + } + return true; } void zenflow_adam_ctrl_exit(uintptr_t control_ptr) diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index b85799da6e9f..bee80d7a6f34 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -288,6 +288,6 @@ void zenflow_adam_submit(uintptr_t control_ptr, std::vector eps, std::vector weight_decay, std::vector bias_correction); -void zenflow_adam_wait(uintptr_t control_ptr); +bool zenflow_adam_wait(uintptr_t control_ptr, double timeout_s); void zenflow_adam_ctrl_exit(uintptr_t control_ptr); #endif diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py index c2d2b51870e1..c6d749a8fbee 100644 --- a/deepspeed/runtime/zenflow/engine_stage3.py +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -14,7 +14,7 @@ from typing import List from deepspeed.accelerator import get_accelerator from typing import TYPE_CHECKING -from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process, ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS if TYPE_CHECKING: from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 @@ -567,7 +567,16 @@ def wait_last_update_and_copy(optimizer_z3, timer_names): optimizer_z3.first_update_round_after_warmup = False return - optimizer_z3.zf_op.zenflow_adam_wait(optimizer_z3.zf_ctrl.data_ptr()) + # Wake periodically to check the optimizer process is alive: if it died mid-step, fail loudly + # here instead of blocking this rank (and the whole job) forever on a semaphore it will never + # post. + while not optimizer_z3.zf_op.zenflow_adam_wait(optimizer_z3.zf_ctrl.data_ptr(), + ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS): + proc = getattr(optimizer_z3, 'process', None) + if proc is not None and not proc.is_alive(): + raise RuntimeError("ZenFlow optimizer process exited during a step (likely an error or OOM in the " + "optimizer process -- check its traceback above) instead of completing the " + "update. Aborting to avoid hanging distributed training.") for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 79924ead162b..48921a9f1880 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -7,7 +7,7 @@ from deepspeed import comm as dist from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer -from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process, ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS from deepspeed.runtime.utils import (see_memory_usage) from deepspeed.ops.adam import ZenFlowSelectiveAdamW @@ -670,6 +670,20 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): dest_tensor.copy_(src_tensor, non_blocking=True) param.grad = None #offload only + def _wait_for_optimizer_process(self): + """Block until the optimizer process signals the submitted step is done. + + The wait wakes up periodically to check the optimizer process is still alive: if it died + mid-step (e.g. an OOM or assertion in the native worker after it signalled ready), fail + loudly here instead of blocking this rank -- and the whole distributed job -- forever on + a semaphore the dead process will never post.""" + while not self.zf_op.zenflow_adam_wait(self.zf_ctrl.data_ptr(), ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS): + proc = getattr(self, 'process', None) + if proc is not None and not proc.is_alive(): + raise RuntimeError("ZenFlow optimizer process exited during a step (likely an error or OOM in " + "the optimizer process -- check its traceback above) instead of completing " + "the update. Aborting to avoid hanging distributed training.") + def wait_last_update_and_copy(self): if not getattr(self, 'process_optimizer_established', False): @@ -680,7 +694,7 @@ def wait_last_update_and_copy(self): return self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() - self.zf_op.zenflow_adam_wait(self.zf_ctrl.data_ptr()) + self._wait_for_optimizer_process() self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() for i, group in enumerate(self.bit16_groups): diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index f530b7e783d1..7bd3cba51a7c 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -10,6 +10,11 @@ from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator +# How long the training side blocks on a single semaphore wait for the optimizer process before +# waking up to check that the process is still alive. A normal step completes far sooner; this +# only bounds how long we hang if the optimizer process dies mid-step. +ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS = 60 + def _flatten_dense_tensors(tensors): """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 0082d4097d29..39414a3aad61 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -364,12 +364,15 @@ def test_zenflow_adam_cross_process(): proc.start() try: assert ready.wait(timeout=60), "optimizer process did not start" + # With no step submitted yet, a bounded wait must time out (return False) rather than + # block -- this is what lets the training side notice a dead optimizer process. + assert op.zenflow_adam_wait(ctrl.data_ptr(), 0.05) is False, "wait should time out when no step is pending" for step in range(1, 6): now = step & 1 grad = torch.randn(n) g[now].copy_(grad) op.zenflow_adam_submit(ctrl.data_ptr(), now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) - op.zenflow_adam_wait(ctrl.data_ptr()) + assert op.zenflow_adam_wait(ctrl.data_ptr(), 60.0), f"wait timed out step {step}" op.adam_update_multi(1, step, lr, beta1, beta2, eps, wd, True, [p_ref], [grad.clone()], [ea_ref[now]], [eq_ref[now]], [st_ref]) assert torch.equal(param, p_ref), f"param mismatch step {step}"