diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index f4c242ff9229..4c14d0552e63 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -7,8 +7,64 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + using namespace pybind11::literals; m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + // `parallel` defaults to true (OpenMP across elements, as before). ZenFlow's native + // pinned thread pool sets it false so each pool thread runs its slice serially. + m.def("adam_update_multi", + &ds_adam_step_multi, + "DeepSpeed CPU Adam fused multi-tensor update (C++)", + "optimizer_id"_a, + "step"_a, + "lr"_a, + "beta1"_a, + "beta2"_a, + "epsilon"_a, + "weight_decay"_a, + "bias_correction"_a, + "params"_a, + "grads"_a, + "exp_avgs"_a, + "exp_avg_sqs"_a, + "stale_params"_a, + "parallel"_a = true); m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); + + // ZenFlowAdam: the native CPU Adam backing ZenFlow's overlapped optimizer step. create / + // register_group / destroy set up the handle-indexed pinned pool (used by the worker process). + m.def("zenflow_adam_create", &zenflow_adam_create, "ZenFlowAdam create (C++)"); + m.def("zenflow_adam_register_group", + &zenflow_adam_register_group, + "ZenFlowAdam register a parameter group (C++)"); + m.def("zenflow_adam_destroy", + &zenflow_adam_destroy, + "ZenFlowAdam destroy (C++)", + pybind11::call_guard()); + +#if defined(__linux__) + // The optimizer runs in a separate process, coordinated through a shared-memory semaphore + // control block. submit/wait/run_worker release the GIL so the optimizer process overlaps + // the Python training thread. 
+ m.def( + "zenflow_adam_ctrl_size", &zenflow_adam_ctrl_size, "ZenFlowAdam control block size (C++)"); + m.def("zenflow_adam_ctrl_init", &zenflow_adam_ctrl_init, "ZenFlowAdam control init (C++)"); + m.def("zenflow_adam_run_worker", + &zenflow_adam_run_worker, + "ZenFlowAdam optimizer-process worker loop (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_submit", + &zenflow_adam_submit, + "ZenFlowAdam submit an overlapped step (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_wait", + &zenflow_adam_wait, + "ZenFlowAdam wait for a submitted step (C++)", + pybind11::call_guard()); + m.def("zenflow_adam_ctrl_exit", + &zenflow_adam_ctrl_exit, + "ZenFlowAdam cross-process exit (C++)", + pybind11::call_guard()); +#endif } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 1f2b8cf0df47..02f2e7773111 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -4,14 +4,28 @@ // DeepSpeed Team #include +#include #include +#include +#include +#include #include #include #include #include +#include +#include #include #include +#include #include "cpu_adam.h" +#if defined(__linux__) +#include +#include +#include +#include +#include +#endif using namespace std::string_literals; static std::unordered_map> s_optimizers; @@ -23,11 +37,12 @@ void Adam_Optimizer::Step_1(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -40,7 +55,7 @@ void Adam_Optimizer::Step_1(ds_params_precision_t* _params, size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - 
t; size_t offset = copy_size + t; -#pragma omp parallel for +#pragma omp parallel for if (parallel) for (size_t k = t; k < offset; k++) { float grad = (float)grads[k]; float param = (float)_params[k]; @@ -72,18 +87,20 @@ void Adam_Optimizer::Step_4(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size)); + (_param_size - rounded_size), + parallel); } int create_adam_optimizer(int optimizer_id, @@ -131,18 +148,20 @@ void Adam_Optimizer::Step_8(ds_params_precision_t* _params, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, parallel); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size)); + (_param_size - rounded_size), + parallel); } template @@ -151,17 +170,20 @@ void step_invoker(std::shared_ptr opt, void* grads, void* _exp_avg, void* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { opt->Step_8((ds_params_precision_t*)(_params), (ds_params_precision_t*)(grads), (ds_state_precision_t*)(_exp_avg), 
(ds_state_precision_t*)(_exp_avg_sq), - _param_size); + _param_size, + parallel); } -std::map, - std::function, void*, void*, void*, void*, size_t)>> +std::map< + std::tuple, + std::function, void*, void*, void*, void*, size_t, bool)>> invokers; // Fill map with template functions for each type @@ -188,7 +210,8 @@ void invoke(std::shared_ptr opt, torch::Tensor& grads, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq, - size_t param_size) + size_t param_size, + bool parallel = true) { c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); @@ -205,7 +228,8 @@ void invoke(std::shared_ptr opt, grads.data_ptr(), exp_avg.data_ptr(), exp_avg_sq.data_ptr(), - param_size); + param_size, + parallel); } int ds_adam_step(int optimizer_id, @@ -236,6 +260,55 @@ int ds_adam_step(int optimizer_id, return 0; } +// Fused multi-tensor variant used by ZenFlow's overlapped optimizer. Driving the +// per-parameter loop in C++ avoids one Python<->C++ crossing (and one OpenMP region +// spawn) per parameter, which dominates on ZeRO Stage 1/2 where a group holds many +// small parameters. When stale_params is non-empty, the post-update parameter is +// snapshotted into it here, replacing the Python-side clone()+copy that ZenFlow used +// to keep a stale copy for the GPU sync. 
+int ds_adam_step_multi(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + std::vector& params, + std::vector& grads, + std::vector& exp_avgs, + std::vector& exp_avg_sqs, + std::vector& stale_params, + bool parallel) +{ + const size_t num_tensors = params.size(); + TORCH_CHECK(grads.size() == num_tensors && exp_avgs.size() == num_tensors && + exp_avg_sqs.size() == num_tensors, + "ds_adam_step_multi: params/grads/exp_avgs/exp_avg_sqs length mismatch"); + const bool has_stale = !stale_params.empty(); + TORCH_CHECK(!has_stale || stale_params.size() == num_tensors, + "ds_adam_step_multi: stale_params length mismatch"); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + // All tensors share one optimizer step, so advance bias-correction state once. + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + for (size_t i = 0; i < num_tensors; ++i) { + auto params_c = params[i].contiguous(); + auto grads_c = grads[i].contiguous(); + auto exp_avg_c = exp_avgs[i].contiguous(); + auto exp_avg_sq_c = exp_avg_sqs[i].contiguous(); + + invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel(), parallel); + + if (has_stale) { stale_params[i].copy_(params_c); } + } + + return 0; +} + void adamw_rollback_inplace(float* params, const float* grads, float* momentum, @@ -338,3 +411,360 @@ int destroy_adam_optimizer(int optimizer_id) return 0; } + +// --------------------------------------------------------------------------- +// ZenFlowAdam: the native CPU Adam that backs ZenFlow's overlapped optimizer step. 
//
// The optimizer runs in a dedicated process (see zenflow_utils.start_optimizer_process):
// run_worker() blocks on a shared-memory control block and, for each requested step, fans
// the heavy per-element math out to a pool of worker threads pinned to ZenFlow's dedicated
// cores, each running its element slice through the serial (parallel=false) kernel. The
// Adam state lives in that process, NUMA-local to the pool.
// ---------------------------------------------------------------------------

// std::exception_ptr / std::current_exception / std::rethrow_exception: PinnedThreadPool
// uses them to carry an exception thrown on a pool thread back to the caller.
#include <exception>

// A persistent pool of threads pinned to a fixed core set. parallel_for() splits
// [0, total) into one contiguous chunk per thread and blocks until all finish.
//
// The pool publishes a single job slot (fn_/total_/align_/done_count_), so parallel_for()
// is not reentrant: only one call may be in flight at a time.
class PinnedThreadPool {
public:
    // Starts one thread per entry of `affinity`, each pinned (Linux only) to that core id.
    // A negative core id leaves the thread unpinned; an empty list yields a single
    // unpinned thread.
    explicit PinnedThreadPool(const std::vector<int>& affinity)
        : n_(std::max<size_t>(1, affinity.size()))
    {
        threads_.reserve(n_);
        try {
            for (size_t i = 0; i < n_; ++i) {
                const int core = affinity.empty() ? -1 : affinity[i];
                threads_.emplace_back([this, i, core] { worker(i, core); });
            }
        } catch (...) {
            // Destroying a joinable std::thread calls std::terminate, so stop and join
            // the threads that did start before reporting the failure.
            shutdown();
            throw;
        }
    }

    ~PinnedThreadPool() { shutdown(); }

    // The workers hold `this`; the pool must never be copied.
    PinnedThreadPool(const PinnedThreadPool&) = delete;
    PinnedThreadPool& operator=(const PinnedThreadPool&) = delete;

    // Number of worker threads (always >= 1).
    size_t size() const { return n_; }

    // Split [0, total) into one chunk per thread. Chunk boundaries are rounded up to a
    // multiple of `align` so each slice's AVX/scalar split lines up with the whole-tensor
    // kernel's split -- otherwise an element could be computed by AVX (FMA) in one layout
    // and the scalar tail (mul+add) in another, which differ in the last bit.
    //
    // `fn(begin, end)` is invoked concurrently from the pool threads on disjoint ranges;
    // threads whose range is empty do not call it. All threads share the one callable, so
    // it must be safe to invoke concurrently. If any invocation throws, the first
    // exception recorded is rethrown here once every thread has finished its chunk.
    void parallel_for(size_t total, size_t align, std::function<void(size_t, size_t)> fn)
    {
        {
            std::lock_guard<std::mutex> lk(m_);
            fn_ = std::move(fn);
            total_ = total;
            align_ = std::max<size_t>(1, align);
            done_count_ = 0;
            error_ = nullptr;
            ++gen_;  // publishing a new generation is what wakes the workers
        }
        cv_start_.notify_all();

        std::exception_ptr failure;
        {
            std::unique_lock<std::mutex> lk(m_);
            cv_done_.wait(lk, [this] { return done_count_ == n_; });
            // Every worker is done with fn_ now; drop it so whatever the closure captured
            // is released instead of living until the next call.
            fn_ = nullptr;
            failure = error_;
            error_ = nullptr;
        }
        if (failure) std::rethrow_exception(failure);
    }

private:
    // Ask every worker to exit and join them. Safe to call with a partially built pool.
    void shutdown()
    {
        {
            std::lock_guard<std::mutex> lk(m_);
            stop_ = true;
            ++gen_;
        }
        cv_start_.notify_all();
        for (auto& t : threads_) {
            if (t.joinable()) t.join();
        }
    }

    // Thread body: pin to `core` (if non-negative), then run one chunk per generation
    // until the pool is stopped.
    void worker(size_t tid, int core)
    {
#if defined(__linux__)
        if (core >= 0) {
            cpu_set_t set;
            CPU_ZERO(&set);
            CPU_SET(core, &set);
            // Best effort: the result is ignored, so a failed call leaves the thread's
            // affinity as it was.
            (void)pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &set);
        }
#else
        (void)core;
#endif
        long seen = 0;
        while (true) {
            const std::function<void(size_t, size_t)>* job = nullptr;
            size_t total = 0;
            size_t align = 1;
            {
                std::unique_lock<std::mutex> lk(m_);
                cv_start_.wait(lk, [this, seen] { return gen_ != seen; });
                seen = gen_;
                if (stop_) return;
                // fn_ is not modified again until every worker has reported done, so it
                // can be called in place instead of copied per thread.
                job = &fn_;
                total = total_;
                align = align_;
            }

            size_t chunk = (total + n_ - 1) / n_;
            chunk = ((chunk + align - 1) / align) * align;  // round up to SIMD-block alignment
            const size_t begin = std::min(tid * chunk, total);
            const size_t end = std::min(begin + chunk, total);

            std::exception_ptr failure;
            if (end > begin) {
                // An exception leaving a thread entry point calls std::terminate; capture
                // it and hand it to parallel_for() instead.
                try {
                    (*job)(begin, end);
                } catch (...) {
                    failure = std::current_exception();
                }
            }

            {
                std::lock_guard<std::mutex> lk(m_);
                if (failure && !error_) error_ = failure;
                ++done_count_;
                if (done_count_ == n_) cv_done_.notify_one();
            }
        }
    }

    size_t n_;
    std::vector<std::thread> threads_;
    std::mutex m_;
    std::condition_variable cv_start_, cv_done_;
    std::function<void(size_t, size_t)> fn_;
    size_t total_ = 0;
    size_t align_ = 1;
    size_t done_count_ = 0;
    long gen_ = 0;
    bool stop_ = false;
    std::exception_ptr error_;  // first exception thrown by fn_ in the current job
};

// SIMD block the Adam AVX kernel rounds to (Step_8 => span 8). Slicing on multiples of
// this keeps each slice's AVX/scalar boundary identical to the whole-tensor kernel.
+#if defined(__AVX512__) or defined(__AVX256__) +static constexpr size_t kZenAdamAlign = SIMD_WIDTH * 8; +#else +static constexpr size_t kZenAdamAlign = 1; +#endif + +struct ZenHP { + float lr, beta1, beta2, eps, weight_decay; + bool bias_correction; +}; + +struct ZenGroup { + torch::Tensor param; + torch::Tensor grad[2]; + torch::Tensor exp_avg[2]; + torch::Tensor exp_avg_sq[2]; + torch::Tensor stale; // may be undefined -> stale snapshot skipped +}; + +#if defined(__linux__) +// Control block placed in a shared-memory buffer (a shared torch tensor's storage) so the +// main process and the optimizer process coordinate through two process-shared semaphores +// instead of a pickling pipe. The main process writes a command + per-group hyperparameters +// and posts cmd_ready; the worker runs the step and posts done. `done` is a counting +// semaphore, so a skipped wait (the engine's post-warmup transition) is drained later. +static constexpr int ZEN_MAX_GROUPS = 1024; +enum { ZEN_CMD_STEP = 0, ZEN_CMD_EXIT = 1 }; + +struct ZenControl { + sem_t cmd_ready; + sem_t done; + int cmd; + int now_state; + int64_t step; + int num_groups; + float hp[ZEN_MAX_GROUPS * 5]; // lr, beta1, beta2, eps, weight_decay per group + uint8_t bias_correction[ZEN_MAX_GROUPS]; +}; +#endif + +class ZenFlowAdam { +public: + ZenFlowAdam(int optimizer_id, std::vector zf_affinity) : opt_id_(optimizer_id) + { + pool_ = std::make_unique(zf_affinity); + } + + ~ZenFlowAdam() = default; + + void register_group(torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale) + { + TORCH_CHECK(param.is_contiguous(), "ZenFlowAdam: param must be contiguous"); + ZenGroup g; + g.param = param; + g.grad[0] = grad0; + g.grad[1] = grad1; + g.exp_avg[0] = exp_avg0; + g.exp_avg[1] = exp_avg1; + g.exp_avg_sq[0] = exp_avg_sq0; + g.exp_avg_sq[1] = exp_avg_sq1; + g.stale = stale; + 
groups_.push_back(std::move(g)); + } + +#if defined(__linux__) + // Process-mode driver: run in the optimizer process, block on the shared-memory control + // block, and run each requested step on the pinned pool. Returns on the exit command. + void run_worker(void* control_ptr) + { + ZenControl* ctrl = reinterpret_cast(control_ptr); + while (true) { + while (sem_wait(&ctrl->cmd_ready) != 0) {} // retry on EINTR + if (ctrl->cmd == ZEN_CMD_EXIT) break; + const int ng = ctrl->num_groups; + std::vector hps(ng); + for (int g = 0; g < ng; ++g) { + hps[g] = {ctrl->hp[g * 5 + 0], + ctrl->hp[g * 5 + 1], + ctrl->hp[g * 5 + 2], + ctrl->hp[g * 5 + 3], + ctrl->hp[g * 5 + 4], + (bool)ctrl->bias_correction[g]}; + } + run_step(ctrl->now_state, ctrl->step, hps); + sem_post(&ctrl->done); + } + } +#endif + +private: + void run_step(int now_state, int64_t step, const std::vector& hps) + { + auto opt = std::static_pointer_cast(s_optimizers[opt_id_]); + for (size_t g = 0; g < groups_.size(); ++g) { + const ZenHP& hp = hps[g]; + // Groups share one Adam_Optimizer; advance its bias-correction state for + // this group before the pool reads it (pool is idle here -> no race). + opt->IncrementStep(step, hp.beta1, hp.beta2); + opt->update_state(hp.lr, hp.eps, hp.weight_decay, hp.bias_correction); + + ZenGroup& grp = groups_[g]; + torch::Tensor& P = grp.param; + torch::Tensor& G = grp.grad[now_state]; + torch::Tensor& M = grp.exp_avg[now_state]; + torch::Tensor& V = grp.exp_avg_sq[now_state]; + + auto it = invokers.find(std::tuple(P.scalar_type(), M.scalar_type())); + TORCH_CHECK(it != invokers.end(), + "ZenFlowAdam: unsupported param/state dtype combination"); + auto fn = it->second; + + char* pp = static_cast(P.data_ptr()); + char* gp = static_cast(G.data_ptr()); + char* mp = static_cast(M.data_ptr()); + char* vp = static_cast(V.data_ptr()); + char* sp = grp.stale.defined() ? 
static_cast(grp.stale.data_ptr()) : nullptr; + const size_t pe = P.element_size(); + const size_t se = M.element_size(); + const size_t numel = P.numel(); + + pool_->parallel_for(numel, kZenAdamAlign, [=](size_t b, size_t e) { + const size_t len = e - b; + // parallel=false: each pinned thread runs its slice serially. + fn(opt, pp + b * pe, gp + b * pe, mp + b * se, vp + b * se, len, false); + if (sp) std::memcpy(sp + b * pe, pp + b * pe, len * pe); + }); + } + } + + int opt_id_; + std::vector groups_; + std::unique_ptr pool_; +}; + +// Handle-indexed registry, mirroring s_optimizers, so the Python side refers to a +// ZenFlowAdam by an int handle and the class itself stays encapsulated here. +static std::unordered_map> s_zenflow_adams; +static int s_next_zenflow_id = 0; + +int zenflow_adam_create(int optimizer_id, std::vector zf_affinity) +{ + int handle = s_next_zenflow_id++; + s_zenflow_adams[handle] = std::make_unique(optimizer_id, std::move(zf_affinity)); + return handle; +} + +void zenflow_adam_register_group(int handle, + torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale) +{ + s_zenflow_adams.at(handle)->register_group( + param, grad0, grad1, exp_avg0, exp_avg1, exp_avg_sq0, exp_avg_sq1, stale); +} + +void zenflow_adam_destroy(int handle) +{ + // Erasing the unique_ptr runs ~ZenFlowAdam, which tears down the pinned pool. + s_zenflow_adams.erase(handle); +} + +#if defined(__linux__) +// Size (bytes) the shared control tensor must hold. +int64_t zenflow_adam_ctrl_size() { return (int64_t)sizeof(ZenControl); } + +// Called once by the main process before spawning the optimizer process. 
+void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups) +{ + TORCH_CHECK(num_groups <= ZEN_MAX_GROUPS, "ZenFlowAdam: too many param groups"); + auto* ctrl = reinterpret_cast(control_ptr); + ctrl->num_groups = num_groups; + ctrl->cmd = ZEN_CMD_STEP; + sem_init(&ctrl->cmd_ready, /*pshared=*/1, 0); + sem_init(&ctrl->done, /*pshared=*/1, 0); +} + +// Called in the optimizer process; blocks running steps until the exit command. +void zenflow_adam_run_worker(int handle, uintptr_t control_ptr) +{ + s_zenflow_adams.at(handle)->run_worker(reinterpret_cast(control_ptr)); +} + +void zenflow_adam_submit(uintptr_t control_ptr, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction) +{ + auto* ctrl = reinterpret_cast(control_ptr); + const int ng = (int)lr.size(); + for (int g = 0; g < ng; ++g) { + ctrl->hp[g * 5 + 0] = lr[g]; + ctrl->hp[g * 5 + 1] = beta1[g]; + ctrl->hp[g * 5 + 2] = beta2[g]; + ctrl->hp[g * 5 + 3] = eps[g]; + ctrl->hp[g * 5 + 4] = weight_decay[g]; + ctrl->bias_correction[g] = bias_correction[g]; + } + ctrl->now_state = now_state; + ctrl->step = step; + ctrl->cmd = ZEN_CMD_STEP; + sem_post(&ctrl->cmd_ready); // release: hyperparameters above are visible to the worker +} + +// Wait up to timeout_s for the optimizer process to post one completion. Returns true if a +// completion was consumed, false on timeout -- so the training side can re-check that the +// optimizer process is still alive and fail loudly instead of blocking forever if the process +// died mid-step (e.g. an OOM or TORCH_CHECK in run_step after it signalled ready). 
+bool zenflow_adam_wait(uintptr_t control_ptr, double timeout_s) +{ + auto* ctrl = reinterpret_cast(control_ptr); + struct timespec deadline; + clock_gettime(CLOCK_REALTIME, &deadline); + deadline.tv_sec += (time_t)timeout_s; + deadline.tv_nsec += (long)((timeout_s - (double)(time_t)timeout_s) * 1e9); + if (deadline.tv_nsec >= 1000000000L) { + deadline.tv_sec += 1; + deadline.tv_nsec -= 1000000000L; + } + while (sem_timedwait(&ctrl->done, &deadline) != 0) { + if (errno == EINTR) continue; // retry on signal + return false; // timed out (or error): caller re-checks process liveness + } + return true; +} + +void zenflow_adam_ctrl_exit(uintptr_t control_ptr) +{ + auto* ctrl = reinterpret_cast(control_ptr); + ctrl->cmd = ZEN_CMD_EXIT; + sem_post(&ctrl->cmd_ready); +} +#endif diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index f07a14e08438..bee80d7a6f34 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "simd.h" #define STEP(SPAN) \ @@ -19,7 +20,8 @@ ds_params_precision_t* grads, \ ds_state_precision_t* _exp_avg, \ ds_state_precision_t* _exp_avg_sq, \ - size_t _param_size); + size_t _param_size, \ + bool parallel = true); class Adam_Optimizer { public: @@ -49,7 +51,8 @@ class Adam_Optimizer { ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t param_size); + size_t param_size, + bool parallel = true); #endif STEP(1) STEP(4) @@ -115,7 +118,8 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, ds_params_precision_t* grads, ds_state_precision_t* _exp_avg, ds_state_precision_t* _exp_avg_sq, - size_t _param_size) + size_t _param_size, + bool parallel) { #if !defined(__AVX512__) if (std::is_same_v || @@ -156,7 +160,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#pragma omp parallel for 
+#pragma omp parallel for if (parallel) for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; simd_load(grad_4, grads + i); @@ -220,6 +224,21 @@ int ds_adam_step(int optimizer_id, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq); +int ds_adam_step_multi(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + std::vector& params, + std::vector& grads, + std::vector& exp_avgs, + std::vector& exp_avg_sqs, + std::vector& stale_params, + bool parallel = true); + int ds_adam_rollback(int optimizer_id, size_t step, float lr, @@ -234,3 +253,41 @@ int ds_adam_rollback(int optimizer_id, torch::Tensor& exp_avg_sq); int destroy_adam_optimizer(int optimizer_id); + +// ZenFlowAdam: the native CPU Adam backing ZenFlow's overlapped optimizer step. The handle +// indexes a pinned thread pool; the optimizer runs in a dedicated process (run_worker) and +// is driven from the main process through the shared-memory control block below. +int zenflow_adam_create(int optimizer_id, std::vector zf_affinity); + +void zenflow_adam_register_group(int handle, + torch::Tensor param, + torch::Tensor grad0, + torch::Tensor grad1, + torch::Tensor exp_avg0, + torch::Tensor exp_avg1, + torch::Tensor exp_avg_sq0, + torch::Tensor exp_avg_sq1, + torch::Tensor stale); + +void zenflow_adam_destroy(int handle); + +#if defined(__linux__) +// The optimizer runs in a separate process and coordinates with the main process through two +// process-shared semaphores in a shared-memory control block. ctrl_size/ctrl_init/ctrl_exit +// set it up and tear it down; the worker process loops in run_worker; the main process drives +// each step with submit (non-blocking) / wait. 
+int64_t zenflow_adam_ctrl_size(); +void zenflow_adam_ctrl_init(uintptr_t control_ptr, int num_groups); +void zenflow_adam_run_worker(int handle, uintptr_t control_ptr); +void zenflow_adam_submit(uintptr_t control_ptr, + int now_state, + int64_t step, + std::vector lr, + std::vector beta1, + std::vector beta2, + std::vector eps, + std::vector weight_decay, + std::vector bias_correction); +bool zenflow_adam_wait(uintptr_t control_ptr, double timeout_s); +void zenflow_adam_ctrl_exit(uintptr_t control_ptr); +#endif diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py index 0809d7a0f7e0..bb62f3893ee5 100644 --- a/deepspeed/ops/adam/zenflow_cpu_adam.py +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -12,12 +12,11 @@ class ZenFlowCPUAdam(DeepSpeedCPUAdam): def __init__(self, *args, overlap_step=False, **kwargs): super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) self.overlap_step = overlap_step + # In the overlapped path the optimizer step is driven natively in the ZenFlow optimizer + # process (see ZenFlowAdam / zenflow_utils.start_optimizer_process), so this object's own + # step() is unused there. Only the sequential (non-overlap) offload path steps here. if not self.overlap_step: - print("ZenFlowCPUAdam initialized with normal step.") self.step = self._sequential_step - else: - print("ZenFlowCPUAdam initialized with overlap step.") - self.step = self._parallel_step @torch.no_grad() def _sequential_step(self, step_id, closure=None): @@ -76,63 +75,3 @@ def _sequential_step(self, step_id, closure=None): group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq']) return loss - - @torch.no_grad() - def _parallel_step(self, step_id, now_state, group_info, closure=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. 
- - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - # intended device for step - device = torch.device('cpu') - - stale_param = None - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - assert p.data.is_shared(), "param.data must be in shared memory" - if not hasattr(p, 'overlap_grad'): - continue - - assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ - "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." - - state = self.state[p] - # State initialization - if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - # print("creating", flush=True) - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) - exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) - state['exp_avg'] = [exp_avg, exp_avg.clone()] - state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] - - state['step'] = step_id - beta1, beta2 = group_info['betas'] - self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2, - group_info['eps'], group_info['weight_decay'], - group_info['bias_correction'], p.data, p.overlap_grad[now_state].data, - state['exp_avg'][now_state], state['exp_avg_sq'][now_state]) - p.stale_param.data.copy_(p.data.clone()) - return loss diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py index 0d95be9f2f8a..c6d749a8fbee 100644 --- a/deepspeed/runtime/zenflow/engine_stage3.py +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -14,7 +14,7 @@ from typing import List from 
deepspeed.accelerator import get_accelerator from typing import TYPE_CHECKING -from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process, ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS if TYPE_CHECKING: from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 @@ -542,40 +542,41 @@ def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_gr if not optimizer_z3.process_optimizer_established: optimizer_z3.start_optimizer_process() - group_infos = [] + lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] for group_no, group in enumerate(optimizer_z3.fp16_groups): optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state) param_group_id = optimizer_z3.sub_group_to_group_id[group_no] + pg = optimizer_z3.optimizer.param_groups[param_group_id] + lr.append(pg["lr"]) + beta1.append(pg["betas"][0]) + beta2.append(pg["betas"][1]) + eps.append(pg["eps"]) + weight_decay.append(pg["weight_decay"]) + bias_correction.append(1 if pg["bias_correction"] else 0) - group_info = { - "lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"], - "betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"], - "eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"], - "weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"], - "bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"], - } - - group_infos.append(group_info) - - optimizer_z3.parent_conn.send({ - "type": "step", - "now_state": now_state, - "micro_step": optimizer_z3.micro_step, - "group_infos": group_infos - }) + optimizer_z3.zf_op.zenflow_adam_submit(optimizer_z3.zf_ctrl.data_ptr(), now_state, optimizer_z3.micro_step + 1, lr, + beta1, beta2, eps, weight_decay, bias_correction) def wait_last_update_and_copy(optimizer_z3, timer_names): - if not hasattr(optimizer_z3, 'parent_conn'): 
+ if not getattr(optimizer_z3, 'process_optimizer_established', False): return if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup: optimizer_z3.first_update_round_after_warmup = False return - msg = optimizer_z3.parent_conn.recv() - assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + # Wake periodically to check the optimizer process is alive: if it died mid-step, fail loudly + # here instead of blocking this rank (and the whole job) forever on a semaphore it will never + # post. + while not optimizer_z3.zf_op.zenflow_adam_wait(optimizer_z3.zf_ctrl.data_ptr(), + ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS): + proc = getattr(optimizer_z3, 'process', None) + if proc is not None and not proc.is_alive(): + raise RuntimeError("ZenFlow optimizer process exited during a step (likely an error or OOM in the " + "optimizer process -- check its traceback above) instead of completing the " + "update. Aborting to avoid hanging distributed training.") for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 2f5e423f1320..48921a9f1880 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -7,7 +7,7 @@ from deepspeed import comm as dist from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer -from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process, ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS from deepspeed.runtime.utils import (see_memory_usage) from deepspeed.ops.adam import ZenFlowSelectiveAdamW @@ -34,6 +34,11 @@ ] INITIAL_MICRO_STEP_ID = -1 +# Number of elements copied per chunk when streaming the updated fp32 master 
partition +# back to the GPU bit16 partition. Bounds the transient fp32 staging tensor on the GPU +# (chunk * 4 bytes) instead of materializing the whole partition at once. +ZENFLOW_COPYBACK_CHUNK_NUMEL = 32 * 1024 * 1024 + SELECTIVE_OPTIMIZER_UPDATE_TIMER = 'selective_optimizer_update' SELECTIVE_OPTIMIZER_PROCESS_TIMER = 'selective_optimizer_process' SELECTIVE_OPTIMIZER_STEP_TIMER = 'selective_optimizer_step' @@ -538,15 +543,15 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type) self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop() - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: + def backward_prologue(self): + """Prepare ZenFlow's per-microbatch state before the backward pass. - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + Called by the engine at the start of each backward. Advances the + micro-step counter and, on an auto-update step, refreshes the + update-interval bookkeeping. At a selection boundary, resyncs the fp32 + master partition from the bit16 weights and clears the selective + optimizer's moments so the next top-k update starts clean. 
""" - self.backward_prologue() self.micro_step += 1 if self.auto_update: @@ -565,16 +570,6 @@ def backward(self, loss, retain_graph=False): self.selective_optimizer.clear_selected_mv() self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).stop() - self.enter_backward() - if self.custom_loss_scaler: - scaled_loss = self.external_loss_scale * loss - scaled_loss.backward(retain_graph=retain_graph) - else: - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - self.backward_epilogue() - self.exit_backward() - def log_selective_optimizer_timers(self): self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) @@ -675,9 +670,23 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): dest_tensor.copy_(src_tensor, non_blocking=True) param.grad = None #offload only + def _wait_for_optimizer_process(self): + """Block until the optimizer process signals the submitted step is done. + + The wait wakes up periodically to check the optimizer process is still alive: if it died + mid-step (e.g. an OOM or assertion in the native worker after it signalled ready), fail + loudly here instead of blocking this rank -- and the whole distributed job -- forever on + a semaphore the dead process will never post.""" + while not self.zf_op.zenflow_adam_wait(self.zf_ctrl.data_ptr(), ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS): + proc = getattr(self, 'process', None) + if proc is not None and not proc.is_alive(): + raise RuntimeError("ZenFlow optimizer process exited during a step (likely an error or OOM in " + "the optimizer process -- check its traceback above) instead of completing " + "the update. 
Aborting to avoid hanging distributed training.") + def wait_last_update_and_copy(self): - if not hasattr(self, 'parent_conn'): + if not getattr(self, 'process_optimizer_established', False): return if self.micro_step + 1 > self.full_warm_up_rounds and self.first_update_round_after_warmup: @@ -685,8 +694,7 @@ def wait_last_update_and_copy(self): return self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() - msg = self.parent_conn.recv() - assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + self._wait_for_optimizer_process() self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() for i, group in enumerate(self.bit16_groups): @@ -694,7 +702,7 @@ def wait_last_update_and_copy(self): bit16_partitions = self.parallel_partitioned_bit16_groups[i] fp32_partition = self.optimizer.param_groups[i]['params'][0].stale_param.data self.timers(OPTIMIZER_TRANSMIT_TIMER).start() - bit16_partitions[partition_id].data.copy_(fp32_partition.to(get_accelerator().current_device_name()).data) + self._copyback_fp32_partition_to_bit16(fp32_partition, bit16_partitions[partition_id].data) self.timers(OPTIMIZER_TRANSMIT_TIMER).stop() see_memory_usage('After optimizer before all-gather') @@ -720,32 +728,44 @@ def wait_last_update_and_copy(self): self.timers.log(OPTIMIZER_TIMERS) see_memory_usage('After zero_optimizer step') + def _copyback_fp32_partition_to_bit16(self, fp32_partition, bit16_partition): + """Stream the updated fp32 master partition back to its GPU bit16 partition in chunks. + + The straightforward ``bit16.copy_(fp32.to(device))`` first materializes the whole + fp32 partition on the GPU, a transient spike of ~2x the bit16 partition (~3GB for a + 0.75B-param partition) stacked on top of the model -- exactly the memory the offload + is meant to save. Copying chunk by chunk keeps only one chunk's fp32 staging tensor + resident, so the peak drops to the chunk size; the bit16 result is unchanged. 
+ """ + device = get_accelerator().current_device_name() + fp32_flat = fp32_partition.view(-1) + bit16_flat = bit16_partition.view(-1) + numel = fp32_flat.numel() + for offset in range(0, numel, ZENFLOW_COPYBACK_CHUNK_NUMEL): + end = min(offset + ZENFLOW_COPYBACK_CHUNK_NUMEL, numel) + gpu_chunk = fp32_flat[offset:end].to(device, non_blocking=True) + bit16_flat[offset:end].copy_(gpu_chunk) + def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): if not self.process_optimizer_established: self.start_optimizer_process() - group_infos = [] + lr, beta1, beta2, eps, weight_decay, bias_correction = [], [], [], [], [], [] for group_no, group in enumerate(self.bit16_groups): single_grad_partition = self.single_partition_of_fp32_groups[group_no].overlap_grad[now_state] self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) - group_info = { - "lr": self.optimizer.param_groups[group_no]["lr"], - "betas": self.optimizer.param_groups[group_no]["betas"], - "eps": self.optimizer.param_groups[group_no]["eps"], - "weight_decay": self.optimizer.param_groups[group_no]["weight_decay"], - "bias_correction": self.optimizer.param_groups[group_no]["bias_correction"], - } - - group_infos.append(group_info) - - self.parent_conn.send({ - "type": "step", - "now_state": now_state, - "micro_step": self.micro_step, - "group_infos": group_infos - }) + pg = self.optimizer.param_groups[group_no] + lr.append(pg["lr"]) + beta1.append(pg["betas"][0]) + beta2.append(pg["betas"][1]) + eps.append(pg["eps"]) + weight_decay.append(pg["weight_decay"]) + bias_correction.append(1 if pg["bias_correction"] else 0) + + self.zf_op.zenflow_adam_submit(self.zf_ctrl.data_ptr(), now_state, self.micro_step + 1, lr, beta1, beta2, eps, + weight_decay, bias_correction) def step(self, closure=None): """ diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index f238b3626506..7bd3cba51a7c 100644 --- 
a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -10,6 +10,11 @@ from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator +# How long the training side blocks on a single semaphore wait for the optimizer process before +# waking up to check that the process is still alive. A normal step completes far sooner; this +# only bounds how long we hang if the optimizer process dies mid-step. +ZENFLOW_OPTIMIZER_WAIT_POLL_SECONDS = 60 + def _flatten_dense_tensors(tensors): """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of @@ -57,48 +62,6 @@ def disable_accelerator(): accelerator._initialized = True -def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity): - disable_accelerator() - - current_process = psutil.Process() - current_process.cpu_affinity(zf_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) - - from deepspeed.ops.adam import ZenFlowCPUAdam - optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) - - pipe.send({"type": "ready"}) - - # TODO: replace this with rpc - - while True: - cmd = pipe.recv() - if cmd["type"] == "step": - now_state = cmd["now_state"] - micro_step = cmd["micro_step"] - group_infos = cmd["group_infos"] - - for group_no, group_info in enumerate(group_infos): - original_param_groups = optimizer.param_groups - optimizer.param_groups = [original_param_groups[group_no]] - group = optimizer.param_groups[0] - - for param_idx, param in enumerate(group["params"]): - key = (group_no, param_idx) - if key in shared_overlap_grad_map: - param.overlap_grad = shared_overlap_grad_map[key] - if key in shared_stale_param_map: - param.stale_param = shared_stale_param_map[key] - - optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info) - - optimizer.param_groups = original_param_groups - - pipe.send({"type": "done"}) - elif cmd["type"] == "exit": - break - - def 
all_tensors_equal(tensor_list): first_tensor = tensor_list[0] for tensor in tensor_list[1:]: @@ -107,47 +70,14 @@ def all_tensors_equal(tensor_list): return True -def start_optimizer_process(zf_optimizer): - from multiprocessing import Pipe, get_context, Manager - - ctx = get_context("spawn") - zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe() - - manager = Manager() - zf_optimizer.shared_overlap_grad_map = manager.dict() - zf_optimizer.shared_stale_param_map = manager.dict() - - if zf_optimizer.zf_stage3: - params_iter = [((group_no, 0), param) - for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)] - else: - params_iter = [((group_no, param_idx), param) - for group_no, group in enumerate(zf_optimizer.optimizer.param_groups) - for param_idx, param in enumerate(group["params"])] - - for key, param in params_iter: - param.data.share_memory_() - - if not hasattr(param, "stale_param"): - param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) - param.stale_param.data.share_memory_() - zf_optimizer.shared_stale_param_map[key] = param.stale_param - - if getattr(param, "overlap_grad", None) is not None: - param.overlap_grad[0].data.share_memory_() - param.overlap_grad[1].data.share_memory_() - zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad - - param_groups_data = ([{ - "params": [param] - } for param in zf_optimizer.fp32_partitioned_groups_flat] - if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups) - +def _compute_zf_pt_affinity(zf_optimizer): + """Split this rank's cores into a ZenFlow-optimizer set and a training (PyTorch) set. 
+ When every rank reports the same affinity the launcher did not bind workers, so do a + soft per-rank bind first, then carve off pt_reserved_cores_perc for training.""" curr_rank = dist.get_rank() total_rank = dist.get_world_size() - current_process = psutil.Process() - current_affinity = current_process.cpu_affinity() + current_affinity = psutil.Process().cpu_affinity() all_affinities = [ torch.zeros(len(current_affinity), dtype=type(current_affinity[0]), @@ -157,15 +87,12 @@ def start_optimizer_process(zf_optimizer): all_affinities, torch.tensor(current_affinity, dtype=type(current_affinity[0]), device=get_accelerator().current_device_name())) - # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here if all_tensors_equal(all_affinities): num_phy_cores = psutil.cpu_count(logical=False) available_phy_cores = [i for i in current_affinity if i < num_phy_cores] - num_available_phy_cores = len(available_phy_cores) - my_rank = curr_rank - my_size = total_rank - cores_per_rank = num_available_phy_cores // my_size - current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank] + cores_per_rank = len(available_phy_cores) // total_rank + current_affinity = available_phy_cores[curr_rank * cores_per_rank:(curr_rank + 1) * cores_per_rank] + pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity)) if pt_num_cores > 0 and pt_num_cores < len(current_affinity): zf_affinity = current_affinity[pt_num_cores:] @@ -173,19 +100,87 @@ def start_optimizer_process(zf_optimizer): else: zf_affinity = current_affinity pt_affinity = current_affinity + return zf_affinity, pt_affinity - zf_optimizer.process = ctx.Process( - target=zenflow_optimizer_process, - args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map, - zf_optimizer.shared_stale_param_map, zf_affinity), - ) - zf_optimizer.process.daemon = True - zf_optimizer.process.start() - 
current_process.cpu_affinity(pt_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) +def zenflow_optimizer_process(groups, ctrl, ready, zf_affinity, adamw_mode): + """ZenFlow overlapped optimizer process (ZeRO stage 1/2/3). Builds the native ZenFlowAdam + pinned pool and runs the worker loop driven by the shared-memory control block (no pickling + pipe). The Adam state is allocated here, in this process pinned to the optimizer cores, so + it is NUMA-local to the pool -- which is what makes a separate process worthwhile over an + in-process thread for large, memory-bandwidth-bound updates.""" + disable_accelerator() + current_process = psutil.Process() + current_process.cpu_affinity(zf_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) + + from deepspeed.ops.op_builder import CPUAdamBuilder + op = CPUAdamBuilder().load() + op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, adamw_mode, False) + handle = op.zenflow_adam_create(0, list(zf_affinity)) + for param, overlap_grad0, overlap_grad1, stale in groups: + exp_avg0 = torch.zeros_like(param) + exp_avg1 = torch.zeros_like(param) + exp_avg_sq0 = torch.zeros_like(param) + exp_avg_sq1 = torch.zeros_like(param) + op.zenflow_adam_register_group(handle, param, overlap_grad0, overlap_grad1, exp_avg0, exp_avg1, exp_avg_sq0, + exp_avg_sq1, stale) + ready.set() + op.zenflow_adam_run_worker(handle, ctrl.data_ptr()) + op.zenflow_adam_destroy(handle) + op.destroy_adam(0) - msg = zf_optimizer.parent_conn.recv() - assert msg["type"] == "ready", "Optimizer process did not initialize correctly." + +def start_optimizer_process(zf_optimizer): + """Start ZenFlow's overlapped optimizer (ZeRO stage 1/2/3) in a dedicated process, + coordinated through a shared-memory semaphore control block. 
A separate process keeps the + Adam state NUMA-local to the optimizer cores and free of contention with the training + thread, while the native control block avoids per-step Python/IPC overhead.""" + from multiprocessing import get_context + from deepspeed.ops.op_builder import CPUAdamBuilder + + op = CPUAdamBuilder().load() + zf_optimizer.zf_op = op + + # Stage 3 steps each flattened sub-group partition; stage 1/2 steps one flat partition per + # param group. Both carry overlap_grad double buffers and a stale snapshot. + if zf_optimizer.zf_stage3: + params = list(zf_optimizer.fp32_partitioned_groups_flat) + else: + params = [group["params"][0] for group in zf_optimizer.optimizer.param_groups] + + # Share the tensors the optimizer process reads/writes; the Adam state stays process-local. + groups = [] + for param in params: + param.data.share_memory_() + if not hasattr(param, "stale_param"): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + param.stale_param.data.share_memory_() + param.overlap_grad[0].data.share_memory_() + param.overlap_grad[1].data.share_memory_() + groups.append((param.data, param.overlap_grad[0].data, param.overlap_grad[1].data, param.stale_param.data)) + + ctrl = torch.zeros(op.zenflow_adam_ctrl_size(), dtype=torch.uint8).share_memory_() + op.zenflow_adam_ctrl_init(ctrl.data_ptr(), len(groups)) + zf_optimizer.zf_ctrl = ctrl + + zf_affinity, pt_affinity = _compute_zf_pt_affinity(zf_optimizer) + + ctx = get_context("spawn") + ready = ctx.Event() + proc = ctx.Process(target=zenflow_optimizer_process, args=(groups, ctrl, ready, zf_affinity, True)) + proc.daemon = True + proc.start() + # Wait for the optimizer process to finish building its pool and registering tensors. + # If it crashed during init (e.g. it never signals), fail loudly instead of blocking the + # training process forever on the first step's wait. 
+ if not ready.wait(timeout=600): + proc.terminate() + raise RuntimeError("ZenFlow optimizer process failed to become ready (it likely crashed " + "during initialization; check the optimizer process traceback above)") + zf_optimizer.process = proc + + psutil.Process().cpu_affinity(pt_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) zf_optimizer.process_optimizer_established = True diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 139419563352..acbce2a8a41d 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -10,7 +10,7 @@ import torch from deepspeed import comm as dist from deepspeed.utils import logger -from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import DeepSpeedCPUAdam, ZenFlowCPUAdam from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad from deepspeed.ops.adam import FusedAdam from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion @@ -43,8 +43,8 @@ class ZeRORuntimeException(Exception): ZERO_SUPPORTED_OPTIMIZERS = [ - torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad, - DeepSpeedCPULion, FusedLion + torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, ZenFlowCPUAdam, torch.optim.Adagrad, + DeepSpeedCPUAdagrad, DeepSpeedCPULion, FusedLion ] # Add MuonWithAuxAdam to supported list if muon is installed diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 003a6f8f6a46..39414a3aad61 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -3,6 +3,8 @@ # DeepSpeed Team +import os +import sys import torch import numpy as np import pytest @@ -175,6 +177,217 @@ def test_bf16_optimizer_states_match_fp32(self, model_size): check_equal(param_fp32_states.float().norm(), param_bf16_states.float().norm(), atol=tolerance) +class TestCPUAdamFusedMultiTensor(DistributedTest): + """adam_update_multi (fused 
multi-tensor, used by ZenFlow overlap) must match a + per-parameter sequence of adam_update bit-for-bit, and write the post-update + parameter snapshot into the stale buffer.""" + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) + def test_multi_matches_single(self, dtype): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ds_opt_adam = CPUAdamBuilder().load() + + lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 + adamw_mode, bias_correction = True, True + # Mixed sizes (including ones that don't divide the SIMD width) exercise both the + # vectorized and scalar tails inside the fused C++ loop. + sizes = [64, 22, 1024, 1048576] + + opt_single, opt_multi = 0, 1 + ds_opt_adam.create_adam(opt_single, lr, beta1, beta2, eps, weight_decay, adamw_mode, False) + ds_opt_adam.create_adam(opt_multi, lr, beta1, beta2, eps, weight_decay, adamw_mode, False) + + torch.manual_seed(0) + params_single = [torch.randn(n, dtype=dtype) for n in sizes] + params_multi = [p.clone() for p in params_single] + exp_avg_single = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_sq_single = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_multi = [torch.zeros(n, dtype=torch.float) for n in sizes] + exp_avg_sq_multi = [torch.zeros(n, dtype=torch.float) for n in sizes] + stale_multi = [torch.zeros(n, dtype=dtype) for n in sizes] + + try: + for step in range(1, 6): + grads = [torch.randn(n, dtype=dtype) for n in sizes] + + for i in range(len(sizes)): + ds_opt_adam.adam_update(opt_single, step, lr, beta1, beta2, eps, weight_decay, bias_correction, + params_single[i], grads[i].clone(), exp_avg_single[i], + exp_avg_sq_single[i]) + + 
ds_opt_adam.adam_update_multi(opt_multi, step, lr, beta1, beta2, eps, weight_decay, bias_correction, + params_multi, [g.clone() for g in grads], exp_avg_multi, + exp_avg_sq_multi, stale_multi) + + for i in range(len(sizes)): + assert torch.equal(params_single[i], params_multi[i]), f"param mismatch at size {sizes[i]}" + assert torch.equal(exp_avg_single[i], exp_avg_multi[i]), f"exp_avg mismatch at size {sizes[i]}" + assert torch.equal(exp_avg_sq_single[i], + exp_avg_sq_multi[i]), f"exp_avg_sq mismatch at size {sizes[i]}" + # stale must hold the post-update parameter snapshot + assert torch.equal(stale_multi[i], params_multi[i]), f"stale mismatch at size {sizes[i]}" + finally: + ds_opt_adam.destroy_adam(opt_single) + ds_opt_adam.destroy_adam(opt_multi) + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) + def test_serial_matches_parallel(self, dtype): + """The serial kernel path (parallel=False, used by ZenFlow's pinned thread pool) + must match the OpenMP path (parallel=True) bit-for-bit.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ds_opt_adam = CPUAdamBuilder().load() + lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.0 + sizes = [64, 22, 1024, 1048576] + + opt_par, opt_ser = 3, 4 + ds_opt_adam.create_adam(opt_par, lr, beta1, beta2, eps, weight_decay, True, False) + ds_opt_adam.create_adam(opt_ser, lr, beta1, beta2, eps, weight_decay, True, False) + + torch.manual_seed(0) + params_par = [torch.randn(n, dtype=dtype) for n in sizes] + params_ser = [p.clone() for p in params_par] + ea_par = [torch.zeros(n) for n in sizes] + eq_par = [torch.zeros(n) for n in sizes] + ea_ser = [torch.zeros(n) for n in sizes] + eq_ser = [torch.zeros(n) for n in sizes] + + try: + for step in range(1, 4): + grads = [torch.randn(n, dtype=dtype) for n in sizes] + ds_opt_adam.adam_update_multi(opt_par, + step, + lr, + beta1, 
+ beta2, + eps, + weight_decay, + True, + params_par, [g.clone() for g in grads], + ea_par, + eq_par, [], + parallel=True) + ds_opt_adam.adam_update_multi(opt_ser, + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + params_ser, [g.clone() for g in grads], + ea_ser, + eq_ser, [], + parallel=False) + for i in range(len(sizes)): + assert torch.equal(params_par[i], params_ser[i]), f"param mismatch at size {sizes[i]}" + assert torch.equal(ea_par[i], ea_ser[i]), f"exp_avg mismatch at size {sizes[i]}" + assert torch.equal(eq_par[i], eq_ser[i]), f"exp_avg_sq mismatch at size {sizes[i]}" + finally: + ds_opt_adam.destroy_adam(opt_par) + ds_opt_adam.destroy_adam(opt_ser) + + def test_multi_without_stale(self): + """An empty stale list is allowed and simply skips the snapshot.""" + ds_opt_adam = CPUAdamBuilder().load() + opt_id = 2 + ds_opt_adam.create_adam(opt_id, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False) + try: + params = [torch.randn(64, dtype=torch.float)] + grads = [torch.randn(64, dtype=torch.float)] + exp_avg = [torch.zeros(64, dtype=torch.float)] + exp_avg_sq = [torch.zeros(64, dtype=torch.float)] + before = params[0].clone() + ds_opt_adam.adam_update_multi(opt_id, 1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, params, grads, exp_avg, + exp_avg_sq, []) + assert not torch.equal(params[0], before), "params should be updated even without stale buffers" + finally: + ds_opt_adam.destroy_adam(opt_id) + + +def _zenflow_adam_proc_worker(param, g0, g1, ea0, ea1, eq0, eq1, stale, ctrl, ready, affinity): + op = CPUAdamBuilder().load() + op.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, False) + handle = op.zenflow_adam_create(0, affinity) + op.zenflow_adam_register_group(handle, param, g0, g1, ea0, ea1, eq0, eq1, stale) + ready.set() + op.zenflow_adam_run_worker(handle, ctrl.data_ptr()) # blocks until the exit command + op.zenflow_adam_destroy(handle) + op.destroy_adam(0) + + +@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="cross-process ZenFlowAdam is 
Linux-only") +def test_zenflow_adam_cross_process(): + """The optimizer-process driver (shared-memory semaphore control + native worker, the + production path for ZenFlow stage 1/2 overlap) must match the fused reference bit-for-bit + with alternating double buffers. Run as a plain test, not DistributedTest, so the pytest + process (non-daemonic) can spawn the optimizer process.""" + import torch.multiprocessing as mp + + op = CPUAdamBuilder().load() + if not hasattr(op, "zenflow_adam_ctrl_size"): + pytest.skip("cross-process ZenFlowAdam not available in this build") + + lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.0 + n = 100003 # non-SIMD-aligned, exercises the scalar tail + affinity = list(range(min(4, os.cpu_count() or 1))) + + ctrl = torch.zeros(op.zenflow_adam_ctrl_size(), dtype=torch.uint8).share_memory_() + op.zenflow_adam_ctrl_init(ctrl.data_ptr(), 1) + + torch.manual_seed(0) + param = torch.randn(n).share_memory_() + g = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + ea = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + eq = [torch.zeros(n).share_memory_(), torch.zeros(n).share_memory_()] + stale = torch.zeros(n).share_memory_() + + op.create_adam(1, lr, beta1, beta2, eps, wd, True, False) + p_ref = param.clone() + ea_ref = [ea[0].clone(), ea[1].clone()] + eq_ref = [eq[0].clone(), eq[1].clone()] + st_ref = stale.clone() + + ctx = mp.get_context("spawn") + ready = ctx.Event() + proc = ctx.Process(target=_zenflow_adam_proc_worker, + args=(param, g[0], g[1], ea[0], ea[1], eq[0], eq[1], stale, ctrl, ready, affinity)) + proc.start() + try: + assert ready.wait(timeout=60), "optimizer process did not start" + # With no step submitted yet, a bounded wait must time out (return False) rather than + # block -- this is what lets the training side notice a dead optimizer process. 
+ assert op.zenflow_adam_wait(ctrl.data_ptr(), 0.05) is False, "wait should time out when no step is pending" + for step in range(1, 6): + now = step & 1 + grad = torch.randn(n) + g[now].copy_(grad) + op.zenflow_adam_submit(ctrl.data_ptr(), now, step, [lr], [beta1], [beta2], [eps], [wd], [1]) + assert op.zenflow_adam_wait(ctrl.data_ptr(), 60.0), f"wait timed out step {step}" + op.adam_update_multi(1, step, lr, beta1, beta2, eps, wd, True, [p_ref], [grad.clone()], [ea_ref[now]], + [eq_ref[now]], [st_ref]) + assert torch.equal(param, p_ref), f"param mismatch step {step}" + assert torch.equal(ea[now], ea_ref[now]), f"exp_avg mismatch step {step}" + assert torch.equal(eq[now], eq_ref[now]), f"exp_avg_sq mismatch step {step}" + assert torch.equal(stale, st_ref), f"stale mismatch step {step}" + op.zenflow_adam_ctrl_exit(ctrl.data_ptr()) + proc.join(timeout=10) + finally: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + op.destroy_adam(1) + + class TestCPUAdamGPUError(DistributedTest): def test_cpu_adam_gpu_error(self):