From 84cd08d484075c99a88b51153bd3d8d5069417c1 Mon Sep 17 00:00:00 2001 From: inaniloquentee <3051000145@qq.com> Date: Tue, 16 Jun 2026 21:55:19 +0800 Subject: [PATCH] feat: add vime rollout logp probe Signed-off-by: inaniloquentee <3051000145@qq.com> --- docs/usage/README.md | 1 + docs/usage/vime-rollout-logp-probe.md | 143 +++++++++++++++ rl_engine/integrations/__init__.py | 4 + rl_engine/integrations/vime/__init__.py | 22 +++ .../integrations/vime/rollout_logp_probe.py | 149 +++++++++++++++ tests/test_vime_rollout_logp_probe.py | 171 ++++++++++++++++++ 6 files changed, 490 insertions(+) create mode 100644 docs/usage/vime-rollout-logp-probe.md create mode 100644 rl_engine/integrations/__init__.py create mode 100644 rl_engine/integrations/vime/__init__.py create mode 100644 rl_engine/integrations/vime/rollout_logp_probe.py create mode 100644 tests/test_vime_rollout_logp_probe.py diff --git a/docs/usage/README.md b/docs/usage/README.md index 095bf40..54b7793 100644 --- a/docs/usage/README.md +++ b/docs/usage/README.md @@ -10,3 +10,4 @@ Start with: - [Single-GPU GRPO Example](grpo-single-gpu-example.md) - [Operators](../operators/README.md) - [Weight Sync Bridge](weight-sync-bridge.md) +- [Vime Rollout LogP Probe](vime-rollout-logp-probe.md) diff --git a/docs/usage/vime-rollout-logp-probe.md b/docs/usage/vime-rollout-logp-probe.md new file mode 100644 index 0000000..4ed8a7e --- /dev/null +++ b/docs/usage/vime-rollout-logp-probe.md @@ -0,0 +1,143 @@ +# Vime Rollout LogP Probe + +This page documents the minimal WS5 proof-of-concept for issue #120. It wires +one existing RL-Kernel operator into a single Vime rollout path using Vime's +public `--custom-generate-function-path` hook. + +## Issue Checklist Mapping + +- Smallest adapter/shim: `custom_generate` wraps one Vime generate call and one + RL-Kernel `logp` probe. +- Explicit opt-in: `RL_KERNEL_VIME_LOGP_PROBE=1` is required before RL-Kernel is + invoked. +- Instrumentation: `Sample.metadata["rl_kernel"]["vime_logp_probe"]` records + structured evidence, including whether the operator was invoked and the + process-local `call_count`. +- Run the minimal vime example: start from the #117 fully-async Qwen2.5-0.5B + baseline and add the exact command/config below. +- Smoke test with mocks: `tests/test_vime_rollout_logp_probe.py` installs a fake + Vime module when full Vime dependencies are unavailable. +- Fallback behavior: when RL-Kernel or its CUDA extension is unavailable, + non-strict mode keeps the native generated sample unchanged; native Vime/vLLM + generation failures still surface normally. + +## What It Proves + +The probe proves that a Vime rollout can invoke RL-Kernel from an opt-in custom +generate shim. It does not replace vLLM sampling or rollout-side logprob +computation. Vime's HTTP rollout path returns selected logprobs, not logits, so +the shim runs a small deterministic synthetic tensor through RL-Kernel's `logp` +operator and records structured evidence in `Sample.metadata`. + +## Entry Point + +Use this Vime custom generate path: + +```text +rl_engine.integrations.vime.rollout_logp_probe.custom_generate +``` + +Enable the probe with: + +```bash +export RL_KERNEL_VIME_LOGP_PROBE=1 +``` + +Optional strict mode: + +```bash +export RL_KERNEL_VIME_LOGP_STRICT=1 +``` + +Strict mode raises if RL-Kernel import or backend dispatch fails. Without strict +mode, the shim records fallback metadata and returns Vime's native generated +sample unchanged. + +## Minimal Vime Command + +Starting from the #117 baseline +`vime/examples/fully_async/run-qwen2.5-0.5B-fully_async.sh`, add the custom +generate function to `ROLLOUT_ARGS`: + +Add this line inside the script's `ROLLOUT_ARGS` array: + +```bash +--custom-generate-function-path rl_engine.integrations.vime.rollout_logp_probe.custom_generate +``` + +Make RL-Kernel importable inside the Ray job runtime environment. Either install +RL-Kernel in the image, or include the checkout path and opt-in variable in the +script's `RUNTIME_ENV_JSON`: + +```json +{ + "env_vars": { + "PYTHONPATH": "/path/to/RL-Kernel:/root/Megatron-LM/:${SCRIPT_DIR}", + "RL_KERNEL_VIME_LOGP_PROBE": "1", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_NVLS_ENABLE": "${HAS_NVLINK}" + } +} +``` + +Then run the baseline script normally: + +```bash +bash examples/fully_async/run-qwen2.5-0.5B-fully_async.sh +``` + +For a direct `train_async.py` or `ray job submit` command, the exact Vime +argument to add is: + +```bash +--custom-generate-function-path rl_engine.integrations.vime.rollout_logp_probe.custom_generate +``` + +The run should produce samples whose metadata contains: + +```python +sample.metadata["rl_kernel"]["vime_logp_probe"] +``` + +Expected fields include: + +```text +enabled +invoked +call_count +op +backend +fallback +fallback_reason +output_shape +output_sum +``` + +The `invoked` field proves the shim reached `kernel_registry.get_op("logp")` +and executed the returned operator for that sample. `call_count` is a +process-local successful invocation counter. + +## Fallback Behavior + +The shim always calls Vime's native `vime.rollout.vllm_rollout.generate` first. +If native Vime/vLLM generation is unavailable, the run fails the same way the +native Vime path would fail; the shim does not hide that failure. If the probe is +disabled, it records `enabled=False` and returns the sample. If RL-Kernel is +unavailable, a backend is unavailable, or the CUDA extension is not built, +non-strict mode records `fallback=True` and returns the native generated sample +unchanged. + +This keeps pure Vime inference and native RL paths unaffected when the probe is +disabled or when RL-Kernel cannot run. + +## Local Smoke Test + +The mock smoke test does not require a full Vime installation: + +```bash +python -m pytest tests/test_vime_rollout_logp_probe.py +``` + +The test installs a fake `vime.rollout.vllm_rollout.generate`, exercises the +custom generate shim, verifies that RL-Kernel `logp` dispatch was invoked, and +checks non-strict fallback behavior. diff --git a/rl_engine/integrations/__init__.py b/rl_engine/integrations/__init__.py new file mode 100644 index 0000000..414e09f --- /dev/null +++ b/rl_engine/integrations/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Framework integration helpers for RL-Kernel.""" diff --git a/rl_engine/integrations/vime/__init__.py b/rl_engine/integrations/vime/__init__.py new file mode 100644 index 0000000..89baa7e --- /dev/null +++ b/rl_engine/integrations/vime/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Opt-in Vime integration helpers.""" + +from rl_engine.integrations.vime.rollout_logp_probe import ( + ENV_ENABLED, + ENV_STRICT, + METADATA_KEY, + RLKernelProbeResult, + custom_generate, + run_logp_probe, +) + +__all__ = [ + "ENV_ENABLED", + "ENV_STRICT", + "METADATA_KEY", + "RLKernelProbeResult", + "custom_generate", + "run_logp_probe", +] diff --git a/rl_engine/integrations/vime/rollout_logp_probe.py b/rl_engine/integrations/vime/rollout_logp_probe.py new file mode 100644 index 0000000..c8f7d2a --- /dev/null +++ b/rl_engine/integrations/vime/rollout_logp_probe.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import inspect +import os +from dataclasses import asdict, dataclass +from typing import Any + +ENV_ENABLED = "RL_KERNEL_VIME_LOGP_PROBE" +ENV_STRICT = "RL_KERNEL_VIME_LOGP_STRICT" +METADATA_KEY = "rl_kernel" + +_TRUE_VALUES = {"1", "true", "yes", "on"} +_CALL_COUNT = 0 + + +@dataclass(frozen=True) +class RLKernelProbeResult: + """Structured evidence that the Vime shim reached RL-Kernel.""" + + enabled: bool + invoked: bool + call_count: int = 0 + op: str = "logp" + backend: str | None = None + fallback: bool = False + fallback_reason: str | None = None + output_shape: tuple[int, ...] | None = None + output_sum: float | None = None + + +def _env_enabled(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in _TRUE_VALUES + + +def _probe_tensors(): + import torch + + from rl_engine.platforms.device import device_ctx + + logits = torch.tensor( + [ + [0.25, 1.5, -0.5, 0.0], + [2.0, -1.0, 0.5, 0.25], + ], + device=device_ctx.device, + dtype=torch.float32, + ) + token_ids = torch.tensor([1, 0], device=device_ctx.device, dtype=torch.long) + return logits, token_ids + + +def _fallback(reason: str) -> RLKernelProbeResult: + return RLKernelProbeResult( + enabled=True, + invoked=False, + call_count=_CALL_COUNT, + fallback=True, + fallback_reason=reason, + ) + + +def _supports_evaluation_arg(fn: Any) -> bool: + try: + parameters = inspect.signature(fn).parameters.values() + except (TypeError, ValueError): + return False + return any( + param.name == "evaluation" or param.kind == inspect.Parameter.VAR_KEYWORD + for param in parameters + ) + + +def run_logp_probe(*, strict: bool | None = None) -> RLKernelProbeResult: + """Invoke one RL-Kernel logp operator on a small deterministic tensor. + + The probe intentionally uses synthetic tensors instead of Vime rollout + logits. Vime's rollout HTTP path exposes selected logprobs, not logits, so + this is an invocation proof rather than a rollout-logprob replacement. + """ + + if strict is None: + strict = _env_enabled(ENV_STRICT) + + try: + from rl_engine.kernels.registry import kernel_registry + + logits, token_ids = _probe_tensors() + op = kernel_registry.get_op("logp") + output = op(logits, token_ids) + global _CALL_COUNT + _CALL_COUNT += 1 + return RLKernelProbeResult( + enabled=True, + invoked=True, + call_count=_CALL_COUNT, + backend=op.__class__.__name__, + output_shape=tuple(output.shape), + output_sum=float(output.detach().float().sum().item()), + ) + except Exception as exc: + if strict: + raise + return _fallback(f"{type(exc).__name__}: {exc}") + + +def _record_probe(sample: Any, result: RLKernelProbeResult) -> None: + metadata = getattr(sample, "metadata", None) + if not isinstance(metadata, dict): + metadata = {} + setattr(sample, "metadata", metadata) + + rl_kernel_metadata = metadata.get(METADATA_KEY) + if not isinstance(rl_kernel_metadata, dict): + rl_kernel_metadata = {} + metadata[METADATA_KEY] = rl_kernel_metadata + + rl_kernel_metadata["vime_logp_probe"] = asdict(result) + + +async def custom_generate( + args: Any, + sample: Any, + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Any: + """Vime ``--custom-generate-function-path`` entry point. + + This shim preserves Vime's native generation path and only adds opt-in + RL-Kernel invocation evidence. Enable it with + ``RL_KERNEL_VIME_LOGP_PROBE=1``. + """ + + from vime.rollout.vllm_rollout import generate + + if _supports_evaluation_arg(generate): + sample = await generate(args, sample, sampling_params, evaluation=evaluation) + else: + sample = await generate(args, sample, sampling_params) + + if not _env_enabled(ENV_ENABLED): + _record_probe(sample, RLKernelProbeResult(enabled=False, invoked=False)) + return sample + + result = run_logp_probe(strict=_env_enabled(ENV_STRICT)) + _record_probe(sample, result) + return sample diff --git a/tests/test_vime_rollout_logp_probe.py b/tests/test_vime_rollout_logp_probe.py new file mode 100644 index 0000000..55e27b6 --- /dev/null +++ b/tests/test_vime_rollout_logp_probe.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import asyncio +import sys +import types +from dataclasses import dataclass, field + +import pytest + +from rl_engine.integrations.vime.rollout_logp_probe import ( + ENV_ENABLED, + ENV_STRICT, + METADATA_KEY, + custom_generate, + run_logp_probe, +) + + +@dataclass +class FakeSample: + tokens: list[int] = field(default_factory=list) + response: str = "" + response_length: int = 0 + metadata: dict = field(default_factory=dict) + + +def _install_fake_vime(monkeypatch, generate_func): + vime_mod = types.ModuleType("vime") + rollout_pkg = types.ModuleType("vime.rollout") + rollout_mod = types.ModuleType("vime.rollout.vllm_rollout") + rollout_mod.generate = generate_func + + monkeypatch.setitem(sys.modules, "vime", vime_mod) + monkeypatch.setitem(sys.modules, "vime.rollout", rollout_pkg) + monkeypatch.setitem(sys.modules, "vime.rollout.vllm_rollout", rollout_mod) + + +async def _native_generate(args, sample, sampling_params): + sample.tokens = [11, 12, 13] + sample.response = "native" + sample.response_length = 1 + return sample + + +async def _native_generate_with_evaluation(args, sample, sampling_params, evaluation=False): + sample.tokens = [21, 22, 23] + sample.response = "eval" if evaluation else "native" + sample.response_length = 1 + sample.metadata["evaluation"] = evaluation + return sample + + +async def _native_generate_with_kwargs(args, sample, sampling_params, **kwargs): + sample.tokens = [31, 32, 33] + sample.response = "kwargs" if kwargs.get("evaluation") else "native" + sample.response_length = 1 + sample.metadata["evaluation"] = kwargs.get("evaluation") + return sample + + +def test_run_logp_probe_invokes_dispatch_backend(): + result = run_logp_probe(strict=True) + + assert result.enabled is True + assert result.invoked is True + assert result.call_count >= 1 + assert result.op == "logp" + assert result.backend + assert result.output_shape == (2,) + assert isinstance(result.output_sum, float) + assert result.fallback is False + + +def test_custom_generate_records_disabled_probe_without_changing_native_sample( + monkeypatch, +): + monkeypatch.delenv(ENV_ENABLED, raising=False) + _install_fake_vime(monkeypatch, _native_generate) + + sample = asyncio.run(custom_generate(object(), FakeSample(), {"temperature": 1.0})) + + assert sample.response == "native" + assert sample.tokens == [11, 12, 13] + probe = sample.metadata[METADATA_KEY]["vime_logp_probe"] + assert probe["enabled"] is False + assert probe["invoked"] is False + + +def test_custom_generate_invokes_probe_when_enabled(monkeypatch): + monkeypatch.setenv(ENV_ENABLED, "1") + _install_fake_vime(monkeypatch, _native_generate_with_evaluation) + + sample = asyncio.run( + custom_generate(object(), FakeSample(), {"temperature": 1.0}, evaluation=True) + ) + + assert sample.response == "eval" + assert sample.metadata["evaluation"] is True + probe = sample.metadata[METADATA_KEY]["vime_logp_probe"] + assert probe["enabled"] is True + assert probe["invoked"] is True + assert probe["call_count"] >= 1 + assert probe["op"] == "logp" + assert probe["backend"] + assert probe["output_shape"] == (2,) + + +def test_custom_generate_passes_evaluation_to_kwargs_generate(monkeypatch): + monkeypatch.delenv(ENV_ENABLED, raising=False) + _install_fake_vime(monkeypatch, _native_generate_with_kwargs) + + sample = asyncio.run( + custom_generate(object(), FakeSample(), {"temperature": 1.0}, evaluation=True) + ) + + assert sample.response == "kwargs" + assert sample.metadata["evaluation"] is True + + +def test_probe_falls_back_when_registry_is_unavailable(monkeypatch): + monkeypatch.setenv(ENV_ENABLED, "1") + monkeypatch.setattr( + "rl_engine.kernels.registry.kernel_registry.get_op", + lambda _op_type: (_ for _ in ()).throw(RuntimeError("backend unavailable")), + ) + _install_fake_vime(monkeypatch, _native_generate) + + sample = asyncio.run(custom_generate(object(), FakeSample(), {"temperature": 1.0})) + + assert sample.response == "native" + probe = sample.metadata[METADATA_KEY]["vime_logp_probe"] + assert probe["enabled"] is True + assert probe["invoked"] is False + assert probe["fallback"] is True + assert "backend unavailable" in probe["fallback_reason"] + + +def test_custom_generate_preserves_existing_rl_kernel_metadata(monkeypatch): + monkeypatch.delenv(ENV_ENABLED, raising=False) + _install_fake_vime(monkeypatch, _native_generate) + sample = FakeSample(metadata={METADATA_KEY: {"existing": "keep"}}) + + sample = asyncio.run(custom_generate(object(), sample, {"temperature": 1.0})) + + assert sample.metadata[METADATA_KEY]["existing"] == "keep" + assert "vime_logp_probe" in sample.metadata[METADATA_KEY] + + +def test_custom_generate_handles_non_dict_metadata(monkeypatch): + monkeypatch.delenv(ENV_ENABLED, raising=False) + _install_fake_vime(monkeypatch, _native_generate) + sample = FakeSample(metadata=None) + + sample = asyncio.run(custom_generate(object(), sample, {"temperature": 1.0})) + + assert isinstance(sample.metadata, dict) + assert "vime_logp_probe" in sample.metadata[METADATA_KEY] + + +def test_probe_strict_mode_raises_on_backend_failure(monkeypatch): + monkeypatch.setenv(ENV_STRICT, "1") + monkeypatch.setattr( + "rl_engine.kernels.registry.kernel_registry.get_op", + lambda _op_type: (_ for _ in ()).throw(RuntimeError("strict failure")), + ) + + with pytest.raises(RuntimeError, match="strict failure"): + run_logp_probe()