From de6e6fe109867b2d5f096a3594e2f9188ce21b38 Mon Sep 17 00:00:00 2001 From: Will Killian Date: Thu, 25 Jun 2026 13:13:08 -0400 Subject: [PATCH] feat(plugin): add native Rust plugin SDK Signed-off-by: Will Killian --- .github/ci-path-filters.yml | 12 + .github/workflows/ci.yaml | 2 + .gitlab-ci.yml | 2 +- Cargo.lock | 9 + Cargo.toml | 2 + RELEASING.md | 30 +- crates/cli/src/main.rs | 52 + crates/cli/tests/coverage/main_tests.rs | 18 +- crates/cli/tests/coverage/setup_tests.rs | 32 +- crates/plugin/Cargo.toml | 18 + crates/plugin/src/lib.rs | 2755 ++++++++++ crates/plugin/tests/typed_callbacks.rs | 4516 +++++++++++++++++ examples/rust-native-plugin/.gitignore | 5 + examples/rust-native-plugin/Cargo.toml | 19 + examples/rust-native-plugin/README.md | 78 + examples/rust-native-plugin/relay-plugin.toml | 24 + examples/rust-native-plugin/src/lib.rs | 351 ++ justfile | 74 + 18 files changed, 7950 insertions(+), 49 deletions(-) create mode 100644 crates/plugin/Cargo.toml create mode 100644 crates/plugin/src/lib.rs create mode 100644 crates/plugin/tests/typed_callbacks.rs create mode 100644 examples/rust-native-plugin/.gitignore create mode 100644 examples/rust-native-plugin/Cargo.toml create mode 100644 examples/rust-native-plugin/README.md create mode 100644 examples/rust-native-plugin/relay-plugin.toml create mode 100644 examples/rust-native-plugin/src/lib.rs diff --git a/.github/ci-path-filters.yml b/.github/ci-path-filters.yml index 3b79c81d..5b3e7d5a 100644 --- a/.github/ci-path-filters.yml +++ b/.github/ci-path-filters.yml @@ -12,6 +12,8 @@ shared: - 'crates/adaptive/src/**' - 'crates/core/Cargo.toml' - 'crates/core/src/**' + - 'crates/types/Cargo.toml' + - 'crates/types/src/**' - 'justfile' - 'rust-toolchain.toml' @@ -24,6 +26,10 @@ rust_package: - 'crates/cli/src/**' - 'crates/core/Cargo.toml' - 'crates/core/src/**' + - 'crates/plugin/Cargo.toml' + - 'crates/plugin/src/**' + - 'crates/types/Cargo.toml' + - 'crates/types/src/**' - 'crates/ffi/Cargo.toml' - 'crates/ffi/build.rs' - 'crates/ffi/cbindgen.toml' @@ -39,6 +45,8 @@ node_package: - 'crates/adaptive/src/**' - 'crates/core/Cargo.toml' - 'crates/core/src/**' + - 'crates/types/Cargo.toml' + - 'crates/types/src/**' - 'crates/node/Cargo.toml' - 'crates/node/build.rs' - 'crates/node/package.json' @@ -59,6 +67,8 @@ python_package: - 'crates/adaptive/src/**' - 'crates/core/Cargo.toml' - 'crates/core/src/**' + - 'crates/types/Cargo.toml' + - 'crates/types/src/**' - 'crates/python/Cargo.toml' - 'crates/python/src/**' - 'justfile' @@ -88,6 +98,8 @@ wasm_package: - 'crates/adaptive/src/**' - 'crates/core/Cargo.toml' - 'crates/core/src/**' + - 'crates/types/Cargo.toml' + - 'crates/types/src/**' - 'crates/wasm/Cargo.toml' - 'crates/wasm/scripts/**' - 'crates/wasm/src/**' diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 17dc3cbc..308f4589 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -374,8 +374,10 @@ jobs: set -euo pipefail version="${{ github.ref_name }}" packages=( + nemo-relay-types nemo-relay nemo-relay-adaptive + nemo-relay-plugin nemo-relay-pii-redaction nemo-relay-ffi nemo-relay-cli diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d0251b5e..6abdcffc 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -262,7 +262,7 @@ publish:artifactory:cargo: artifactory = { index = "sparse+${NEMO_RELAY_CI_ARTIFACTORY_CARGO_URL}" } EOF export CARGO_REGISTRIES_ARTIFACTORY_TOKEN="Bearer ${NEMO_RELAY_CI_ARTIFACTORY_KEY}" - export NEMO_RELAY_ARTIFACTORY_CRATE_DIRS="core adaptive pii-redaction ffi cli" + export NEMO_RELAY_ARTIFACTORY_CRATE_DIRS="types core adaptive plugin pii-redaction ffi cli" crates="$( uv run --no-project python - <<'PY' diff --git a/Cargo.lock b/Cargo.lock index 92151cb9..9c38de95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1470,6 +1470,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "nemo-relay-plugin" +version = "0.5.0" +dependencies = [ + "nemo-relay-types", + "serde", + "serde_json", +] + [[package]] name = "nemo-relay-python" version = "0.5.0" diff --git a/Cargo.toml b/Cargo.toml index f4ecf923..355411fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crates/core", "crates/types", + "crates/plugin", "crates/adaptive", "crates/pii-redaction", "crates/cli", @@ -26,6 +27,7 @@ repository = "https://github.com/NVIDIA/NeMo-Relay" [workspace.dependencies] nemo-relay = { version = "0.5.0", path = "crates/core", default-features = false } nemo-relay-types = { version = "0.5.0", path = "crates/types" } +nemo-relay-plugin = { version = "0.5.0", path = "crates/plugin" } nemo-relay-adaptive = { version = "0.5.0", path = "crates/adaptive" } nemo-relay-pii-redaction = { version = "0.5.0", path = "crates/pii-redaction" } nemo-relay-ffi = { version = "0.5.0", path = "crates/ffi" } diff --git a/RELEASING.md b/RELEASING.md index 0b829dc2..e711940a 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -30,7 +30,7 @@ The release pipeline publishes these package surfaces from a tag push: | Ecosystem | Published Surface | |---|---| -| crates.io | `nemo-relay`, `nemo-relay-adaptive`, `nemo-relay-pii-redaction`, `nemo-relay-ffi`, `nemo-relay-cli` | +| crates.io | `nemo-relay-types`, `nemo-relay`, `nemo-relay-adaptive`, `nemo-relay-plugin`, `nemo-relay-pii-redaction`, `nemo-relay-ffi`, `nemo-relay-cli` | | PyPI | `nemo-relay` | | npm | `nemo-relay-node`, `nemo-relay-openclaw`, `nemo-relay-wasm` | | GitHub Releases | CLI binaries and `SHA256SUMS` | @@ -50,7 +50,8 @@ NeMo Relay versions are anchored on the workspace SemVer in the repository root - The root `Cargo.toml` `workspace.package.version` is the canonical release version for the Rust workspace. -- The root `Cargo.toml` `workspace.dependencies` entries for `nemo-relay`, +- The root `Cargo.toml` `workspace.dependencies` entries for + `nemo-relay-types`, `nemo-relay`, `nemo-relay-plugin`, `nemo-relay-adaptive`, `nemo-relay-pii-redaction`, `nemo-relay-ffi`, and `nemo-relay-cli` must stay aligned with that same version. - `crates/node/package.json` carries the base npm version for the Node.js @@ -132,9 +133,9 @@ Before you create a release tag, confirm the following: 3. The working tree you use for local validation is clean or disposable. 4. Registry credentials and repository settings are in place: - GitHub Actions `id-token: write` access for the top-level crates.io publish job - - crates.io trusted publishers for `nemo-relay`, `nemo-relay-adaptive`, - `nemo-relay-pii-redaction`, `nemo-relay-ffi`, and `nemo-relay-cli` are - configured for the top-level + - crates.io trusted publishers for `nemo-relay-types`, `nemo-relay`, + `nemo-relay-adaptive`, `nemo-relay-plugin`, `nemo-relay-pii-redaction`, + `nemo-relay-ffi`, and `nemo-relay-cli` are configured for the top-level [`.github/workflows/ci.yaml`](.github/workflows/ci.yaml) workflow - GitHub Actions `id-token: write` access is available for the top-level npm publish job - GitHub Actions `id-token: write` access for the top-level PyPI publish job @@ -154,8 +155,9 @@ The helper updates: 1. The root [`Cargo.toml`](Cargo.toml) workspace version. 2. The root [`Cargo.toml`](Cargo.toml) `workspace.dependencies` versions for - `nemo-relay`, `nemo-relay-adaptive`, `nemo-relay-pii-redaction`, - `nemo-relay-ffi`, and `nemo-relay-cli`. + `nemo-relay-types`, `nemo-relay`, `nemo-relay-plugin`, + `nemo-relay-adaptive`, `nemo-relay-pii-redaction`, `nemo-relay-ffi`, and + `nemo-relay-cli`. 3. [`crates/node/package.json`](crates/node/package.json) and the `crates/node` entry in the root [`package-lock.json`](package-lock.json) to the same release version. @@ -193,6 +195,7 @@ If you want to validate the packaging recipes before pushing a tag, run: ```bash just --set output_dir "$PWD/target/release-artifacts" --set ref_name 0.1.0 package-node just --set output_dir "$PWD/target/release-artifacts" --set ref_name 0.1.0 package-openclaw +just --set output_dir "$PWD/target/release-artifacts" --set ref_name 0.1.0 package-rust just --set output_dir "$PWD/target/release-artifacts" --set ref_name 0.1.0 package-python just --set output_dir "$PWD/target/release-artifacts" --set ref_name 0.1.0 package-wasm ``` @@ -234,6 +237,7 @@ The release pipeline then: 2. Runs the required repository checks, language test jobs, and Fern documentation validation. 3. Builds publishable package artifacts with the exact tag version: + - `package-rust` packs the published Rust crates for local validation. - `package-node` packs the npm Node.js package. - `package-openclaw` packs the npm OpenClaw plugin package. - `package-python` builds platform wheels. @@ -243,9 +247,10 @@ The release pipeline then: 4. Publishes packages from the top-level workflow after the reusable packaging jobs complete: - `publish-rust` stamps Cargo workspace versions from the release tag, then - runs `cargo publish --package` for `nemo-relay`, `nemo-relay-adaptive`, - `nemo-relay-pii-redaction`, `nemo-relay-ffi`, and `nemo-relay-cli` - through trusted publishing from the top-level workflow + runs `cargo publish --package` for `nemo-relay-types`, `nemo-relay`, + `nemo-relay-adaptive`, `nemo-relay-plugin`, `nemo-relay-pii-redaction`, + `nemo-relay-ffi`, and `nemo-relay-cli` through trusted publishing from + the top-level workflow - `publish-python` uploads the wheel artifacts to PyPI with trusted publishing from the top-level workflow - `publish-npm` publishes the Node.js, OpenClaw plugin, and WebAssembly npm @@ -310,8 +315,9 @@ for that tag. After the release is live, verify: -1. The `nemo-relay`, `nemo-relay-adaptive`, `nemo-relay-pii-redaction`, - `nemo-relay-ffi`, and `nemo-relay-cli` crates are visible on crates.io. +1. The `nemo-relay-types`, `nemo-relay`, `nemo-relay-adaptive`, + `nemo-relay-plugin`, `nemo-relay-pii-redaction`, `nemo-relay-ffi`, and + `nemo-relay-cli` crates are visible on crates.io. 2. The `nemo-relay` wheel is visible on PyPI. 3. The `nemo-relay-node`, `nemo-relay-openclaw`, and `nemo-relay-wasm` packages are visible on npm. diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 41ec0ee3..41349982 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -209,9 +209,61 @@ async fn run_default(server_args: &ServerArgs) -> Result, + prev: Option, + } + + impl CwdTestScope { + pub(crate) fn locked() -> Self { + Self { + _guard: lock_cwd(), + prev: None, + } + } + + pub(crate) fn enter(path: &std::path::Path) -> Self { + let guard = lock_cwd(); + let prev = std::env::current_dir().unwrap(); + std::env::set_current_dir(path).unwrap(); + Self { + _guard: guard, + prev: Some(prev), + } + } + } + + impl Drop for CwdTestScope { + fn drop(&mut self) { + if let Some(prev) = &self.prev + && let Err(error) = std::env::set_current_dir(prev) + { + CWD_RESTORE_FAILED.store(true, std::sync::atomic::Ordering::SeqCst); + if std::thread::panicking() { + eprintln!("failed to restore current_dir to {prev:?}: {error}"); + } else { + panic!("failed to restore current_dir to {prev:?}: {error}"); + } + } + } + } + + pub(crate) static CWD_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + static CWD_RESTORE_FAILED: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); pub(crate) static ENV_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); pub(crate) static PLUGIN_CONFIG_TEST_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(()); + + fn lock_cwd() -> std::sync::MutexGuard<'static, ()> { + let guard = CWD_TEST_LOCK.lock().expect("CWD_TEST_LOCK poisoned"); + assert!( + !CWD_RESTORE_FAILED.load(std::sync::atomic::Ordering::SeqCst), + "current_dir restore failed in a previous test; aborting to prevent cross-test contamination", + ); + guard + } } #[cfg(test)] diff --git a/crates/cli/tests/coverage/main_tests.rs b/crates/cli/tests/coverage/main_tests.rs index 437d3987..9dd10b5b 100644 --- a/crates/cli/tests/coverage/main_tests.rs +++ b/crates/cli/tests/coverage/main_tests.rs @@ -12,6 +12,7 @@ use crate::config::{ }; struct EnvScope { + _cwd_guard: Option, _guard: std::sync::MutexGuard<'static, ()>, values: Vec<(&'static str, Option)>, } @@ -20,13 +21,19 @@ impl EnvScope { fn hermetic(temp: &tempfile::TempDir) -> Self { let xdg = temp.path().join("xdg"); std::fs::create_dir_all(&xdg).unwrap(); - Self::set(&[ - ("HOME", Some(temp.path().as_os_str())), - ("XDG_CONFIG_HOME", Some(xdg.as_os_str())), - ]) + Self::set_with_cwd_guard( + &[ + ("HOME", Some(temp.path().as_os_str())), + ("XDG_CONFIG_HOME", Some(xdg.as_os_str())), + ], + Some(crate::test_support::CwdTestScope::locked()), + ) } - fn set(values: &[(&'static str, Option<&std::ffi::OsStr>)]) -> Self { + fn set_with_cwd_guard( + values: &[(&'static str, Option<&std::ffi::OsStr>)], + cwd_guard: Option, + ) -> Self { let guard = crate::test_support::ENV_TEST_LOCK .lock() .unwrap_or_else(|error| error.into_inner()); @@ -43,6 +50,7 @@ impl EnvScope { } } Self { + _cwd_guard: cwd_guard, _guard: guard, values: previous, } diff --git a/crates/cli/tests/coverage/setup_tests.rs b/crates/cli/tests/coverage/setup_tests.rs index 370c0d0c..b156b61c 100644 --- a/crates/cli/tests/coverage/setup_tests.rs +++ b/crates/cli/tests/coverage/setup_tests.rs @@ -2,16 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; +use crate::test_support::CwdTestScope as CwdScope; use std::ffi::OsString; use std::path::PathBuf; -use std::sync::{Mutex, OnceLock}; - -// Current-directory changes are process-wide, so tests that enter a temp workspace -// must run serially with respect to each other. -fn cwd_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} // Tests that exercise the global-config write path clear `$XDG_CONFIG_HOME` // because CI runners commonly set it to a real `/home/runner/.config` path. @@ -47,29 +40,6 @@ impl Drop for XdgScope { } } -struct CwdScope { - _guard: std::sync::MutexGuard<'static, ()>, - prev: PathBuf, -} - -impl CwdScope { - fn enter(path: &std::path::Path) -> Self { - let guard = cwd_lock().lock().unwrap_or_else(|e| e.into_inner()); - let prev = std::env::current_dir().unwrap(); - std::env::set_current_dir(path).unwrap(); - Self { - _guard: guard, - prev, - } - } -} - -impl Drop for CwdScope { - fn drop(&mut self) { - std::env::set_current_dir(&self.prev).unwrap(); - } -} - struct EnvScope { _guard: std::sync::MutexGuard<'static, ()>, values: Vec<(&'static str, Option)>, diff --git a/crates/plugin/Cargo.toml b/crates/plugin/Cargo.toml new file mode 100644 index 00000000..78c2cabb --- /dev/null +++ b/crates/plugin/Cargo.toml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "nemo-relay-plugin" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Rust plugin authoring SDK and stable native plugin ABI for NeMo Relay." + +[lints] +workspace = true + +[dependencies] +nemo-relay-types.workspace = true +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs new file mode 100644 index 00000000..1d764b85 --- /dev/null +++ b/crates/plugin/src/lib.rs @@ -0,0 +1,2755 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![deny(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links)] + +//! Stable native plugin ABI and Rust authoring helpers for NeMo Relay. +//! +//! This crate intentionally does not depend on the `nemo-relay` runtime crate. +//! Native plugins built with it communicate with a host through versioned +//! C-compatible tables and host-owned string handles. + +use std::ffi::{c_char, c_void}; +use std::marker::{PhantomData, PhantomPinned}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::ptr; +use std::sync::Mutex; + +pub use nemo_relay_types::Json; +pub use nemo_relay_types::api::event::{Event, ScopeCategory}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; +pub use nemo_relay_types::api::tool::ToolAttributes; +pub use nemo_relay_types::codec::request::AnnotatedLlmRequest; +pub use nemo_relay_types::codec::response::AnnotatedLlmResponse; +pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::Map; + +/// Native plugin ABI version supported by this crate. +pub const NEMO_RELAY_NATIVE_ABI_VERSION: u32 = 1; + +/// Status codes returned by stable native ABI functions. +#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NemoRelayStatus { + /// Operation completed successfully. + Ok = 0, + /// A resource with the given name already exists. + AlreadyExists = 1, + /// The requested resource was not found. + NotFound = 2, + /// The scope stack is empty. + ScopeStackEmpty = 3, + /// A guardrail rejected the operation. + GuardrailRejected = 4, + /// An internal runtime error occurred. + Internal = 5, + /// A required pointer argument was null. + NullPointer = 6, + /// A JSON string argument could not be parsed. + InvalidJson = 7, + /// A string argument contained invalid UTF-8. + InvalidUtf8 = 8, + /// A function argument had an invalid value. + InvalidArg = 9, + /// A stream reached end-of-stream and has no chunk to return. + StreamEnd = 10, +} + +/// Opaque host-owned UTF-8 string or JSON byte buffer. +#[repr(C)] +pub struct NemoRelayNativeString { + _private: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Opaque plugin registration context borrowed from the host during registration. +#[repr(C)] +pub struct NemoRelayNativePluginContext { + _private: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Opaque host-owned scope handle. +#[repr(C)] +pub struct NemoRelayNativeScopeHandle { + _private: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Opaque host-owned scope stack handle. +#[repr(C)] +pub struct NemoRelayNativeScopeStack { + _private: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Opaque host-owned captured scope-stack binding. +#[repr(C)] +pub struct NemoRelayNativeScopeStackBinding { + _private: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Scope category used by native plugins when opening scopes. +#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NemoRelayNativeScopeType { + /// Top-level agent scope. + Agent = 0, + /// Generic function scope. + Function = 1, + /// Tool invocation scope. + Tool = 2, + /// LLM call scope. + Llm = 3, + /// Retriever scope. + Retriever = 4, + /// Embedder scope. + Embedder = 5, + /// Reranker scope. + Reranker = 6, + /// Guardrail evaluation scope. + Guardrail = 7, + /// Evaluator scope. + Evaluator = 8, + /// User-defined custom scope. + Custom = 9, + /// Unknown or unspecified scope type. + Unknown = 10, +} + +/// Optional destructor for user data captured by native callbacks. +pub type NemoRelayNativeFreeFn = Option; + +/// Native callback executed while a host scope stack is temporarily active. +pub type NemoRelayNativeWithScopeStackCb = + unsafe extern "C" fn(user_data: *mut c_void) -> NemoRelayStatus; + +/// Runtime-provided continuation for tool execution intercepts. +pub type NemoRelayNativeToolNextFn = unsafe extern "C" fn( + args_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Runtime-provided continuation for LLM execution intercepts. +pub type NemoRelayNativeLlmNextFn = unsafe extern "C" fn( + request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native stream poll callback. +/// +/// Return [`NemoRelayStatus::Ok`] with `out_json` set for one chunk, +/// [`NemoRelayStatus::StreamEnd`] with `out_json` null at end of stream, or an +/// error status for stream failure. +pub type NemoRelayNativeLlmStreamPollFn = unsafe extern "C" fn( + user_data: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Optional native stream cancellation callback. +pub type NemoRelayNativeLlmStreamCancelFn = + Option NemoRelayStatus>; + +/// Optional native stream destructor callback. +pub type NemoRelayNativeLlmStreamDropFn = Option; + +/// Native LLM JSON stream handle table. +#[repr(C)] +pub struct NemoRelayNativeLlmStreamV1 { + /// Size of this struct as seen by the producer. + pub struct_size: usize, + /// Stream state passed back to poll/cancel/drop callbacks. + pub user_data: *mut c_void, + /// Polls the next stream chunk. + pub next: Option, + /// Cancels an in-flight stream when a consumer stops before stream end. + pub cancel: NemoRelayNativeLlmStreamCancelFn, + /// Drops stream state after stream completion, error, or cancellation. + pub drop: NemoRelayNativeLlmStreamDropFn, +} + +impl Default for NemoRelayNativeLlmStreamV1 { + fn default() -> Self { + Self { + struct_size: std::mem::size_of::(), + user_data: ptr::null_mut(), + next: None, + cancel: None, + drop: None, + } + } +} + +/// Runtime-provided continuation for LLM stream execution intercepts. +pub type NemoRelayNativeLlmStreamNextFn = unsafe extern "C" fn( + request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_stream: *mut NemoRelayNativeLlmStreamV1, +) -> NemoRelayStatus; + +/// Native event subscriber callback. +pub type NemoRelayNativeEventSubscriberCb = unsafe extern "C" fn( + user_data: *mut c_void, + event_json: *const NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native JSON transform callback for tool request/response sanitizers and tool request intercepts. +pub type NemoRelayNativeToolJsonCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + payload_json: *const NemoRelayNativeString, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native tool conditional-execution callback. +pub type NemoRelayNativeToolConditionalCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + args_json: *const NemoRelayNativeString, + out_reason: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native tool execution intercept callback. +pub type NemoRelayNativeToolExecutionCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + args_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeToolNextFn, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native LLM request transform callback for request sanitizers. +pub type NemoRelayNativeLlmRequestCb = unsafe extern "C" fn( + user_data: *mut c_void, + request_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native JSON transform callback for LLM response sanitizers. +pub type NemoRelayNativeJsonCb = unsafe extern "C" fn( + user_data: *mut c_void, + payload_json: *const NemoRelayNativeString, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native LLM conditional-execution callback. +pub type NemoRelayNativeLlmConditionalCb = unsafe extern "C" fn( + user_data: *mut c_void, + request_json: *const NemoRelayNativeString, + out_reason: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native LLM request intercept callback. +pub type NemoRelayNativeLlmRequestInterceptCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + annotated_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, + out_annotated_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native LLM execution intercept callback. +pub type NemoRelayNativeLlmExecutionCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeLlmNextFn, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native LLM stream execution intercept callback. +pub type NemoRelayNativeLlmStreamExecutionCb = unsafe extern "C" fn( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeLlmStreamNextFn, + next_ctx: *mut c_void, + out_stream: *mut NemoRelayNativeLlmStreamV1, +) -> NemoRelayStatus; + +/// Native plugin validation callback. +pub type NemoRelayNativePluginValidateFn = unsafe extern "C" fn( + user_data: *mut c_void, + plugin_config_json: *const NemoRelayNativeString, + out_diagnostics_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus; + +/// Native plugin registration callback. +pub type NemoRelayNativePluginRegisterFn = unsafe extern "C" fn( + user_data: *mut c_void, + plugin_config_json: *const NemoRelayNativeString, + ctx: *mut NemoRelayNativePluginContext, +) -> NemoRelayStatus; + +/// Native plugin drop callback. +pub type NemoRelayNativePluginDropFn = Option; + +/// Versioned host API table passed to native plugin entry symbols. +#[repr(C)] +#[derive(Clone, Copy)] +pub struct NemoRelayNativeHostApiV1 { + /// ABI version implemented by this table. + pub abi_version: u32, + /// Size of this struct as seen by the host. + pub struct_size: usize, + /// Null-terminated host Relay version string. + pub relay_version: *const c_char, + /// Allocates a host-owned string from UTF-8 bytes. + pub string_new: unsafe extern "C" fn( + data: *const u8, + len: usize, + out: *mut *mut NemoRelayNativeString, + ) -> NemoRelayStatus, + /// Returns the string data pointer for a host-owned string. + pub string_data: unsafe extern "C" fn(value: *const NemoRelayNativeString) -> *const u8, + /// Returns the byte length for a host-owned string. + pub string_len: unsafe extern "C" fn(value: *const NemoRelayNativeString) -> usize, + /// Frees a host-owned string. + pub string_free: unsafe extern "C" fn(value: *mut NemoRelayNativeString), + /// Clears the host thread-local native ABI error message. + pub last_error_clear: unsafe extern "C" fn(), + /// Sets the host thread-local native ABI error message. + pub last_error_set: unsafe extern "C" fn(message: *const NemoRelayNativeString), + /// Registers an event subscriber through the plugin context. + pub plugin_context_register_subscriber: unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + cb: NemoRelayNativeEventSubscriberCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers a tool sanitize-request guardrail through the plugin context. + pub plugin_context_register_tool_sanitize_request_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers a tool sanitize-response guardrail through the plugin context. + pub plugin_context_register_tool_sanitize_response_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers a tool conditional-execution guardrail through the plugin context. + pub plugin_context_register_tool_conditional_execution_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers a tool request intercept through the plugin context. + pub plugin_context_register_tool_request_intercept: unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) + -> NemoRelayStatus, + /// Registers a tool execution intercept through the plugin context. + pub plugin_context_register_tool_execution_intercept: unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) + -> NemoRelayStatus, + /// Registers an LLM sanitize-request guardrail through the plugin context. + pub plugin_context_register_llm_sanitize_request_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmRequestCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers an LLM sanitize-response guardrail through the plugin context. + pub plugin_context_register_llm_sanitize_response_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers an LLM conditional-execution guardrail through the plugin context. + pub plugin_context_register_llm_conditional_execution_guardrail: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers an LLM request intercept through the plugin context. + pub plugin_context_register_llm_request_intercept: unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeLlmRequestInterceptCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Registers an LLM execution intercept through the plugin context. + pub plugin_context_register_llm_execution_intercept: unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) + -> NemoRelayStatus, + /// Registers an LLM stream execution intercept through the plugin context. + pub plugin_context_register_llm_stream_execution_intercept: + unsafe extern "C" fn( + ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmStreamExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus, + /// Frees a host-owned scope handle. + pub scope_handle_free: unsafe extern "C" fn(handle: *mut NemoRelayNativeScopeHandle), + /// Retrieves the current scope handle from the active stack. + pub scope_get_current: + unsafe extern "C" fn(out: *mut *mut NemoRelayNativeScopeHandle) -> NemoRelayStatus, + /// Pushes a scope, emits its start event, and returns its handle. + pub scope_push: unsafe extern "C" fn( + name: *const NemoRelayNativeString, + scope_type: NemoRelayNativeScopeType, + parent: *const NemoRelayNativeScopeHandle, + attributes: u32, + data_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + input_json: *const NemoRelayNativeString, + timestamp_unix_micros: *const i64, + out: *mut *mut NemoRelayNativeScopeHandle, + ) -> NemoRelayStatus, + /// Pops a scope handle, emits its end event, and clears scope-local registrations. + pub scope_pop: unsafe extern "C" fn( + handle: *const NemoRelayNativeScopeHandle, + output_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + timestamp_unix_micros: *const i64, + ) -> NemoRelayStatus, + /// Emits a mark event under the current or provided parent scope. + pub emit_mark: unsafe extern "C" fn( + name: *const NemoRelayNativeString, + parent: *const NemoRelayNativeScopeHandle, + data_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + timestamp_unix_micros: *const i64, + ) -> NemoRelayStatus, + /// Creates a new independent scope stack with its own root scope. + pub scope_stack_create: + unsafe extern "C" fn(out: *mut *mut NemoRelayNativeScopeStack) -> NemoRelayStatus, + /// Frees a host-owned scope stack handle. + pub scope_stack_free: unsafe extern "C" fn(stack: *mut NemoRelayNativeScopeStack), + /// Binds a scope stack to the current OS thread. + pub scope_stack_set_thread: + unsafe extern "C" fn(stack: *const NemoRelayNativeScopeStack) -> NemoRelayStatus, + /// Captures the current thread-local scope-stack binding. + pub scope_stack_capture_thread: + unsafe extern "C" fn(out: *mut *mut NemoRelayNativeScopeStackBinding) -> NemoRelayStatus, + /// Restores and frees a captured thread-local scope-stack binding. + pub scope_stack_restore_thread: + unsafe extern "C" fn(binding: *mut NemoRelayNativeScopeStackBinding) -> NemoRelayStatus, + /// Frees a captured thread-local binding without restoring it. + pub scope_stack_binding_free: + unsafe extern "C" fn(binding: *mut NemoRelayNativeScopeStackBinding), + /// Returns whether the current context has an explicitly active scope stack. + pub scope_stack_active: unsafe extern "C" fn() -> bool, + /// Runs a callback with the provided scope stack visible to host runtime APIs. + pub scope_stack_with_current: unsafe extern "C" fn( + stack: *const NemoRelayNativeScopeStack, + cb: NemoRelayNativeWithScopeStackCb, + user_data: *mut c_void, + ) -> NemoRelayStatus, +} + +// The host API table is immutable after construction. Function pointers and +// the null-terminated version string pointer are safe to share across threads. +unsafe impl Send for NemoRelayNativeHostApiV1 {} +unsafe impl Sync for NemoRelayNativeHostApiV1 {} + +/// Versioned plugin descriptor returned by native plugin entry symbols. +#[repr(C)] +pub struct NemoRelayNativePluginV1 { + /// Size of this struct as seen by the plugin. + pub struct_size: usize, + /// Host-owned plugin kind string. + pub plugin_kind: *mut NemoRelayNativeString, + /// Whether this plugin kind supports multiple configured components. + pub allows_multiple_components: bool, + /// Plugin-owned state pointer passed to callbacks. + pub user_data: *mut c_void, + /// Optional validation callback. + pub validate: Option, + /// Required registration callback. + pub register: Option, + /// Optional plugin-owned state destructor. + pub drop: NemoRelayNativePluginDropFn, +} + +impl Default for NemoRelayNativePluginV1 { + fn default() -> Self { + Self { + struct_size: std::mem::size_of::(), + plugin_kind: ptr::null_mut(), + allows_multiple_components: true, + user_data: ptr::null_mut(), + validate: None, + register: None, + drop: None, + } + } +} + +/// Native entry symbol type loaded by the host. +pub type NemoRelayNativePluginEntry = unsafe extern "C" fn( + host: *const NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, +) -> NemoRelayStatus; + +/// Result type used by the Rust native plugin SDK. +pub type Result = std::result::Result; + +/// Synchronous JSON chunk stream used by native LLM stream intercept helpers. +pub type LlmJsonStream = Box> + Send>; + +/// Cloneable high-level runtime handle for host APIs available to native plugins. +#[derive(Clone)] +pub struct PluginRuntime { + host: NemoRelayNativeHostApiV1, +} + +impl PluginRuntime { + /// Creates a runtime handle from the host ABI table. + pub fn new(host: &NemoRelayNativeHostApiV1) -> Self { + Self { host: *host } + } + + /// Returns the underlying host ABI table. + pub fn host_api(&self) -> &NemoRelayNativeHostApiV1 { + &self.host + } + + /// Retrieves the current scope handle. + pub fn current_scope(&self) -> Result> { + current_scope(&self.host) + } + + /// Pushes a scope and emits its start event. + pub fn push_scope( + &self, + name: &str, + scope_type: ScopeType, + data: Option<&Json>, + metadata: Option<&Json>, + input: Option<&Json>, + ) -> Result> { + push_scope(&self.host, name, scope_type.into(), data, metadata, input) + } + + /// Pops a scope and emits its end event. + pub fn pop_scope( + &self, + handle: &ScopeHandle<'_>, + output: Option<&Json>, + metadata: Option<&Json>, + ) -> Result<()> { + pop_scope(&self.host, handle, output, metadata) + } + + /// Opens a scope that is popped automatically when the guard is closed or dropped. + pub fn scope( + &self, + name: &str, + scope_type: ScopeType, + data: Option<&Json>, + metadata: Option<&Json>, + input: Option<&Json>, + ) -> Result> { + let handle = self.push_scope(name, scope_type, data, metadata, input)?; + Ok(ScopeGuard { + runtime: self, + handle: Some(handle), + }) + } + + /// Emits a mark event under the current scope. + pub fn emit_mark( + &self, + name: &str, + data: Option<&Json>, + metadata: Option<&Json>, + ) -> Result<()> { + emit_mark(&self.host, name, data, metadata) + } + + /// Creates a new independent scope stack. + pub fn create_scope_stack(&self) -> Result> { + create_scope_stack(&self.host) + } + + /// Captures the current thread-local scope-stack binding. + pub fn capture_scope_stack_thread(&self) -> Result> { + capture_scope_stack_thread(&self.host) + } + + /// Returns whether the current context has an explicitly active scope stack. + pub fn scope_stack_active(&self) -> bool { + unsafe { (self.host.scope_stack_active)() } + } + + /// Binds `stack` to the current OS thread until the returned guard is dropped. + pub fn bind_scope_stack_thread<'a>( + &'a self, + stack: &'a ScopeStack<'a>, + ) -> Result> { + let previous = self.capture_scope_stack_thread()?; + let status = stack.set_thread(); + if status == NemoRelayStatus::Ok { + Ok(ThreadScopeStackGuard { + previous: Some(previous), + }) + } else { + let _ = previous.restore(); + Err(format!("scope_stack_set_thread failed: {status:?}")) + } + } +} + +impl From for NemoRelayNativeScopeType { + fn from(value: ScopeType) -> Self { + match value { + ScopeType::Agent => Self::Agent, + ScopeType::Function => Self::Function, + ScopeType::Tool => Self::Tool, + ScopeType::Llm => Self::Llm, + ScopeType::Retriever => Self::Retriever, + ScopeType::Embedder => Self::Embedder, + ScopeType::Reranker => Self::Reranker, + ScopeType::Guardrail => Self::Guardrail, + ScopeType::Evaluator => Self::Evaluator, + ScopeType::Custom => Self::Custom, + ScopeType::Unknown => Self::Unknown, + } + } +} + +/// RAII guard for a host scope opened by [`PluginRuntime::scope`]. +pub struct ScopeGuard<'a> { + runtime: &'a PluginRuntime, + handle: Option>, +} + +impl<'a> ScopeGuard<'a> { + /// Returns the active scope handle. + pub fn handle(&self) -> Option<&ScopeHandle<'a>> { + self.handle.as_ref() + } + + /// Pops the scope with optional output and metadata. + pub fn close(&mut self, output: Option<&Json>, metadata: Option<&Json>) -> Result<()> { + let Some(handle) = self.handle.as_ref() else { + return Ok(()); + }; + self.runtime.pop_scope(handle, output, metadata)?; + self.handle.take(); + Ok(()) + } +} + +impl Drop for ScopeGuard<'_> { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + let _ = self.runtime.pop_scope(&handle, None, None); + } + } +} + +/// RAII guard that restores the previous thread-local scope stack on drop. +pub struct ThreadScopeStackGuard<'a> { + previous: Option>, +} + +impl ThreadScopeStackGuard<'_> { + /// Restores the previous thread-local scope stack immediately. + pub fn restore(mut self) -> Result<()> { + let Some(previous) = self.previous.take() else { + return Ok(()); + }; + let status = previous.restore(); + if status == NemoRelayStatus::Ok { + Ok(()) + } else { + Err(format!("scope_stack_restore_thread failed: {status:?}")) + } + } +} + +impl Drop for ThreadScopeStackGuard<'_> { + fn drop(&mut self) { + if let Some(previous) = self.previous.take() { + let _ = previous.restore(); + } + } +} + +/// Typed continuation passed to tool execution intercepts. +pub struct ToolNext<'a> { + host: &'a NemoRelayNativeHostApiV1, + next_fn: NemoRelayNativeToolNextFn, + next_ctx: *mut c_void, +} + +impl ToolNext<'_> { + /// Continues the tool execution chain with replacement arguments. + pub fn call(&self, args: Json) -> Result { + let args = HostString::from_json(self.host, &args) + .ok_or_else(|| "failed to allocate tool next args".to_string())?; + let mut out = ptr::null_mut(); + let status = unsafe { (self.next_fn)(args.as_ptr(), self.next_ctx, &mut out) }; + if status != NemoRelayStatus::Ok { + return Err(format!("tool next failed: {status:?}")); + } + if out.is_null() { + return Err("tool next returned null output".into()); + } + let result = read_json_value(self.host, out, "tool next result"); + unsafe { (self.host.string_free)(out) }; + result.map_err(|status| format!("tool next returned invalid JSON: {status:?}")) + } +} + +/// Typed continuation passed to LLM execution intercepts. +pub struct LlmNext<'a> { + host: &'a NemoRelayNativeHostApiV1, + next_fn: NemoRelayNativeLlmNextFn, + next_ctx: *mut c_void, +} + +impl LlmNext<'_> { + /// Continues the LLM execution chain with a replacement request. + pub fn call(&self, request: LlmRequest) -> Result { + let request = HostString::from_json(self.host, &request) + .ok_or_else(|| "failed to allocate LLM next request".to_string())?; + let mut out = ptr::null_mut(); + let status = unsafe { (self.next_fn)(request.as_ptr(), self.next_ctx, &mut out) }; + if status != NemoRelayStatus::Ok { + return Err(format!("llm next failed: {status:?}")); + } + if out.is_null() { + return Err("llm next returned null output".into()); + } + let result = read_json_value(self.host, out, "llm next result"); + unsafe { (self.host.string_free)(out) }; + result.map_err(|status| format!("llm next returned invalid JSON: {status:?}")) + } +} + +/// Typed continuation passed to LLM stream execution intercepts. +pub struct LlmStreamNext<'a> { + host: &'a NemoRelayNativeHostApiV1, + next_fn: NemoRelayNativeLlmStreamNextFn, + next_ctx: *mut c_void, +} + +impl LlmStreamNext<'_> { + /// Continues the LLM stream execution chain with a replacement request. + pub fn call(&self, request: LlmRequest) -> Result { + let request = HostString::from_json(self.host, &request) + .ok_or_else(|| "failed to allocate LLM stream next request".to_string())?; + let mut raw = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { (self.next_fn)(request.as_ptr(), self.next_ctx, &mut raw) }; + if status != NemoRelayStatus::Ok { + return Err(format!("llm stream next failed: {status:?}")); + } + unsafe { LlmStream::from_raw(self.host, raw) } + } +} + +/// Host- or plugin-owned stream returned across the native LLM stream ABI. +pub struct LlmStream { + host: NemoRelayNativeHostApiV1, + raw: NemoRelayNativeLlmStreamV1, + finished: bool, +} + +// The host ABI table is Send, and stream ownership is exclusive through this wrapper. +unsafe impl Send for LlmStream {} + +impl LlmStream { + /// Creates a typed stream wrapper from a raw stream table. + /// + /// # Safety + /// `raw` must contain callbacks and `user_data` produced by the same host + /// and must not be used again after it is moved into this wrapper. + pub unsafe fn from_raw( + host: &NemoRelayNativeHostApiV1, + mut raw: NemoRelayNativeLlmStreamV1, + ) -> Result { + let expected_size = std::mem::size_of::(); + if raw.struct_size != expected_size { + if raw.struct_size >= expected_size { + unsafe { drop_raw_llm_stream(&mut raw) }; + } + return Err(format!( + "unsupported LLM stream struct size: {}", + raw.struct_size + )); + } + if raw.next.is_none() { + unsafe { drop_raw_llm_stream(&mut raw) }; + return Err("LLM stream next callback was null".into()); + } + Ok(Self { + host: *host, + raw, + finished: false, + }) + } + + /// Polls the next stream chunk. + pub fn next_chunk(&mut self) -> Result> { + if self.finished { + return Ok(None); + } + let Some(next) = self.raw.next else { + self.finished = true; + return Err("LLM stream next callback was null".into()); + }; + let mut out = ptr::null_mut(); + let status = unsafe { next(self.raw.user_data, &mut out) }; + match status { + NemoRelayStatus::Ok => { + if out.is_null() { + self.finished = true; + return Err("LLM stream returned null chunk".into()); + } + let result = read_json_value(&self.host, out, "LLM stream chunk"); + unsafe { (self.host.string_free)(out) }; + match result { + Ok(chunk) => Ok(Some(chunk)), + Err(status) => { + self.finished = true; + Err(format!("LLM stream returned invalid JSON: {status:?}")) + } + } + } + NemoRelayStatus::StreamEnd => { + if !out.is_null() { + unsafe { (self.host.string_free)(out) }; + } + self.finished = true; + Ok(None) + } + other => { + if !out.is_null() { + unsafe { (self.host.string_free)(out) }; + } + self.finished = true; + Err(format!("LLM stream failed: {other:?}")) + } + } + } + + /// Cancels the stream if it has not reached end-of-stream. + pub fn cancel(&mut self) -> Result<()> { + if self.finished { + return Ok(()); + } + if let Some(cancel) = self.raw.cancel { + let status = unsafe { cancel(self.raw.user_data) }; + if status != NemoRelayStatus::Ok { + return Err(format!("LLM stream cancel failed: {status:?}")); + } + } + self.finished = true; + Ok(()) + } +} + +impl Iterator for LlmStream { + type Item = Result; + + fn next(&mut self) -> Option { + match self.next_chunk() { + Ok(Some(chunk)) => Some(Ok(chunk)), + Ok(None) => None, + Err(message) => Some(Err(message)), + } + } +} + +unsafe fn drop_raw_llm_stream(raw: &mut NemoRelayNativeLlmStreamV1) { + if let Some(drop_fn) = raw.drop.take() { + unsafe { drop_fn(raw.user_data) }; + } + raw.user_data = ptr::null_mut(); +} + +impl Drop for LlmStream { + fn drop(&mut self) { + if !self.finished { + if let Some(cancel) = self.raw.cancel { + let _ = unsafe { cancel(self.raw.user_data) }; + } + self.finished = true; + } + unsafe { drop_raw_llm_stream(&mut self.raw) }; + } +} + +/// Host-owned scope handle returned by native scope APIs. +pub struct ScopeHandle<'a> { + host: &'a NemoRelayNativeHostApiV1, + ptr: *mut NemoRelayNativeScopeHandle, +} + +impl<'a> ScopeHandle<'a> { + /// Returns the raw ABI pointer. + pub fn as_ptr(&self) -> *const NemoRelayNativeScopeHandle { + self.ptr + } +} + +impl Drop for ScopeHandle<'_> { + fn drop(&mut self) { + unsafe { (self.host.scope_handle_free)(self.ptr) }; + } +} + +/// Host-owned isolated scope stack returned by native scope-stack APIs. +pub struct ScopeStack<'a> { + host: &'a NemoRelayNativeHostApiV1, + ptr: *mut NemoRelayNativeScopeStack, +} + +impl<'a> ScopeStack<'a> { + /// Returns the raw ABI pointer. + pub fn as_ptr(&self) -> *const NemoRelayNativeScopeStack { + self.ptr + } + + fn set_thread(&self) -> NemoRelayStatus { + unsafe { (self.host.scope_stack_set_thread)(self.ptr) } + } + + /// Executes `f` while this stack is visible to host runtime APIs. + pub fn with_current(&self, f: F) -> Result<()> + where + F: FnOnce() -> Result<()>, + { + struct State { + f: Option, + error: Option, + } + + unsafe extern "C" fn trampoline(user_data: *mut c_void) -> NemoRelayStatus + where + F: FnOnce() -> Result<()>, + { + if user_data.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &mut *(user_data as *mut State) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let Some(f) = state.f.take() else { + return Err("scope-stack callback was already consumed".to_string()); + }; + f() + })); + match result { + Ok(Ok(())) => NemoRelayStatus::Ok, + Ok(Err(message)) => { + state.error = Some(message); + NemoRelayStatus::Internal + } + Err(_) => { + state.error = Some("scope-stack callback panicked".into()); + NemoRelayStatus::Internal + } + } + } + + let mut state = State { + f: Some(f), + error: None, + }; + let status = unsafe { + (self.host.scope_stack_with_current)( + self.ptr, + trampoline::, + (&mut state as *mut State<_>).cast(), + ) + }; + if status == NemoRelayStatus::Ok { + Ok(()) + } else { + Err(state + .error + .unwrap_or_else(|| format!("scope_stack_with_current failed: {status:?}"))) + } + } +} + +impl Drop for ScopeStack<'_> { + fn drop(&mut self) { + unsafe { (self.host.scope_stack_free)(self.ptr) }; + } +} + +/// Captured thread-local scope-stack binding. +pub struct ScopeStackBinding<'a> { + host: &'a NemoRelayNativeHostApiV1, + ptr: *mut NemoRelayNativeScopeStackBinding, +} + +impl<'a> ScopeStackBinding<'a> { + /// Restores and consumes this binding. + pub fn restore(mut self) -> NemoRelayStatus { + let ptr = std::mem::replace(&mut self.ptr, ptr::null_mut()); + unsafe { (self.host.scope_stack_restore_thread)(ptr) } + } +} + +impl Drop for ScopeStackBinding<'_> { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { (self.host.scope_stack_binding_free)(self.ptr) }; + } + } +} + +/// Retrieves the current scope handle. +pub fn current_scope(host: &NemoRelayNativeHostApiV1) -> Result> { + let mut out = ptr::null_mut(); + let status = unsafe { (host.scope_get_current)(&mut out) }; + if status == NemoRelayStatus::Ok && !out.is_null() { + Ok(ScopeHandle { host, ptr: out }) + } else { + Err(format!("scope_get_current failed: {status:?}")) + } +} + +/// Pushes a scope and emits its start event. +pub fn push_scope<'a>( + host: &'a NemoRelayNativeHostApiV1, + name: &str, + scope_type: NemoRelayNativeScopeType, + data: Option<&Json>, + metadata: Option<&Json>, + input: Option<&Json>, +) -> Result> { + let name = + HostString::new(host, name).ok_or_else(|| "failed to allocate scope name".to_string())?; + let data = OptionalHostJson::new(host, data)?; + let metadata = OptionalHostJson::new(host, metadata)?; + let input = OptionalHostJson::new(host, input)?; + let mut out = ptr::null_mut(); + let status = unsafe { + (host.scope_push)( + name.as_ptr(), + scope_type, + ptr::null(), + 0, + data.as_ptr(), + metadata.as_ptr(), + input.as_ptr(), + ptr::null(), + &mut out, + ) + }; + if status == NemoRelayStatus::Ok && !out.is_null() { + Ok(ScopeHandle { host, ptr: out }) + } else { + Err(format!("scope_push failed: {status:?}")) + } +} + +/// Pops a scope and emits its end event. +pub fn pop_scope( + host: &NemoRelayNativeHostApiV1, + handle: &ScopeHandle<'_>, + output: Option<&Json>, + metadata: Option<&Json>, +) -> Result<()> { + let output = OptionalHostJson::new(host, output)?; + let metadata = OptionalHostJson::new(host, metadata)?; + let status = unsafe { + (host.scope_pop)( + handle.as_ptr(), + output.as_ptr(), + metadata.as_ptr(), + ptr::null(), + ) + }; + if status == NemoRelayStatus::Ok { + Ok(()) + } else { + Err(format!("scope_pop failed: {status:?}")) + } +} + +/// Emits a mark event under the current scope. +pub fn emit_mark( + host: &NemoRelayNativeHostApiV1, + name: &str, + data: Option<&Json>, + metadata: Option<&Json>, +) -> Result<()> { + let name = + HostString::new(host, name).ok_or_else(|| "failed to allocate mark name".to_string())?; + let data = OptionalHostJson::new(host, data)?; + let metadata = OptionalHostJson::new(host, metadata)?; + let status = unsafe { + (host.emit_mark)( + name.as_ptr(), + ptr::null(), + data.as_ptr(), + metadata.as_ptr(), + ptr::null(), + ) + }; + if status == NemoRelayStatus::Ok { + Ok(()) + } else { + Err(format!("emit_mark failed: {status:?}")) + } +} + +/// Creates a new independent scope stack. +pub fn create_scope_stack(host: &NemoRelayNativeHostApiV1) -> Result> { + let mut out = ptr::null_mut(); + let status = unsafe { (host.scope_stack_create)(&mut out) }; + if status == NemoRelayStatus::Ok && !out.is_null() { + Ok(ScopeStack { host, ptr: out }) + } else { + Err(format!("scope_stack_create failed: {status:?}")) + } +} + +/// Captures the current thread-local scope-stack binding. +pub fn capture_scope_stack_thread( + host: &NemoRelayNativeHostApiV1, +) -> Result> { + let mut out = ptr::null_mut(); + let status = unsafe { (host.scope_stack_capture_thread)(&mut out) }; + if status == NemoRelayStatus::Ok && !out.is_null() { + Ok(ScopeStackBinding { host, ptr: out }) + } else { + Err(format!("scope_stack_capture_thread failed: {status:?}")) + } +} + +/// Trait implemented by Rust native plugins. +pub trait NativePlugin: Send + 'static { + /// Returns the stable plugin kind. + fn plugin_kind(&self) -> &str; + + /// Returns whether the plugin allows multiple configured components. + fn allows_multiple_components(&self) -> bool { + true + } + + /// Validates one component-local JSON config object. + fn validate(&self, _plugin_config: &Map) -> Vec { + vec![] + } + + /// Registers runtime behavior through the component-scoped plugin context. + fn register( + &mut self, + plugin_config: &Map, + ctx: &mut PluginContext<'_>, + ) -> Result<()>; +} + +/// Borrowed safe wrapper around a host plugin registration context. +pub struct PluginContext<'a> { + host: &'a NemoRelayNativeHostApiV1, + raw: *mut NemoRelayNativePluginContext, +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +impl<'a> PluginContext<'a> { + /// Creates a plugin context wrapper from raw ABI parts. + /// + /// # Safety + /// `host` and `raw` must remain valid for the lifetime of this wrapper. + pub unsafe fn from_raw( + host: &'a NemoRelayNativeHostApiV1, + raw: *mut NemoRelayNativePluginContext, + ) -> Self { + Self { host, raw } + } + + /// Returns the host ABI table backing this registration context. + pub fn host_api(&self) -> &'a NemoRelayNativeHostApiV1 { + self.host + } + + /// Returns a cloneable high-level runtime handle. + pub fn runtime(&self) -> PluginRuntime { + PluginRuntime::new(self.host) + } + + /// Registers a typed event subscriber callback. + pub fn register_subscriber(&mut self, name: &str, callback: F) -> Result<()> + where + F: Fn(&Event) + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_subscriber_raw( + name, + typed_subscriber_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::(self.host, status, user_data, "subscriber") + } + + /// Registers a typed tool sanitize-request guardrail. + pub fn register_tool_sanitize_request_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(&str, Json) -> Json + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_tool_sanitize_request_guardrail_raw( + name, + priority, + typed_tool_sanitize_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "tool sanitize request guardrail", + ) + } + + /// Registers a typed tool sanitize-response guardrail. + pub fn register_tool_sanitize_response_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(&str, Json) -> Json + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_tool_sanitize_response_guardrail_raw( + name, + priority, + typed_tool_sanitize_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "tool sanitize response guardrail", + ) + } + + /// Registers a typed tool conditional-execution guardrail. + pub fn register_tool_conditional_execution_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(&str, &Json) -> Result> + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_tool_conditional_execution_guardrail_raw( + name, + priority, + typed_tool_conditional_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "tool conditional execution guardrail", + ) + } + + /// Registers a typed tool request intercept. + pub fn register_tool_request_intercept( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: F, + ) -> Result<()> + where + F: Fn(&str, Json) -> Result + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_tool_request_intercept_raw( + name, + priority, + break_chain, + typed_tool_intercept_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::(self.host, status, user_data, "tool request intercept") + } + + /// Registers a typed tool execution intercept. + pub fn register_tool_execution_intercept( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_tool_execution_intercept_raw( + name, + priority, + typed_tool_execution_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::(self.host, status, user_data, "tool execution intercept") + } + + /// Registers a typed LLM sanitize-request guardrail. + pub fn register_llm_sanitize_request_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(LlmRequest) -> LlmRequest + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_sanitize_request_guardrail_raw( + name, + priority, + typed_llm_sanitize_request_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "llm sanitize request guardrail", + ) + } + + /// Registers a typed LLM sanitize-response guardrail. + pub fn register_llm_sanitize_response_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(Json) -> Json + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_sanitize_response_guardrail_raw( + name, + priority, + typed_llm_sanitize_response_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "llm sanitize response guardrail", + ) + } + + /// Registers a typed LLM conditional-execution guardrail. + pub fn register_llm_conditional_execution_guardrail( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: Fn(&LlmRequest) -> Result> + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_conditional_execution_guardrail_raw( + name, + priority, + typed_llm_conditional_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "llm conditional execution guardrail", + ) + } + + /// Registers a typed LLM request intercept. + pub fn register_llm_request_intercept( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: F, + ) -> Result<()> + where + F: Fn( + &str, + LlmRequest, + Option, + ) -> Result<(LlmRequest, Option)> + + Send + + Sync + + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_request_intercept_raw( + name, + priority, + break_chain, + typed_llm_request_intercept_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::(self.host, status, user_data, "llm request intercept") + } + + /// Registers a typed LLM execution intercept. + pub fn register_llm_execution_intercept( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: for<'next> Fn(&str, LlmRequest, LlmNext<'next>) -> Result + Send + Sync + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_execution_intercept_raw( + name, + priority, + typed_llm_execution_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::(self.host, status, user_data, "llm execution intercept") + } + + /// Registers a typed LLM stream execution intercept. + /// + /// Native ABI v1 represents stream execution as one JSON result. The host + /// wraps that result as a one-chunk stream. + pub fn register_llm_stream_execution_intercept( + &mut self, + name: &str, + priority: i32, + callback: F, + ) -> Result<()> + where + F: for<'next> Fn(&str, LlmRequest, LlmStreamNext<'next>) -> Result + + Send + + Sync + + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_stream_execution_intercept_raw( + name, + priority, + typed_llm_stream_execution_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "llm stream execution intercept", + ) + } + + /// Registers a raw event subscriber callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_subscriber_raw( + &mut self, + name: &str, + cb: NemoRelayNativeEventSubscriberCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_subscriber)(self.raw, name, cb, user_data, free_fn) + }) + } + + /// Registers a raw tool sanitize-request guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_tool_sanitize_request_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_tool_sanitize_request_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw tool sanitize-response guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_tool_sanitize_response_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_tool_sanitize_response_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw tool conditional-execution guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_tool_conditional_execution_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeToolConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_tool_conditional_execution_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw tool request intercept callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_tool_request_intercept_raw( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_tool_request_intercept)( + self.raw, + name, + priority, + break_chain, + cb, + user_data, + free_fn, + ) + }) + } + + /// Registers a raw tool execution intercept callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_tool_execution_intercept_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeToolExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_tool_execution_intercept)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw LLM sanitize-request guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_sanitize_request_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeLlmRequestCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_sanitize_request_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw LLM sanitize-response guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_sanitize_response_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_sanitize_response_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw LLM conditional-execution guardrail callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_conditional_execution_guardrail_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeLlmConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_conditional_execution_guardrail)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw LLM request intercept callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_request_intercept_raw( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeLlmRequestInterceptCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_request_intercept)( + self.raw, + name, + priority, + break_chain, + cb, + user_data, + free_fn, + ) + }) + } + + /// Registers a raw LLM execution intercept callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_execution_intercept_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeLlmExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_execution_intercept)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + /// Registers a raw LLM stream execution intercept callback. + /// + /// # Safety + /// `cb`, `user_data`, and `free_fn` must remain valid for every host + /// callback invocation until the host deregisters the callback or calls + /// `free_fn`. `free_fn` must match the allocation behind `user_data`. + pub unsafe fn register_llm_stream_execution_intercept_raw( + &mut self, + name: &str, + priority: i32, + cb: NemoRelayNativeLlmStreamExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, + ) -> NemoRelayStatus { + self.with_name(name, |host, name| unsafe { + (host.plugin_context_register_llm_stream_execution_intercept)( + self.raw, name, priority, cb, user_data, free_fn, + ) + }) + } + + fn with_name( + &self, + name: &str, + f: impl FnOnce(&NemoRelayNativeHostApiV1, *const NemoRelayNativeString) -> NemoRelayStatus, + ) -> NemoRelayStatus { + let name = match HostString::try_new(self.host, name) { + Ok(name) => name, + Err(status) => return status, + }; + f(self.host, name.as_ptr()) + } +} + +struct TypedCallback { + host: NemoRelayNativeHostApiV1, + callback: F, +} + +fn typed_callback_user_data(host: &NemoRelayNativeHostApiV1, callback: F) -> *mut c_void { + Box::into_raw(Box::new(TypedCallback { + host: *host, + callback, + })) as *mut c_void +} + +unsafe extern "C" fn drop_typed_callback(user_data: *mut c_void) { + if !user_data.is_null() { + let callback = unsafe { Box::from_raw(user_data as *mut TypedCallback) }; + let host = callback.host; + if catch_unwind(AssertUnwindSafe(|| drop(callback))).is_err() { + set_last_error(&host, "native plugin typed callback state drop panicked"); + } + } +} + +fn finish_typed_registration( + host: &NemoRelayNativeHostApiV1, + status: NemoRelayStatus, + user_data: *mut c_void, + label: &str, +) -> Result<()> { + if status == NemoRelayStatus::Ok { + Ok(()) + } else { + unsafe { drop_typed_callback::(user_data) }; + Err(status_error(host, status, label)) + } +} + +fn status_error(host: &NemoRelayNativeHostApiV1, status: NemoRelayStatus, label: &str) -> String { + match status { + NemoRelayStatus::Ok => format!("{label} succeeded"), + other => { + set_last_error(host, &format!("{label} failed: {other:?}")); + format!("{label} failed: {other:?}") + } + } +} + +fn callback_error(host: &NemoRelayNativeHostApiV1, message: String) -> NemoRelayStatus { + set_last_error(host, &message); + NemoRelayStatus::Internal +} + +fn callback_panic(host: &NemoRelayNativeHostApiV1, label: &str) -> NemoRelayStatus { + set_last_error(host, &format!("{label} panicked")); + NemoRelayStatus::Internal +} + +unsafe extern "C" fn typed_subscriber_trampoline( + user_data: *mut c_void, + event_json: *const NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&Event) + Send + Sync + 'static, +{ + if user_data.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let event: Event = read_json_value(&state.host, event_json, "event")?; + (state.callback)(&event); + Ok::<_, NemoRelayStatus>(()) + })); + match result { + Ok(Ok(())) => NemoRelayStatus::Ok, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "subscriber callback"), + } +} + +unsafe extern "C" fn typed_tool_sanitize_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + payload_json: *const NemoRelayNativeString, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&str, Json) -> Json + Send + Sync + 'static, +{ + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "tool name")?; + let payload: Json = read_json_value(&state.host, payload_json, "tool payload")?; + let output = (state.callback)(&name, payload); + Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)) + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "tool sanitize callback"), + } +} + +unsafe extern "C" fn typed_tool_intercept_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + payload_json: *const NemoRelayNativeString, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&str, Json) -> Result + Send + Sync + 'static, +{ + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "tool name")?; + let payload: Json = read_json_value(&state.host, payload_json, "tool payload")?; + match (state.callback)(&name, payload) { + Ok(output) => Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)), + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "tool intercept callback"), + } +} + +unsafe extern "C" fn typed_tool_conditional_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + args_json: *const NemoRelayNativeString, + out_reason: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&str, &Json) -> Result> + Send + Sync + 'static, +{ + if user_data.is_null() || out_reason.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_reason = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "tool name")?; + let args: Json = read_json_value(&state.host, args_json, "tool args")?; + match (state.callback)(&name, &args) { + Ok(Some(reason)) => { + let reason = + HostString::new(&state.host, &reason).ok_or(NemoRelayStatus::Internal)?; + unsafe { *out_reason = reason.ptr }; + std::mem::forget(reason); + Ok(NemoRelayStatus::Ok) + } + Ok(None) => { + unsafe { *out_reason = ptr::null_mut() }; + Ok(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "tool conditional callback"), + } +} + +unsafe extern "C" fn typed_tool_execution_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + args_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeToolNextFn, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + Send + Sync + 'static, +{ + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "tool name")?; + let args: Json = read_json_value(&state.host, args_json, "tool args")?; + let next = ToolNext { + host: &state.host, + next_fn, + next_ctx, + }; + match (state.callback)(&name, args, next) { + Ok(output) => Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)), + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "tool execution callback"), + } +} + +unsafe extern "C" fn typed_llm_sanitize_request_trampoline( + user_data: *mut c_void, + request_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(LlmRequest) -> LlmRequest + Send + Sync + 'static, +{ + if user_data.is_null() || out_request_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_request_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let output = (state.callback)(request); + Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_request_json)) + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM sanitize request callback"), + } +} + +unsafe extern "C" fn typed_llm_sanitize_response_trampoline( + user_data: *mut c_void, + payload_json: *const NemoRelayNativeString, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(Json) -> Json + Send + Sync + 'static, +{ + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let payload: Json = read_json_value(&state.host, payload_json, "LLM response")?; + let output = (state.callback)(payload); + Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)) + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM sanitize response callback"), + } +} + +unsafe extern "C" fn typed_llm_conditional_trampoline( + user_data: *mut c_void, + request_json: *const NemoRelayNativeString, + out_reason: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&LlmRequest) -> Result> + Send + Sync + 'static, +{ + if user_data.is_null() || out_reason.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_reason = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + match (state.callback)(&request) { + Ok(Some(reason)) => { + let reason = + HostString::new(&state.host, &reason).ok_or(NemoRelayStatus::Internal)?; + unsafe { *out_reason = reason.ptr }; + std::mem::forget(reason); + Ok(NemoRelayStatus::Ok) + } + Ok(None) => { + unsafe { *out_reason = ptr::null_mut() }; + Ok(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM conditional callback"), + } +} + +unsafe extern "C" fn typed_llm_request_intercept_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + annotated_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, + out_annotated_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn( + &str, + LlmRequest, + Option, + ) -> Result<(LlmRequest, Option)> + + Send + + Sync + + 'static, +{ + if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { + *out_request_json = ptr::null_mut(); + *out_annotated_json = ptr::null_mut(); + } + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "LLM name")?; + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let annotated: Option = + read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; + match (state.callback)(&name, request, annotated) { + Ok((request, annotated)) => { + let Some(request) = HostString::from_json(&state.host, &request) else { + set_last_error(&state.host, "failed to allocate LLM request output"); + return Ok(NemoRelayStatus::Internal); + }; + let annotated = match annotated.as_ref() { + Some(annotated) => { + let Some(annotated) = HostString::from_json(&state.host, annotated) else { + set_last_error( + &state.host, + "failed to allocate annotated LLM request output", + ); + return Ok(NemoRelayStatus::Internal); + }; + Some(annotated) + } + None => None, + }; + unsafe { + *out_request_json = request.ptr; + *out_annotated_json = annotated + .as_ref() + .map(|annotated| annotated.ptr) + .unwrap_or(ptr::null_mut()); + } + std::mem::forget(request); + if let Some(annotated) = annotated { + std::mem::forget(annotated); + } + Ok(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM request intercept callback"), + } +} + +unsafe extern "C" fn typed_llm_execution_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeLlmNextFn, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: for<'next> Fn(&str, LlmRequest, LlmNext<'next>) -> Result + Send + Sync + 'static, +{ + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "LLM name")?; + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let next = LlmNext { + host: &state.host, + next_fn, + next_ctx, + }; + match (state.callback)(&name, request, next) { + Ok(output) => Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)), + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM execution callback"), + } +} + +struct TypedLlmJsonStream { + host: NemoRelayNativeHostApiV1, + state: Mutex, +} + +struct TypedLlmJsonStreamState { + iter: LlmJsonStream, + finished: bool, +} + +fn native_stream_from_iter( + host: &NemoRelayNativeHostApiV1, + iter: LlmJsonStream, +) -> NemoRelayNativeLlmStreamV1 { + let state = Box::new(TypedLlmJsonStream { + host: *host, + state: Mutex::new(TypedLlmJsonStreamState { + iter, + finished: false, + }), + }); + NemoRelayNativeLlmStreamV1 { + struct_size: std::mem::size_of::(), + user_data: Box::into_raw(state).cast(), + next: Some(poll_typed_llm_json_stream), + cancel: Some(cancel_typed_llm_json_stream), + drop: Some(drop_typed_llm_json_stream), + } +} + +unsafe extern "C" fn poll_typed_llm_json_stream( + user_data: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let stream = unsafe { &*(user_data as *const TypedLlmJsonStream) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let mut state = match stream.state.lock() { + Ok(state) => state, + Err(_) => { + set_last_error(&stream.host, "native plugin stream state lock poisoned"); + return NemoRelayStatus::Internal; + } + }; + if state.finished { + return NemoRelayStatus::StreamEnd; + } + match state.iter.next() { + Some(Ok(chunk)) => { + let status = write_json(&stream.host, &chunk, out_json); + if status != NemoRelayStatus::Ok { + state.finished = true; + } + status + } + Some(Err(message)) => { + state.finished = true; + callback_error(&stream.host, message) + } + None => { + state.finished = true; + NemoRelayStatus::StreamEnd + } + } + })); + result.unwrap_or_else(|_| callback_panic(&stream.host, "LLM stream callback")) +} + +unsafe extern "C" fn cancel_typed_llm_json_stream(user_data: *mut c_void) -> NemoRelayStatus { + if user_data.is_null() { + return NemoRelayStatus::NullPointer; + } + let stream = unsafe { &*(user_data as *const TypedLlmJsonStream) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let mut state = match stream.state.lock() { + Ok(state) => state, + Err(_) => { + set_last_error(&stream.host, "native plugin stream state lock poisoned"); + return NemoRelayStatus::Internal; + } + }; + state.finished = true; + NemoRelayStatus::Ok + })); + result.unwrap_or_else(|_| callback_panic(&stream.host, "LLM stream cancel callback")) +} + +unsafe extern "C" fn drop_typed_llm_json_stream(user_data: *mut c_void) { + if !user_data.is_null() { + let stream = unsafe { Box::from_raw(user_data as *mut TypedLlmJsonStream) }; + let host = stream.host; + if catch_unwind(AssertUnwindSafe(|| drop(stream))).is_err() { + set_last_error(&host, "native plugin LLM stream state drop panicked"); + } + } +} + +unsafe extern "C" fn typed_llm_stream_execution_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + next_fn: NemoRelayNativeLlmStreamNextFn, + next_ctx: *mut c_void, + out_stream: *mut NemoRelayNativeLlmStreamV1, +) -> NemoRelayStatus +where + F: for<'next> Fn(&str, LlmRequest, LlmStreamNext<'next>) -> Result + + Send + + Sync + + 'static, +{ + if user_data.is_null() || out_stream.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_stream = NemoRelayNativeLlmStreamV1::default() }; + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "LLM name")?; + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let next = LlmStreamNext { + host: &state.host, + next_fn, + next_ctx, + }; + match (state.callback)(&name, request, next) { + Ok(stream) => { + unsafe { *out_stream = native_stream_from_iter(&state.host, stream) }; + Ok::<_, NemoRelayStatus>(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM stream execution callback"), + } +} + +struct HostString<'a> { + host: &'a NemoRelayNativeHostApiV1, + ptr: *mut NemoRelayNativeString, +} + +impl<'a> HostString<'a> { + fn try_new( + host: &'a NemoRelayNativeHostApiV1, + value: &str, + ) -> std::result::Result { + let mut out = ptr::null_mut(); + let status = unsafe { (host.string_new)(value.as_ptr(), value.len(), &mut out) }; + if status != NemoRelayStatus::Ok { + return Err(status); + } + if out.is_null() { + return Err(NemoRelayStatus::Internal); + } + Ok(Self { host, ptr: out }) + } + + fn new(host: &'a NemoRelayNativeHostApiV1, value: &str) -> Option { + Self::try_new(host, value).ok() + } + + fn from_json(host: &'a NemoRelayNativeHostApiV1, value: &T) -> Option { + serde_json::to_string(value) + .ok() + .and_then(|json| Self::new(host, &json)) + } + + fn as_ptr(&self) -> *const NemoRelayNativeString { + self.ptr + } +} + +impl Drop for HostString<'_> { + fn drop(&mut self) { + unsafe { (self.host.string_free)(self.ptr) }; + } +} + +struct OptionalHostJson<'a>(Option>); + +impl<'a> OptionalHostJson<'a> { + fn new(host: &'a NemoRelayNativeHostApiV1, value: Option<&Json>) -> Result { + match value { + Some(value) => HostString::from_json(host, value) + .map(|value| Self(Some(value))) + .ok_or_else(|| "failed to allocate JSON host string".into()), + None => Ok(Self(None)), + } + } + + fn as_ptr(&self) -> *const NemoRelayNativeString { + self.0 + .as_ref() + .map(HostString::as_ptr) + .unwrap_or(ptr::null()) + } +} + +struct PluginState

{ + host: NemoRelayNativeHostApiV1, + plugin: Mutex

, +} + +unsafe extern "C" fn drop_plugin_state(user_data: *mut c_void) { + if !user_data.is_null() { + let state = unsafe { Box::from_raw(user_data as *mut PluginState

) }; + let host = state.host; + if catch_unwind(AssertUnwindSafe(|| drop(state))).is_err() { + set_last_error(&host, "native plugin state drop panicked"); + } + } +} + +unsafe extern "C" fn validate_trampoline( + user_data: *mut c_void, + plugin_config_json: *const NemoRelayNativeString, + out_diagnostics_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if user_data.is_null() || out_diagnostics_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_diagnostics_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const PluginState

) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let config = match read_json_object(&state.host, plugin_config_json) { + Ok(config) => config, + Err(status) => return status, + }; + let plugin = match state.plugin.lock() { + Ok(plugin) => plugin, + Err(_) => { + set_last_error(&state.host, "native plugin state lock poisoned"); + return NemoRelayStatus::Internal; + } + }; + let diagnostics = plugin.validate(&config); + write_json(&state.host, &diagnostics, out_diagnostics_json) + })); + result.unwrap_or_else(|_| { + set_last_error(&state.host, "native plugin validate callback panicked"); + NemoRelayStatus::Internal + }) +} + +unsafe extern "C" fn register_trampoline( + user_data: *mut c_void, + plugin_config_json: *const NemoRelayNativeString, + ctx: *mut NemoRelayNativePluginContext, +) -> NemoRelayStatus { + if user_data.is_null() || ctx.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &*(user_data as *const PluginState

) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let config = match read_json_object(&state.host, plugin_config_json) { + Ok(config) => config, + Err(status) => return status, + }; + let mut ctx = unsafe { PluginContext::from_raw(&state.host, ctx) }; + let mut plugin = match state.plugin.lock() { + Ok(plugin) => plugin, + Err(_) => { + set_last_error(&state.host, "native plugin state lock poisoned"); + return NemoRelayStatus::Internal; + } + }; + match plugin.register(&config, &mut ctx) { + Ok(()) => NemoRelayStatus::Ok, + Err(message) => { + set_last_error(&state.host, &message); + NemoRelayStatus::Internal + } + } + })); + result.unwrap_or_else(|_| { + set_last_error(&state.host, "native plugin register callback panicked"); + NemoRelayStatus::Internal + }) +} + +fn read_json_object( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> std::result::Result, NemoRelayStatus> { + let value: Json = read_json_value(host, value, "plugin config")?; + match value { + Json::Object(map) => Ok(map), + _ => { + set_last_error(host, "plugin config must be a JSON object"); + Err(NemoRelayStatus::InvalidJson) + } + } +} + +fn read_json_value( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, + label: &str, +) -> std::result::Result { + let text = read_required_host_string(host, value, label)?; + serde_json::from_str::(&text).map_err(|error| { + set_last_error(host, &format!("{label} was invalid JSON: {error}")); + NemoRelayStatus::InvalidJson + }) +} + +fn read_optional_json_value( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, + label: &str, +) -> std::result::Result, NemoRelayStatus> { + if value.is_null() { + Ok(None) + } else { + read_json_value(host, value, label).map(Some) + } +} + +enum HostStringReadError { + Null, + InvalidUtf8, +} + +fn read_required_host_string( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, + label: &str, +) -> std::result::Result { + match read_host_string(host, value) { + Ok(value) => Ok(value), + Err(HostStringReadError::Null) => { + set_last_error(host, &format!("{label} was null")); + Err(NemoRelayStatus::NullPointer) + } + Err(HostStringReadError::InvalidUtf8) => { + set_last_error(host, &format!("{label} contained invalid UTF-8")); + Err(NemoRelayStatus::InvalidUtf8) + } + } +} + +fn read_host_string( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> std::result::Result { + if value.is_null() { + return Err(HostStringReadError::Null); + } + let len = unsafe { (host.string_len)(value) }; + let data = unsafe { (host.string_data)(value) }; + if data.is_null() && len > 0 { + return Err(HostStringReadError::InvalidUtf8); + } + let bytes = if len == 0 { + &[][..] + } else { + unsafe { std::slice::from_raw_parts(data, len) } + }; + std::str::from_utf8(bytes) + .map(str::to_owned) + .map_err(|_| HostStringReadError::InvalidUtf8) +} + +fn write_json( + host: &NemoRelayNativeHostApiV1, + value: &T, + out: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out = ptr::null_mut() }; + let json = match serde_json::to_value(value) { + Ok(value) => value, + Err(error) => { + set_last_error(host, &format!("failed to serialize JSON: {error}")); + return NemoRelayStatus::Internal; + } + }; + let Some(handle) = HostString::from_json(host, &json) else { + set_last_error(host, "failed to allocate host string"); + return NemoRelayStatus::Internal; + }; + unsafe { *out = handle.ptr }; + std::mem::forget(handle); + NemoRelayStatus::Ok +} + +fn set_last_error(host: &NemoRelayNativeHostApiV1, message: &str) { + if let Some(message) = HostString::new(host, message) { + unsafe { (host.last_error_set)(message.as_ptr()) }; + } +} + +/// Sets a host last-error message from generated entry symbols. +/// +/// # Safety +/// `host` must be null or point to a valid [`NemoRelayNativeHostApiV1`]. +#[doc(hidden)] +pub unsafe fn __set_last_error_from_entry(host: *const NemoRelayNativeHostApiV1, message: &str) { + if !host.is_null() { + set_last_error(unsafe { &*host }, message); + } +} + +/// Initializes a native plugin descriptor for a Rust SDK plugin value. +/// +/// # Safety +/// `host` must point to a valid [`NemoRelayNativeHostApiV1`] for the duration +/// of the call, and `out` must point to writable memory for one +/// [`NemoRelayNativePluginV1`] descriptor. +pub unsafe fn export_plugin( + host: *const NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, + plugin: P, +) -> NemoRelayStatus { + if host.is_null() || out.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out = NemoRelayNativePluginV1::default() }; + let host_ref = unsafe { &*host }; + export_plugin_checked(host_ref, out, plugin) +} + +/// Initializes a native plugin descriptor from a constructor callback. +/// +/// # Safety +/// `host` must point to a valid [`NemoRelayNativeHostApiV1`] for the duration +/// of the call, and `out` must point to writable memory for one +/// [`NemoRelayNativePluginV1`] descriptor. +#[doc(hidden)] +pub unsafe fn __export_plugin_from_constructor( + host: *const NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, + constructor: F, +) -> NemoRelayStatus +where + P: NativePlugin, + F: FnOnce() -> P, +{ + if host.is_null() || out.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out = NemoRelayNativePluginV1::default() }; + let host_ref = unsafe { &*host }; + if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { + return NemoRelayStatus::InvalidArg; + } + if host_ref.struct_size < std::mem::size_of::() { + return NemoRelayStatus::InvalidArg; + } + + export_plugin_checked(host_ref, out, constructor()) +} + +fn export_plugin_checked( + host_ref: &NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, + plugin: P, +) -> NemoRelayStatus { + if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { + return NemoRelayStatus::InvalidArg; + } + if host_ref.struct_size < std::mem::size_of::() { + return NemoRelayStatus::InvalidArg; + } + + let kind = plugin.plugin_kind().to_owned(); + let allows_multiple_components = plugin.allows_multiple_components(); + let Some(kind_handle) = HostString::new(host_ref, &kind) else { + return NemoRelayStatus::Internal; + }; + let state = Box::new(PluginState { + host: *host_ref, + plugin: Mutex::new(plugin), + }); + unsafe { + *out = NemoRelayNativePluginV1 { + struct_size: std::mem::size_of::(), + plugin_kind: kind_handle.ptr, + allows_multiple_components, + user_data: Box::into_raw(state) as *mut c_void, + validate: Some(validate_trampoline::

), + register: Some(register_trampoline::

), + drop: Some(drop_plugin_state::

), + }; + } + std::mem::forget(kind_handle); + NemoRelayStatus::Ok +} + +/// Exports a concrete plugin constructor as a native plugin entry symbol body. +#[macro_export] +macro_rules! nemo_relay_plugin { + ($symbol:ident, $constructor:expr) => { + #[doc = "Native plugin entry symbol generated by `nemo_relay_plugin!`."] + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $symbol( + host: *const $crate::NemoRelayNativeHostApiV1, + out: *mut $crate::NemoRelayNativePluginV1, + ) -> $crate::NemoRelayStatus { + match ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| unsafe { + $crate::__export_plugin_from_constructor(host, out, $constructor) + })) { + Ok(status) => status, + Err(_) => { + unsafe { + $crate::__set_last_error_from_entry( + host, + "native plugin entry callback panicked", + ) + }; + $crate::NemoRelayStatus::Internal + } + } + } + }; +} diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs new file mode 100644 index 00000000..214827c1 --- /dev/null +++ b/crates/plugin/tests/typed_callbacks.rs @@ -0,0 +1,4516 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Public-API tests for typed native plugin callback registration. + +use std::collections::VecDeque; +use std::ffi::c_void; +use std::mem::{align_of, offset_of, size_of}; +use std::ptr::{self, NonNull}; +use std::sync::{ + Arc, Mutex, MutexGuard, + atomic::{AtomicUsize, Ordering}, +}; + +use nemo_relay_plugin::{ + AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, + LlmRequest, LlmStream, LlmStreamNext, NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, + NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, + NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, + NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, + NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, + NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, + NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, + NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, + NemoRelayNativeWithScopeStackCb, NemoRelayStatus, PluginContext, PluginRuntime, ScopeType, + ToolNext, +}; +use serde_json::{Map, json}; + +struct TestString(Vec); + +struct RegisteredSubscriber { + name: String, + cb: NemoRelayNativeEventSubscriberCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredSubscriber { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredToolJson { + name: String, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeToolJsonCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredToolJson { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredToolConditional { + name: String, + priority: i32, + cb: NemoRelayNativeToolConditionalCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredToolConditional { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredToolExecution { + name: String, + priority: i32, + cb: NemoRelayNativeToolExecutionCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredToolExecution { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmRequest { + name: String, + priority: i32, + cb: NemoRelayNativeLlmRequestCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmRequest { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmJson { + name: String, + priority: i32, + cb: NemoRelayNativeJsonCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmJson { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmConditional { + name: String, + priority: i32, + cb: NemoRelayNativeLlmConditionalCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmConditional { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmExecution { + name: String, + priority: i32, + cb: NemoRelayNativeLlmExecutionCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmExecution { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmStreamExecution { + name: String, + priority: i32, + cb: NemoRelayNativeLlmStreamExecutionCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmStreamExecution { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +struct RegisteredLlmRequestIntercept { + name: String, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeLlmRequestInterceptCb, + user_data: usize, + free_fn: NemoRelayNativeFreeFn, +} + +impl RegisteredLlmRequestIntercept { + unsafe fn free(self) { + if let Some(free_fn) = self.free_fn { + unsafe { free_fn(self.user_data as *mut c_void) }; + } + } +} + +trait CapturedRegistration { + unsafe fn free(self); +} + +macro_rules! impl_captured_registration { + ($($ty:ty),+ $(,)?) => { + $( + impl CapturedRegistration for $ty { + unsafe fn free(self) { + unsafe { <$ty>::free(self) }; + } + } + )+ + }; +} + +impl_captured_registration!( + RegisteredSubscriber, + RegisteredToolJson, + RegisteredToolConditional, + RegisteredToolExecution, + RegisteredLlmRequest, + RegisteredLlmJson, + RegisteredLlmConditional, + RegisteredLlmExecution, + RegisteredLlmStreamExecution, + RegisteredLlmRequestIntercept, +); + +fn replace_registration(slot: &Mutex>, registration: T) { + let previous = { + let mut slot = slot.lock().unwrap(); + slot.replace(registration) + }; + if let Some(previous) = previous { + unsafe { previous.free() }; + } +} + +fn clear_registration(slot: &Mutex>) { + let registration = { + let mut slot = slot.lock().unwrap(); + slot.take() + }; + if let Some(registration) = registration { + unsafe { registration.free() }; + } +} + +static TEST_LOCK: Mutex<()> = Mutex::new(()); +static LAST_ERROR: Mutex> = Mutex::new(None); +static REGISTRATION_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static STRING_NEW_REMAINING_SUCCESSES: Mutex> = Mutex::new(None); +static STRING_NEW_RETURNS_NULL: Mutex = Mutex::new(false); +static SCOPE_GET_CURRENT_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_GET_CURRENT_RETURNS_NULL: Mutex = Mutex::new(false); +static SCOPE_PUSH_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_PUSH_RETURNS_NULL: Mutex = Mutex::new(false); +static SCOPE_POP_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static EMIT_MARK_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_STACK_CREATE_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_STACK_CREATE_RETURNS_NULL: Mutex = Mutex::new(false); +static SCOPE_STACK_SET_THREAD_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_STACK_CAPTURE_THREAD_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_STACK_CAPTURE_THREAD_RETURNS_NULL: Mutex = Mutex::new(false); +static SCOPE_STACK_RESTORE_THREAD_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static SCOPE_STACK_WITH_CURRENT_STATUS: Mutex = Mutex::new(NemoRelayStatus::Ok); +static STRING_LIVE_COUNT: AtomicUsize = AtomicUsize::new(0); +static RUNTIME_CALLS: Mutex> = Mutex::new(Vec::new()); +static SCOPE_HANDLE_FREES: AtomicUsize = AtomicUsize::new(0); +static SCOPE_STACK_FREES: AtomicUsize = AtomicUsize::new(0); +static SCOPE_STACK_BINDING_FREES: AtomicUsize = AtomicUsize::new(0); +static SCOPE_STACK_BINDING_RESTORES: AtomicUsize = AtomicUsize::new(0); +static SUBSCRIBER_REGISTRATION: Mutex> = Mutex::new(None); +static TOOL_JSON_REGISTRATION: Mutex> = Mutex::new(None); +static TOOL_CONDITIONAL_REGISTRATION: Mutex> = Mutex::new(None); +static TOOL_EXECUTION_REGISTRATION: Mutex> = Mutex::new(None); +static LLM_REQUEST_REGISTRATION: Mutex> = Mutex::new(None); +static LLM_JSON_REGISTRATION: Mutex> = Mutex::new(None); +static LLM_CONDITIONAL_REGISTRATION: Mutex> = Mutex::new(None); +static LLM_EXECUTION_REGISTRATION: Mutex> = Mutex::new(None); +static LLM_STREAM_EXECUTION_REGISTRATION: Mutex> = + Mutex::new(None); +static LLM_REQUEST_INTERCEPT_REGISTRATION: Mutex> = + Mutex::new(None); + +#[test] +fn native_abi_v1_struct_sizes_are_self_describing() { + assert_eq!(NEMO_RELAY_NATIVE_ABI_VERSION, 1); + assert_eq!( + size_of::(), + test_host().struct_size + ); + assert_eq!( + size_of::(), + NemoRelayNativePluginV1::default().struct_size + ); + assert_eq!( + size_of::(), + NemoRelayNativeLlmStreamV1::default().struct_size + ); + assert_eq!(NemoRelayStatus::StreamEnd as i32, 10); + + #[cfg(target_pointer_width = "64")] + { + assert_eq!(align_of::(), 8); + assert_eq!(size_of::(), 272); + assert_eq!( + host_api_offsets(), + [ + 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, + 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, + ] + ); + assert_eq!(align_of::(), 8); + assert_eq!(size_of::(), 56); + assert_eq!(plugin_offsets(), [0, 8, 16, 24, 32, 40, 48]); + assert_eq!(align_of::(), 8); + assert_eq!(size_of::(), 40); + assert_eq!(stream_offsets(), [0, 8, 16, 24, 32]); + } + + #[cfg(target_pointer_width = "32")] + { + assert_eq!(align_of::(), 4); + assert_eq!(size_of::(), 136); + assert_eq!( + host_api_offsets(), + [ + 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, + 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, + ] + ); + assert_eq!(align_of::(), 4); + assert_eq!(size_of::(), 28); + assert_eq!(plugin_offsets(), [0, 4, 8, 12, 16, 20, 24]); + assert_eq!(align_of::(), 4); + assert_eq!(size_of::(), 20); + assert_eq!(stream_offsets(), [0, 4, 8, 12, 16]); + } +} + +fn host_api_offsets() -> [usize; 34] { + [ + offset_of!(NemoRelayNativeHostApiV1, abi_version), + offset_of!(NemoRelayNativeHostApiV1, struct_size), + offset_of!(NemoRelayNativeHostApiV1, relay_version), + offset_of!(NemoRelayNativeHostApiV1, string_new), + offset_of!(NemoRelayNativeHostApiV1, string_data), + offset_of!(NemoRelayNativeHostApiV1, string_len), + offset_of!(NemoRelayNativeHostApiV1, string_free), + offset_of!(NemoRelayNativeHostApiV1, last_error_clear), + offset_of!(NemoRelayNativeHostApiV1, last_error_set), + offset_of!(NemoRelayNativeHostApiV1, plugin_context_register_subscriber), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_tool_sanitize_request_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_tool_sanitize_response_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_tool_conditional_execution_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_tool_request_intercept + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_tool_execution_intercept + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_sanitize_request_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_sanitize_response_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_conditional_execution_guardrail + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_request_intercept + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_execution_intercept + ), + offset_of!( + NemoRelayNativeHostApiV1, + plugin_context_register_llm_stream_execution_intercept + ), + offset_of!(NemoRelayNativeHostApiV1, scope_handle_free), + offset_of!(NemoRelayNativeHostApiV1, scope_get_current), + offset_of!(NemoRelayNativeHostApiV1, scope_push), + offset_of!(NemoRelayNativeHostApiV1, scope_pop), + offset_of!(NemoRelayNativeHostApiV1, emit_mark), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_create), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_free), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_set_thread), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_capture_thread), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_restore_thread), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_binding_free), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_active), + offset_of!(NemoRelayNativeHostApiV1, scope_stack_with_current), + ] +} + +fn plugin_offsets() -> [usize; 7] { + [ + offset_of!(NemoRelayNativePluginV1, struct_size), + offset_of!(NemoRelayNativePluginV1, plugin_kind), + offset_of!(NemoRelayNativePluginV1, allows_multiple_components), + offset_of!(NemoRelayNativePluginV1, user_data), + offset_of!(NemoRelayNativePluginV1, validate), + offset_of!(NemoRelayNativePluginV1, register), + offset_of!(NemoRelayNativePluginV1, drop), + ] +} + +fn stream_offsets() -> [usize; 5] { + [ + offset_of!(NemoRelayNativeLlmStreamV1, struct_size), + offset_of!(NemoRelayNativeLlmStreamV1, user_data), + offset_of!(NemoRelayNativeLlmStreamV1, next), + offset_of!(NemoRelayNativeLlmStreamV1, cancel), + offset_of!(NemoRelayNativeLlmStreamV1, drop), + ] +} + +unsafe extern "C" fn test_string_new( + data: *const u8, + len: usize, + out: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if out.is_null() || (data.is_null() && len > 0) { + return NemoRelayStatus::NullPointer; + } + { + let mut remaining = STRING_NEW_REMAINING_SUCCESSES.lock().unwrap(); + if let Some(remaining) = remaining.as_mut() { + if *remaining == 0 { + return NemoRelayStatus::Internal; + } + *remaining -= 1; + } + } + if *STRING_NEW_RETURNS_NULL.lock().unwrap() { + unsafe { *out = ptr::null_mut() }; + return NemoRelayStatus::Ok; + } + let bytes = if len == 0 { + Vec::new() + } else { + unsafe { std::slice::from_raw_parts(data, len) }.to_vec() + }; + unsafe { *out = Box::into_raw(Box::new(TestString(bytes))).cast() }; + STRING_LIVE_COUNT.fetch_add(1, Ordering::SeqCst); + NemoRelayStatus::Ok +} + +unsafe extern "C" fn test_string_data(value: *const NemoRelayNativeString) -> *const u8 { + if value.is_null() { + return ptr::null(); + } + unsafe { &*(value.cast::()) }.0.as_ptr() +} + +unsafe extern "C" fn test_string_len(value: *const NemoRelayNativeString) -> usize { + if value.is_null() { + return 0; + } + unsafe { &*(value.cast::()) }.0.len() +} + +unsafe extern "C" fn test_string_free(value: *mut NemoRelayNativeString) { + if !value.is_null() { + drop(unsafe { Box::from_raw(value.cast::()) }); + STRING_LIVE_COUNT.fetch_sub(1, Ordering::SeqCst); + } +} + +unsafe extern "C" fn test_last_error_clear() { + *LAST_ERROR.lock().unwrap() = None; +} + +unsafe extern "C" fn test_last_error_set(message: *const NemoRelayNativeString) { + let host = test_host(); + *LAST_ERROR.lock().unwrap() = read_host_string(&host, message); +} + +unsafe extern "C" fn capture_register_subscriber( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + cb: NemoRelayNativeEventSubscriberCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &SUBSCRIBER_REGISTRATION, + RegisteredSubscriber { + name, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_tool_json( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &TOOL_JSON_REGISTRATION, + RegisteredToolJson { + name, + priority, + break_chain: false, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn passthrough_tool_json_cb( + _user_data: *mut c_void, + _name: *const NemoRelayNativeString, + _payload_json: *const NemoRelayNativeString, + _out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_tool_conditional( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &TOOL_CONDITIONAL_REGISTRATION, + RegisteredToolConditional { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_tool_request_intercept( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeToolJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &TOOL_JSON_REGISTRATION, + RegisteredToolJson { + name, + priority, + break_chain, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_tool_execution( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeToolExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &TOOL_EXECUTION_REGISTRATION, + RegisteredToolExecution { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_request( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmRequestCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_REQUEST_REGISTRATION, + RegisteredLlmRequest { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_json( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeJsonCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_JSON_REGISTRATION, + RegisteredLlmJson { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_conditional( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmConditionalCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_CONDITIONAL_REGISTRATION, + RegisteredLlmConditional { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_request_intercept( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + break_chain: bool, + cb: NemoRelayNativeLlmRequestInterceptCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_REQUEST_INTERCEPT_REGISTRATION, + RegisteredLlmRequestIntercept { + name, + priority, + break_chain, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_stream_execution( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmStreamExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_STREAM_EXECUTION_REGISTRATION, + RegisteredLlmStreamExecution { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_llm_execution( + _ctx: *mut NemoRelayNativePluginContext, + name: *const NemoRelayNativeString, + priority: i32, + cb: NemoRelayNativeLlmExecutionCb, + user_data: *mut c_void, + free_fn: NemoRelayNativeFreeFn, +) -> NemoRelayStatus { + let status = *REGISTRATION_STATUS.lock().unwrap(); + if status == NemoRelayStatus::Ok { + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + replace_registration( + &LLM_EXECUTION_REGISTRATION, + RegisteredLlmExecution { + name, + priority, + cb, + user_data: user_data as usize, + free_fn, + }, + ); + } + status +} + +unsafe extern "C" fn capture_scope_get_current( + out: *mut *mut NemoRelayNativeScopeHandle, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_GET_CURRENT_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + RUNTIME_CALLS.lock().unwrap().push("current_scope".into()); + if *SCOPE_GET_CURRENT_RETURNS_NULL.lock().unwrap() { + unsafe { *out = ptr::null_mut() }; + } else { + unsafe { *out = Box::into_raw(Box::new(0_u8)).cast() }; + } + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_push( + name: *const NemoRelayNativeString, + scope_type: NemoRelayNativeScopeType, + parent: *const NemoRelayNativeScopeHandle, + attributes: u32, + data_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + input_json: *const NemoRelayNativeString, + _timestamp_unix_micros: *const i64, + out: *mut *mut NemoRelayNativeScopeHandle, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_PUSH_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + let data = match optional_host_string(&host, data_json) { + Ok(data) => data, + Err(status) => return status, + }; + let metadata = match optional_host_string(&host, metadata_json) { + Ok(metadata) => metadata, + Err(status) => return status, + }; + let input = match optional_host_string(&host, input_json) { + Ok(input) => input, + Err(status) => return status, + }; + RUNTIME_CALLS.lock().unwrap().push(format!( + "push:{name}:{scope_type:?}:{attributes}:parent={}:data={data}:metadata={metadata}:input={input}", + !parent.is_null() + )); + if *SCOPE_PUSH_RETURNS_NULL.lock().unwrap() { + unsafe { *out = ptr::null_mut() }; + } else { + unsafe { *out = Box::into_raw(Box::new(0_u8)).cast() }; + } + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_pop( + handle: *const NemoRelayNativeScopeHandle, + output_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + _timestamp_unix_micros: *const i64, +) -> NemoRelayStatus { + if handle.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_POP_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + let host = test_host(); + let output = match optional_host_string(&host, output_json) { + Ok(output) => output, + Err(status) => return status, + }; + let metadata = match optional_host_string(&host, metadata_json) { + Ok(metadata) => metadata, + Err(status) => return status, + }; + RUNTIME_CALLS + .lock() + .unwrap() + .push(format!("pop:output={output}:metadata={metadata}")); + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_emit_mark( + name: *const NemoRelayNativeString, + parent: *const NemoRelayNativeScopeHandle, + data_json: *const NemoRelayNativeString, + metadata_json: *const NemoRelayNativeString, + _timestamp_unix_micros: *const i64, +) -> NemoRelayStatus { + let status = *EMIT_MARK_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + let host = test_host(); + let name = match required_host_string(&host, name) { + Ok(name) => name, + Err(status) => return status, + }; + let data = match optional_host_string(&host, data_json) { + Ok(data) => data, + Err(status) => return status, + }; + let metadata = match optional_host_string(&host, metadata_json) { + Ok(metadata) => metadata, + Err(status) => return status, + }; + RUNTIME_CALLS.lock().unwrap().push(format!( + "mark:{name}:parent={}:data={data}:metadata={metadata}", + !parent.is_null() + )); + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_stack_create( + out: *mut *mut NemoRelayNativeScopeStack, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_STACK_CREATE_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + RUNTIME_CALLS.lock().unwrap().push("stack_create".into()); + if *SCOPE_STACK_CREATE_RETURNS_NULL.lock().unwrap() { + unsafe { *out = ptr::null_mut() }; + } else { + unsafe { *out = Box::into_raw(Box::new(0_u8)).cast() }; + } + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_stack_set_thread( + stack: *const NemoRelayNativeScopeStack, +) -> NemoRelayStatus { + if stack.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_STACK_SET_THREAD_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + RUNTIME_CALLS + .lock() + .unwrap() + .push("stack_set_thread".into()); + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_stack_capture_thread( + out: *mut *mut NemoRelayNativeScopeStackBinding, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_STACK_CAPTURE_THREAD_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + RUNTIME_CALLS.lock().unwrap().push("stack_capture".into()); + if *SCOPE_STACK_CAPTURE_THREAD_RETURNS_NULL.lock().unwrap() { + unsafe { *out = ptr::null_mut() }; + } else { + unsafe { *out = Box::into_raw(Box::new(0_u8)).cast() }; + } + NemoRelayStatus::Ok +} + +unsafe extern "C" fn capture_scope_stack_restore_thread( + binding: *mut NemoRelayNativeScopeStackBinding, +) -> NemoRelayStatus { + if binding.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_STACK_RESTORE_THREAD_STATUS.lock().unwrap(); + RUNTIME_CALLS.lock().unwrap().push("stack_restore".into()); + unsafe { drop(Box::from_raw(binding.cast::())) }; + SCOPE_STACK_BINDING_RESTORES.fetch_add(1, Ordering::SeqCst); + status +} + +unsafe extern "C" fn capture_scope_stack_with_current( + stack: *const NemoRelayNativeScopeStack, + cb: NemoRelayNativeWithScopeStackCb, + user_data: *mut c_void, +) -> NemoRelayStatus { + if stack.is_null() { + return NemoRelayStatus::NullPointer; + } + let status = *SCOPE_STACK_WITH_CURRENT_STATUS.lock().unwrap(); + if status != NemoRelayStatus::Ok { + return status; + } + RUNTIME_CALLS + .lock() + .unwrap() + .push("stack_with_current".into()); + unsafe { cb(user_data) } +} + +unsafe extern "C" fn capture_scope_handle_free(handle: *mut NemoRelayNativeScopeHandle) { + if !handle.is_null() { + unsafe { drop(Box::from_raw(handle.cast::())) }; + SCOPE_HANDLE_FREES.fetch_add(1, Ordering::SeqCst); + } +} +unsafe extern "C" fn capture_scope_stack_free(stack: *mut NemoRelayNativeScopeStack) { + if !stack.is_null() { + unsafe { drop(Box::from_raw(stack.cast::())) }; + SCOPE_STACK_FREES.fetch_add(1, Ordering::SeqCst); + } +} +unsafe extern "C" fn capture_scope_stack_binding_free( + binding: *mut NemoRelayNativeScopeStackBinding, +) { + if !binding.is_null() { + unsafe { drop(Box::from_raw(binding.cast::())) }; + SCOPE_STACK_BINDING_FREES.fetch_add(1, Ordering::SeqCst); + } +} +unsafe extern "C" fn true_scope_stack_active() -> bool { + true +} + +fn test_host() -> NemoRelayNativeHostApiV1 { + NemoRelayNativeHostApiV1 { + abi_version: NEMO_RELAY_NATIVE_ABI_VERSION, + struct_size: std::mem::size_of::(), + relay_version: c"test".as_ptr(), + string_new: test_string_new, + string_data: test_string_data, + string_len: test_string_len, + string_free: test_string_free, + last_error_clear: test_last_error_clear, + last_error_set: test_last_error_set, + plugin_context_register_subscriber: capture_register_subscriber, + plugin_context_register_tool_sanitize_request_guardrail: capture_tool_json, + plugin_context_register_tool_sanitize_response_guardrail: capture_tool_json, + plugin_context_register_tool_conditional_execution_guardrail: capture_tool_conditional, + plugin_context_register_tool_request_intercept: capture_tool_request_intercept, + plugin_context_register_tool_execution_intercept: capture_tool_execution, + plugin_context_register_llm_sanitize_request_guardrail: capture_llm_request, + plugin_context_register_llm_sanitize_response_guardrail: capture_llm_json, + plugin_context_register_llm_conditional_execution_guardrail: capture_llm_conditional, + plugin_context_register_llm_request_intercept: capture_llm_request_intercept, + plugin_context_register_llm_execution_intercept: capture_llm_execution, + plugin_context_register_llm_stream_execution_intercept: capture_llm_stream_execution, + scope_handle_free: capture_scope_handle_free, + scope_get_current: capture_scope_get_current, + scope_push: capture_scope_push, + scope_pop: capture_scope_pop, + emit_mark: capture_emit_mark, + scope_stack_create: capture_scope_stack_create, + scope_stack_free: capture_scope_stack_free, + scope_stack_set_thread: capture_scope_stack_set_thread, + scope_stack_capture_thread: capture_scope_stack_capture_thread, + scope_stack_restore_thread: capture_scope_stack_restore_thread, + scope_stack_binding_free: capture_scope_stack_binding_free, + scope_stack_active: true_scope_stack_active, + scope_stack_with_current: capture_scope_stack_with_current, + } +} + +fn begin_test() -> MutexGuard<'static, ()> { + let guard = TEST_LOCK + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + reset_state(); + guard +} + +fn reset_state() { + clear_registration(&SUBSCRIBER_REGISTRATION); + clear_registration(&TOOL_JSON_REGISTRATION); + clear_registration(&TOOL_CONDITIONAL_REGISTRATION); + clear_registration(&TOOL_EXECUTION_REGISTRATION); + clear_registration(&LLM_REQUEST_REGISTRATION); + clear_registration(&LLM_JSON_REGISTRATION); + clear_registration(&LLM_CONDITIONAL_REGISTRATION); + clear_registration(&LLM_EXECUTION_REGISTRATION); + clear_registration(&LLM_STREAM_EXECUTION_REGISTRATION); + clear_registration(&LLM_REQUEST_INTERCEPT_REGISTRATION); + assert_eq!( + STRING_LIVE_COUNT.load(Ordering::SeqCst), + 0, + "previous test leaked host strings" + ); + *LAST_ERROR.lock().unwrap() = None; + *REGISTRATION_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + *STRING_NEW_RETURNS_NULL.lock().unwrap() = false; + *SCOPE_GET_CURRENT_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_GET_CURRENT_RETURNS_NULL.lock().unwrap() = false; + *SCOPE_PUSH_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_PUSH_RETURNS_NULL.lock().unwrap() = false; + *SCOPE_POP_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *EMIT_MARK_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_STACK_CREATE_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_STACK_CREATE_RETURNS_NULL.lock().unwrap() = false; + *SCOPE_STACK_SET_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_STACK_CAPTURE_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_STACK_CAPTURE_THREAD_RETURNS_NULL.lock().unwrap() = false; + *SCOPE_STACK_RESTORE_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + *SCOPE_STACK_WITH_CURRENT_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + RUNTIME_CALLS.lock().unwrap().clear(); + SCOPE_HANDLE_FREES.store(0, Ordering::SeqCst); + SCOPE_STACK_FREES.store(0, Ordering::SeqCst); + SCOPE_STACK_BINDING_FREES.store(0, Ordering::SeqCst); + SCOPE_STACK_BINDING_RESTORES.store(0, Ordering::SeqCst); +} + +fn test_context(host: &NemoRelayNativeHostApiV1) -> PluginContext<'_> { + unsafe { + PluginContext::from_raw( + host, + NonNull::::dangling().as_ptr(), + ) + } +} + +fn read_host_string( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> Option { + if value.is_null() { + return None; + } + let data = unsafe { (host.string_data)(value) }; + let len = unsafe { (host.string_len)(value) }; + if data.is_null() && len > 0 { + return None; + } + let bytes = if len == 0 { + &[][..] + } else { + unsafe { std::slice::from_raw_parts(data, len) } + }; + std::str::from_utf8(bytes).ok().map(ToOwned::to_owned) +} + +fn required_host_string( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> std::result::Result { + if value.is_null() { + return Err(NemoRelayStatus::NullPointer); + } + read_host_string(host, value).ok_or(NemoRelayStatus::InvalidArg) +} + +fn optional_host_string( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> std::result::Result { + if value.is_null() { + return Ok(String::new()); + } + read_host_string(host, value).ok_or(NemoRelayStatus::InvalidArg) +} + +fn host_string(host: &NemoRelayNativeHostApiV1, value: &str) -> *mut NemoRelayNativeString { + let mut out = ptr::null_mut(); + let status = unsafe { (host.string_new)(value.as_ptr(), value.len(), &mut out) }; + assert_eq!(status, NemoRelayStatus::Ok); + out +} + +fn bytes_host_string(host: &NemoRelayNativeHostApiV1, value: &[u8]) -> *mut NemoRelayNativeString { + let mut out = ptr::null_mut(); + let status = unsafe { (host.string_new)(value.as_ptr(), value.len(), &mut out) }; + assert_eq!(status, NemoRelayStatus::Ok); + out +} + +fn json_host_string(host: &NemoRelayNativeHostApiV1, value: Json) -> *mut NemoRelayNativeString { + host_string(host, &serde_json::to_string(&value).unwrap()) +} + +fn read_json_and_free(host: &NemoRelayNativeHostApiV1, value: *mut NemoRelayNativeString) -> Json { + let result: Json = serde_json::from_str(&read_host_string(host, value).unwrap()).unwrap(); + unsafe { (host.string_free)(value) }; + result +} + +fn read_string_and_free( + host: &NemoRelayNativeHostApiV1, + value: *mut NemoRelayNativeString, +) -> String { + let result = read_host_string(host, value).unwrap(); + unsafe { (host.string_free)(value) }; + result +} + +fn live_host_strings() -> usize { + STRING_LIVE_COUNT.load(Ordering::SeqCst) +} + +fn expect_string_err(result: std::result::Result) -> String { + match result { + Ok(_) => panic!("operation should have failed"), + Err(error) => error, + } +} + +fn poll_stream_chunk( + host: &NemoRelayNativeHostApiV1, + stream: &NemoRelayNativeLlmStreamV1, +) -> (NemoRelayStatus, Option) { + let mut out = ptr::null_mut(); + let status = unsafe { stream.next.unwrap()(stream.user_data, &mut out) }; + let chunk = if out.is_null() { + None + } else { + Some(read_json_and_free(host, out)) + }; + (status, chunk) +} + +unsafe fn drop_stream(stream: &mut NemoRelayNativeLlmStreamV1) { + if let Some(drop_fn) = stream.drop.take() { + unsafe { drop_fn(stream.user_data) }; + } + stream.user_data = ptr::null_mut(); +} + +unsafe extern "C" fn count_stream_drop(user_data: *mut c_void) { + if !user_data.is_null() { + unsafe { (&*(user_data as *const AtomicUsize)).fetch_add(1, Ordering::SeqCst) }; + } +} + +fn write_json( + host: &NemoRelayNativeHostApiV1, + value: &Json, + out: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if out.is_null() { + return NemoRelayStatus::NullPointer; + } + let encoded = serde_json::to_string(value).unwrap(); + let mut string = ptr::null_mut(); + let status = unsafe { (host.string_new)(encoded.as_ptr(), encoded.len(), &mut string) }; + if status == NemoRelayStatus::Ok { + unsafe { *out = string }; + } + status +} + +fn take_tool_json_registration() -> RegisteredToolJson { + TOOL_JSON_REGISTRATION + .lock() + .unwrap() + .take() + .expect("tool JSON callback should be registered") +} + +fn take_subscriber_registration() -> RegisteredSubscriber { + SUBSCRIBER_REGISTRATION + .lock() + .unwrap() + .take() + .expect("subscriber callback should be registered") +} + +fn take_tool_conditional_registration() -> RegisteredToolConditional { + TOOL_CONDITIONAL_REGISTRATION + .lock() + .unwrap() + .take() + .expect("tool conditional callback should be registered") +} + +fn take_tool_execution_registration() -> RegisteredToolExecution { + TOOL_EXECUTION_REGISTRATION + .lock() + .unwrap() + .take() + .expect("tool execution callback should be registered") +} + +fn take_llm_request_registration() -> RegisteredLlmRequest { + LLM_REQUEST_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM request callback should be registered") +} + +fn take_llm_json_registration() -> RegisteredLlmJson { + LLM_JSON_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM JSON callback should be registered") +} + +fn take_llm_conditional_registration() -> RegisteredLlmConditional { + LLM_CONDITIONAL_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM conditional callback should be registered") +} + +fn take_llm_execution_registration() -> RegisteredLlmExecution { + LLM_EXECUTION_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM execution callback should be registered") +} + +fn take_llm_request_intercept_registration() -> RegisteredLlmRequestIntercept { + LLM_REQUEST_INTERCEPT_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM request intercept callback should be registered") +} + +fn take_llm_stream_execution_registration() -> RegisteredLlmStreamExecution { + LLM_STREAM_EXECUTION_REGISTRATION + .lock() + .unwrap() + .take() + .expect("LLM stream execution callback should be registered") +} + +struct PanicOnDrop(&'static str); + +impl Drop for PanicOnDrop { + fn drop(&mut self) { + panic!("{}", self.0); + } +} + +struct PanicIterator { + _panic_on_drop: PanicOnDrop, +} + +impl Iterator for PanicIterator { + type Item = std::result::Result; + + fn next(&mut self) -> Option { + None + } +} + +#[test] +fn llm_stream_from_raw_drops_rejected_streams() { + let _guard = begin_test(); + let host = test_host(); + + let undersized_drop_calls = AtomicUsize::new(0); + let wrong_size = NemoRelayNativeLlmStreamV1 { + struct_size: 0, + user_data: (&undersized_drop_calls as *const AtomicUsize) + .cast_mut() + .cast(), + next: None, + cancel: None, + drop: Some(count_stream_drop), + }; + let err = match unsafe { LlmStream::from_raw(&host, wrong_size) } { + Ok(_) => panic!("undersized stream should be rejected"), + Err(err) => err, + }; + assert!(err.contains("unsupported LLM stream struct size")); + assert_eq!(undersized_drop_calls.load(Ordering::SeqCst), 0); + + let dropped = Arc::new(AtomicUsize::new(0)); + let mut wrong_size = test_llm_stream( + &host, + vec![], + Arc::new(AtomicUsize::new(0)), + dropped.clone(), + ); + wrong_size.struct_size = size_of::() + 8; + let err = match unsafe { LlmStream::from_raw(&host, wrong_size) } { + Ok(_) => panic!("oversized stream should be rejected"), + Err(err) => err, + }; + assert!(err.contains("unsupported LLM stream struct size")); + assert_eq!(dropped.load(Ordering::SeqCst), 1); + + let dropped = Arc::new(AtomicUsize::new(0)); + let mut null_next = test_llm_stream( + &host, + vec![], + Arc::new(AtomicUsize::new(0)), + dropped.clone(), + ); + null_next.next = None; + let err = match unsafe { LlmStream::from_raw(&host, null_next) } { + Ok(_) => panic!("null-next stream should be rejected"), + Err(err) => err, + }; + assert!(err.contains("LLM stream next callback was null")); + assert_eq!(dropped.load(Ordering::SeqCst), 1); +} + +#[test] +fn llm_stream_from_raw_polls_iterates_cancels_and_drops() { + let _guard = begin_test(); + let host = test_host(); + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + let raw = manual_llm_stream( + &host, + vec![ + ManualStreamPoll::Json(json!({ "chunk": 1 })), + ManualStreamPoll::Json(json!({ "chunk": 2 })), + ManualStreamPoll::EndWithJson(json!({ "ignored": true })), + ], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + + assert_eq!(stream.next_chunk().unwrap().unwrap()["chunk"], json!(1)); + assert_eq!(stream.next().unwrap().unwrap()["chunk"], json!(2)); + assert!(stream.next().is_none()); + assert!(stream.next_chunk().unwrap().is_none()); + assert!(stream.cancel().is_ok()); + drop(stream); + + assert_eq!(cancelled.load(Ordering::SeqCst), 0); + assert_eq!(dropped.load(Ordering::SeqCst), 1); +} + +#[test] +fn llm_stream_from_raw_reports_chunk_and_status_failures() { + let _guard = begin_test(); + let host = test_host(); + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::NullOk], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + assert_eq!( + stream.next_chunk().unwrap_err(), + "LLM stream returned null chunk" + ); + assert!(stream.next_chunk().unwrap().is_none()); + drop(stream); + assert_eq!(cancelled.load(Ordering::SeqCst), 0); + assert_eq!(dropped.load(Ordering::SeqCst), 1); + + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::InvalidJson], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + assert_eq!( + stream.next().unwrap().unwrap_err(), + "LLM stream returned invalid JSON: InvalidJson" + ); + assert!(stream.next().is_none()); + drop(stream); + assert_eq!(dropped.load(Ordering::SeqCst), 2); + + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::StatusWithJson( + NemoRelayStatus::GuardrailRejected, + json!({ "discarded": true }), + )], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + let live_before = live_host_strings(); + assert_eq!( + stream.next_chunk().unwrap_err(), + "LLM stream failed: GuardrailRejected" + ); + assert_eq!(live_host_strings(), live_before); + drop(stream); + assert_eq!(dropped.load(Ordering::SeqCst), 3); + + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::Status(NemoRelayStatus::NotFound)], + NemoRelayStatus::Ok, + cancelled, + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + assert_eq!( + stream.next().unwrap().unwrap_err(), + "LLM stream failed: NotFound" + ); + drop(stream); + assert_eq!(dropped.load(Ordering::SeqCst), 4); +} + +#[test] +fn llm_stream_cancel_handles_finished_missing_and_failing_callbacks() { + let _guard = begin_test(); + let host = test_host(); + + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::Json(json!({ "chunk": true }))], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + stream.cancel().unwrap(); + stream.cancel().unwrap(); + drop(stream); + assert_eq!(cancelled.load(Ordering::SeqCst), 1); + assert_eq!(dropped.load(Ordering::SeqCst), 1); + + let mut raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::Json(json!({ "chunk": true }))], + NemoRelayStatus::Ok, + cancelled.clone(), + dropped.clone(), + ); + raw.cancel = None; + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + stream.cancel().unwrap(); + drop(stream); + assert_eq!(cancelled.load(Ordering::SeqCst), 1); + assert_eq!(dropped.load(Ordering::SeqCst), 2); + + let raw = manual_llm_stream( + &host, + vec![ManualStreamPoll::Json(json!({ "chunk": true }))], + NemoRelayStatus::Internal, + cancelled.clone(), + dropped.clone(), + ); + let mut stream = unsafe { LlmStream::from_raw(&host, raw) }.unwrap(); + assert_eq!( + stream.cancel().unwrap_err(), + "LLM stream cancel failed: Internal" + ); + drop(stream); + assert_eq!(cancelled.load(Ordering::SeqCst), 3); + assert_eq!(dropped.load(Ordering::SeqCst), 3); +} + +#[test] +fn plugin_runtime_scope_mark_and_stack_helpers_call_host() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + assert_eq!( + runtime.host_api().abi_version, + NEMO_RELAY_NATIVE_ABI_VERSION + ); + + let current = runtime.current_scope().unwrap(); + assert!(!current.as_ptr().is_null()); + drop(current); + + let mut scope = runtime + .scope( + "work", + ScopeType::Tool, + Some(&json!({ "data": true })), + Some(&json!({ "metadata": true })), + Some(&json!({ "input": true })), + ) + .unwrap(); + assert!(scope.handle().is_some()); + runtime + .emit_mark( + "checkpoint", + Some(&json!({ "mark": true })), + Some(&json!({ "meta": true })), + ) + .unwrap(); + scope + .close( + Some(&json!({ "output": true })), + Some(&json!({ "closed": true })), + ) + .unwrap(); + assert!(scope.handle().is_none()); + scope.close(None, None).unwrap(); + + let stack = runtime.create_scope_stack().unwrap(); + assert!(runtime.scope_stack_active()); + let with_current_calls = Arc::new(AtomicUsize::new(0)); + stack + .with_current({ + let with_current_calls = with_current_calls.clone(); + move || { + with_current_calls.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + }) + .unwrap(); + assert_eq!(with_current_calls.load(Ordering::SeqCst), 1); + runtime + .bind_scope_stack_thread(&stack) + .unwrap() + .restore() + .unwrap(); + drop(stack); + + let calls = RUNTIME_CALLS.lock().unwrap().clone(); + assert!(calls.iter().any(|call| call == "current_scope")); + assert!(calls.iter().any(|call| { + call.starts_with("push:work:Tool:0:parent=false") + && call.contains(r#""data":true"#) + && call.contains(r#""metadata":true"#) + && call.contains(r#""input":true"#) + })); + assert!(calls.iter().any(|call| { + call.starts_with("mark:checkpoint:parent=false") + && call.contains(r#""mark":true"#) + && call.contains(r#""meta":true"#) + })); + assert!(calls.iter().any(|call| { + call.starts_with("pop:") + && call.contains(r#""output":true"#) + && call.contains(r#""closed":true"#) + })); + assert!(calls.iter().any(|call| call == "stack_create")); + assert!(calls.iter().any(|call| call == "stack_with_current")); + assert!(calls.iter().any(|call| call == "stack_capture")); + assert!(calls.iter().any(|call| call == "stack_set_thread")); + assert!(calls.iter().any(|call| call == "stack_restore")); + assert_eq!(SCOPE_HANDLE_FREES.load(Ordering::SeqCst), 2); + assert_eq!(SCOPE_STACK_FREES.load(Ordering::SeqCst), 1); + assert_eq!(SCOPE_STACK_BINDING_RESTORES.load(Ordering::SeqCst), 1); + assert_eq!(SCOPE_STACK_BINDING_FREES.load(Ordering::SeqCst), 0); +} + +#[test] +fn scope_guard_drops_unclosed_scope_and_maps_scope_types() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + + assert_eq!( + [ + NemoRelayNativeScopeType::from(ScopeType::Agent), + NemoRelayNativeScopeType::from(ScopeType::Function), + NemoRelayNativeScopeType::from(ScopeType::Tool), + NemoRelayNativeScopeType::from(ScopeType::Llm), + NemoRelayNativeScopeType::from(ScopeType::Retriever), + NemoRelayNativeScopeType::from(ScopeType::Embedder), + NemoRelayNativeScopeType::from(ScopeType::Reranker), + NemoRelayNativeScopeType::from(ScopeType::Guardrail), + NemoRelayNativeScopeType::from(ScopeType::Evaluator), + NemoRelayNativeScopeType::from(ScopeType::Custom), + NemoRelayNativeScopeType::from(ScopeType::Unknown), + ], + [ + NemoRelayNativeScopeType::Agent, + NemoRelayNativeScopeType::Function, + NemoRelayNativeScopeType::Tool, + NemoRelayNativeScopeType::Llm, + NemoRelayNativeScopeType::Retriever, + NemoRelayNativeScopeType::Embedder, + NemoRelayNativeScopeType::Reranker, + NemoRelayNativeScopeType::Guardrail, + NemoRelayNativeScopeType::Evaluator, + NemoRelayNativeScopeType::Custom, + NemoRelayNativeScopeType::Unknown, + ] + ); + + { + let scope = runtime + .scope("auto", ScopeType::Agent, None, None, None) + .unwrap(); + assert!(scope.handle().is_some()); + } + + let calls = RUNTIME_CALLS.lock().unwrap().clone(); + assert!(calls.iter().any(|call| call.starts_with("push:auto:Agent"))); + assert!(calls.iter().any(|call| call == "pop:output=:metadata=")); + assert_eq!(SCOPE_HANDLE_FREES.load(Ordering::SeqCst), 1); +} + +#[test] +fn plugin_runtime_reports_scope_host_failures_and_allocation_failures() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + + *SCOPE_GET_CURRENT_STATUS.lock().unwrap() = NemoRelayStatus::NotFound; + assert_eq!( + expect_string_err(runtime.current_scope()), + "scope_get_current failed: NotFound" + ); + *SCOPE_GET_CURRENT_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + *SCOPE_GET_CURRENT_RETURNS_NULL.lock().unwrap() = true; + assert_eq!( + expect_string_err(runtime.current_scope()), + "scope_get_current failed: Ok" + ); + *SCOPE_GET_CURRENT_RETURNS_NULL.lock().unwrap() = false; + + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + assert_eq!( + expect_string_err(runtime.push_scope("scope", ScopeType::Tool, None, None, None)), + "failed to allocate scope name" + ); + assert_eq!( + runtime.emit_mark("mark", None, None).unwrap_err(), + "failed to allocate mark name" + ); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + + *SCOPE_PUSH_STATUS.lock().unwrap() = NemoRelayStatus::InvalidArg; + assert_eq!( + expect_string_err(runtime.push_scope("scope", ScopeType::Tool, None, None, None)), + "scope_push failed: InvalidArg" + ); + *SCOPE_PUSH_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + *SCOPE_PUSH_RETURNS_NULL.lock().unwrap() = true; + assert_eq!( + expect_string_err(runtime.push_scope("scope", ScopeType::Tool, None, None, None)), + "scope_push failed: Ok" + ); + *SCOPE_PUSH_RETURNS_NULL.lock().unwrap() = false; + + *EMIT_MARK_STATUS.lock().unwrap() = NemoRelayStatus::Internal; + assert_eq!( + runtime.emit_mark("mark", None, None).unwrap_err(), + "emit_mark failed: Internal" + ); + *EMIT_MARK_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + let handle = runtime + .push_scope("scope", ScopeType::Tool, None, None, None) + .unwrap(); + *SCOPE_POP_STATUS.lock().unwrap() = NemoRelayStatus::ScopeStackEmpty; + assert_eq!( + runtime.pop_scope(&handle, None, None).unwrap_err(), + "scope_pop failed: ScopeStackEmpty" + ); + *SCOPE_POP_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + drop(handle); + + *SCOPE_STACK_CREATE_STATUS.lock().unwrap() = NemoRelayStatus::Internal; + assert_eq!( + expect_string_err(runtime.create_scope_stack()), + "scope_stack_create failed: Internal" + ); + *SCOPE_STACK_CREATE_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + *SCOPE_STACK_CREATE_RETURNS_NULL.lock().unwrap() = true; + assert_eq!( + expect_string_err(runtime.create_scope_stack()), + "scope_stack_create failed: Ok" + ); + *SCOPE_STACK_CREATE_RETURNS_NULL.lock().unwrap() = false; + + *SCOPE_STACK_CAPTURE_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::NotFound; + assert_eq!( + expect_string_err(runtime.capture_scope_stack_thread()), + "scope_stack_capture_thread failed: NotFound" + ); + *SCOPE_STACK_CAPTURE_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + *SCOPE_STACK_CAPTURE_THREAD_RETURNS_NULL.lock().unwrap() = true; + assert_eq!( + expect_string_err(runtime.capture_scope_stack_thread()), + "scope_stack_capture_thread failed: Ok" + ); + *SCOPE_STACK_CAPTURE_THREAD_RETURNS_NULL.lock().unwrap() = false; + + *STRING_NEW_RETURNS_NULL.lock().unwrap() = true; + assert_eq!( + runtime.emit_mark("mark", None, None).unwrap_err(), + "failed to allocate mark name" + ); + *STRING_NEW_RETURNS_NULL.lock().unwrap() = false; +} + +#[test] +fn scope_stack_with_current_reports_callback_and_host_failures() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + let stack = runtime.create_scope_stack().unwrap(); + assert!(!stack.as_ptr().is_null()); + + assert_eq!( + stack + .with_current(|| Err("scope stack callback failed".into())) + .unwrap_err(), + "scope stack callback failed" + ); + assert_eq!( + stack + .with_current(|| panic!("scope stack panic")) + .unwrap_err(), + "scope-stack callback panicked" + ); + + *SCOPE_STACK_WITH_CURRENT_STATUS.lock().unwrap() = NemoRelayStatus::NotFound; + assert_eq!( + stack.with_current(|| Ok(())).unwrap_err(), + "scope_stack_with_current failed: NotFound" + ); +} + +#[test] +fn scope_stack_thread_binding_restores_on_set_failure_and_reports_restore_failure() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + let stack = runtime.create_scope_stack().unwrap(); + + *SCOPE_STACK_SET_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::InvalidArg; + assert_eq!( + expect_string_err(runtime.bind_scope_stack_thread(&stack)), + "scope_stack_set_thread failed: InvalidArg" + ); + assert_eq!(SCOPE_STACK_BINDING_RESTORES.load(Ordering::SeqCst), 1); + *SCOPE_STACK_SET_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Ok; + + *SCOPE_STACK_RESTORE_THREAD_STATUS.lock().unwrap() = NemoRelayStatus::Internal; + let guard = runtime.bind_scope_stack_thread(&stack).unwrap(); + assert_eq!( + guard.restore().unwrap_err(), + "scope_stack_restore_thread failed: Internal" + ); + assert_eq!(SCOPE_STACK_BINDING_RESTORES.load(Ordering::SeqCst), 2); + assert_eq!(SCOPE_STACK_BINDING_FREES.load(Ordering::SeqCst), 0); +} + +#[test] +fn scope_stack_bindings_restore_or_free_on_drop() { + let _guard = begin_test(); + let host = test_host(); + let runtime = PluginRuntime::new(&host); + let stack = runtime.create_scope_stack().unwrap(); + + { + let _guard = runtime.bind_scope_stack_thread(&stack).unwrap(); + } + assert_eq!(SCOPE_STACK_BINDING_RESTORES.load(Ordering::SeqCst), 1); + assert_eq!(SCOPE_STACK_BINDING_FREES.load(Ordering::SeqCst), 0); + + let binding = runtime.capture_scope_stack_thread().unwrap(); + drop(binding); + assert_eq!(SCOPE_STACK_BINDING_FREES.load(Ordering::SeqCst), 1); +} + +#[test] +fn typed_subscriber_registration_decodes_events() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_subscriber("events", { + let called = called.clone(); + move |event: &Event| { + assert_eq!(event.kind(), "mark"); + called.fetch_add(1, Ordering::SeqCst); + } + }) + .unwrap(); + + let registration = take_subscriber_registration(); + assert_eq!(registration.name, "events"); + let event = json_host_string( + &host, + json!({ + "kind": "mark", + "atof_version": "0.1", + "uuid": "00000000-0000-0000-0000-000000000000", + "timestamp": "2026-01-01T00:00:00Z", + "name": "checkpoint" + }), + ); + let status = unsafe { (registration.cb)(registration.user_data as *mut c_void, event) }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(called.load(Ordering::SeqCst), 1); + + unsafe { + (host.string_free)(event); + registration.free(); + } +} + +#[test] +fn repeated_captured_registration_frees_previous_callback_state() { + struct DropCounter(Arc); + + impl Drop for DropCounter { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + + let _guard = begin_test(); + let host = test_host(); + let drops = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + + ctx.register_subscriber("first", { + let counter = DropCounter(drops.clone()); + move |_event: &Event| { + let _ = &counter; + } + }) + .unwrap(); + ctx.register_subscriber("second", { + let counter = DropCounter(drops.clone()); + move |_event: &Event| { + let _ = &counter; + } + }) + .unwrap(); + + assert_eq!(drops.load(Ordering::SeqCst), 1); + let registration = take_subscriber_registration(); + assert_eq!(registration.name, "second"); + unsafe { registration.free() }; + assert_eq!(drops.load(Ordering::SeqCst), 2); +} + +#[test] +fn typed_tool_sanitize_guardrails_transform_payloads() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_sanitize_request_guardrail("tool-sanitize-request", 4, |name, mut args| { + assert_eq!(name, "tool"); + args["surface"] = json!("request"); + args + }) + .unwrap(); + + let registration = take_tool_json_registration(); + assert_eq!(registration.name, "tool-sanitize-request"); + assert_eq!(registration.priority, 4); + assert!(!registration.break_chain); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_json_and_free(&host, out)["surface"], json!("request")); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_tool_sanitize_response_guardrail( + "tool-sanitize-response", + 5, + |name, mut value| { + assert_eq!(name, "tool"); + value["surface"] = json!("response"); + value + }, + ) + .unwrap(); + + let registration = take_tool_json_registration(); + assert_eq!(registration.name, "tool-sanitize-response"); + assert_eq!(registration.priority, 5); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({ "output": true })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_json_and_free(&host, out)["surface"], json!("response")); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_json_callbacks_report_output_allocation_failures() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_sanitize_request_guardrail("tool-sanitize", 0, |_name, value| value) + .unwrap(); + + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_tool_conditional_guardrail_returns_optional_reason() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_conditional_execution_guardrail("tool-conditional", 8, |name, args| { + assert_eq!(name, "tool"); + if args["block"].as_bool().unwrap_or(false) { + Ok(Some("blocked by policy".into())) + } else { + Ok(None) + } + }) + .unwrap(); + + let registration = take_tool_conditional_registration(); + assert_eq!(registration.name, "tool-conditional"); + assert_eq!(registration.priority, 8); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "block": false })); + let sentinel = host_string(&host, "sentinel"); + let mut reason = sentinel; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + &mut reason, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert!(reason.is_null()); + unsafe { + (host.string_free)(sentinel); + (host.string_free)(args); + } + + let args = json_host_string(&host, json!({ "block": true })); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + &mut reason, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_string_and_free(&host, reason), "blocked by policy"); + unsafe { + (host.string_free)(name); + (host.string_free)(args); + registration.free(); + } +} + +#[test] +fn typed_tool_intercept_registration_rewrites_json() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_request_intercept("tool", 17, true, |_name, mut value| { + value["typed"] = json!(true); + Ok(value) + }) + .unwrap(); + + let registration = take_tool_json_registration(); + assert_eq!(registration.name, "tool"); + assert_eq!(registration.priority, 17); + assert!(registration.break_chain); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({ "input": "value" })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_json_and_free(&host, out)["typed"], json!(true)); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_tool_intercept_registration_reports_invalid_json() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_request_intercept("tool", 0, false, |_name, value| Ok(value)) + .unwrap(); + + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = host_string(&host, "{not json"); + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::InvalidJson); + assert!(out.is_null()); + unsafe { + (host.string_free)(stale_out); + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_tool_intercept_reports_null_inputs_separately_from_invalid_utf8() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_request_intercept("tool", 0, false, |_name, value| Ok(value)) + .unwrap(); + + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + ptr::null(), + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::NullPointer); + assert!(out.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool name was null") + ); + unsafe { (host.string_free)(stale_out) }; + + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + ptr::null(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::NullPointer); + assert!(out.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool payload was null") + ); + unsafe { (host.string_free)(stale_out) }; + + let invalid_name = bytes_host_string(&host, b"\xff"); + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + invalid_name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::InvalidUtf8); + assert!(out.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool name contained invalid UTF-8") + ); + + unsafe { + (host.string_free)(stale_out); + (host.string_free)(invalid_name); + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_tool_intercept_registration_maps_callback_errors_and_panics() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_request_intercept("tool", 0, false, |_name, _value| { + Err("callback failed".into()) + }) + .unwrap(); + + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("callback failed") + ); + unsafe { + (host.string_free)(stale_out); + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_tool_request_intercept( + "tool", + 0, + false, + |_name, _value| -> Result { panic!("boom") }, + ) + .unwrap(); + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let stale_out = host_string(&host, r#"{"stale":true}"#); + let mut out = stale_out; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool intercept callback panicked") + ); + unsafe { + (host.string_free)(stale_out); + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } +} + +#[test] +fn typed_callback_free_catches_drop_panics() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + let panic_on_drop = PanicOnDrop("typed callback drop panic"); + ctx.register_tool_request_intercept("tool", 0, false, move |_name, value| { + let _ = &panic_on_drop; + Ok(value) + }) + .unwrap(); + + let registration = take_tool_json_registration(); + *LAST_ERROR.lock().unwrap() = None; + unsafe { registration.free() }; + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin typed callback state drop panicked") + ); +} + +#[test] +fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { + let _guard = begin_test(); + let host = test_host(); + + let mut ctx = test_context(&host); + ctx.register_subscriber("events", |_event: &Event| {}) + .unwrap(); + let registration = take_subscriber_registration(); + let event = json_host_string( + &host, + json!({ + "kind": "mark", + "atof_version": "0.1", + "uuid": "00000000-0000-0000-0000-000000000000", + "timestamp": "2026-01-01T00:00:00Z", + "name": "checkpoint" + }), + ); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), event) }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(event); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_tool_sanitize_request_guardrail("tool-sanitize", 0, |_name, value| value) + .unwrap(); + let registration = take_tool_json_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let mut out = ptr::null_mut(); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), name, payload, &mut out) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_tool_conditional_execution_guardrail("tool-conditional", 0, |_name, _value| { + Ok(None) + }) + .unwrap(); + let registration = take_tool_conditional_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let mut reason = ptr::null_mut(); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), name, payload, &mut reason) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool-exec", 0, |_name, value, _next| Ok(value)) + .unwrap(); + let registration = take_tool_execution_registration(); + let name = host_string(&host, "tool"); + let payload = json_host_string(&host, json!({})); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: Arc::new(AtomicUsize::new(0)), + })); + assert_eq!( + unsafe { + (registration.cb)( + ptr::null_mut(), + name, + payload, + fake_tool_next, + next_state.cast(), + &mut out, + ) + }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + payload, + fake_tool_next, + next_state.cast(), + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(payload); + drop(Box::from_raw(next_state)); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_sanitize_request_guardrail("llm-request", 0, |request| request) + .unwrap(); + let registration = take_llm_request_registration(); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), request, &mut out) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + request, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(request); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_sanitize_response_guardrail("llm-response", 0, |value| value) + .unwrap(); + let registration = take_llm_json_registration(); + let response = json_host_string(&host, json!({})); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), response, &mut out) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + response, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(response); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_conditional_execution_guardrail("llm-conditional", 0, |_request| Ok(None)) + .unwrap(); + let registration = take_llm_conditional_registration(); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + assert_eq!( + unsafe { (registration.cb)(ptr::null_mut(), request, &mut reason) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + request, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(request); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm-request-intercept", 0, false, |_name, request, ann| { + Ok((request, ann)) + }) + .unwrap(); + let registration = take_llm_request_intercept_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out_request = ptr::null_mut(); + let mut out_annotated = ptr::null_mut(); + assert_eq!( + unsafe { + (registration.cb)( + ptr::null_mut(), + name, + request, + ptr::null(), + &mut out_request, + &mut out_annotated, + ) + }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + ptr::null_mut(), + &mut out_annotated, + ) + }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + &mut out_request, + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_execution_intercept( + "llm-exec", + 0, + |_name, request, _next| Ok(request.content), + ) + .unwrap(); + let registration = take_llm_execution_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: Arc::new(AtomicUsize::new(0)), + })); + assert_eq!( + unsafe { + (registration.cb)( + ptr::null_mut(), + name, + request, + failing_llm_next, + next_state.cast(), + &mut out, + ) + }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + failing_llm_next, + next_state.cast(), + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept("llm-stream", 0, |_name, _request, _next| { + Ok(Box::new(std::iter::empty())) + }) + .unwrap(); + let registration = take_llm_stream_execution_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let next_state = Box::into_raw(Box::new(StreamNextState { + host, + called: Arc::new(AtomicUsize::new(0)), + cancelled: Arc::new(AtomicUsize::new(0)), + dropped: Arc::new(AtomicUsize::new(0)), + })); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + assert_eq!( + unsafe { + (registration.cb)( + ptr::null_mut(), + name, + request, + fake_llm_stream_next, + next_state.cast(), + &mut stream, + ) + }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + fake_llm_stream_next, + next_state.cast(), + ptr::null_mut(), + ) + }, + NemoRelayStatus::NullPointer + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +struct NextState { + host: NemoRelayNativeHostApiV1, + called: Arc, +} + +unsafe extern "C" fn fake_tool_next( + args_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + let mut args: Json = + serde_json::from_str(&read_host_string(&state.host, args_json).unwrap()).unwrap(); + args["next_called"] = json!(true); + write_json(&state.host, &args, out_json) +} + +unsafe extern "C" fn failing_tool_next( + _args_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + _out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + NemoRelayStatus::GuardrailRejected +} + +unsafe extern "C" fn invalid_json_tool_next( + _args_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + let invalid = b"{not json"; + unsafe { (state.host.string_new)(invalid.as_ptr(), invalid.len(), out_json) } +} + +unsafe extern "C" fn null_tool_next( + _args_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + unsafe { *out_json = ptr::null_mut() }; + NemoRelayStatus::Ok +} + +unsafe extern "C" fn failing_llm_next( + _request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + _out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + NemoRelayStatus::GuardrailRejected +} + +unsafe extern "C" fn invalid_json_llm_next( + _request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + let invalid = b"{not json"; + unsafe { (state.host.string_new)(invalid.as_ptr(), invalid.len(), out_json) } +} + +unsafe extern "C" fn null_llm_next( + _request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const NextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + unsafe { *out_json = ptr::null_mut() }; + NemoRelayStatus::Ok +} + +struct StreamNextState { + host: NemoRelayNativeHostApiV1, + called: Arc, + cancelled: Arc, + dropped: Arc, +} + +struct TestLlmStreamState { + host: NemoRelayNativeHostApiV1, + chunks: Mutex>>, + cancelled: Arc, + dropped: Arc, +} + +fn test_llm_stream( + host: &NemoRelayNativeHostApiV1, + chunks: Vec>, + cancelled: Arc, + dropped: Arc, +) -> NemoRelayNativeLlmStreamV1 { + let state = Box::new(TestLlmStreamState { + host: *host, + chunks: Mutex::new(VecDeque::from(chunks)), + cancelled, + dropped, + }); + NemoRelayNativeLlmStreamV1 { + struct_size: size_of::(), + user_data: Box::into_raw(state).cast(), + next: Some(poll_test_llm_stream), + cancel: Some(cancel_test_llm_stream), + drop: Some(drop_test_llm_stream), + } +} + +unsafe extern "C" fn poll_test_llm_stream( + user_data: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const TestLlmStreamState) }; + let mut chunks = state.chunks.lock().unwrap(); + match chunks.pop_front() { + Some(Ok(chunk)) => write_json(&state.host, &chunk, out_json), + Some(Err(message)) => { + let message = host_string(&state.host, &message); + unsafe { + (state.host.last_error_set)(message); + (state.host.string_free)(message); + } + NemoRelayStatus::Internal + } + None => NemoRelayStatus::StreamEnd, + } +} + +unsafe extern "C" fn cancel_test_llm_stream(user_data: *mut c_void) -> NemoRelayStatus { + if user_data.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &*(user_data as *const TestLlmStreamState) }; + state.cancelled.fetch_add(1, Ordering::SeqCst); + NemoRelayStatus::Ok +} + +unsafe extern "C" fn drop_test_llm_stream(user_data: *mut c_void) { + if !user_data.is_null() { + let state = unsafe { Box::from_raw(user_data as *mut TestLlmStreamState) }; + state.dropped.fetch_add(1, Ordering::SeqCst); + } +} + +unsafe extern "C" fn fake_llm_stream_next( + _request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + out_stream: *mut NemoRelayNativeLlmStreamV1, +) -> NemoRelayStatus { + if out_stream.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &*(next_ctx as *const StreamNextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + unsafe { + *out_stream = test_llm_stream( + &state.host, + vec![Ok(json!({ "chunk": 1 })), Ok(json!({ "chunk": 2 }))], + state.cancelled.clone(), + state.dropped.clone(), + ) + }; + NemoRelayStatus::Ok +} + +unsafe extern "C" fn failing_llm_stream_next( + _request_json: *const NemoRelayNativeString, + next_ctx: *mut c_void, + _out_stream: *mut NemoRelayNativeLlmStreamV1, +) -> NemoRelayStatus { + let state = unsafe { &*(next_ctx as *const StreamNextState) }; + state.called.fetch_add(1, Ordering::SeqCst); + NemoRelayStatus::GuardrailRejected +} + +enum ManualStreamPoll { + Json(Json), + InvalidJson, + NullOk, + Status(NemoRelayStatus), + StatusWithJson(NemoRelayStatus, Json), + End, + EndWithJson(Json), +} + +struct ManualStreamState { + host: NemoRelayNativeHostApiV1, + polls: Mutex>, + cancel_status: NemoRelayStatus, + cancelled: Arc, + dropped: Arc, +} + +fn manual_llm_stream( + host: &NemoRelayNativeHostApiV1, + polls: Vec, + cancel_status: NemoRelayStatus, + cancelled: Arc, + dropped: Arc, +) -> NemoRelayNativeLlmStreamV1 { + let state = Box::new(ManualStreamState { + host: *host, + polls: Mutex::new(VecDeque::from(polls)), + cancel_status, + cancelled, + dropped, + }); + NemoRelayNativeLlmStreamV1 { + struct_size: size_of::(), + user_data: Box::into_raw(state).cast(), + next: Some(poll_manual_llm_stream), + cancel: Some(cancel_manual_llm_stream), + drop: Some(drop_manual_llm_stream), + } +} + +unsafe extern "C" fn poll_manual_llm_stream( + user_data: *mut c_void, + out_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if user_data.is_null() || out_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_json = ptr::null_mut() }; + let state = unsafe { &*(user_data as *const ManualStreamState) }; + match state + .polls + .lock() + .unwrap() + .pop_front() + .unwrap_or(ManualStreamPoll::End) + { + ManualStreamPoll::Json(value) => write_json(&state.host, &value, out_json), + ManualStreamPoll::InvalidJson => { + let invalid = b"{not json"; + unsafe { (state.host.string_new)(invalid.as_ptr(), invalid.len(), out_json) } + } + ManualStreamPoll::NullOk => NemoRelayStatus::Ok, + ManualStreamPoll::Status(status) => status, + ManualStreamPoll::StatusWithJson(status, value) => { + let write_status = write_json(&state.host, &value, out_json); + if write_status == NemoRelayStatus::Ok { + status + } else { + write_status + } + } + ManualStreamPoll::End => NemoRelayStatus::StreamEnd, + ManualStreamPoll::EndWithJson(value) => { + let write_status = write_json(&state.host, &value, out_json); + if write_status == NemoRelayStatus::Ok { + NemoRelayStatus::StreamEnd + } else { + write_status + } + } + } +} + +unsafe extern "C" fn cancel_manual_llm_stream(user_data: *mut c_void) -> NemoRelayStatus { + if user_data.is_null() { + return NemoRelayStatus::NullPointer; + } + let state = unsafe { &*(user_data as *const ManualStreamState) }; + state.cancelled.fetch_add(1, Ordering::SeqCst); + state.cancel_status +} + +unsafe extern "C" fn drop_manual_llm_stream(user_data: *mut c_void) { + if !user_data.is_null() { + let state = unsafe { Box::from_raw(user_data as *mut ManualStreamState) }; + state.dropped.fetch_add(1, Ordering::SeqCst); + } +} + +#[test] +fn typed_tool_execution_registration_calls_next() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool", 23, |_name, args, next: ToolNext<'_>| { + next.call(args) + }) + .unwrap(); + + let registration = take_tool_execution_registration(); + assert_eq!(registration.name, "tool"); + assert_eq!(registration.priority, 23); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + fake_tool_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!(read_json_and_free(&host, out)["next_called"], json!(true)); + unsafe { + (host.string_free)(name); + (host.string_free)(args); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_tool_execution_surfaces_next_status_failures() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { + next.call(args) + }) + .unwrap(); + + let registration = take_tool_execution_registration(); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + failing_tool_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool next failed: GuardrailRejected") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(args); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_tool_execution_surfaces_invalid_next_json() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { + next.call(args) + }) + .unwrap(); + + let registration = take_tool_execution_registration(); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + let live_before = live_host_strings(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + invalid_json_tool_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!(live_host_strings(), live_before); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool next returned invalid JSON: InvalidJson") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(args); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_tool_execution_surfaces_null_next_output() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { + next.call(args) + }) + .unwrap(); + + let registration = take_tool_execution_registration(); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "input": true })); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + null_tool_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("tool next returned null output") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(args); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_sanitize_guardrails_transform_request_and_response() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_sanitize_request_guardrail("llm-sanitize-request", 12, |mut request| { + request.headers.insert("x-policy".into(), json!("sdk")); + request.content["sanitized"] = json!(true); + request + }) + .unwrap(); + + let registration = take_llm_request_registration(); + assert_eq!(registration.name, "llm-sanitize-request"); + assert_eq!(registration.priority, 12); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out = ptr::null_mut(); + let status = + unsafe { (registration.cb)(registration.user_data as *mut c_void, request, &mut out) }; + assert_eq!(status, NemoRelayStatus::Ok); + let output = read_json_and_free(&host, out); + assert_eq!(output["headers"]["x-policy"], json!("sdk")); + assert_eq!(output["content"]["sanitized"], json!(true)); + unsafe { + (host.string_free)(request); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_sanitize_response_guardrail("llm-sanitize-response", 13, |mut payload| { + payload["sanitized"] = json!(true); + payload + }) + .unwrap(); + + let registration = take_llm_json_registration(); + assert_eq!(registration.name, "llm-sanitize-response"); + assert_eq!(registration.priority, 13); + let response = json_host_string(&host, json!({ "output": true })); + let mut out = ptr::null_mut(); + let status = + unsafe { (registration.cb)(registration.user_data as *mut c_void, response, &mut out) }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_json_and_free(&host, out)["sanitized"], json!(true)); + unsafe { + (host.string_free)(response); + registration.free(); + } +} + +#[test] +fn typed_llm_conditional_guardrail_returns_optional_reason() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_conditional_execution_guardrail("llm-conditional", 14, |request| { + if request.content["block"].as_bool().unwrap_or(false) { + Ok(Some("LLM blocked".into())) + } else { + Ok(None) + } + }) + .unwrap(); + + let registration = take_llm_conditional_registration(); + assert_eq!(registration.name, "llm-conditional"); + assert_eq!(registration.priority, 14); + let request = json_host_string( + &host, + serde_json::to_value(LlmRequest { + headers: Map::new(), + content: json!({ "block": false }), + }) + .unwrap(), + ); + let sentinel = host_string(&host, "sentinel"); + let mut reason = sentinel; + let status = + unsafe { (registration.cb)(registration.user_data as *mut c_void, request, &mut reason) }; + assert_eq!(status, NemoRelayStatus::Ok); + assert!(reason.is_null()); + unsafe { + (host.string_free)(sentinel); + (host.string_free)(request); + } + + let request = json_host_string( + &host, + serde_json::to_value(LlmRequest { + headers: Map::new(), + content: json!({ "block": true }), + }) + .unwrap(), + ); + let status = + unsafe { (registration.cb)(registration.user_data as *mut c_void, request, &mut reason) }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(read_string_and_free(&host, reason), "LLM blocked"); + unsafe { + (host.string_free)(request); + registration.free(); + } +} + +#[test] +fn typed_llm_execution_surfaces_next_status_failures() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_execution_intercept("llm", 0, |_name, request, next: LlmNext<'_>| { + next.call(request) + }) + .unwrap(); + + let registration = take_llm_execution_registration(); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + failing_llm_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("llm next failed: GuardrailRejected") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_execution_surfaces_invalid_next_json() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_execution_intercept("llm", 0, |_name, request, next: LlmNext<'_>| { + next.call(request) + }) + .unwrap(); + + let registration = take_llm_execution_registration(); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out = ptr::null_mut(); + let live_before = live_host_strings(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + invalid_json_llm_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!(live_host_strings(), live_before); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("llm next returned invalid JSON: InvalidJson") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_execution_surfaces_null_next_output() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_execution_intercept("llm", 31, |_name, request, next: LlmNext<'_>| { + next.call(request) + }) + .unwrap(); + + let registration = take_llm_execution_registration(); + assert_eq!(registration.name, "llm"); + assert_eq!(registration.priority, 31); + let next_state = Box::into_raw(Box::new(NextState { + host, + called: called.clone(), + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + null_llm_next, + next_state.cast(), + &mut out, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out.is_null()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("llm next returned null output") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_stream_execution_wraps_next_chunks() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept( + "llm-stream", + 31, + |_name, request, next: LlmStreamNext<'_>| { + let stream = next.call(request)?; + let stream: LlmJsonStream = Box::new(stream.map(|chunk| { + chunk.map(|mut chunk| { + chunk["wrapped"] = json!(true); + chunk + }) + })); + Ok(stream) + }, + ) + .unwrap(); + + let registration = take_llm_stream_execution_registration(); + assert_eq!(registration.name, "llm-stream"); + assert_eq!(registration.priority, 31); + let next_state = Box::into_raw(Box::new(StreamNextState { + host, + called: called.clone(), + cancelled: cancelled.clone(), + dropped: dropped.clone(), + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + fake_llm_stream_next, + next_state.cast(), + &mut stream, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(called.load(Ordering::SeqCst), 1); + + let (status, chunk) = poll_stream_chunk(&host, &stream); + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(chunk.unwrap()["wrapped"], json!(true)); + let (status, chunk) = poll_stream_chunk(&host, &stream); + assert_eq!(status, NemoRelayStatus::Ok); + let chunk = chunk.unwrap(); + assert_eq!(chunk["chunk"], json!(2)); + assert_eq!(chunk["wrapped"], json!(true)); + let (status, chunk) = poll_stream_chunk(&host, &stream); + assert_eq!(status, NemoRelayStatus::StreamEnd); + assert!(chunk.is_none()); + + unsafe { + drop_stream(&mut stream); + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } + assert_eq!(cancelled.load(Ordering::SeqCst), 0); + assert_eq!(dropped.load(Ordering::SeqCst), 1); +} + +#[test] +fn typed_llm_stream_drop_catches_stream_state_panics() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept("llm-stream", 0, |_name, _request, _next| { + let stream: LlmJsonStream = Box::new(PanicIterator { + _panic_on_drop: PanicOnDrop("LLM stream state drop panic"), + }); + Ok(stream) + }) + .unwrap(); + + let registration = take_llm_stream_execution_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + fake_llm_stream_next, + ptr::null_mut(), + &mut stream, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + + *LAST_ERROR.lock().unwrap() = None; + unsafe { + drop_stream(&mut stream); + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin LLM stream state drop panicked") + ); +} + +#[test] +fn typed_llm_stream_execution_surfaces_next_failures() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept( + "llm-stream", + 0, + |_name, request, next: LlmStreamNext<'_>| { + let stream = next.call(request)?; + Ok(Box::new(stream)) + }, + ) + .unwrap(); + + let registration = take_llm_stream_execution_registration(); + let next_state = Box::into_raw(Box::new(StreamNextState { + host, + called: called.clone(), + cancelled, + dropped, + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + failing_llm_stream_next, + next_state.cast(), + &mut stream, + ) + }; + assert_eq!(status, NemoRelayStatus::Internal); + assert_eq!( + stream.struct_size, + NemoRelayNativeLlmStreamV1::default().struct_size + ); + assert!(stream.user_data.is_null()); + assert!(stream.next.is_none()); + assert!(stream.cancel.is_none()); + assert!(stream.drop.is_none()); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("llm stream next failed: GuardrailRejected") + ); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_stream_execution_surfaces_chunk_errors() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept("llm-stream", 0, |_name, _request, _next| { + let stream: LlmJsonStream = Box::new(std::iter::once(Err("chunk failed".into()))); + Ok(stream) + }) + .unwrap(); + + let registration = take_llm_stream_execution_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let next_state = Box::into_raw(Box::new(StreamNextState { + host, + called: Arc::new(AtomicUsize::new(0)), + cancelled: Arc::new(AtomicUsize::new(0)), + dropped: Arc::new(AtomicUsize::new(0)), + })); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + fake_llm_stream_next, + next_state.cast(), + &mut stream, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + let (status, chunk) = poll_stream_chunk(&host, &stream); + assert_eq!(status, NemoRelayStatus::Internal); + assert!(chunk.is_none()); + assert_eq!(LAST_ERROR.lock().unwrap().as_deref(), Some("chunk failed")); + + unsafe { + drop_stream(&mut stream); + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +#[test] +fn typed_llm_stream_execution_cancels_unconsumed_next_stream() { + let _guard = begin_test(); + let host = test_host(); + let called = Arc::new(AtomicUsize::new(0)); + let cancelled = Arc::new(AtomicUsize::new(0)); + let dropped = Arc::new(AtomicUsize::new(0)); + let mut ctx = test_context(&host); + ctx.register_llm_stream_execution_intercept( + "llm-stream", + 0, + |_name, request, next: LlmStreamNext<'_>| { + let stream = next.call(request)?; + drop(stream); + let stream: LlmJsonStream = Box::new(std::iter::empty()); + Ok(stream) + }, + ) + .unwrap(); + + let registration = take_llm_stream_execution_registration(); + let next_state = Box::into_raw(Box::new(StreamNextState { + host, + called: called.clone(), + cancelled: cancelled.clone(), + dropped: dropped.clone(), + })); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut stream = NemoRelayNativeLlmStreamV1::default(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + fake_llm_stream_next, + next_state.cast(), + &mut stream, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!(cancelled.load(Ordering::SeqCst), 1); + assert_eq!(dropped.load(Ordering::SeqCst), 1); + let (status, chunk) = poll_stream_chunk(&host, &stream); + assert_eq!(status, NemoRelayStatus::StreamEnd); + assert!(chunk.is_none()); + + unsafe { + drop_stream(&mut stream); + (host.string_free)(name); + (host.string_free)(request); + drop(Box::from_raw(next_state)); + registration.free(); + } +} + +fn test_llm_request() -> LlmRequest { + LlmRequest { + headers: Map::new(), + content: json!({ "input": true }), + } +} + +fn test_annotated_llm_request() -> AnnotatedLlmRequest { + serde_json::from_value(json!({ "messages": [] })).unwrap() +} + +#[test] +fn typed_llm_request_intercept_does_not_publish_partial_outputs() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { + Ok((request, Some(test_annotated_llm_request()))) + }) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + assert_eq!(registration.name, "llm"); + assert_eq!(registration.priority, 0); + assert!(!registration.break_chain); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let stale_request = host_string(&host, r#"{"stale":"request"}"#); + let stale_annotated = host_string(&host, r#"{"stale":"annotated"}"#); + let mut out_request = stale_request; + let mut out_annotated = stale_annotated; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(1); + let live_before = live_host_strings(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + &mut out_request, + &mut out_annotated, + ) + }; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out_request.is_null()); + assert!(out_annotated.is_null()); + assert_eq!(live_host_strings(), live_before); + unsafe { + (host.string_free)(stale_request); + (host.string_free)(stale_annotated); + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { + Ok((request, None)) + }) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out_request = ptr::null_mut(); + let mut out_annotated = ptr::null_mut(); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + &mut out_request, + &mut out_annotated, + ) + }; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out_request.is_null()); + assert!(out_annotated.is_null()); + unsafe { + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } +} + +#[test] +fn typed_llm_request_intercept_round_trips_request_and_annotations() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm", 19, true, |name, mut request, annotated| { + assert_eq!(name, "llm"); + assert!(annotated.is_some()); + request.headers.insert("x-mutated".into(), json!(true)); + request.content["rewritten"] = json!(true); + Ok((request, Some(test_annotated_llm_request()))) + }) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + assert_eq!(registration.name, "llm"); + assert_eq!(registration.priority, 19); + assert!(registration.break_chain); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let annotated = json_host_string( + &host, + serde_json::to_value(test_annotated_llm_request()).unwrap(), + ); + let mut out_request = ptr::null_mut(); + let mut out_annotated = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + annotated, + &mut out_request, + &mut out_annotated, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + let out_request = read_json_and_free(&host, out_request); + assert_eq!(out_request["headers"]["x-mutated"], json!(true)); + assert_eq!(out_request["content"]["rewritten"], json!(true)); + let out_annotated = read_json_and_free(&host, out_annotated); + assert_eq!(out_annotated["messages"], json!([])); + + unsafe { + (host.string_free)(name); + (host.string_free)(request); + (host.string_free)(annotated); + registration.free(); + } + + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { + Ok((request, None)) + }) + .unwrap(); + let registration = take_llm_request_intercept_registration(); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out_request = ptr::null_mut(); + let mut out_annotated = host_string(&host, r#"{"stale":true}"#); + let stale_annotated = out_annotated; + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + &mut out_request, + &mut out_annotated, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + assert!(out_annotated.is_null()); + assert_eq!( + read_json_and_free(&host, out_request)["content"]["input"], + json!(true) + ); + unsafe { + (host.string_free)(stale_annotated); + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } +} + +struct DropCounter(Arc); + +impl Drop for DropCounter { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::SeqCst); + } +} + +#[test] +fn failed_typed_registration_drops_callback_state() { + let _guard = begin_test(); + let host = test_host(); + *REGISTRATION_STATUS.lock().unwrap() = NemoRelayStatus::AlreadyExists; + let drops = Arc::new(AtomicUsize::new(0)); + let drop_counter = DropCounter(drops.clone()); + let mut ctx = test_context(&host); + let result = ctx.register_tool_request_intercept("duplicate", 0, false, move |_name, value| { + let _keep_alive = &drop_counter; + Ok(value) + }); + + assert!(result.is_err()); + assert_eq!(drops.load(Ordering::SeqCst), 1); + assert!(TOOL_JSON_REGISTRATION.lock().unwrap().is_none()); +} + +#[test] +fn raw_registration_propagates_name_allocation_status() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + let status = unsafe { + ctx.register_tool_request_intercept_raw( + "tool", + 0, + false, + passthrough_tool_json_cb, + ptr::null_mut(), + None, + ) + }; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + + assert_eq!(status, NemoRelayStatus::Internal); + assert!(TOOL_JSON_REGISTRATION.lock().unwrap().is_none()); +} + +#[test] +fn typed_registration_name_allocation_failure_drops_callback_state() { + let _guard = begin_test(); + let host = test_host(); + let drops = Arc::new(AtomicUsize::new(0)); + let drop_counter = DropCounter(drops.clone()); + let mut ctx = test_context(&host); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + let result = ctx.register_tool_request_intercept("tool", 0, false, move |_name, value| { + let _keep_alive = &drop_counter; + Ok(value) + }); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + + assert!(result.is_err()); + assert_eq!(drops.load(Ordering::SeqCst), 1); + assert!(TOOL_JSON_REGISTRATION.lock().unwrap().is_none()); +} + +struct ConstructorPanicPlugin; + +impl NativePlugin for ConstructorPanicPlugin { + fn plugin_kind(&self) -> &str { + "test.constructor_panic" + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +static CONSTRUCTOR_CALLS: AtomicUsize = AtomicUsize::new(0); + +struct CountingPlugin; + +impl NativePlugin for CountingPlugin { + fn plugin_kind(&self) -> &str { + "test.counting" + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +struct DiagnosticsPlugin; + +impl NativePlugin for DiagnosticsPlugin { + fn plugin_kind(&self) -> &str { + "test.diagnostics" + } + + fn allows_multiple_components(&self) -> bool { + false + } + + fn validate(&self, plugin_config: &Map) -> Vec { + vec![ConfigDiagnostic { + level: DiagnosticLevel::Warning, + code: "test.warning".into(), + component: plugin_config + .get("component") + .and_then(Json::as_str) + .map(ToOwned::to_owned), + field: Some("component".into()), + message: "diagnostic from plugin".into(), + }] + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +struct RegisteringPlugin; + +impl NativePlugin for RegisteringPlugin { + fn plugin_kind(&self) -> &str { + "test.registering" + } + + fn register( + &mut self, + plugin_config: &Map, + ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + assert_eq!(plugin_config.get("enabled"), Some(&json!(true))); + assert_eq!(ctx.host_api().abi_version, NEMO_RELAY_NATIVE_ABI_VERSION); + assert!(ctx.runtime().scope_stack_active()); + ctx.register_subscriber("registered", |_event: &Event| {})?; + Ok(()) + } +} + +struct RegisterErrorPlugin; + +impl NativePlugin for RegisterErrorPlugin { + fn plugin_kind(&self) -> &str { + "test.register_error" + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Err("register rejected config".into()) + } +} + +struct PluginKindPanicPlugin; + +impl NativePlugin for PluginKindPanicPlugin { + fn plugin_kind(&self) -> &str { + panic!("plugin kind panic") + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +struct AllowsMultiplePanicPlugin; + +impl NativePlugin for AllowsMultiplePanicPlugin { + fn plugin_kind(&self) -> &str { + "test.allows_multiple_panic" + } + + fn allows_multiple_components(&self) -> bool { + panic!("allows multiple panic") + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +struct ValidatePanicPlugin; + +impl NativePlugin for ValidatePanicPlugin { + fn plugin_kind(&self) -> &str { + "test.validate_panic" + } + + fn validate( + &self, + _plugin_config: &Map, + ) -> Vec { + panic!("validate panic") + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +struct RegisterPanicPlugin; + +impl NativePlugin for RegisterPanicPlugin { + fn plugin_kind(&self) -> &str { + "test.register_panic" + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + panic!("register panic") + } +} + +struct DropPanicPlugin; + +impl Drop for DropPanicPlugin { + fn drop(&mut self) { + panic!("plugin state drop panic") + } +} + +impl NativePlugin for DropPanicPlugin { + fn plugin_kind(&self) -> &str { + "test.drop_panic" + } + + fn register( + &mut self, + _plugin_config: &Map, + _ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + Ok(()) + } +} + +nemo_relay_plugin::nemo_relay_plugin!(constructor_counting_entry, || { + CONSTRUCTOR_CALLS.fetch_add(1, Ordering::SeqCst); + CountingPlugin +}); +nemo_relay_plugin::nemo_relay_plugin!(constructor_panic_entry, || -> ConstructorPanicPlugin { + panic!("constructor panic") +}); +nemo_relay_plugin::nemo_relay_plugin!(plugin_kind_panic_entry, || PluginKindPanicPlugin); +nemo_relay_plugin::nemo_relay_plugin!(allows_multiple_panic_entry, || AllowsMultiplePanicPlugin); + +unsafe fn drop_exported_plugin(host: &NemoRelayNativeHostApiV1, plugin: NemoRelayNativePluginV1) { + unsafe { (host.string_free)(plugin.plugin_kind) }; + if let Some(drop_fn) = plugin.drop { + unsafe { drop_fn(plugin.user_data) }; + } +} + +#[test] +fn direct_export_plugin_validates_host_table_and_kind_allocation() { + let _guard = begin_test(); + let host = test_host(); + + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(ptr::null(), &mut plugin, CountingPlugin) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, ptr::null_mut(), CountingPlugin) }, + NemoRelayStatus::NullPointer + ); + + let mut bad_host = host; + bad_host.abi_version = NEMO_RELAY_NATIVE_ABI_VERSION + 1; + let stale_kind = host_string(&host, "stale"); + let mut plugin = NemoRelayNativePluginV1 { + struct_size: 123, + plugin_kind: stale_kind, + allows_multiple_components: false, + user_data: NonNull::::dangling().as_ptr().cast(), + validate: None, + register: None, + drop: None, + }; + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&bad_host, &mut plugin, CountingPlugin) }, + NemoRelayStatus::InvalidArg + ); + unsafe { (host.string_free)(stale_kind) }; + assert!(plugin.plugin_kind.is_null()); + assert!(plugin.user_data.is_null()); + + let mut short_host = host; + short_host.struct_size = size_of::() - 1; + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&short_host, &mut plugin, CountingPlugin) }, + NemoRelayStatus::InvalidArg + ); + + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, CountingPlugin) }, + NemoRelayStatus::Internal + ); + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + assert!(plugin.plugin_kind.is_null()); + assert!(plugin.user_data.is_null()); +} + +#[test] +fn exported_plugin_validate_serializes_diagnostics_and_rejects_invalid_config() { + let _guard = begin_test(); + let host = test_host(); + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, DiagnosticsPlugin) }, + NemoRelayStatus::Ok + ); + assert!(!plugin.allows_multiple_components); + assert_eq!( + read_host_string(&host, plugin.plugin_kind).as_deref(), + Some("test.diagnostics") + ); + + let config = json_host_string(&host, json!({ "component": "policy" })); + let mut diagnostics = ptr::null_mut(); + assert_eq!( + unsafe { plugin.validate.unwrap()(plugin.user_data, config, &mut diagnostics) }, + NemoRelayStatus::Ok + ); + let diagnostics: Vec = + serde_json::from_value(read_json_and_free(&host, diagnostics)).unwrap(); + assert_eq!(diagnostics.len(), 1); + assert_eq!(diagnostics[0].level, DiagnosticLevel::Warning); + assert_eq!(diagnostics[0].component.as_deref(), Some("policy")); + unsafe { (host.string_free)(config) }; + + let config = json_host_string(&host, json!(["not", "object"])); + let stale = host_string(&host, r#"[{"stale":true}]"#); + let mut diagnostics = stale; + assert_eq!( + unsafe { plugin.validate.unwrap()(plugin.user_data, config, &mut diagnostics) }, + NemoRelayStatus::InvalidJson + ); + assert!(diagnostics.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("plugin config must be a JSON object") + ); + unsafe { + (host.string_free)(stale); + (host.string_free)(config); + } + + let config = host_string(&host, "{not json"); + assert_eq!( + unsafe { plugin.validate.unwrap()(plugin.user_data, config, ptr::null_mut()) }, + NemoRelayStatus::NullPointer + ); + let mut diagnostics = ptr::null_mut(); + assert_eq!( + unsafe { plugin.validate.unwrap()(ptr::null_mut(), config, &mut diagnostics) }, + NemoRelayStatus::NullPointer + ); + assert_eq!( + unsafe { plugin.validate.unwrap()(plugin.user_data, config, &mut diagnostics) }, + NemoRelayStatus::InvalidJson + ); + let last_error = LAST_ERROR.lock().unwrap().clone().unwrap(); + assert!(last_error.starts_with("plugin config was invalid JSON:")); + unsafe { + (host.string_free)(config); + drop_exported_plugin(&host, plugin); + } +} + +#[test] +fn exported_plugin_default_validate_returns_empty_diagnostics() { + let _guard = begin_test(); + let host = test_host(); + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, CountingPlugin) }, + NemoRelayStatus::Ok + ); + + let config = json_host_string(&host, json!({})); + let mut diagnostics = ptr::null_mut(); + assert_eq!( + unsafe { plugin.validate.unwrap()(plugin.user_data, config, &mut diagnostics) }, + NemoRelayStatus::Ok + ); + let diagnostics: Vec = + serde_json::from_value(read_json_and_free(&host, diagnostics)).unwrap(); + assert!(diagnostics.is_empty()); + unsafe { + (host.string_free)(config); + drop_exported_plugin(&host, plugin); + } +} + +#[test] +fn exported_plugin_register_installs_callbacks_and_propagates_errors() { + let _guard = begin_test(); + let host = test_host(); + + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, RegisteringPlugin) }, + NemoRelayStatus::Ok + ); + let config = json_host_string(&host, json!({ "enabled": true })); + assert_eq!( + unsafe { + plugin.register.unwrap()( + plugin.user_data, + config, + NonNull::::dangling().as_ptr(), + ) + }, + NemoRelayStatus::Ok + ); + let registration = take_subscriber_registration(); + assert_eq!(registration.name, "registered"); + unsafe { + registration.free(); + (host.string_free)(config); + } + + let config = json_host_string(&host, json!({ "enabled": true })); + assert_eq!( + unsafe { plugin.register.unwrap()(plugin.user_data, config, ptr::null_mut()) }, + NemoRelayStatus::NullPointer + ); + unsafe { (host.string_free)(config) }; + unsafe { drop_exported_plugin(&host, plugin) }; + + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, RegisterErrorPlugin) }, + NemoRelayStatus::Ok + ); + let config = json_host_string(&host, json!({})); + assert_eq!( + unsafe { + plugin.register.unwrap()( + plugin.user_data, + config, + NonNull::::dangling().as_ptr(), + ) + }, + NemoRelayStatus::Internal + ); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("register rejected config") + ); + unsafe { + (host.string_free)(config); + drop_exported_plugin(&host, plugin); + } +} + +#[test] +fn exported_entry_symbol_validates_args_before_constructor() { + let _guard = begin_test(); + let host = test_host(); + CONSTRUCTOR_CALLS.store(0, Ordering::SeqCst); + + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { constructor_counting_entry(ptr::null(), &mut plugin) }, + NemoRelayStatus::NullPointer + ); + assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); + + assert_eq!( + unsafe { constructor_counting_entry(&host, ptr::null_mut()) }, + NemoRelayStatus::NullPointer + ); + assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); + + let mut bad_host = host; + bad_host.abi_version = NEMO_RELAY_NATIVE_ABI_VERSION + 1; + let stale_kind = host_string(&host, "stale"); + let mut plugin = NemoRelayNativePluginV1 { + struct_size: 123, + plugin_kind: stale_kind, + allows_multiple_components: true, + user_data: NonNull::::dangling().as_ptr().cast(), + validate: None, + register: None, + drop: None, + }; + assert_eq!( + unsafe { constructor_counting_entry(&bad_host, &mut plugin) }, + NemoRelayStatus::InvalidArg + ); + unsafe { (host.string_free)(stale_kind) }; + assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); + let default_plugin = NemoRelayNativePluginV1::default(); + assert_eq!(plugin.struct_size, default_plugin.struct_size); + assert!(plugin.plugin_kind.is_null()); + assert_eq!( + plugin.allows_multiple_components, + default_plugin.allows_multiple_components + ); + assert!(plugin.user_data.is_null()); + assert!(plugin.validate.is_none()); + assert!(plugin.register.is_none()); + assert!(plugin.drop.is_none()); +} + +#[test] +fn exported_entry_symbol_catches_panics() { + let _guard = begin_test(); + let host = test_host(); + + for entry in [ + constructor_panic_entry, + plugin_kind_panic_entry, + allows_multiple_panic_entry, + ] { + *LAST_ERROR.lock().unwrap() = Some("stale error".into()); + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { entry(&host, &mut plugin) }, + NemoRelayStatus::Internal + ); + assert!(plugin.plugin_kind.is_null()); + assert!(plugin.user_data.is_null()); + assert!(plugin.validate.is_none()); + assert!(plugin.register.is_none()); + assert!(plugin.drop.is_none()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin entry callback panicked") + ); + } +} + +#[test] +fn plugin_drop_callback_catches_state_drop_panics() { + let _guard = begin_test(); + let host = test_host(); + let mut plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, DropPanicPlugin) }, + NemoRelayStatus::Ok + ); + + *LAST_ERROR.lock().unwrap() = None; + unsafe { drop_exported_plugin(&host, plugin) }; + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin state drop panicked") + ); +} + +#[test] +fn plugin_validate_and_register_panics_replace_last_error() { + let _guard = begin_test(); + let host = test_host(); + + let mut validate_plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { + nemo_relay_plugin::export_plugin(&host, &mut validate_plugin, ValidatePanicPlugin) + }, + NemoRelayStatus::Ok + ); + *LAST_ERROR.lock().unwrap() = Some("stale error".into()); + let config = json_host_string(&host, json!({})); + let stale_diagnostics = host_string(&host, r#"[{"stale":true}]"#); + let mut diagnostics = stale_diagnostics; + assert_eq!( + unsafe { + validate_plugin.validate.unwrap()(validate_plugin.user_data, config, &mut diagnostics) + }, + NemoRelayStatus::Internal + ); + assert!(diagnostics.is_null()); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin validate callback panicked") + ); + unsafe { + (host.string_free)(stale_diagnostics); + (host.string_free)(config); + drop_exported_plugin(&host, validate_plugin); + } + + let mut register_plugin = NemoRelayNativePluginV1::default(); + assert_eq!( + unsafe { + nemo_relay_plugin::export_plugin(&host, &mut register_plugin, RegisterPanicPlugin) + }, + NemoRelayStatus::Ok + ); + *LAST_ERROR.lock().unwrap() = Some("stale error".into()); + let config = json_host_string(&host, json!({})); + assert_eq!( + unsafe { + register_plugin.register.unwrap()( + register_plugin.user_data, + config, + NonNull::::dangling().as_ptr(), + ) + }, + NemoRelayStatus::Internal + ); + assert_eq!( + LAST_ERROR.lock().unwrap().as_deref(), + Some("native plugin register callback panicked") + ); + unsafe { + (host.string_free)(config); + drop_exported_plugin(&host, register_plugin); + } +} diff --git a/examples/rust-native-plugin/.gitignore b/examples/rust-native-plugin/.gitignore new file mode 100644 index 00000000..d2577c43 --- /dev/null +++ b/examples/rust-native-plugin/.gitignore @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +/target/ +/Cargo.lock diff --git a/examples/rust-native-plugin/Cargo.toml b/examples/rust-native-plugin/Cargo.toml new file mode 100644 index 00000000..2a47f953 --- /dev/null +++ b/examples/rust-native-plugin/Cargo.toml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "nemo-relay-rust-native-plugin-example" +version = "0.1.0" +edition = "2024" +publish = false +description = "Example Rust native dynamic plugin for NeMo Relay" + +[workspace] + +[lib] +name = "nemo_relay_rust_native_plugin_example" +crate-type = ["cdylib"] + +[dependencies] +nemo-relay-plugin = { path = "../../crates/plugin" } +serde_json = "1" diff --git a/examples/rust-native-plugin/README.md b/examples/rust-native-plugin/README.md new file mode 100644 index 00000000..69e77d0e --- /dev/null +++ b/examples/rust-native-plugin/README.md @@ -0,0 +1,78 @@ + + +# Rust Native Dynamic Plugin + +This example shows a trusted in-process Rust dynamic plugin using the +high-level `nemo-relay-plugin` SDK. It builds as a `cdylib`, exports a stable +native ABI entry symbol, validates JSON config, registers middleware and +subscribers, emits runtime marks/scopes, and creates an isolated scope stack. + +The example intentionally depends on `nemo-relay-plugin`, not on the host +`nemo-relay` runtime crate. Rust DTOs stay inside the plugin crate; the +dynamic-library boundary remains the stable C ABI. + +## Build + +Run this command from the example directory: + +```bash +cargo build +``` + +Before you register the plugin, replace `` in +`relay-plugin.toml` with the file name that `cargo build` creates for your +platform: + +| Platform | Library path | +|---|---| +| macOS | `target/debug/libnemo_relay_rust_native_plugin_example.dylib` | +| Linux | `target/debug/libnemo_relay_rust_native_plugin_example.so` | +| Windows | `target/debug/nemo_relay_rust_native_plugin_example.dll` | + +## Register With Relay + +After updating `load.library`, run these commands from the repository root: + +```bash +nemo-relay plugins add ./examples/rust-native-plugin/relay-plugin.toml +nemo-relay plugins enable examples.rust_native_policy +``` + +You can also reference the manifest manually from `plugins.toml`: + +```toml +[[plugins.dynamic]] +manifest = "./examples/rust-native-plugin/relay-plugin.toml" + +[plugins.dynamic.config] +tag = "demo" +block_tools = false +block_llms = false +emit_isolated_scope = true +``` + +Start the gateway normally after the dynamic record is enabled: + +```bash +nemo-relay gateway +``` + +## What the Example Registers + +The example registers the following runtime behavior: + +- A subscriber that emits a mark when it sees non-plugin scope starts. +- Tool sanitize request/response guardrails for observability payload tagging. +- Conditional execution guardrails for tools and LLMs controlled by config. +- Request and execution intercepts for tools that mutate JSON payloads and call + continuations. +- LLM sanitize request/response guardrails. +- LLM request, execution, and stream execution intercepts. +- Runtime mark and scope events. +- A plugin-owned isolated scope stack for non-correlated visibility. + +Native plugins are not sandboxed. They run in the Relay process and must not +unwind across ABI callbacks. diff --git a/examples/rust-native-plugin/relay-plugin.toml b/examples/rust-native-plugin/relay-plugin.toml new file mode 100644 index 00000000..77644210 --- /dev/null +++ b/examples/rust-native-plugin/relay-plugin.toml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +manifest_version = 1 + +[plugin] +id = "examples.rust_native_policy" +kind = "rust_dynamic" + +[compat] +relay = ">=0.5,<1.0" +native_api = "1" + +[defaults] +enabled = false + +[capabilities] +items = ["plugin_native"] + +[load] +# Replace `` with the file built by `cargo build` for +# your platform. Refer to README.md for the expected debug artifact names. +library = "target/debug/" +symbol = "nemo_relay_register_plugin" diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs new file mode 100644 index 00000000..f18296f5 --- /dev/null +++ b/examples/rust-native-plugin/src/lib.rs @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use nemo_relay_plugin::{ + ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, NativePlugin, + PluginContext, PluginRuntime, ScopeCategory, ScopeType, +}; +use serde_json::{Map, json}; + +struct ExampleNativePlugin; + +#[derive(Clone, Debug)] +struct ExampleConfig { + tag: String, + block_tools: bool, + block_llms: bool, + emit_isolated_scope: bool, +} + +#[derive(Clone, Copy)] +enum ConfigField { + Tag, + BlockTools, + BlockLlms, + EmitIsolatedScope, +} + +impl ConfigField { + const ALL: [Self; 4] = [ + Self::Tag, + Self::BlockTools, + Self::BlockLlms, + Self::EmitIsolatedScope, + ]; + + const fn name(self) -> &'static str { + match self { + Self::Tag => "tag", + Self::BlockTools => "block_tools", + Self::BlockLlms => "block_llms", + Self::EmitIsolatedScope => "emit_isolated_scope", + } + } + + const fn expected_type(self) -> &'static str { + match self { + Self::Tag => "string", + Self::BlockTools | Self::BlockLlms | Self::EmitIsolatedScope => "boolean", + } + } + + const fn invalid_code(self) -> &'static str { + match self { + Self::Tag => "examples.rust_native_policy.invalid_tag", + Self::BlockTools | Self::BlockLlms | Self::EmitIsolatedScope => { + "examples.rust_native_policy.invalid_boolean" + } + } + } + + fn accepts(self, value: &Json) -> bool { + match self { + Self::Tag => value.is_string(), + Self::BlockTools | Self::BlockLlms | Self::EmitIsolatedScope => value.is_boolean(), + } + } + + fn parse_into( + self, + value: &Json, + config: &mut ExampleConfig, + ) -> nemo_relay_plugin::Result<()> { + if !self.accepts(value) { + return Err(format!( + "{} must be a {}", + self.name(), + self.expected_type() + )); + } + match self { + Self::Tag => { + config.tag = value.as_str().expect("checked config field type").to_owned(); + } + Self::BlockTools => { + config.block_tools = value.as_bool().expect("checked config field type"); + } + Self::BlockLlms => { + config.block_llms = value.as_bool().expect("checked config field type"); + } + Self::EmitIsolatedScope => { + config.emit_isolated_scope = value + .as_bool() + .expect("checked config field type"); + } + } + Ok(()) + } +} + +impl Default for ExampleConfig { + fn default() -> Self { + Self { + tag: "rust-native-example".into(), + block_tools: false, + block_llms: false, + emit_isolated_scope: true, + } + } +} + +impl ExampleConfig { + fn parse(plugin_config: &Map) -> nemo_relay_plugin::Result { + let mut config = Self::default(); + for field in ConfigField::ALL { + if let Some(value) = plugin_config.get(field.name()) { + field.parse_into(value, &mut config)?; + } + } + + Ok(config) + } +} + +impl NativePlugin for ExampleNativePlugin { + fn plugin_kind(&self) -> &str { + "examples.rust_native_policy" + } + + fn allows_multiple_components(&self) -> bool { + false + } + + fn validate(&self, plugin_config: &Map) -> Vec { + let mut diagnostics = Vec::new(); + + for key in plugin_config.keys() { + if !ConfigField::ALL + .iter() + .any(|field| field.name() == key.as_str()) + { + diagnostics.push(diagnostic( + DiagnosticLevel::Warning, + "examples.rust_native_policy.unknown_field", + Some(key), + format!("unknown config field '{key}' will be ignored"), + )); + } + } + + for field in ConfigField::ALL { + if let Some(value) = plugin_config.get(field.name()) { + if !field.accepts(value) { + diagnostics.push(diagnostic( + DiagnosticLevel::Error, + field.invalid_code(), + Some(field.name()), + format!("{} must be a {}", field.name(), field.expected_type()), + )); + } + } + } + + diagnostics + } + + fn register( + &mut self, + plugin_config: &Map, + ctx: &mut PluginContext<'_>, + ) -> nemo_relay_plugin::Result<()> { + let config = ExampleConfig::parse(plugin_config)?; + let runtime = ctx.runtime(); + + ctx.register_subscriber("example_native_subscriber", { + let runtime = runtime.clone(); + let tag = config.tag.clone(); + move |event| subscriber_mark(&runtime, &tag, event) + })?; + + ctx.register_tool_sanitize_request_guardrail("example_tool_sanitize_request", 10, { + let tag = config.tag.clone(); + move |_name, args| tag_json(args, "native_tool_sanitize_request", &tag) + })?; + ctx.register_tool_sanitize_response_guardrail("example_tool_sanitize_response", 10, { + let tag = config.tag.clone(); + move |_name, result| tag_json(result, "native_tool_sanitize_response", &tag) + })?; + ctx.register_tool_conditional_execution_guardrail("example_tool_conditional", 10, { + let block_tools = config.block_tools; + move |name, _args| { + Ok(block_tools.then(|| format!("tool '{name}' blocked by Rust native plugin"))) + } + })?; + ctx.register_tool_request_intercept("example_tool_request", 20, false, { + let runtime = runtime.clone(); + let tag = config.tag.clone(); + let emit_isolated_scope = config.emit_isolated_scope; + move |name, args| { + emit_runtime_events(&runtime, &tag, emit_isolated_scope)?; + let mut scope = runtime.scope( + "example.native.tool_request", + ScopeType::Tool, + Some(&json!({ "tool": name, "tag": tag })), + None, + Some(&args), + )?; + let tagged = tag_json(args, "native_tool_request_intercept", &tag); + scope.close(Some(&tagged), None)?; + Ok(tagged) + } + })?; + ctx.register_tool_execution_intercept("example_tool_execution", 30, { + let tag = config.tag.clone(); + move |_name, args, next| { + let request = tag_json(args, "native_tool_execution_request", &tag); + let result = next.call(request)?; + Ok(tag_json(result, "native_tool_execution_response", &tag)) + } + })?; + + ctx.register_llm_sanitize_request_guardrail("example_llm_sanitize_request", 10, { + let tag = config.tag.clone(); + move |request| tag_llm_request(request, "native_llm_sanitize_request", &tag) + })?; + ctx.register_llm_sanitize_response_guardrail("example_llm_sanitize_response", 10, { + let tag = config.tag.clone(); + move |response| tag_json(response, "native_llm_sanitize_response", &tag) + })?; + ctx.register_llm_conditional_execution_guardrail("example_llm_conditional", 10, { + let block_llms = config.block_llms; + move |_request| { + Ok(block_llms.then(|| "LLM call blocked by Rust native plugin".to_string())) + } + })?; + ctx.register_llm_request_intercept("example_llm_request", 20, false, { + let tag = config.tag.clone(); + move |_name, request, annotated| { + Ok(( + tag_llm_request(request, "native_llm_request_intercept", &tag), + annotated, + )) + } + })?; + ctx.register_llm_execution_intercept("example_llm_execution", 30, { + let tag = config.tag.clone(); + move |_name, request, next| { + let request = tag_llm_request(request, "native_llm_execution_request", &tag); + let response = next.call(request)?; + Ok(tag_json(response, "native_llm_execution_response", &tag)) + } + })?; + ctx.register_llm_stream_execution_intercept("example_llm_stream_execution", 30, { + let tag = config.tag; + move |_name, request, next| { + let request = tag_llm_request(request, "native_llm_stream_execution_request", &tag); + let stream = next.call(request)?; + let tag = tag.clone(); + let stream: LlmJsonStream = Box::new(stream.map(move |chunk| { + chunk.map(|chunk| { + tag_json(chunk, "native_llm_stream_execution_response", &tag) + }) + })); + Ok(stream) + } + })?; + + Ok(()) + } +} + +fn diagnostic( + level: DiagnosticLevel, + code: &str, + field: Option<&str>, + message: impl Into, +) -> ConfigDiagnostic { + ConfigDiagnostic { + level, + code: code.into(), + component: Some("examples.rust_native_policy".into()), + field: field.map(str::to_owned), + message: message.into(), + } +} + +fn subscriber_mark(runtime: &PluginRuntime, tag: &str, event: &Event) { + if event.scope_category() == Some(ScopeCategory::Start) + && !event.name().starts_with("example.native") + { + let _ = runtime.emit_mark( + "example.native.subscriber.seen", + Some(&json!({ "event": event.name(), "tag": tag })), + None, + ); + } +} + +fn emit_runtime_events( + runtime: &PluginRuntime, + tag: &str, + emit_isolated_scope: bool, +) -> nemo_relay_plugin::Result<()> { + runtime.emit_mark( + "example.native.tool_request.seen", + Some(&json!({ "tag": tag })), + None, + )?; + + if !emit_isolated_scope { + return Ok(()); + } + + let isolated = runtime.create_scope_stack()?; + isolated.with_current(|| { + runtime.emit_mark( + "example.native.isolated.mark", + Some(&json!({ "tag": tag })), + None, + )?; + let mut scope = runtime.scope( + "example.native.isolated.scope", + ScopeType::Custom, + None, + Some(&json!({ "visibility": "isolated" })), + Some(&json!({ "tag": tag })), + )?; + scope.close(Some(&json!({ "done": true })), None) + }) +} + +fn tag_llm_request(mut request: LlmRequest, key: &str, tag: &str) -> LlmRequest { + request.headers.insert( + "x-nemo-relay-native-plugin".into(), + Json::String(tag.into()), + ); + request.content = tag_json(request.content, key, tag); + request +} + +fn tag_json(value: Json, key: &str, tag: &str) -> Json { + match value { + Json::Object(mut object) => { + object.insert(key.into(), Json::Bool(true)); + object.insert("native_plugin_tag".into(), Json::String(tag.into())); + Json::Object(object) + } + other => other, + } +} + +nemo_relay_plugin::nemo_relay_plugin!(nemo_relay_register_plugin, || ExampleNativePlugin); diff --git a/justfile b/justfile index f51252f1..20fa8e36 100644 --- a/justfile +++ b/justfile @@ -453,7 +453,9 @@ output = [] changed = [] found_workspace_version = False local_dependencies = ( + "nemo-relay-types", "nemo-relay", + "nemo-relay-plugin", "nemo-relay-adaptive", "nemo-relay-pii-redaction", "nemo-relay-ffi", @@ -649,6 +651,17 @@ print(f"crates/python/Cargo.toml version updated to {cargo_version}") PY } +published_cargo_packages() { + printf '%s\n' \ + nemo-relay-types \ + nemo-relay \ + nemo-relay-adaptive \ + nemo-relay-plugin \ + nemo-relay-pii-redaction \ + nemo-relay-ffi \ + nemo-relay-cli +} + # Keep local wheel packaging aligned with the CI matrix without requiring raw # maturin flags to be passed through `just --set`. linux_manylinux_compatibility() { @@ -844,6 +857,8 @@ clean: python/nemo_relay/*.so \ python/nemo_relay/__pycache__ \ python/nemo_relay/_native*.pyd \ + examples/rust-native-plugin/Cargo.lock \ + examples/rust-native-plugin/target \ python/tests/__pycache__ \ target @@ -1117,6 +1132,65 @@ set-version version="": cd "$NEMO_RELAY_REPO_ROOT" set_project_version "$version" +# --set [output_dir=] [ref_name=] +package-rust: + #!/usr/bin/env bash + {{ bash_helpers }} + output_dir="{{ output_dir }}" + cd "$NEMO_RELAY_REPO_ROOT" + package_dir="$(prepare_package_dir crates)" + package_target_dir="$(mktemp -d)" + cleanup_package_target_dir() { + rm -rf "$package_target_dir" + } + trap cleanup_package_target_dir EXIT + ref_name={{ quote(ref_name) }} + if [[ -n "$ref_name" ]]; then + echo "Using explicit version $ref_name" + set_cargo_workspace_version "$ref_name" + fi + cargo_dirty_args=() + if [[ -n "$ref_name" ]]; then + cargo_dirty_args+=(--allow-dirty) + fi + while IFS= read -r package; do + cargo_package_config=() + case "$package" in + nemo-relay) + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-types.path="crates/types"') + ;; + nemo-relay-adaptive) + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-types.path="crates/types"') + cargo_package_config+=(--config 'patch.crates-io.nemo-relay.path="crates/core"') + ;; + nemo-relay-plugin) + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-types.path="crates/types"') + ;; + nemo-relay-pii-redaction) + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-types.path="crates/types"') + cargo_package_config+=(--config 'patch.crates-io.nemo-relay.path="crates/core"') + ;; + nemo-relay-ffi|nemo-relay-cli) + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-types.path="crates/types"') + cargo_package_config+=(--config 'patch.crates-io.nemo-relay.path="crates/core"') + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-adaptive.path="crates/adaptive"') + cargo_package_config+=(--config 'patch.crates-io.nemo-relay-pii-redaction.path="crates/pii-redaction"') + ;; + esac + if ((${#cargo_package_config[@]} == 0)); then + cargo package --locked --package "$package" "${cargo_dirty_args[@]}" --target-dir "$package_target_dir" + else + cargo package --locked --package "$package" "${cargo_dirty_args[@]}" --target-dir "$package_target_dir" "${cargo_package_config[@]}" + fi + done < <(published_cargo_packages) + find "$package_target_dir/package" -maxdepth 1 -type f -name '*.crate' -exec cp {} "$package_dir"/ \; + shopt -s nullglob + packages=("$package_dir"/*.crate) + if ((${#packages[@]} == 0)); then + echo "Error: No Cargo package artifacts found in $package_dir" + exit 1 + fi + # --set [output_dir=] [ref_name=] package-node: #!/usr/bin/env bash