[CUDA Plugin EP] Expose kernel sync stream for scratch allocation #29244
tianleiwu wants to merge 4 commits into
Pull request overview
This PR extends the kernel-context C API to expose the framework OrtSyncStream* for the current kernel invocation, and updates the CUDA plugin EP to use that stream for stream-aware scratch allocation bookkeeping so it can safely advertise concurrent Session::Run() when supported by the host runtime.
Changes:
- Adds OrtApi::KernelContext_GetSyncStream (plus C++ and adapter wrappers) to retrieve the framework stream wrapper associated with a kernel context.
- Updates the CUDA plugin kernel adapter to associate scratch/workspace allocations with the framework stream (instead of a null stream tag).
- Re-enables CUDA plugin EP concurrent-run support when the host runtime supports the new API and unified-stream mode is not forced.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| include/onnxruntime/core/session/onnxruntime_c_api.h | Adds the public C API entry and documentation for KernelContext_GetSyncStream. |
| onnxruntime/core/session/ort_apis.h | Declares the new OrtApis implementation entry point. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Wires the new function pointer into the OrtApi table. |
| onnxruntime/core/session/custom_ops.cc | Implements KernelContext_GetSyncStream by returning the kernel’s framework compute stream wrapper. |
| include/onnxruntime/core/session/onnxruntime_cxx_api.h | Adds Ort::KernelContext::GetSyncStream() declaration. |
| include/onnxruntime/core/session/onnxruntime_cxx_inline.h | Implements the C++ wrapper calling into the C API. |
| include/onnxruntime/ep/adapter/op_kernel.h | Adds version-gated adapter access to GetSyncStream() for plugin kernels. |
| onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h | Tracks/uses the framework stream wrapper for stream-aware scratch allocation tagging. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Gates IsConcurrentRunSupported on API availability and unified-stream configuration. |
| onnxruntime/test/shared_lib/custom_op_utils.cc | Extends shared-lib custom-op tests to exercise GetSyncStream(). |
| docs/cuda_plugin_ep/cuda_plugin_ep_design.md | Updates plugin design docs for the new gated stream-tagged scratch capability. |
| docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md | Updates CUDA graph docs to reflect stream-tagged scratch allocation and concurrent-run conditions. |
| docs/cuda_plugin_ep/arena_allocator_migration_design.md | Updates allocator migration design docs to reflect stream-tagged scratch allocation and compatibility behavior. |
e3d1cc9 to 39c5fe2
  inline void* GetComputeStream(OpKernelContext* ctx) const {
-   return ctx->GetGPUComputeStream();
+   void* cuda_stream = ctx->GetGPUComputeStream();
+   cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream());
+   return cuda_stream;
Good catch — fixed. Stream(ctx) now calls RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()) for the current Compute call, so kernels that call Stream(ctx) and then GetTransientScratchBuffer()/GetScratchBuffer(..., nullptr) before ever calling GetComputeStream()/GetOrtStream() (e.g. conv algo search) get correctly stream-tagged scratch instead of a stale/null framework stream.
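To make the ordering issue in this reply concrete, here is a minimal sketch of the idea, not the plugin's actual code. `FrameworkStream`, `KernelContext`, `RegisterAndGetCudaStream`, and `TagForScratch` are hypothetical stand-ins introduced only for illustration. The assumption is that every stream accessor refreshes a per-thread "current framework stream" record, so a scratch request made with a null stream argument is tagged correctly regardless of which accessor the kernel called first.

```cpp
#include <cassert>

// Hypothetical stand-ins; the real types live in ONNX Runtime and the CUDA plugin EP.
struct FrameworkStream {};  // plays the role of the framework stream wrapper
struct KernelContext {      // plays the role of OpKernelContext
  void* cuda_stream;             // raw CUDA stream handle
  FrameworkStream* sync_stream;  // framework stream for this Compute call
};

// Per-thread record of the framework stream for the Compute call in progress.
thread_local FrameworkStream* g_current_framework_stream = nullptr;

// Shared helper: both accessors go through here, so the record is refreshed
// no matter which accessor a kernel happens to call first.
inline void* RegisterAndGetCudaStream(const KernelContext& ctx) {
  g_current_framework_stream = ctx.sync_stream;
  return ctx.cuda_stream;
}

inline void* Stream(const KernelContext& ctx) { return RegisterAndGetCudaStream(ctx); }
inline void* GetComputeStream(const KernelContext& ctx) { return RegisterAndGetCudaStream(ctx); }

// A scratch request with a null stream argument falls back to whatever the
// current Compute call registered (assumed behavior for this sketch).
inline FrameworkStream* TagForScratch(FrameworkStream* stream_arg) {
  return stream_arg != nullptr ? stream_arg : g_current_framework_stream;
}
```

If only `GetComputeStream` registered the stream, a kernel that called `Stream(ctx)` and then requested scratch would see the record left behind by an earlier Compute call, which is the stale-tag case described above.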
inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream, OrtSyncStream* framework_stream) {
  current_cuda_stream = cuda_stream;
  current_framework_stream = reinterpret_cast<onnxruntime::Stream*>(framework_stream);

  if (current_framework_stream == nullptr) {
    return;
  }

  stream_to_framework_stream[current_framework_stream] = current_framework_stream;

  if (cuda_stream != nullptr) {
    stream_to_framework_stream[cuda_stream] = current_framework_stream;
  }
}
Fixed. Removed the stream_to_framework_stream[current_framework_stream] = current_framework_stream self-entry. GetFrameworkStreamForStreamArg already handles stream == current_framework_stream directly, so the entry was unused and only risked unbounded thread-local map growth and retaining framework stream pointers past the Session::Run() teardown lifetime. The map now keys only off raw cudaStream_t handles.
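A self-contained sketch of the post-fix shape described in this reply, under stated assumptions: `FrameworkStream` is a hypothetical stand-in for `onnxruntime::Stream`, and the body of `GetFrameworkStreamForStreamArg` is a guess at the lookup order (the PR snippet above does not show it). In particular, the null-argument fallback is an assumption made for this sketch.

```cpp
#include <cassert>
#include <unordered_map>

struct FrameworkStream {};  // hypothetical stand-in for onnxruntime::Stream

thread_local void* current_cuda_stream = nullptr;
thread_local FrameworkStream* current_framework_stream = nullptr;
// Keyed only by raw CUDA stream handles; no self-entry for the framework stream.
thread_local std::unordered_map<void*, FrameworkStream*> stream_to_framework_stream;

inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream,
                                                 FrameworkStream* framework_stream) {
  current_cuda_stream = cuda_stream;
  current_framework_stream = framework_stream;
  if (framework_stream == nullptr || cuda_stream == nullptr) {
    return;  // nothing to record in the map
  }
  stream_to_framework_stream[cuda_stream] = framework_stream;
}

// Assumed lookup order: null argument, then the current framework stream
// passed directly, then a raw CUDA handle recorded in the map.
inline FrameworkStream* GetFrameworkStreamForStreamArg(void* stream) {
  if (stream == nullptr) {
    return current_framework_stream;  // assumption: null means "current call"
  }
  if (stream == static_cast<void*>(current_framework_stream)) {
    return current_framework_stream;  // handled directly, so no map entry is needed
  }
  auto it = stream_to_framework_stream.find(stream);
  return it != stream_to_framework_stream.end() ? it->second : nullptr;
}
```

Because the direct comparison already covers the framework-stream case, the map grows only with the number of distinct CUDA stream handles seen on the thread.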
Description
This PR adds a kernel-context C API accessor for the framework OrtSyncStream* and uses it in the CUDA plugin EP so scratch allocations can be tagged with the actual compute stream selected for the kernel. It is stacked on #29221 and turns the previously documented concurrent multi-stream limitation into a gated capability: older runtimes keep the conservative fallback, while runtimes with the new API can safely advertise concurrent runs when EP-level unified stream mode is not forced.
Summary of Changes
Public API and Adapters
- include/onnxruntime/core/session/onnxruntime_c_api.h: KernelContext_GetSyncStream to expose the borrowed framework stream for stream-aware allocation and synchronization bookkeeping.
- onnxruntime/core/session/custom_ops.cc: … OpKernelContext::GetComputeStream() inside ORT core.
- onnxruntime/core/session/ort_apis.h and onnxruntime/core/session/onnxruntime_c_api.cc
- include/onnxruntime/core/session/onnxruntime_cxx_api.h and include/onnxruntime/core/session/onnxruntime_cxx_inline.h: … Ort::KernelContext::GetSyncStream() wrapper.
- include/onnxruntime/ep/adapter/op_kernel.h
CUDA Plugin EP
- … OrtStreamAdapter stream arguments.
- … KernelContext_GetSyncStream is available and EP-level unified stream mode is not forced.
Tests and Docs
- … Ort::KernelContext::GetSyncStream().
Why a C API is needed
The implementation of KernelContext_GetSyncStream is intentionally small, but the API boundary is the important part. ORT core can safely cast OrtKernelContext* back to onnxruntime::OpKernelContext* because it owns both the opaque C handle and the private C++ implementation. A plugin kernel should not perform that cast directly: it would make the plugin depend on ORT-core private C++ layout, vtables, and exact build compatibility.
The new API keeps that private cast inside ORT core and gives plugin kernels a stable ABI entry point.
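The opaque-handle pattern described here can be sketched as follows. This is illustrative only: the real `KernelContext_GetSyncStream` signature is not shown in this section, so every name below (`DemoKernelContext`, `DemoSyncStream`, `Demo_KernelContext_GetSyncStream`, and the `core` namespace) is hypothetical, and the return-by-pointer shape is an assumption rather than the actual ORT C API.

```cpp
#include <cassert>

// ---- Public side (what a plugin would see): opaque handles and a C entry point.
struct DemoKernelContext;  // hypothetical opaque handle, in the spirit of OrtKernelContext
struct DemoSyncStream;     // hypothetical opaque handle, in the spirit of OrtSyncStream
extern "C" DemoSyncStream* Demo_KernelContext_GetSyncStream(const DemoKernelContext* ctx);

// ---- Core side (private): the concrete C++ types behind the handles.
namespace core {
struct Stream {
  int id;
};

class OpKernelContext {
 public:
  explicit OpKernelContext(Stream* s) : stream_(s) {}
  Stream* GetComputeStream() const { return stream_; }

 private:
  Stream* stream_;
};
}  // namespace core

// Only the core knows that the handle is really a core::OpKernelContext,
// so the cast lives on this side of the boundary.
extern "C" DemoSyncStream* Demo_KernelContext_GetSyncStream(const DemoKernelContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  const auto* impl = reinterpret_cast<const core::OpKernelContext*>(ctx);
  // The returned pointer is borrowed: the caller must not free it.
  return reinterpret_cast<DemoSyncStream*>(impl->GetComputeStream());
}
```

A plugin compiled against only the public declarations never sees the layout of `core::OpKernelContext`, so the core can change that class without breaking the plugin binary.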
This also lets the plugin use runtime version gating. When loaded by an older ORT runtime that does not expose the API, the adapter returns null, scratch allocation uses the conservative fallback, and concurrent runs are not advertised.
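The gating behavior in the paragraph above might look roughly like this. It is a sketch under assumptions: `DemoApi`, `kFirstVersionWithGetSyncStream`, `AdapterGetSyncStream`, and `GetSyncStreamImpl` are hypothetical names, the version number is made up, and the real adapter's version check is not shown in this section.

```cpp
#include <cassert>
#include <cstdint>

struct SyncStream {};  // hypothetical stand-in for the framework stream
struct KernelContext {
  SyncStream* stream;
};

// Hypothetical function table: newer runtimes fill in the new entry,
// older runtimes do not provide it.
struct DemoApi {
  std::uint32_t version;
  SyncStream* (*KernelContext_GetSyncStream)(const KernelContext*);
};

constexpr std::uint32_t kFirstVersionWithGetSyncStream = 2;  // made-up number

inline bool HasGetSyncStream(const DemoApi& api) {
  return api.version >= kFirstVersionWithGetSyncStream &&
         api.KernelContext_GetSyncStream != nullptr;
}

// Adapter: returns null on older runtimes so callers take the conservative path.
inline SyncStream* AdapterGetSyncStream(const DemoApi& api, const KernelContext& ctx) {
  return HasGetSyncStream(api) ? api.KernelContext_GetSyncStream(&ctx) : nullptr;
}

// Concurrent runs are advertised only when the API exists and
// unified stream mode is not forced at the EP level.
inline bool IsConcurrentRunSupported(const DemoApi& api, bool unified_stream_forced) {
  return HasGetSyncStream(api) && !unified_stream_forced;
}

// Example implementation a newer runtime could plug into the table.
inline SyncStream* GetSyncStreamImpl(const KernelContext* ctx) { return ctx->stream; }
```

Checking both the version and the function pointer is a defensive choice for the sketch; a real adapter may rely on the version alone.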
Testing
- lintrunner -a
- ninja -C build/cu130_plugin/Debug onnxruntime_providers_cuda_plugin
- ninja -C build/cu130_plugin/Debug onnxruntime_shared_lib_test
- cd build/cu130_plugin/Debug && ./onnxruntime_shared_lib_test --gtest_filter=CApiTest.custom_op_handler --gtest_color=no
Checklist