chore(cuda): stream migration refactoring #2662
**stephenh-axiom-xyz** left a comment:

Got about 25% through, but in general I think the pattern we have, where `DeviceContext` only works with `GpuDevice`, is really hurting us. `DeviceContext` should be a per-device concept, which would let us avoid the specific `#[cfg(not(feature = "cuda"))]` gating we're doing essentially everywhere.
```diff
- let trace = DeviceMatrix::<F>::with_capacity(self.count.len(), N + 1);
- let d_sizes = self.sizes.to_device().unwrap();
+ let trace = DeviceMatrix::<F>::with_capacity_on(self.count.len(), N + 1, &self.device_ctx);
+ trace.buffer().fill_zero_on(&self.device_ctx).unwrap();
```
> nit: Same comment as above, do we need to zero this?

Same as above — needed because GPU memory pool buffers contain stale data. Added comments explaining the rationale.
… duplication in provers

- Add `ctx.stream.synchronize()` before return in GPU `RootTraceGen` methods to prevent an async race when `DeviceContext` is dropped with in-flight transfers
- Merge identical `new()` and `from_pk()` cfg variants in `InnerAggregationProver`, `DeferralInnerProver`, and `DeferralHookProver` (removed the unnecessary `MaybeDeviceContext` bound from methods that don't use `device_ctx_for_engine`)
- Remove redundant `#[cfg(feature = "cuda")]` on `device_ctx_for_engine()` calls inside methods already gated by `#[cfg(feature = "cuda")]`

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…or cuda builds

`commit_child_vk` requires `E::PD: MaybeDeviceContext` when the cuda feature is enabled, so the cfg duplication on `new()` and `from_pk()` cannot be removed. Restores the two-variant pattern for these methods. The `agg_prove`/`prove` cleanup (removing the redundant `#[cfg]` on `device_ctx_for_engine` inside already-cfg'd methods) remains.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…nstead of creating temp streams

The GPU `RootTraceGen` impl was creating a throwaway `DeviceContext` with fresh streams for H2D transfers. This is inconsistent with the rest of the codebase (`InnerTraceGen`, `DeferralHookTraceGen`), which accepts the engine's `DeviceContext`.

- Add a `#[cfg(feature = "cuda")]` `device_ctx` parameter to the `RootTraceGen` trait
- GPU impl uses the passed-in context instead of allocating a new stream
- CPU impl ignores the parameter
- Caller in prover/root/trace.rs passes `device_ctx` through
- Removes the `synchronize()` workaround since the stream is no longer temporary

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Rename all `DeviceContext`-typed variables, parameters, and struct fields from `ctx` to `device_ctx` across the entire openvm codebase. This disambiguates them from the other context types (`ProvingContext`, `AirProvingContext`, etc.) that coexist in the same files. Also updates Cargo.lock for the stark-backend `device_ctx` rename.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Make the device context a first-class associated type on `ProverDevice` so every device carries its own context type — `GpuDeviceCtx` for GPU, `()` for CPU. This eliminates `MaybeDeviceContext`, `device_ctx_for_engine`, and all `#[cfg(feature = "cuda")]` `device_ctx` parameter duplication.

Key changes:

- `ProverDevice` gains a `DeviceCtx` associated type + `device_ctx()` method
- `VerifierTraceGen`, `InnerTraceGen`, `RootTraceGen`, `DeferralHookTraceGen`, `DeferralInnerTraceGen`, `DeferredVerifyTraceGen` all gain a `DC` generic
- ~12 cfg-duplicated method pairs unified into single methods
- `DeviceContext` renamed to `GpuDeviceCtx`
- `GkrExpBitsLenSink` + `WhirExpBitsLenSink` consolidated into `ExpBitsLenSink`
- `gpu_ctx` variables renamed for consistency
- Zero-fill comments added explaining necessity

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Restore Cargo.lock to avoid bumping AWS deps beyond the rustc 1.90.0 MSRV, only updating stark-backend git refs
- Rename `ctx` -> `device_ctx` in gkr/mod.rs GPU tracegen (missed in the previous rename pass, flagged by CI review)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix `generate_proving_ctxs_base` calls in the CUDA test block: `None` -> `&()`, `Some(&ctx)` -> `&ctx`
- Fix `generate_proving_ctx` turbofish: `::<PB>` -> `::<PB, _>` for the new `DC` generic parameter
- Use the `.device_ctx` field directly instead of the trait method to avoid ambiguous `ProverDevice` resolution
- Remove unused `ProverDevice` imports in test files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This reverts commit 824c41ab46c0567c9d0cad3d264f539d140d9cdc.
**Claude** finished @gaxiom's task in 4m 4s — Code Review:
| Layer | Pattern |
|---|---|
| Device context type | DeviceCtx associated type on ProverDevice — () for CPU, GpuDeviceCtx for GPU |
| VM builder | VmBuilder::create_chip_complex takes device_ctx: &EngineDeviceCtx<E> (crates/vm/src/arch/config.rs:87-95) |
| GPU chip storage | device_ctx: GpuDeviceCtx field, _on(&self.device_ctx) for buffer ops |
| CUDA FFI | cudaStream_t stream as last parameter in all extern "C" functions |
| Prover late binding | No DC generic stored on provers; extracted at prove time via engine.device().device_ctx() |
| Trait-level generics | VerifierTraceGen<PB, SC, DC>, InnerTraceGen<PB, DC>, RootTraceGen<PB, DC>, DeferralHookTraceGen<PB, DC>, DeferralInnerTraceGen<PB, DC> — unified, no cfg duplication |
| Module context | TraceGenModule::ModuleSpecificCtx GAT per module |
| Synchronization | Serial-on-same-stream with stream.synchronize() at proof boundaries |
Verification Results

- Zero remaining `cudaStreamPerThread` / `current_stream_sync()` / `current_stream_id()` / `MaybeDeviceContext` / `ExpBitsLenSink` / `GpuDeviceCtx::for_current_device()` references — confirmed via grep
- All CUDA kernel launchers use the 4-arg syntax `<<<grid, block, 0, stream>>>`
- All FFI wrappers consistently pass `stream: cudaStream_t` as the final parameter
- All buffer allocations use `_on()` variants (`with_capacity_on`, `to_device_on`, `fill_zero_on`)
- `scan.cuh` (crates/recursion/cuda/include/scan.cuh:32,47) and `affine_scan.cuh` — default `cudaStreamPerThread` parameters correctly removed; `stream` is required
- `scripts/bin/ptx_details.rs` — `--default-stream=per-thread` removed
- All GPU chip zero-fills have explanatory comments (e.g., crates/circuits/primitives/src/var_range/cuda.rs:66,77)
- `DeviceContext::clone()` is safe — `StreamGuard` uses `Arc` internally
Issues with the latest commit (36ad944, "resolve comment")

The reviewer's approval left two concrete style nits. The most recent commit addresses one of them, but in a way that does not match the reviewer's explicit suggestion:

1. `apply_merkle_precomputation` — reviewer's suggested pattern not applied (crates/recursion/src/system/mod.rs:682-705)

The reviewer asked for a single unified function with a `#[cfg(feature = "cuda")]`-gated parameter:

```rust
fn apply_merkle_precomputation(
    proof: &Proof<BabyBearPoseidon2Config>,
    preflight: &mut Preflight,
    #[cfg(feature = "cuda")] device_ctx: &GpuDeviceCtx,
) { ... }
```

The commit instead split the function into two separate functions — `apply_merkle_precomputation_cpu` (always-on, line 682) and `apply_merkle_precomputation` (cuda-gated, line 693) — and deleted the private `MerklePrecomputationDeviceCtx` dispatch trait. The bodies are identical except for one line (`compute_merkle_precomputation` vs `compute_merkle_precomputation_cuda`), so this reintroduces ~11 lines of duplication.

The reviewer's `#[cfg]`-on-parameter pattern is cleaner, avoids the `_cpu` suffix, and avoids duplicating the 4-field assignment body.
2. `_device_ctx` prefix still present (crates/recursion/src/system/mod.rs:1088)

Still `_device_ctx: &()` with the `_` prefix. The reviewer explicitly asked for `#[cfg_attr(not(feature = "cuda"), allow(unused_variables))]` on `device_ctx` instead. Unaddressed by the latest commit.
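A minimal sketch of the requested pattern (the function name and `DC` generic here are hypothetical stand-ins, not the real signature at mod.rs:1088):

```rust
// The cfg_attr keeps a single parameter name for both builds: on non-CUDA
// builds the unused-variable lint is silenced instead of renaming the
// parameter to `_device_ctx`.
fn apply_precomputation<DC>(
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))] device_ctx: &DC,
) {
    #[cfg(feature = "cuda")]
    {
        let _ = device_ctx; // stand-in for the CUDA-only work
    }
}

fn main() {
    apply_precomputation(&()); // `()` is the CPU device context in this PR
}
```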
Other findings (non-blocking)
1. Missing `size_of::<F>() == size_of::<u32>()` assertion in two primitives

var_range/cuda.rs:49 correctly asserts the size match before casting `DeviceBuffer<F>` to `*const u32`. The same cast is performed unchecked in:

- `bitwise_op_lookup` — crates/circuits/primitives/src/cuda_abi.rs:48 (`d_count.as_ptr() as *const u32`)
- `range_tuple` — crates/circuits/primitives/src/cuda_abi.rs:111 (`d_count.as_ptr() as *const u32`)

Since `F = BabyBear` is 32-bit today, this is currently a non-issue, but for consistency the same `assert_eq!(size_of::<F>(), size_of::<u32>())` should be added to the `generate_proving_ctx` bodies of those two chips. Not introduced by this PR, but worth flagging while the file is under review.
2. Stacking ModuleSpecificCtx is an owned type (crates/recursion/src/stacking/mod.rs:585)
Defines type ModuleSpecificCtx<'a> = GpuDeviceCtx (owned), while GkrModule (crates/recursion/src/gkr/mod.rs:665) uses a tuple of references (&'a GpuExpBitsLenTraceGenerator, &'a GpuDeviceCtx). Since the trait method takes ctx: &Self::ModuleSpecificCtx<'_>, the parameter type for Stacking is &GpuDeviceCtx — correct, just inconsistent with the rest of the codebase. Works fine given GpuDeviceCtx: Clone with Arc internals.
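For reference, a simplified sketch of the GAT shape being compared (all types here are stand-ins; the real `TraceGenModule` trait carries more members):

```rust
// Stand-ins for the real context and trace-generator types.
struct GpuDeviceCtx;
struct ExpBitsLenGen;

trait TraceGenModule {
    // Generic associated type: each module picks its own context shape.
    type ModuleSpecificCtx<'a>;
    fn generate(&self, ctx: &Self::ModuleSpecificCtx<'_>);
}

struct Stacking;
impl TraceGenModule for Stacking {
    // Owned type, as in stacking/mod.rs: the parameter is `&GpuDeviceCtx`.
    type ModuleSpecificCtx<'a> = GpuDeviceCtx;
    fn generate(&self, _ctx: &Self::ModuleSpecificCtx<'_>) {}
}

struct GkrModule;
impl TraceGenModule for GkrModule {
    // Tuple of references, as in gkr/mod.rs: the parameter is `&(&gen, &ctx)`.
    type ModuleSpecificCtx<'a> = (&'a ExpBitsLenGen, &'a GpuDeviceCtx);
    fn generate(&self, _ctx: &Self::ModuleSpecificCtx<'_>) {}
}

fn main() {
    Stacking.generate(&GpuDeviceCtx);
    GkrModule.generate(&(&ExpBitsLenGen, &GpuDeviceCtx));
}
```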
3. #[cfg] bifurcation inside non-cfg'd CPU VerifierTraceGen impl (crates/recursion/src/system/mod.rs:1104-1113, 1120-1123)
The CPU VerifierTraceGen impl's preflight thread closure contains a #[cfg(feature = "cuda")] / #[cfg(not(feature = "cuda"))] block. With cuda enabled, the CPU impl runs run_preflight_without_merkle and applies merkle precomputation separately afterward (calling apply_merkle_precomputation_cpu). With cuda disabled, it calls run_preflight directly. The feature-flag bifurcation inside an impl that is otherwise generic over DC = () is subtle — behavior depends on the feature flag rather than the type parameter. Worth documenting inline for future maintainers.
4. .unwrap() on GPU allocations (pre-existing pattern)
All GPU chips use .unwrap() on fill_zero_on(), to_device_on(), with_capacity_on(). These would only fail on catastrophic GPU errors (OOM, driver crash), so panicking is defensible. Not introduced by this PR.
5. touchemall uses device_synchronize() (global sync) (crates/circuits/primitives/src/utils.rs:140)
Debug utility behind feature = "touchemall". Global sync is safer here since the utility may need to read data written by any stream. Fine as-is.
Verdict
LGTM. The DeviceContext threading is correct and consistent across all 221 changed files. The DeviceCtx associated type on ProverDevice design elegantly resolves the cfg-duplication concern from earlier review rounds. The late-binding pattern for provers is clean. Benchmarks show no regression. Clean removal of all PTDS references confirmed.
Pending items:

- Address the reviewer's `#[cfg]`-on-parameter suggestion for `apply_merkle_precomputation` (commit 36ad944 took a different approach)
- Replace the `_device_ctx` prefix with `#[cfg_attr(not(feature = "cuda"), allow(unused_variables))]` at system/mod.rs:1088
- Update `Cargo.toml` from the `stark-backend` branch to the v2.0.0-beta.2 tag before final merge
Branch: chore/stream-migration

Note: cells_used metrics omitted because CUDA tracegen does not expose unpadded trace heights. Commit: 36ad944

---
Closes INT-6464
## CUDA Explicit Stream Migration (OpenVM side)

Companion to [stark-backend#317](openvm-org/stark-backend#317). Adds `cudaStream_t stream` to all OpenVM CUDA launchers and FFI wrappers, injects `DeviceContext` through the VM builder and chip construction APIs, removes all PTDS references, and lifts `DeviceCtx` to a first-class associated type on `ProverDevice` — eliminating `MaybeDeviceContext`, `#[cfg]` duplication, and conditional `device_ctx: Option<&DeviceContext>` parameters.

**224 files changed, +2902 / -1353**

See [cuda-stream-migration-design.md](https://github.com/openvm-org/v2-proof-system/blob/test/stream-migration/cuda-stream-migration-design.md) for full design rationale.
## CUDA launchers (~67 `.cu` files + 2 `.cuh` headers)

Every `extern "C"` launcher gains `cudaStream_t stream` as the final parameter. All kernel launches use `<<<grid, block, 0, stream>>>`. All CUB calls (`DeviceScan`, `DeviceReduce`, `DeviceMergeSort`) pass an explicit `stream`. `scan.cuh` and `affine_scan.cuh` had their `= cudaStreamPerThread` default parameters removed — `stream` is now required.

## FFI wrappers (15 `cuda_abi.rs` / `abi.rs` files)

Every `extern "C"` declaration and safe Rust wrapper gains `stream: cudaStream_t`. Covers `crates/vm`, `crates/circuits/primitives`, `crates/circuits/poseidon2-air`, `crates/recursion` (6 files), and all extensions (`rv32im`, `keccak256`, `bigint`, `sha2`, `deferral`).
## VM builder — `DeviceContext` injection

`VmBuilder::create_chip_complex` gains a `device: &E::PD` parameter. `VirtualMachine::new` passes `engine.device()` through. 25 implementations updated (19 CPU impls ignore it, 6 GPU impls extract `DeviceContext`).
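Schematically, the change has this shape (a simplified sketch; the real `VmBuilder` trait in crates/vm/src/arch/config.rs carries more parameters and associated types):

```rust
// Simplified stand-ins for the real builder and engine types.
struct ChipComplex;
struct CpuDevice;

trait VmBuilder<PD> {
    // The engine's device is now threaded through chip construction, so GPU
    // builders can extract its device context; CPU builders ignore it.
    fn create_chip_complex(&self, device: &PD) -> ChipComplex;
}

struct CpuBuilder;
impl VmBuilder<CpuDevice> for CpuBuilder {
    fn create_chip_complex(&self, _device: &CpuDevice) -> ChipComplex {
        ChipComplex // one of the 19 CPU impls that ignore the parameter
    }
}

fn main() {
    let _complex = CpuBuilder.create_chip_complex(&CpuDevice);
}
```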
## Stateful GPU chip constructors

GPU chips that allocate persistent `DeviceBuffer`s now store `device_ctx: DeviceContext` and use `_on` allocation variants:

- `BitwiseOperationLookupChipGPU` — histogram accumulator
- `VariableRangeCheckerChipGPU` — range check accumulator
- `RangeTupleCheckerChipGPU` — range tuple accumulator
- `Poseidon2ChipGPU` — shared hash records buffer + index counter
- `MemoryMerkleTree` — merkle tree device state

The `Chip::generate_proving_ctx(&self, records)` trait is **unchanged** — chips access `self.device_ctx` internally.
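A sketch of the stored-context pattern, with stand-in buffer types (the real chips use `DeviceBuffer<F>` and the backend's context type):

```rust
// Stand-ins for the CUDA backend types.
#[derive(Clone)]
struct DeviceContext; // owns the stream (Arc-backed internally, per review)
struct DeviceBuffer;

impl DeviceBuffer {
    // `_on` variants allocate and operate on the caller's stream.
    fn with_capacity_on(_len: usize, _device_ctx: &DeviceContext) -> Self {
        DeviceBuffer
    }
    fn fill_zero_on(&self, _device_ctx: &DeviceContext) -> Result<(), ()> {
        Ok(())
    }
}

// Shape of a stateful GPU chip after the migration.
struct VariableRangeCheckerChipGPU {
    device_ctx: DeviceContext,
    count: DeviceBuffer, // persistent range-check accumulator
}

impl VariableRangeCheckerChipGPU {
    fn new(device_ctx: DeviceContext, num_bins: usize) -> Self {
        let count = DeviceBuffer::with_capacity_on(num_bins, &device_ctx);
        // Zeroing is required: pool-backed GPU allocations can hand back
        // buffers containing stale data (the rationale noted in review).
        count.fill_zero_on(&device_ctx).unwrap();
        Self { device_ctx, count }
    }
}

fn main() {
    let chip = VariableRangeCheckerChipGPU::new(DeviceContext, 1 << 16);
    let _ = (&chip.device_ctx, &chip.count);
}
```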
## VM system CUDA modules (`crates/vm/src/system/cuda/`)

All GPU system modules use `self.device_ctx.stream.as_raw()`: `poseidon2.rs`, `memory.rs`, `boundary.rs`, `phantom.rs`, `program.rs`, `merkle_tree/mod.rs`.

## `DeviceCtx` as associated type on `ProverDevice` (87 files, -630 net lines)

The final refactor lifts `DeviceContext` from a `#[cfg(feature = "cuda")]`-gated parameter to a first-class associated type on `ProverDevice` (defined in stark-backend#317):
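```rust
// In ProverDevice (stark-backend):
type DeviceCtx: Clone + Send + Sync;
fn device_ctx(&self) -> &Self::DeviceCtx;

// GpuDevice: DeviceCtx = DeviceContext
// CpuDevice: DeviceCtx = ()
```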
This eliminates:

- `MaybeDeviceContext` trait + all 3 impls + `device_ctx_for_engine()` helper — **deleted entirely**
- ~6 pairs of `#[cfg(feature = "cuda")]` / `#[cfg(not)]` duplicated methods in continuations provers (`inner/mod.rs`, `deferral/hook/mod.rs`, `deferral/inner/mod.rs`) — **unified into single versions**
- ~30 `#[cfg(feature = "cuda")] device_ctx: Option<&DeviceContext>` conditional parameters — **replaced with unconditional `&DC` generic**
- cfg-duplicated `VerifierTraceGen` trait methods in `recursion/src/system/mod.rs` — **unified**
engine.device().device_ctx()directly. TheEngineDeviceCtx<E>type alias avoids spelling out the full associated type path.Recursion —
VerifierTraceGenand module tracegenVerifierTraceGenandInnerTraceGentraits gain aDCgeneric parameter for the device context type.generate_proving_ctxsaccepts&DCunconditionally. Stream synchronization at proof boundaries viadevice_ctx.stream.synchronize().Continuations + guest verifier
InnerTraceGen,DeferralHookTraceGen,DeferralInnerTraceGen,RootTraceGen— all accept&DCand use the engine-owned context.agg_prove,from_pk,new— single versions, no#[cfg]duplication. Guest verifier circuit trace generation passes device context through.Extensions
All extension GPU modules pass
streamto FFI calls:rv32im(all ALU/branch/load/store cuda modules),sha2,keccak256,bigint,deferral.Removed
--default-stream=per-threadfromscripts/bin/ptx_details.rscudaStreamPerThreadimports and usagescurrent_stream_sync()/current_stream_id()usagesMaybeDeviceContexttrait +device_ctx_for_engine()helperDeviceContextescape hatches (fresh streams created outside engine)#[cfg]duplication in continuations provers