Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ ggml_backend_t Backend::handle() const {
return impl_ ? impl_->backend : nullptr;
}

bool Backend::is_gpu() const {
return impl_ && impl_->use_sched;
}

void Backend::register_input(ggml_tensor* t, const void* host, size_t nbytes) {
impl_->pending.push_back({t, host, nbytes});
}
Expand Down Expand Up @@ -387,4 +391,154 @@ void weight_to_host_f32(const ModelLoader& ml, const char* name, std::vector<flo
ggml_backend_tensor_get(t, out.data(), 0, ggml_nbytes(t));
}

// ---------------------------------------------------------------------------
// ReplayGraph: build a graph once, recompute it many times on the persistent
// backend/gallocr WITHOUT re-running ggml_init / gallocr_alloc / ggml_free per
// call. Keeping the ggml context (and thus cgraph->nodes[0]) alive across calls
// is what lets ggml-cuda's CUDA-graph capture warm up and replay — the direct
// analogue of megapar's torch.cuda.CUDAGraph replay. See backend.hpp.
// ---------------------------------------------------------------------------
ReplayGraph::ReplayGraph(Backend& backend,
const std::function<ggml_tensor*(ggml_context*)>& build)
: backend_(backend) {
// Metadata-only context, kept alive for this ReplayGraph's lifetime.
struct ggml_init_params params = {
/* .mem_size = */ ggml_tensor_overhead() * kGraphSize +
ggml_graph_overhead_custom(kGraphSize, false),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ctx_ = ggml_init(params);
assert(ctx_ && "ReplayGraph: ggml_init failed");

// Drive add_graph_input()/capture registrations to this Backend for the
// build call (same mechanism Backend::compute uses).
Backend::Impl* impl = backend_.impl_;
impl->pending.clear();
impl->captures.clear();
Backend* prev_active = t_active;
t_active = &backend_;
out_ = build(ctx_);
t_active = prev_active;
assert(out_ && "ReplayGraph: build() returned null output tensor");

// Record the input-tensor handles (registration order) BEFORE clearing
// pending, so set_input() can re-feed them later.
inputs_.reserve(impl->pending.size());
for (const PendingInput& pi : impl->pending) inputs_.push_back(pi.tensor);
impl->pending.clear();
// Likewise record capture tensors + dst (compute_with_captures re-fills them
// each step; they must NOT be cleared like Backend::compute does).
captures_.reserve(impl->captures.size());
for (const PendingCapture& pc : impl->captures)
captures_.push_back({pc.tensor, pc.dst});
impl->captures.clear();

// Mark the output AND every capture so the gallocr keeps them, then expand
// the forward graph over captures first (robust if the output's subgraph
// does not reach them, mirroring Backend::compute).
ggml_set_output(out_);
for (const auto& cap : captures_) ggml_set_output(cap.first);
gf_ = ggml_new_graph_custom(ctx_, kGraphSize, false);
for (const auto& cap : captures_) ggml_build_forward_expand(gf_, cap.first);
ggml_build_forward_expand(gf_, out_);

// Allocate the graph ONCE now (persistent gallocr / sched) so the input
// tensors get ->buffer/->data set BEFORE the first set_input() call, and so
// the steady-state replay path does NOT re-plan the allocation per step.
// The persistent gallocr keeps the buffer for a stable graph shape across
// graph_compute calls, so one alloc here covers every replay.
if (!alloc_internal()) {
assert(false && "ReplayGraph: initial graph allocation failed");
}
}

ReplayGraph::~ReplayGraph() {
if (ctx_) ggml_free(ctx_);
}

void ReplayGraph::set_input(size_t i, const void* host, size_t nbytes) {
assert(i < inputs_.size() && "ReplayGraph::set_input index out of range");
ggml_tensor* t = inputs_[i];
// The input tensor was marked ggml_set_input() during build() and allocated
// by the gallocr (in the constructor); its ->buffer/->data are set, so a
// tensor_set feeds the new step's data in place.
ggml_backend_tensor_set(t, host, 0, nbytes);
}

// Allocate gf_ on the persistent gallocr (or sched fallback). Returns true on
// success and records need_sched_ so compute() knows which compute path to use.
bool ReplayGraph::alloc_internal() {
Backend::Impl* impl = backend_.impl_;
need_sched_ = false;
if (impl->use_sched) {
const int n_nodes = ggml_graph_n_nodes(gf_);
for (int i = 0; i < n_nodes; ++i) {
if (!ggml_backend_supports_op(impl->backend, ggml_graph_node(gf_, i))) {
need_sched_ = true;
break;
}
}
}
if (need_sched_) {
if (!impl->sched) {
ggml_backend_t backs[2] = { impl->backend, impl->cpu_backend };
impl->sched = ggml_backend_sched_new(
backs, nullptr, 2, kGraphSize, false, true);
}
ggml_backend_sched_reset(impl->sched);
bool ok = ggml_backend_sched_alloc_graph(impl->sched, gf_);
if (!ok) PK_LOG("ReplayGraph: ggml_backend_sched_alloc_graph failed");
return ok;
}
if (!impl->galloc) {
impl->galloc = ggml_gallocr_new(
ggml_backend_get_default_buffer_type(impl->backend));
if (!impl->galloc) { PK_LOG("ReplayGraph: ggml_gallocr_new failed"); return false; }
}
bool ok = ggml_gallocr_alloc_graph(impl->galloc, gf_);
if (!ok) PK_LOG("ReplayGraph: ggml_gallocr_alloc_graph failed");
else g_last_graph_alloc_bytes = ggml_gallocr_get_buffer_size(impl->galloc, 0);
return ok;
}

bool ReplayGraph::compute(std::vector<float>& out) {
Backend::Impl* impl = backend_.impl_;
if (!impl || !impl->backend) {
PK_LOG("ReplayGraph::compute called on an uninitialised backend");
return false;
}

// Recompute the SAME already-allocated graph. The persistent gallocr keeps
// the buffer for this stable graph shape, so no per-step alloc is needed:
// set_input() wrote the new step's inputs into the persistent input tensors,
// and graph_compute reads them in place. Keeping the ggml context (and thus
// cgraph->nodes[0]) stable across calls is what lets ggml-cuda capture +
// replay the per-step work (megapar's win).
enum ggml_status status = need_sched_
? ggml_backend_sched_graph_compute(impl->sched, gf_)
: ggml_backend_graph_compute(impl->backend, gf_);
if (status != GGML_STATUS_SUCCESS) {
PK_LOG("ReplayGraph: ggml_backend_graph_compute failed (status=%d)", (int)status);
return false;
}

size_t n = (size_t)ggml_nelements(out_);
out.resize(n);
ggml_backend_tensor_get(out_, out.data(), 0, n * sizeof(float));
return true;
}

bool ReplayGraph::compute_with_captures(std::vector<float>& out) {
if (!compute(out)) return false;
// Re-fill every capture's stable dst vector from the persistent capture
// tensors (the same graph_compute just wrote them in place).
for (const auto& cap : captures_) {
size_t cn = (size_t)ggml_nelements(cap.first);
cap.second->resize(cn);
ggml_backend_tensor_get(cap.first, cap.second->data(), 0, cn * sizeof(float));
}
return true;
}

} // namespace pk
89 changes: 89 additions & 0 deletions src/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

struct ggml_context;
struct ggml_tensor;
struct ggml_cgraph;
struct ggml_backend;
typedef struct ggml_backend* ggml_backend_t;

Expand Down Expand Up @@ -55,6 +56,12 @@ class Backend {
// registry device name for a GPU backend, e.g. the CUDA device name).
const char* device_name() const { return device_name_.c_str(); }

// True iff the active backend is a non-CPU (GPU/IGPU) device. The graph-
// capture/replay decode optimisation only helps GPU (it is launch-overhead
// bound there); on CPU the per-step set_input + capture-readback overhead
// it adds is a net regression, so callers gate on this.
bool is_gpu() const;

// The underlying CPU ggml backend. Exposed so the loader can give its weight
// tensors a backend buffer over the SAME backend graphs run on (see
// ModelLoader::realize_weights). Any CPU buffer is compatible with the CPU
Expand Down Expand Up @@ -95,6 +102,8 @@ class Backend {
Impl* impl_;
int n_threads_ = 1;
std::string device_name_ = "cpu";

friend class ReplayGraph;
};

// Register a host-backed graph input for the currently-active Backend::compute
Expand Down Expand Up @@ -149,4 +158,84 @@ void ensure_weights_realized(const ModelLoader& ml);
// (preprocessing, batch-norm folding) — NOT for graph leaves (use clone_weight).
void weight_to_host_f32(const ModelLoader& ml, const char* name, std::vector<float>& out);

// A graph built ONCE and recomputed many times, keeping the SAME ggml context /
// cgraph / input tensors alive across calls. The C++ analogue of megapar's
// CUDA-graph step capture, realized through ggml's OWN capture:
//
// ggml-cuda keys its internal CUDA graph on `cgraph->nodes[0]` (a tensor pointer
// owned by the compute context). Backend::compute does ggml_init + ggml_free
// EVERY call, so every per-step graph gets a NEW context -> NEW node pointers ->
// a different key -> CUDA-graph capture NEVER warms up and every tiny per-step
// op is launched directly (the launch-overhead regime that dominates the GPU
// transducer decode loop). ReplayGraph keeps the context + cgraph alive across
// calls, so nodes[0] is a STABLE pointer: ggml-cuda warms up after one extra
// direct eval and then replays the captured graph, collapsing the per-step
// launch storm. On CPU this still wins by skipping the per-call ggml_init /
// gallocr re-plan / ggml_free.
//
// The replayed graph's input tensors live in the (persistent) ggml context; the
// caller feeds fresh data into them each step via the returned input handles.
//
// Usage:
// ReplayGraph rg(backend, [&](ggml_context* ctx){
// ggml_tensor* a = pk::graph_input_tensor(ctx, ...);
// return some_op(ctx, a, weight);
// });
// // build() recorded `a`'s handle; feed it:
// rg.set_input(0, host_a, nbytes_a);
// std::vector<float> out; rg.compute(out);
class ReplayGraph {
public:
// Build the graph now: runs `build(ctx)` in a no_alloc context (same
// contract as Backend::compute's build lambda — register host inputs via
// pk::add_graph_input / pk::graph_input_tensor), allocates the result via
// the backend's persistent gallocr, and remembers the input-tensor handles
// (in registration order) for set_input(). `backend` must outlive this
// object (it owns the gallocr the graph is allocated in).
ReplayGraph(Backend& backend,
const std::function<ggml_tensor*(ggml_context*)>& build);
~ReplayGraph();

ReplayGraph(const ReplayGraph&) = delete;
ReplayGraph& operator=(const ReplayGraph&) = delete;

// Feed `nbytes` from `host` into input #`i` (the i-th tensor registered
// during build). The data lands in the persistent input tensor; it must
// stay valid until compute() returns.
void set_input(size_t i, const void* host, size_t nbytes);

// Recompute the graph (with whatever data set_input wrote) and read the
// output tensor's f32 contents into `out`. Returns true on success.
bool compute(std::vector<float>& out);

// Number of input tensors registered during build().
size_t n_inputs() const { return inputs_.size(); }

// Recompute the graph and read BOTH the output tensor (into `out`) and every
// tensor registered via pk::capture_graph_output() during build() (into the
// caller's stable dst vectors). The captures are remembered across calls
// (unlike Backend::compute, which clears them), so the same dst vectors are
// re-filled each step. Used by the prediction net to pull each layer's new
// (h', c') state out of the replayed graph.
bool compute_with_captures(std::vector<float>& out);

private:
Backend& backend_;
ggml_context* ctx_ = nullptr;
ggml_cgraph* gf_ = nullptr;
ggml_tensor* out_ = nullptr;
// Input tensors recorded in build() order; set_input(i) writes into these.
std::vector<ggml_tensor*> inputs_;
// Capture tensors + their stable dst vectors (recorded from
// pk::capture_graph_output() during build). compute_with_captures() re-fills
// each dst after every replay.
std::vector<std::pair<ggml_tensor*, std::vector<float>*>> captures_;
// Whether gf_ was allocated via the sched fallback (true) or the fast
// gallocr path (false). Set once in alloc_internal(); read in compute().
bool need_sched_ = false;

// Allocate gf_ on the persistent gallocr / sched (called once in the ctor).
bool alloc_internal();
};

} // namespace pk
Loading