diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index e88c6beff5280..2369af53621b2 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -144,10 +144,14 @@ jobs: # --- Run the CUDA plugin EP C++ GoogleTest binary --- # onnxruntime_provider_test is built into the artifact and links the plugin tests - # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). The user-stream + CUDA graph test - # registers the plugin .so via GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), - # which returns the platform-specific filename without a directory component. Run from - # /build/Release/Release so that filename resolves to the plugin .so built there. + # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). These tests register the plugin .so via + # GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), which returns the + # platform-specific filename without a directory component. Run from /build/Release/Release + # so that filename resolves to the plugin .so built there. 
+ # The filter covers every CUDA plugin EP suite linked into this binary: + # CudaPlugin* -> CudaPluginUserStreamGraphTest, CudaPluginArenaTest, + # CudaPluginPartitioningTest, CudaPluginProfilingTest + # CudaResourcePartitioning* -> CudaResourcePartitioningTest - name: Run CUDA Plugin EP C++ Tests run: | docker run --rm --gpus all \ @@ -163,5 +167,5 @@ jobs: cd /build/Release/Release ls -la onnxruntime_provider_test libonnxruntime_providers_cuda_plugin.so - ./onnxruntime_provider_test --gtest_filter='CudaPluginUserStreamGraphTest.*' + ./onnxruntime_provider_test --gtest_filter='CudaPlugin*:CudaResourcePartitioning*' " diff --git a/docs/cuda_plugin_ep/arena_allocator_migration_design.md b/docs/cuda_plugin_ep/arena_allocator_migration_design.md index f082b444e10b0..d4ff21e713f85 100644 --- a/docs/cuda_plugin_ep/arena_allocator_migration_design.md +++ b/docs/cuda_plugin_ep/arena_allocator_migration_design.md @@ -62,16 +62,16 @@ if (!factory.arena_allocator_) { **Stream-aware allocation.** `ArenaImpl::AllocOnStream(size, stream)` tracks which chunks are assigned to which stream. `ResetChunksUsingStream(stream_impl)` is called from `OrtSyncStreamImpl::OnSessionRunEnd` to release chunk-to-stream assignments when a session run completes. -**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT ≥ 1.23), falling back to plain `Alloc` otherwise. The plugin `GetScratchBuffer` deliberately passes a **null stream** to the arena rather than forwarding the kernel's compute stream. 
A plugin kernel only has the raw `cudaStream_t` (via `KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena persists in each chunk (`chunk->stream`) and later dereferences through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). Synthesizing a temporary framework `Stream` wrapper over the raw handle would be unsafe: it would dangle once `GetScratchBuffer` returns while the arena still holds the pointer, and it would be type-confused (a framework `Stream*` reinterpreted as an `OrtSyncStream*` that ORT never created for this stream). With a null stream the arena tracks scratch chunks as freely reusable (the same semantics as a plain non-stream-aware BFC arena). This is still what keeps scratch allocations served from already-reserved chunks during CUDA graph capture — capture stability comes from chunk reuse, not from stream tagging — and it is safe for the CUDA graph path, which runs on a single unified stream. +**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` uses the framework `OrtSyncStream*` exposed through `KernelContext_GetSyncStream` to stream-tag scratch chunks, while kernels continue to use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` for launches and library handles. This keeps allocation bookkeeping on the same framework stream wrapper that the arena stores in `chunk->stream` and later queries through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). 
If the negotiated ORT API version does not include `KernelContext_GetSyncStream`, the adapter falls back to a null stream tag and the EP does not advertise concurrent run support. -#### Scratch buffer stream tagging — limitation and future work +#### Scratch buffer stream tagging -A common review question is: *"Passing a null stream to the scratch allocator looks wrong — won't it cause a synchronization issue? Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is that, for the path this code targets, it is correct and safe. The longer answer clarifies what the `stream` argument actually does and why forwarding the real stream is not currently possible. +A common review question is: *"Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is yes for concurrent multi-stream runs, but the allocator must receive the framework stream wrapper, not the raw CUDA handle. -- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is always consumed by the kernel on its real compute stream. So a null tag does not move work onto the default stream or skip any required sync — it only relaxes cross-stream chunk reuse. -- **Why null is safe here.** The scratch routing targets serialized runs and the CUDA graph path, which runs on a single **unified stream** when graph capture and a user compute stream are combined. On a single stream, alloc -> use -> free -> reuse are implicitly ordered by the stream itself, so there is never a second stream that could reuse a chunk while the first is still using it. A null-tagged ("freely reusable") chunk behaves exactly like a plain non-stream-aware BFC arena chunk, which is the correct behavior for one stream. 
Because null-tagged chunks are not safe for overlapping runs on different CUDA streams, the CUDA plugin EP does not advertise concurrent `Session::Run()` support until scratch chunks can be properly stream-tagged. -- **Why we cannot forward the real stream today (C-API limitation).** The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) — the ORT-core wrapper it stored in `chunk->stream`. A plugin kernel only has the raw `cudaStream_t`. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (an `OrtSyncStreamImpl`), but that is a *different* object from the ORT-core `OrtSyncStream*` the arena expects; passing it (or a stack-allocated shim over the raw handle) would be both dangling and type-confused. -- **Future work.** To properly stream-tag scratch chunks — which only becomes necessary if this path is extended to support concurrent multi-stream runs sharing one arena — ORT needs new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels at dispatch time (e.g. via `KernelContext`). Until then, the null-stream tag is the correct and intentional choice. The matching code comment lives in `CudaKernel::GetScratchBuffer` (`onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h`). +- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is consumed by the kernel on its raw CUDA compute stream. +- **Raw CUDA stream and framework stream are different objects.** `KernelContext_GetGPUComputeStream` returns the raw `cudaStream_t` used for CUDA calls. 
The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) because that stable wrapper is what it persists in each chunk. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (`OrtSyncStreamImpl`), but that is not the ORT-core `OrtSyncStream*` the arena expects. +- **How the plugin bridges them.** `KernelContext_GetSyncStream` exposes the framework stream for the current kernel dispatch. The CUDA plugin adapter records the mapping from raw `cudaStream_t` to framework stream when migrated kernels call `GetComputeStream(ctx)`, and `GetScratchBuffer` uses the framework stream for `AllocOnStream`. This preserves the existing migrated-kernel pattern while making scratch chunks safe for cross-stream reuse decisions. +- **Compatibility fallback.** When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, scratch allocations use a null stream tag. A null tag is correct for serialized runs and single-unified-stream CUDA graph capture, but it is not safe for overlapping runs on different CUDA streams, so the plugin EP only advertises concurrent `Session::Run()` when `KernelContext_GetSyncStream` is available. 
diff --git a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md index 1cac9464430dc..9035ce91bb3bb 100644 --- a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md +++ b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md @@ -92,8 +92,9 @@ A natural question when reading `GetPerThreadContext()` is why `use_external_str | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member | | `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture`; removed the validation that rejected `user_compute_stream` + `enable_cuda_graph` (the combination is now supported) | -| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) with a null stream, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call, so scratch allocations are served from already-reserved arena chunks during capture | -| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available; `GetScratchBuffer` still passes a null stream until plugin kernels can receive a stable framework `OrtSyncStream*` | +| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) and stream-tags scratch chunks with the framework stream exposed by `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call | +| 
`include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available | +| `include/onnxruntime/core/session/onnxruntime_c_api.h` | Added `KernelContext_GetSyncStream` so plugin kernels can obtain the framework `OrtSyncStream*` for stream-aware allocation bookkeeping while still using `KernelContext_GetGPUComputeStream` for raw CUDA work | | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` | | `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` | | `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion | @@ -128,7 +129,7 @@ CUDA graph capture requires that all memory allocations happen during warmup, no **Arena integration details (now implemented):** - Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls. -- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. 
After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` passes a **null stream** to the arena. This is *not* a synchronization bug: the `stream` argument is only bookkeeping metadata the stream-aware arena uses to decide when a freed chunk may be reused on a *different* stream without a sync - it does not change where the kernel runs (the buffer is still consumed on the real compute stream). In a serialized run (and within one graph-capture run), alloc/free/reuse are implicitly ordered on that stream, so a null-tagged ("freely reusable") chunk is correct and safe. It is also currently the only safe option, because a plugin kernel only has the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` the stream-aware arena persists per chunk and later dereferences through the EP stream API; note that the ORT-core `OrtSyncStream` (`struct OrtSyncStream : public onnxruntime::Stream`) is a different object from the plugin's `CudaSyncStream` (an `OrtSyncStreamImpl`). Synthesizing a temporary `Stream*` over the raw handle would dangle after `GetScratchBuffer` returns and be type-confused, so scratch chunks are tracked with a null stream (freely reusable, like a plain BFC arena). Capture stability comes from chunk reuse, not stream tagging. 
Properly stream-tagging scratch chunks (required before this path can support concurrent multi-stream runs) is **future work** that requires new C-API surface to expose the framework `OrtSyncStream*` to plugin kernels — see [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). +- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` stream-tags scratch chunks with the framework `OrtSyncStream*` exposed by `KernelContext_GetSyncStream`. The raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` is still used for CUDA launches and library calls; the framework stream is used only for the arena's cross-stream reuse bookkeeping. - When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology. - Pinned allocations are also arena-backed, but remain non-stream-aware. - The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay. 
@@ -137,12 +138,12 @@ CUDA graph capture requires that all memory allocations happen during warmup, no ### Concurrent Run Support -Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plugin EP while migrated kernels route scratch/workspace allocations through the EP arena with a null stream tag. +Concurrent `Session::Run()` is advertised by the CUDA plugin EP when the host ORT runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. - `CudaEp::PerThreadContext` still owns graph stream, graph manager, warm-up run counts, and memory watermark state per thread. This keeps graph bookkeeping thread-local and avoids sharing captured graph executables across threads. -- However, plugin kernels currently receive only the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena stores in each chunk and later uses for safe cross-stream reuse checks. -- Because `CudaKernel::GetScratchBuffer` cannot safely provide that framework stream, it passes a null stream tag. Null-tagged scratch chunks are freely reusable, which is safe for serialized runs and single-unified-stream graph capture but unsafe for overlapping runs on different CUDA streams. -- Therefore `CudaEp::IsConcurrentRunSupportedImpl()` returns false. Re-enabling concurrent multi-stream runs is future work and requires new C-API surface to expose a stable framework stream (or equivalent sync id) to plugin kernels so scratch chunks can be properly stream-tagged. +- Plugin kernels now obtain the framework `OrtSyncStream*` through `KernelContext_GetSyncStream` and use it only for scratch/workspace allocation bookkeeping. CUDA work still launches on the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`. +- Stream-tagged scratch chunks let the shared arena apply its normal cross-stream reuse rules for overlapping runs on different CUDA streams. 
+- When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, `CudaKernel::GetScratchBuffer` falls back to a null stream tag and `CudaEp::IsConcurrentRunSupportedImpl()` returns false. ## Verification @@ -159,4 +160,3 @@ Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plug ## Future Work 1. **Profiling integration**: CUDA graph replay currently bypasses the CUDA plugin EP profiler path because the CUDA plugin EP does not yet implement `OrtEp::CreateProfiler`. Wiring graph replay into that path is future work. -2. **Stream-tagged scratch allocations**: `CudaKernel::GetScratchBuffer` passes a null stream to the EP arena because plugin kernels cannot currently obtain the framework `OrtSyncStream*` the stream-aware arena needs (they only have the raw `cudaStream_t`). This is correct and safe for serialized runs and within one graph-capture run, but it is why the EP does not advertise concurrent `Session::Run()` support. Supporting concurrent multi-stream runs that share one arena would require new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels so scratch chunks can be properly stream-tagged. See [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). 
diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 438fb8606fc09..8f9a0d388d1af 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -97,10 +97,11 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA | API surface | Newest `\since` used | Representative functions | | --- | --- | --- | | `OrtApi` — direct calls (`ort_api_.*`, `Ort::GetApi().*`) | **1.23** | `SyncStream_GetHandle`, `GetTensorSizeInBytes`, `GetRunConfigEntry`, `CreateMemoryInfo_V2`, `Graph_GetNumNodes`/`Graph_GetNodes` (older: `CreateStatus`, `Logger_LogMessage`, `*KeyValuePairs`, `HardwareDevice_*`, `MemoryInfoGet*`, `GetSessionConfigEntry`) | +| `OrtApi` — optional gated kernel-context capability | **1.28** | `KernelContext_GetSyncStream` (called from the adapter only when `CurrentOrtApiVersion() >= 28`; otherwise scratch allocation uses a null stream tag and concurrent run support is not advertised) | | `OrtEpApi` — direct calls (`ep_api_.*`, `Ort::GetEpApi().*`) | **1.24** | `CreateKernelRegistry`, `KernelRegistry_AddKernel`, `ReleaseKernelRegistry`, `CreateIfKernel`/`CreateLoopKernel`/`CreateScanKernel`, `EpGraphSupportInfo_LookUpKernel` (older: `MemoryDevice_*`, `MemoryInfo_GetMemoryDevice`, `SyncStream_*`, `EpDevice_AddAllocatorInfo`, `EpGraphSupportInfo_AddSingleNode`, `CreateEpDevice`/`ReleaseEpDevice`) | | EP profiler API (only when built with `ENABLE_CUDA_PROFILING`) | **1.25** | `CreateProfilingEvent`, `ProfilingEventsContainer_AddEvents`, `ReleaseProfilingEvent` (called from `cuda_profiler_plugin.cc` via the `Ort::ProfilingEvent` / `Ort::UnownedProfilingEventsContainer` wrappers) | -`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. 
**Apart from the optional EP profiler, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. +`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. **Apart from optional gated capabilities such as EP profiling and stream-tagged scratch allocation, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. **Defensive capability gating.** Reading a struct field is safe because the field is append-only and ORT only reads fields it knows about. The real hazard is *calling* an `OrtApi`/`OrtEpApi` function that the (possibly older) runtime does not provide. The correct guard for that is the runtime API version, `onnxruntime::ep::CurrentOrtApiVersion()`, not `ort_version_supported`. The `CudaEp` constructor (`cuda_ep.cc`) therefore reads `const uint32_t ort_version = onnxruntime::ep::CurrentOrtApiVersion();` and only installs an `OrtEp` callback when that runtime version is new enough to provide both the callback field and every API its implementation calls: @@ -113,7 +114,9 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA All other `OrtEp` and `OrtEpFactory` callbacks are `\since 1.24` or older and are installed unconditionally. Gating `CreateProfiler` is what makes the three `\since 1.25` profiler functions unreachable on an older runtime: when the profiler is never created, ORT never drives the `OrtEpProfilerImpl` callbacks that call them. -The gates use **graceful degradation rather than throwing**: the gated callbacks are all optional capabilities (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. 
This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). +`KernelContext_GetSyncStream` is guarded at the adapter call site rather than through an `OrtEp` callback field: `OpKernelContext::GetSyncStream()` returns null when `CurrentOrtApiVersion() < 28`, and `CudaEp::IsConcurrentRunSupportedImpl()` only advertises concurrent runs when that API is available. Older runtimes therefore keep the previous serialized-run behavior while still using the same plugin binary. + +The gates use **graceful degradation rather than throwing**: the gated callbacks and adapter capabilities are optional features (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting, stream-tagged scratch for concurrent runs), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). --- @@ -463,14 +466,14 @@ The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: Migrated kernels need a valid device allocator in two places: scratch/workspace buffers during `Compute()`, and one-time weight conversion or packing during `PrePack()`. Both now resolve the allocator the same way the bundled CUDA EP does, through the kernel's own `OpKernelInfo`. -- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) with a null stream tag, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. 
The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). The null stream tag is intentional: plugin kernels only have the raw `cudaStream_t`, not the framework `OrtSyncStream*` that the stream-aware arena persists in chunks for safe cross-stream reuse. +- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) and stream-tags scratch chunks with the framework `OrtSyncStream*` from `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). CUDA launches still use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`; the framework stream is used only for stream-aware arena bookkeeping. - **PrePack.** The framework prepack loop (`SessionState::PrepackConstantInitializedTensors`) resolves the allocator with `GetInitializerAllocator(kernel->Info().GetDevice(OrtMemTypeDefault))`, a session map keyed by device. For a plugin EP registered as a separate library, that device-keyed lookup can miss and return null. The loop now falls back to `kernel->Info().GetAllocator(OrtMemTypeDefault)` when the lookup is null, so every `PrePack` implementation receives a valid allocator at the single framework call site. 
This replaces the earlier approach of adding a per-kernel `if (!alloc) alloc = Info().GetAllocator(...)` guard to each prepacking op (which only covered the few ops that were touched and risked missing future ones). The fallback is behavior-neutral for in-tree EPs, whose device-keyed lookup already succeeds, and it does **not** force `is_packed`/`prepacked_weights` handling — ops such as `QMoE` and `MatMulNBits` still set `is_packed = true` and populate prepacked weights normally. -The enabling adapter change is in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h): `IAllocatorWrappingOrtAllocator` now implements `IsStreamAware()`/`AllocOnStream()` (previously `ORT_NOT_IMPLEMENTED`) by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` does not use that stream-aware path yet because the plugin kernel layer cannot safely provide the framework `OrtSyncStream*`; stream-tagged scratch allocation is future work and is documented in [arena_allocator_migration_design.md](arena_allocator_migration_design.md#scratch-buffer-stream-tagging--limitation-and-future-work). +The enabling adapter changes are in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h) and [`include/onnxruntime/ep/adapter/op_kernel.h`](../../include/onnxruntime/ep/adapter/op_kernel.h): `IAllocatorWrappingOrtAllocator` implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), and `OpKernelContext::GetSyncStream()` exposes the framework stream when the negotiated ORT API version includes `KernelContext_GetSyncStream`. The CUDA plugin uses that framework stream for `GetScratchBuffer`; if it is unavailable, allocation falls back to a null stream tag and concurrent `Session::Run()` is not advertised. 
### 5.4 CUDA Graph Support -CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is intentionally not advertised while scratch allocations are null-stream-tagged; supporting concurrent multi-stream runs requires future C-API work to expose a stable framework stream or sync id to plugin kernels. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. +CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is supported when the host runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. #### 5.4.1 OrtEp C API Extensions (v1.26) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5dd53a8cf45c0..a73868527771a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7496,6 +7496,27 @@ struct OrtApi { * \since Version 1.28. 
*/ ORT_API_T(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); + + /** \brief Get the framework synchronization stream associated with a kernel context. + * + * This returns the framework stream wrapper for the execution provider stream used by this kernel invocation. + * It is intended for APIs that need a stable framework stream object for stream-aware allocation and + * synchronization bookkeeping. Use KernelContext_GetGPUComputeStream when launching native GPU work. + * + * \param[in] context OrtKernelContext instance. + * \param[out] out Returns the framework synchronization stream, or nullptr if the kernel has no stream. + * Do not free or mutate the returned pointer. It is owned by the underlying session. + * The pointer may be stored and used for stream-aware allocation and synchronization + * bookkeeping beyond the Compute call (e.g. an allocator may persist it in arena + * chunks); it remains valid until the owning Session::Run() completes its teardown. + * Do not retain or dereference it after the run that produced this kernel context ends. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.28. 
+ */ + ORT_API2_STATUS(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4798d3d4ad1b8..55a4e36167e86 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3020,6 +3020,7 @@ struct KernelContext { UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; + OrtSyncStream* GetSyncStream() const; Logger GetLogger() const; Ort::Allocator GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d7439e7b356c6..ed3abc0961be6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2876,6 +2876,12 @@ inline void* KernelContext::GetGPUComputeStream() const { return out; } +inline OrtSyncStream* KernelContext::GetSyncStream() const { + OrtSyncStream* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetSyncStream(ctx_, &out)); + return out; +} + inline Ort::Allocator KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { OrtAllocator* out = nullptr; Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 27a46cc10e306..1f103b64a443e 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -164,6 +164,14 @@ struct OpKernelContext { void* GetGPUComputeStream() const { return 
context_.GetGPUComputeStream(); } + OrtSyncStream* GetSyncStream() const { + static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + if (CurrentOrtApiVersion() < kOrtKernelContextGetSyncStreamMinVersion) { + return nullptr; + } + + return context_.GetSyncStream(); + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 13f7bfd7a40cf..73fa92d19cd1f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -464,14 +464,12 @@ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl( return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null."); } - // Plugin kernels currently expose only the raw cudaStream_t to GetScratchBuffer(), not the - // framework OrtSyncStream* that the stream-aware arena needs to tag scratch chunks by stream. - // Scratch chunks are therefore allocated with a null stream tag and can be reused freely. That is - // safe when runs are serialized, but it is not safe to advertise concurrent Session::Run(): two - // runs on different CUDA streams could reuse the same scratch chunk while earlier work is still - // in flight. Re-enable concurrent runs only after the plugin kernel layer can pass a stable - // framework stream (or equivalent sync id) to the arena. - *is_supported = false; + auto* ep = static_cast(this_ptr); + // Concurrent runs require stream-tagged scratch allocations. The plugin kernel adapter can tag + // scratch chunks only when the hosting ORT runtime exposes KernelContext_GetSyncStream. 
+ static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + *is_supported = !ep->config_.use_ep_level_unified_stream && + ::onnxruntime::ep::CurrentOrtApiVersion() >= kOrtKernelContextGetSyncStreamMinVersion; return nullptr; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index f134c599d5b46..06fe635e35716 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -16,6 +16,7 @@ #pragma once #include +#include #include "core/common/status.h" #include "core/common/narrow.h" @@ -53,6 +54,81 @@ namespace onnxruntime { struct CudaStream; +namespace cuda_plugin { +namespace detail { +inline thread_local std::unordered_map stream_to_framework_stream; +inline thread_local void* current_cuda_stream = nullptr; +inline thread_local onnxruntime::Stream* current_framework_stream = nullptr; + +inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream, OrtSyncStream* framework_stream) { + current_cuda_stream = cuda_stream; + current_framework_stream = reinterpret_cast(framework_stream); + + if (current_framework_stream == nullptr) { + return; + } + + // Map only from the raw cudaStream_t handle to the current framework stream. The framework + // stream is already handled directly by GetFrameworkStreamForStreamArg, so we deliberately do + // not insert a framework_stream -> framework_stream entry: it would be unused and would grow the + // thread-local map without bound while retaining framework stream pointers past the + // Session::Run() teardown lifetime documented for KernelContext_GetSyncStream. + if (cuda_stream != nullptr) { + stream_to_framework_stream[cuda_stream] = current_framework_stream; + } +} + +inline onnxruntime::Stream* GetFrameworkStreamForStreamArg(void* stream) { + // A null stream argument means "the compute stream of the current Compute call". 
This is the + // form used by GetTransientScratchBuffer and legacy GetScratchBuffer(..., nullptr). Map it to + // the framework stream registered for this call so scratch chunks are still stream-tagged even + // when the kernel runs on a non-default CUDA stream (where current_cuda_stream is non-null and a + // nullptr arg would otherwise miss the map lookup and fall back to a null stream tag). + // + // current_framework_stream is scoped to a single CudaKernel::Compute invocation by + // ComputeStreamScope (see below). Outside any Compute call it is nullptr, so allocations made + // from kernel constructors (which also call GetScratchBuffer(..., nullptr)) fall back to the + // non-stream-tagged path instead of inheriting a stale framework stream pointer whose lifetime + // ended with a previous Session::Run(). + if (stream == nullptr || stream == current_cuda_stream || stream == current_framework_stream) { + return current_framework_stream; + } + + auto it = stream_to_framework_stream.find(stream); + return it == stream_to_framework_stream.end() ? nullptr : it->second; +} + +// RAII guard that scopes the thread-local "current Compute call" framework stream to the lifetime +// of a single CudaKernel::Compute invocation on a worker thread. +// +// On entry it clears current_cuda_stream/current_framework_stream so that scratch allocated before +// the kernel registers its stream (via Stream(ctx)/GetComputeStream(ctx)/GetOrtStream(ctx)), or via +// a nullptr stream argument, does not inherit a stale framework stream left over from a previous +// Compute call on this worker thread. On exit it restores the previous values, which keeps nested +// Compute calls (a kernel that invokes another kernel's Compute) correct and leaves the per-thread +// "current" stream cleared once the outermost Compute returns. The borrowed framework stream is +// only valid until its owning Session::Run() completes teardown, so it must not outlive the call. 
+struct ComputeStreamScope { + ComputeStreamScope() + : saved_cuda_stream_(current_cuda_stream), + saved_framework_stream_(current_framework_stream) { + current_cuda_stream = nullptr; + current_framework_stream = nullptr; + } + ~ComputeStreamScope() { + current_cuda_stream = saved_cuda_stream_; + current_framework_stream = saved_framework_stream_; + } + + private: + void* saved_cuda_stream_; + onnxruntime::Stream* saved_framework_stream_; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ComputeStreamScope); +}; +} // namespace detail +} // namespace cuda_plugin + // Lightweight Stream shim for plugin build: wraps a raw cudaStream_t as a // framework-compatible Stream* that can be passed to _impl.cu functions which // call stream->GetHandle(). Stack-allocated; does NOT own the stream. @@ -70,6 +146,11 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : plugin_stream_shim_(cuda_stream_handle), stream_(&plugin_stream_shim_) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : plugin_stream_shim_(cuda_stream_handle), + stream_(framework_stream == nullptr ? static_cast(&plugin_stream_shim_) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -83,6 +164,10 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : stream_(static_cast(cuda_stream_handle)) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : stream_(framework_stream == nullptr ? 
static_cast(cuda_stream_handle) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -868,6 +953,11 @@ class CudaKernel : public OpKernel { } virtual ~CudaKernel() = default; Status Compute(OpKernelContext* ctx) const { + // Scope the thread-local "current Compute call" framework stream to this invocation so that + // scratch tagged via a nullptr stream argument never inherits a stale framework stream from a + // previous Compute call (or leaks one to a later kernel constructor) on this worker thread. + cuda_plugin::detail::ComputeStreamScope compute_stream_scope; + // Ensure the correct CUDA device is active for this kernel. // Worker threads default to device 0; sessions on device > 0 need an // explicit cudaSetDevice. Skip during CUDA graph capture because @@ -903,17 +993,27 @@ class CudaKernel : public OpKernel { cudaStream_t Stream(OpKernelContext* ctx) const { if (!ctx) return nullptr; - return static_cast(ctx->GetGPUComputeStream()); + // Register the framework sync stream for this Compute call so that scratch allocated via + // GetTransientScratchBuffer()/GetScratchBuffer(..., nullptr) is still stream-tagged for kernels + // that call Stream(ctx) before GetComputeStream()/GetOrtStream() (e.g. conv algo search). + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return static_cast(cuda_stream); } // Returns an opaque stream pointer for passing to GetScratchBuffer/AddDeferredReleaseCPUPtr/CopyToGpu. // Returns void* for dual-build compatibility: framework wraps Stream*, plugin wraps cudaStream_t. 
inline void* GetComputeStream(OpKernelContext* ctx) const { - return ctx->GetGPUComputeStream(); + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return cuda_stream; } inline onnxruntime::OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { - return onnxruntime::OrtStreamAdapter(GetComputeStream(ctx)); + void* cuda_stream = ctx->GetGPUComputeStream(); + OrtSyncStream* framework_stream = ctx->GetSyncStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, framework_stream); + return onnxruntime::OrtStreamAdapter(cuda_stream, framework_stream); } static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { @@ -1023,7 +1123,7 @@ class CudaKernel : public OpKernel { template using IAllocatorUniquePtr = std::unique_ptr>; template - inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* /*stream*/) const { + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* stream) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); // Route kernel scratch/workspace allocations through the EP allocator @@ -1047,36 +1147,18 @@ class CudaKernel : public OpKernel { // enforced inside MakeUniquePtr via ValidatedCalcMemSizeForArray (it throws // on cnt * sizeof(T) overflow). // - // The compute stream is intentionally NOT forwarded to the allocator here. This is a - // bookkeeping decision, NOT a synchronization bug: the `stream` argument to a stream-aware - // arena is only metadata used to decide when a freed chunk may be reused on a *different* - // stream without an intervening sync. It does not change where the kernel runs - the returned - // buffer is still consumed by the kernel on the real compute stream. In a serialized run (and - // within one graph-capture run), alloc/free/reuse ordering is implicit on that stream, so there - // is no cross-stream chunk to race on. 
Tagging chunks with a null stream (freely reusable, the - // same semantics as a plain non-stream-aware BFC arena) is therefore correct and safe as long - // as the EP does not advertise concurrent Session::Run() support. - // - // It is also currently the only safe option, because of a C-API type constraint: a plugin - // kernel only has the raw cudaStream_t (KernelContext::GetGPUComputeStream), not the framework - // OrtSyncStream* that the stream-aware arena persists in each chunk (CudaArena stores - // `chunk->stream` and later dereferences it through the EP stream API, e.g. - // SyncStream_GetImpl/SyncStream_GetSyncId). Note that OrtSyncStream (the ORT-core wrapper, - // `struct OrtSyncStream : public onnxruntime::Stream`) is a DIFFERENT object from the plugin's - // CudaSyncStream (an OrtSyncStreamImpl); CudaSyncStream::FromCudaStream() recovers the latter, - // not the former. Wrapping the raw handle in a temporary framework Stream shim and passing it - // down would be unsafe on two counts: (1) the shim is stack-allocated and would dangle after - // this function returns while the arena still holds the pointer, and (2) it is type-confused — - // the arena would reinterpret a framework Stream* as an OrtSyncStream* that was never created - // by ORT for this stream. - // - // Properly stream-tagging scratch chunks (needed before this path can support concurrent - // multi-stream runs) requires new C-API surface to expose the framework OrtSyncStream* to - // plugin kernels. See docs/cuda_plugin_ep/arena_allocator_migration_design.md ("Scratch buffer - // stream tagging") for the limitation and future work. + // The `stream` argument is the raw cudaStream_t used by migrated CUDA kernels, or a Stream* + // from OrtStreamAdapter in code paths that need stream->GetHandle(). Stream-aware arena + // allocation needs the stable framework Stream* wrapper instead, because the arena stores it + // in each chunk and later queries sync ids through the EP stream API. 
Stream(ctx), + // GetComputeStream(ctx) and GetOrtStream(ctx) record the mapping from both argument forms to + // the framework stream for the current Compute call. + // If the negotiated ORT API version does not include KernelContext_GetSyncStream, the lookup + // returns null and allocation falls back to the non-stream-tagged path. + auto* framework_stream = cuda_plugin::detail::GetFrameworkStreamForStreamArg(stream); return ::onnxruntime::IAllocator::MakeUniquePtr( Info().GetAllocator(OrtMemType::OrtMemTypeDefault), cnt, /*use_reserve*/ false, - /*stream*/ nullptr); + framework_stream); } template inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 2c8b81e4ffefe..89969172c1bdc 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -206,6 +206,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKe }); }; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + auto* stream = reinterpret_cast(context)->GetComputeStream(); + *out = reinterpret_cast(stream); + return nullptr; + }); +}; + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a663d209cfa53..22df898ca3227 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4916,6 +4916,8 @@ static constexpr OrtApi ort_api_1_to_28 = { // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) 
&OrtApis::GetExperimentalFunction, + + &OrtApis::KernelContext_GetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 61ece2dd9a682..e747d0d0ab2d8 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -196,6 +196,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetInputCount, _In_ const OrtKernelContext* co ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +ORT_API_STATUS_IMPL(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, _Outptr_result_maybenull_ OrtSyncStream** out); // OrtTypeInfo methods ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc index d49faf3c90ea8..b75de767bb7f6 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc @@ -88,6 +88,12 @@ Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { return Ort::ConstEpDevice{nullptr}; } +// Dummy external allocator callbacks. They are only used to make the external-allocator +// configuration non-null; the plugin EP rejects the combination with user_compute_stream +// before either is ever invoked. 
+void* DummyExternalAlloc(size_t /*size*/) { return nullptr; } +void DummyExternalFree(void* /*ptr*/) {} + } // namespace class CudaPluginUserStreamGraphTest : public ::testing::Test { @@ -129,6 +135,70 @@ class CudaPluginUserStreamGraphTest : public ::testing::Test { return so; } + // Allocate device input/output, bind them, and run `iterations` times on `stream`, verifying + // Y = X * W each run. The input is uploaded once up front and then left constant: when CUDA graph + // capture is enabled, issuing host->device work on the stream immediately before the capture run + // would interfere with cudaStreamBeginCapture, so the buffers are populated and synchronized + // before any capture happens. When `graph_ids` is non-empty, run i sets gpu_graph_id to + // graph_ids[i % size] to exercise CUDA graph annotation-id switching. mul_1.onnx computes + // Y = X * W with W = [1..6] (shape 3x2). + void RunAndVerifyOnStream(Ort::Session& session, cudaStream_t stream, int iterations, + const std::vector& graph_ids = {}) { + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const std::array x_values = {2.0f, 3.0f, 5.0f, 7.0f, 11.0f, 13.0f}; + + // Fixed device buffers so captured CUDA graphs keep valid IO addresses across replays. + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + // Populate the input once and synchronize, so no host-issued work is pending on `stream` + // when graph capture begins on a later run. 
+ ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, x_values.data(), kBytes, cudaMemcpyHostToDevice, stream)); + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + for (int i = 0; i < iterations; ++i) { + Ort::RunOptions run_options; + if (!graph_ids.empty()) { + run_options.AddConfigEntry("gpu_graph_id", graph_ids[i % graph_ids.size()].c_str()); + } + session.Run(run_options, binding); + + // Kernels run on `stream`; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + std::array y{}; + ASSERT_EQ(cudaSuccess, cudaMemcpy(y.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + for (size_t j = 0; j < kNumElements; ++j) { + EXPECT_FLOAT_EQ(y[j], x_values[j] * w_values[j]) << "mismatch at iteration " << i << " index " << j; + } + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + } + std::unique_ptr registration_; Ort::ConstEpDevice cuda_device_{nullptr}; }; @@ -234,6 +304,69 @@ TEST_F(CudaPluginUserStreamGraphTest, CaptureAndReplayOnUserStream) { ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); } +// Negative: a user_compute_stream combined with an external GPU allocator +// (gpu_external_alloc/gpu_external_free) is not supported and must be rejected at session +// creation with an error rather than silently ignored. 
+TEST_F(CudaPluginUserStreamGraphTest, RejectsUserStreamWithExternalAllocator) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"user_compute_stream", std::to_string(reinterpret_cast(user_stream))}, + {"gpu_external_alloc", std::to_string(reinterpret_cast(&DummyExternalAlloc))}, + {"gpu_external_free", std::to_string(reinterpret_cast(&DummyExternalFree))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + EXPECT_THROW( + { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }, + Ort::Exception); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Edge case: cudaStream_t(0) (the CUDA default stream) is a valid user-provided stream. Because +// user_compute_stream parses to nullptr, the caller must set has_user_compute_stream explicitly, +// otherwise the stream would be treated as "not provided". Session creation must succeed and +// inference must run correctly on the default stream. +// +// Note: CUDA graph capture is intentionally NOT enabled here. The legacy default stream (stream 0) +// cannot be captured (cudaStreamBeginCapture returns cudaErrorStreamCaptureUnsupported), so this +// test exercises only that stream 0 is honored as the compute stream for non-graph execution. +TEST_F(CudaPluginUserStreamGraphTest, DefaultStreamAsUserStream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"has_user_compute_stream", "1"}, + {"user_compute_stream", "0"}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Run several iterations on the default stream (stream 0) and verify correctness. 
+ RunAndVerifyOnStream(session, /*stream=*/nullptr, /*iterations=*/4); +} + +// Switching the CUDA graph annotation id (gpu_graph_id) between runs while using a user stream +// must capture/replay a distinct graph per id without crashing and keep producing correct results. +TEST_F(CudaPluginUserStreamGraphTest, GraphAnnotationIdSwitchingWithUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Alternate between annotation ids "1" and "2". With min_num_runs_before_cuda_graph_capture == 2, + // 8 iterations let each id accumulate warmup runs, capture, and then replay on the user stream. + RunAndVerifyOnStream(session, user_stream, /*iterations=*/8, /*graph_ids=*/{"1", "2"}); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 53745cae9d803..a58f82deab4ff 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -54,8 +54,11 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { EXPECT_NE(allocated, nullptr) << "KernelContext_GetAllocator() can successfully allocate some memory"; allocator.Free(allocated); + OrtSyncStream* sync_stream = ctx.GetSyncStream(); + // Do computation #ifdef USE_CUDA + EXPECT_NE(sync_stream, nullptr) << "KernelContext_GetSyncStream() returns the kernel compute stream"; // Launch on stream 0 or user provided stream void* stream; Ort::ThrowOnError(ort_.KernelContext_GetGPUComputeStream(context, &stream)); @@ -70,6 +73,7 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { // and use the same compute stream to launch the custom op. 
// Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.) #else + EXPECT_EQ(sync_stream, nullptr) << "CPU custom ops do not have a compute stream"; ORT_UNUSED_PARAMETER(ort_); for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i];