Add batch-invariant deterministic CUDA logp #98

Open
inaniloquentee wants to merge 7 commits into main from feat/deterministic-logp

inaniloquentee commented Jun 11, 2026

Collaborator

Summary

This PR adds an opt-in batch-invariant deterministic CUDA selected-token logprob backend for RL-Kernel. It is intended for RL train/eval paths where the same flattened row must produce the same selected logprob regardless of batch size, batch position, sparse indexing pattern, or unrelated rows in the launch.

Why

GRPO/PPO-style training is sensitive to tiny logprob drift: policy ratios and KL penalties can amplify rollout-time vs train-time differences. Existing high-throughput reductions may change floating-point accumulation order with launch shape or sparse indexing. This PR adds a CUDA path whose reduction topology is locked by vocab bucket, so a row's result is independent of the surrounding batch while avoiding the worst cost of one universal block size.
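
The order sensitivity described above can be reproduced without a GPU. The following is a minimal sketch, not code from this PR: it emulates float32 addition in pure Python (the helper names `f32`, `sum_left_to_right`, and `sum_pairwise` are made up for illustration) and shows the same four values producing different sums under two accumulation orders.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_left_to_right(values):
    """Accumulate in index order, rounding to float32 after every add."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)
    return acc


def sum_pairwise(values):
    """Add neighbours first, then combine the partial sums (a small tree)."""
    vals = list(values)
    while len(vals) > 1:
        if len(vals) % 2:
            vals.append(0.0)  # pad odd-length levels with the identity
        vals = [f32(vals[i] + vals[i + 1]) for i in range(0, len(vals), 2)]
    return vals[0]


values = [1e8, 1.0, -1e8, 1.0]  # each value is exactly representable in float32
print(sum_left_to_right(values))  # 1.0
print(sum_pairwise(values))       # 0.0
```

Neither order is "wrong"; they simply round at different points. A reduction whose shape depends on the launch configuration can therefore return different bits for the same row, which is the drift this PR is meant to rule out.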

Implementation

  • Adds csrc/deterministic_logp_kernel.cu.
  • Uses one CUDA block per flattened logit row.
  • Locks deterministic reduction topology by vocab_size only:
    • vocab_size <= 128 uses BlockSize=128
    • 129 <= vocab_size <= 4096 uses BlockSize=256
    • vocab_size > 4096 uses BlockSize=512
  • Keeps bucket selection independent of batch size, sequence length, launch row count, sparse index count, and row position.
  • Uses explicit two-pass log-sum-exp:
    • fixed tree reduction for row max
    • fixed tree reduction for exp-sum
  • Avoids atomicAdd, CUB block reductions, autotuned block sizes, and dynamic reduction topology.
  • Exposes dense, out, fp32, indexed out, and indexed fp32 bindings through _C.
  • Adds DeterministicLogpCUDAOp plus registry op types:
    • logp_deterministic
    • logp_deterministic_indexed
  • Adds user-facing aliases via resolve_logp_op_type, including deterministic, deterministic_cuda, and batch_invariant.
  • Wires rollout/training config validation and examples/grpo_single_gpu.py through:
    • --logp-backend
    • --require-batch-invariant-logp
  • Adds benchmark candidates:
    • deterministic_fp32
    • deterministic_indexed_fp32
  • Removes default CUDA --use_fast_math; it is now opt-in with KERNEL_ALIGN_USE_FAST_MATH=1.
  • Adds a Windows CUDA build escape hatch for unsupported MSVC combinations via KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC=1.
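
As a rough illustration of the bucket rule and the two-pass log-sum-exp listed above, here is a pure-Python sketch. It is an assumption-laden model, not the kernel: the function names are hypothetical, Python floats are float64 rather than fp32, and the tree below is a plain shared-memory-style halving tree, whereas the walkthrough says the real kernel combines warp shuffles with shared memory. Only the bucket thresholds (128 / 4096) and block sizes (128 / 256 / 512) come from the PR description.

```python
import math


def select_block_size(vocab_size: int) -> int:
    """Bucket rule from the PR description: depends on vocab_size only."""
    if vocab_size <= 128:
        return 128
    if vocab_size <= 4096:
        return 256
    return 512


def tree_reduce(values, op, identity, block_size):
    """Fixed-topology reduction: strided per-thread fold, then a halving tree."""
    partial = [identity] * block_size
    for t in range(block_size):
        acc = identity
        # "Thread" t folds elements t, t + block_size, ... in ascending order.
        for i in range(t, len(values), block_size):
            acc = op(acc, values[i])
        partial[t] = acc
    stride = block_size // 2
    while stride > 0:
        for t in range(stride):
            partial[t] = op(partial[t], partial[t + stride])
        stride //= 2
    return partial[0]


def selected_logp(row, token_id):
    """Two-pass log-sum-exp: row max first, then the sum of shifted exps."""
    block = select_block_size(len(row))
    row_max = tree_reduce(row, max, -math.inf, block)
    shifted = [math.exp(x - row_max) for x in row]
    exp_sum = tree_reduce(shifted, lambda a, b: a + b, 0.0, block)
    return row[token_id] - row_max - math.log(exp_sum)


print(select_block_size(4096), select_block_size(4097))  # 256 512
```

Because `select_block_size` takes nothing but the vocabulary size, the order in which a row's elements are combined cannot change with batch size, row position, or the number of indexed rows.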

Test Matrix

CUDA Shape / Dtype / Bucket Matrix

| Axis | Values Covered | Assertion | Result |
|---|---|---|---|
| Vocab size | 1, 2, 31, 32, 33, 127, 128, 129, 255, 256, 257, 1024, 4095, 4096, 4097, 4099, 8192 | matches PyTorch log_softmax(...).gather(...) within dtype tolerance | Pass |
| Bucket boundaries | 128, 129, 4096, 4097 | dense and full-row indexed results are byte-for-byte invariant across batch sizes and positions | Pass |
| Block topology contract | source checks for 128/256/512 block constants and 128/4096 vocab limits | topology is fixed by vocab bucket and not autotuned | Pass |
| Batch/sequence shapes | (1,1), (1,3), (2,5), (2,7), (3,4), (4,3), (4,5), (2,6), (2,3) paired with vocab matrix | output shape equals token id shape | Pass |
| Input dtype | fp16, bf16 when supported, fp32 | fp32 deterministic output matches reference | Pass |
| out dtype | fp16, bf16 when supported, fp32, fp64 | reuses caller storage and matches reference after cast | Pass |
| Non-contiguous inputs | strided logits and strided token ids | wrapper handles non-contiguous tensors via reshape/contiguous prep | Pass |

Determinism Matrix

| Case | Coverage | Assertion | Result |
|---|---|---|---|
| Repeatability | 20 repeated launches on same tensor | byte-for-byte identical output | Pass |
| Batch-size invariance | same target row packed into batch sizes 1, 2, 4, 8, 16 | target row byte-for-byte identical | Pass |
| Batch-position invariance | same target row placed at positions 0..7 | target row byte-for-byte identical | Pass |
| Batch-noise invariance | target row embedded in batch size 32, unrelated rows randomized for seeds 20..29 | target row byte-for-byte identical | Pass |
| Dense/indexed parity | sparse row sets including single row, subset, duplicate rows, and all rows | indexed selected rows byte-match dense rows | Pass |
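
The batch-size and batch-position cases above share one test shape: place a fixed target row in batches of different sizes and positions, then compare raw bytes rather than using a tolerance. The sketch below shows that shape with a stand-in pure-Python op; `row_logp`, `batch_logp`, and `check_batch_invariance` are hypothetical names, and the actual suite in `tests/test_deterministic_logp.py` presumably calls the CUDA op instead.

```python
import math
import random
import struct


def row_logp(row, token_id):
    """Stand-in op: selected-token logprob with a fixed left-to-right sum."""
    row_max = max(row)
    exp_sum = 0.0
    for x in row:
        exp_sum += math.exp(x - row_max)
    return row[token_id] - row_max - math.log(exp_sum)


def batch_logp(rows, token_ids):
    """Each row is processed independently, like one block per row."""
    return [row_logp(r, t) for r, t in zip(rows, token_ids)]


def as_bytes(x: float) -> bytes:
    """Bitwise view of a result, so equality means byte-for-byte identical."""
    return struct.pack("<d", x)


def check_batch_invariance(vocab_size=257, seed=0):
    rng = random.Random(seed)
    target = [rng.uniform(-5.0, 5.0) for _ in range(vocab_size)]
    target_id = rng.randrange(vocab_size)
    expected = as_bytes(batch_logp([target], [target_id])[0])
    for batch_size in (1, 2, 4, 8, 16):
        for position in range(batch_size):
            # Unrelated rows are re-randomized for every configuration.
            rows = [[rng.uniform(-5.0, 5.0) for _ in range(vocab_size)]
                    for _ in range(batch_size)]
            ids = [rng.randrange(vocab_size) for _ in range(batch_size)]
            rows[position], ids[position] = target, target_id
            if as_bytes(batch_logp(rows, ids)[position]) != expected:
                return False
    return True


print(check_batch_invariance())  # True
```

Comparing packed bytes instead of calling an `allclose`-style check is the point: a result that differs in the last bit would pass a tolerance test but fail this one.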

Indexed / Edge-Case Matrix

| Case | Coverage | Expected Behavior | Result |
|---|---|---|---|
| Empty indices | indexed_out with zero row indices | output remains unchanged | Pass |
| Empty indices helper | indexed_fp32 with zero row indices | returns zero-filled fp32 output | Pass |
| Duplicate indices | duplicate sparse row list | selected row value remains deterministic | Pass |
| Unordered indices | unordered sparse row list | selected rows match dense results | Pass |
| Inactive rows | sentinel-filled output | inactive rows preserve sentinel | Pass |
| Invalid token ids | negative and >= vocab_size ids | invalid positions zero-fill, valid positions match reference | Pass |
| Bad token shape | token ids shape does not match logits leading shape | raises | Pass |
| Bad output shape | output shape mismatch | raises | Pass |
| Unsupported output dtype | integer output tensor | raises | Pass |
| Bad row index dtype at C++ boundary | int32 row indices passed directly to _C | raises row_indices must be int64 | Pass |
| Out-of-range row indices | negative and too-large row ids mixed with valid id | valid row is written, invalid rows do not overwrite output | Pass |

Numerical-Stability Matrix

| Logit Pattern | Vocab | Token Targets | Assertion | Result |
|---|---|---|---|---|
| all 0.0 | 4099 | first token | finite and matches reference | Pass |
| all 80.0 | 4099 | last token | finite and matches reference | Pass |
| all -80.0 | 4099 | middle token | finite and matches reference | Pass |
| linearly increasing [-80, 80] | 4099 | last token | finite and matches reference | Pass |
| linearly decreasing [80, -80] | 4099 | first token | finite and matches reference | Pass |
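
For the three constant-logit rows there is a closed form to check against: when every logit equals the same value c, the max-shifted sum is exactly V, so the selected logprob is -log(V) regardless of c. A small sketch of that check follows, using a hypothetical helper `stable_selected_logp` in float64 rather than the kernel's fp32 path.

```python
import math


def stable_selected_logp(row, token_id):
    """Max-shifted log-sum-exp; every exponent is <= 0, so exp cannot overflow."""
    row_max = max(row)
    exp_sum = sum(math.exp(x - row_max) for x in row)
    return row[token_id] - row_max - math.log(exp_sum)


def uniform_row_logp(value: float, vocab_size: int, token_id: int) -> float:
    """Selected logprob for a row whose logits all equal `value`."""
    return stable_selected_logp([value] * vocab_size, token_id)


vocab = 4099
for value, token_id in ((0.0, 0), (80.0, vocab - 1), (-80.0, vocab // 2)):
    logp = uniform_row_logp(value, vocab, token_id)
    assert math.isfinite(logp)
    assert abs(logp + math.log(vocab)) < 1e-9
```

The three (value, token) pairs mirror the first three table rows; the two linear ramps have no equally simple closed form and are better compared against a reference implementation, as the table states.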

Validation Results

| Command / Check | Result |
|---|---|
| CUDA extension rebuild with CUDA 11.8 / VS BuildTools on Windows | Pass |
| Direct cp39 CUDA production matrix | bucketed_deterministic_logp_cuda_production_matrix_passed |
| Direct cp39 CUDA error contract | deterministic_logp_cuda_error_contract_passed |
| `py -3.13 -m pytest rl_engine/tests/test_dispatch.py tests\test_grpo_single_gpu_example.py tests\test_vllm_rollout_sampler.py tests\test_deepspeed_training_worker.py tests\test_deterministic_logp.py tests\test_op_accuracy.py -q` | 51 passed, 73 skipped |
| `py -3.13 -m pytest tests\test_deterministic_logp.py -q` | 1 passed, 73 skipped on py3.13, because CUDA extension is cp39 in this local env |
| `py -3.13 -m pre_commit run --all-files` | Pass |
| `py -3.13 -m mypy --ignore-missing-imports rl_engine/` | Pass |
| `py -3.13 -m compileall rl_engine tests\test_deterministic_logp.py` | Pass |
| `py -3.13 -m mkdocs build --strict -f mkdocs.yaml` | Pass |
| `git diff --check` | Pass, with CRLF conversion warnings only |
| `python benchmarks\benchmark_rl_kernels.py --candidate deterministic_fp32 --smoke --warmup 1 --repeat 2` | Pass |
| `python benchmarks\benchmark_rl_kernels.py --candidate deterministic_indexed_fp32 --smoke --warmup 1 --repeat 2` | Pass |
| `python benchmarks\benchmark_rl_kernels.py --candidate deterministic_fp32 --vocab-sizes 4097 --completion-lens 4 --g-sizes 2 --num-prompts 1 --mask-densities 1.0 --warmup 1 --repeat 2` | Pass |
| `python benchmarks\benchmark_rl_kernels.py --candidate deterministic_indexed_fp32 --vocab-sizes 4097 --completion-lens 4 --g-sizes 2 --num-prompts 1 --mask-densities 1.0 --warmup 1 --repeat 2` | Pass |

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a deterministic log-probability backend with selectable logp_backend and optional batch-invariant enforcement (require_batch_invariant_logp).
    • Introduced deterministic candidates for benchmarking, plus backend selection support in the GRPO single-GPU example and rollout executor.
  • Bug Fixes

    • Fixed DeepSpeed training config initialization to run parent validation logic.
  • Tests

    • Added a CUDA-gated deterministic LogP test suite (correctness, bitwise repeatability, indexing semantics, and edge cases).
    • Added unit tests for configuration validation and backend alias handling.

Signed-off-by: inaniloquentee <3051000145@qq.com>

coderabbitai Bot commented Jun 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a5743eeb-ef67-442f-bc3b-85dc4fd6cc97

📥 Commits

Reviewing files that changed from the base of the PR and between 7db5247 and f266afd.

📒 Files selected for processing (8)
  • csrc/deterministic_logp_kernel.cu
  • envs.py
  • examples/grpo_single_gpu.py
  • rl_engine/kernels/ops/cuda/loss/logp.py
  • setup.py
  • tests/test_deterministic_logp.py
  • tests/test_envs.py
  • tests/test_grpo_single_gpu_example.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • setup.py
  • examples/grpo_single_gpu.py
  • csrc/deterministic_logp_kernel.cu
  • tests/test_deterministic_logp.py

Walkthrough

Adds a deterministic CUDA log-probability kernel with FP32/indexed variants, exposes PyBind11 bindings and type stubs, extends the kernel registry and operator wrapper, wires config/CLI and benchmarks to select deterministic backends, updates build sources/flags using environment helpers, and adds comprehensive deterministic tests.

Changes

Deterministic Log-Probability Feature

| Layer | File(s) | Summary |
|---|---|---|
| CUDA Kernel Implementation | csrc/deterministic_logp_kernel.cu | Deterministic CUDA kernel with fixed block-size reduction topology (128/256/512 threads), warp-shuffle + shared-memory reductions, host input validation, launch wrapper dispatching over floating types, and public entry points supporting non-indexed/indexed execution with FP32 variants. |
| Environment flag utilities | envs.py, tests/test_envs.py | Helper module for parsing boolean environment variables with explicit string normalization, default handling, and ValueError for invalid values; environment variable name constants for build flags; parametrized tests validating flag parsing and error behavior. |
| C++ exports and type stubs | csrc/ops.cpp, rl_engine/_C.pyi | PyBind11 forward declarations and module exports for `deterministic_logp*` entry points (base, `*_out`, `*_fp32`, indexed variants); type stubs for compiled extension with fully-typed tensor signatures. |
| Registry and op-type resolution | rl_engine/kernels/registry.py | Adds OpBackend.CUDA_DETERMINISTIC_LOGP enum member; implements resolve_logp_op_type() for user-provided backend alias normalization with optional batch-invariant enforcement; extends priority-map entries for deterministic dispatch across CUDA/ROCm/CPU platforms. |
| Python operator wrapper and fused markers | rl_engine/kernels/ops/cuda/loss/logp.py | DeterministicLogpCUDAOp operator class binding to deterministic float32 kernels with `__call__`/apply/apply_fp32/out/indexed_out/indexed_fp32/online variants; adds is_fused_logp = True marker to FusedLogpSM90Op and FusedLogpGenericOp; replaces view with reshape in _prepare_inputs/_prepare_indices for proper tensor flattening. |
| Config and executor wiring | rl_engine/executors/training_contract.py, rl_engine/executors/rollout.py, rl_engine/executors/deepspeed_trainer.py | TorchRLTrainingConfig adds logp_backend and require_batch_invariant_logp fields with `__post_init__` validation; RolloutExecutor imports resolve_logp_op_type, computes and caches logp_op_type during init, and retrieves resolved operator from kernel_registry; DeepSpeedTrainingConfig calls parent `__post_init__` to ensure inherited validation. |
| Example CLI and operator resolution | examples/grpo_single_gpu.py | Adds --logp-backend and --require-batch-invariant-logp command-line arguments; updates resolve_logp_op() to accept backend parameters and delegate to registry-based resolution; rewrites is_fused_logp_backend() to check operator instance capability flag instead of string-name prefix. |
| Build configuration | setup.py | Dynamically loads envs.py helper module; adds csrc/deterministic_logp_kernel.cu to CUDA extension sources; makes NVCC flags (fast-math, lineinfo) and Windows MSVC allowances conditional via env_flag(). |
| Benchmark candidate support | benchmarks/benchmark_rl_kernels.py | Adds deterministic_fp32 and deterministic_indexed_fp32 benchmark candidates with kernel op lookups, backend name mappings, required-backend validation, execution branches for both variants (indexed variant uses batch.valid_indices), and updated CLI argument choices. |
| Comprehensive tests | tests/test_deterministic_logp.py, tests/test_deepspeed_training_worker.py, tests/test_vllm_rollout_sampler.py, tests/test_grpo_single_gpu_example.py | Deterministic CUDA kernel tests covering shape/dtype matrix, bitwise repeatability, out-buffer storage/dtype behavior, non-contiguous inputs, batch-size/position invariance, boundary-vocab invariance, noise robustness, indexed equivalence/stability, indexed_out semantics, empty indices, invalid token ids, extreme logits, error handling, out-of-range index safety, source-level reduction contract, and reference tolerance. Config validation tests for require_batch_invariant_logp with deterministic/non-deterministic backends. Executor and CLI tests for backend alias resolution, batch-invariant validation, and capability flag detection. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Suggested reviewers

  • Flink-ddd
  • EthanZero2Hero

Poem

🐰 In shuffled warps my kernels hum,
Each row's logp, deterministic, done,
No races here, just stable art,
From CUDA heat to Python's heart,
I hop, I bind, and pass the test with fun!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 4.26% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add batch-invariant deterministic CUDA logp' clearly and concisely describes the main feature addition: a new deterministic CUDA log-probability backend that is batch-invariant, which aligns with the extensive changes across the codebase. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/executors/rollout.py`:
- Around line 43-46: The call to resolve_logp_op_type is being passed
require_batch_invariant=bool(self.config.get("require_batch_invariant_logp",
False)), which incorrectly coerces string values like "false" or "0" to True;
update the logic that reads self.config.get("require_batch_invariant_logp") in
rollout.py so it treats explicit booleans correctly and only returns True for
genuine truthy indicators (e.g., if the value is a bool use it directly,
otherwise normalize strings by lowercasing and compare to "true" or "1"), then
pass that normalized boolean to resolve_logp_op_type when assigning
self.logp_op_type.
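
A minimal sketch of the normalization this comment asks for, under stated assumptions: `parse_bool_option` is a hypothetical helper name, and rejecting unrecognized values with a ValueError is a stricter choice than the comment requires (it only asks that "true" and "1" map to True).

```python
def parse_bool_option(value, default=False):
    """Interpret a config value as a boolean without relying on bool(str)."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value  # explicit booleans pass through unchanged
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "1"):
            return True
        if normalized in ("false", "0"):
            return False
    # Assumption: fail loudly rather than guess for anything else.
    raise ValueError(f"expected a boolean-like value, got {value!r}")


# The pitfall: any non-empty string is truthy.
print(bool("false"))               # True
print(parse_bool_option("false"))  # False
print(parse_bool_option("1"))      # True
```

At the call site, the normalized value would replace the `bool(self.config.get(...))` expression before it is passed as `require_batch_invariant`.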

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4cff2109-e66b-4284-bf90-fda84b84ff29

📥 Commits

Reviewing files that changed from the base of the PR and between 4a9ca42 and 3f44526.

📒 Files selected for processing (14)
  • benchmarks/benchmark_rl_kernels.py
  • csrc/deterministic_logp_kernel.cu
  • csrc/ops.cpp
  • examples/grpo_single_gpu.py
  • rl_engine/_C.pyi
  • rl_engine/executors/deepspeed_trainer.py
  • rl_engine/executors/rollout.py
  • rl_engine/executors/training_contract.py
  • rl_engine/kernels/ops/cuda/loss/logp.py
  • rl_engine/kernels/registry.py
  • setup.py
  • tests/test_deepspeed_training_worker.py
  • tests/test_deterministic_logp.py
  • tests/test_vllm_rollout_sampler.py

Comment thread rl_engine/executors/rollout.py
Signed-off-by: inaniloquentee <3051000145@qq.com>

coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/test_deterministic_logp.py`:
- Around line 65-79: The tolerance currently derives only from output_dtype in
_assert_close_to_reference, which forces fp16/bf16-input tests to use fp32-level
tolerance; update the function to consider the input dtype (logits.dtype) as
well by computing input_tol = _dtype_tolerance(logits.dtype) and output_tol =
_dtype_tolerance(output_dtype) and set tolerance = max(input_tol, output_tol)
before the torch.allclose check; make the same change in the other analogous
assertion block(s) referenced (lines ~150-184) so input-aware tolerances are
used consistently; use the existing helpers _dtype_tolerance and keep the rest
of _assert_close_to_reference (and its counterparts) unchanged.
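
The suggested rule can be sketched in a few lines. This is illustrative only: the tolerance values in the table below are made up, dtypes are represented as plain strings so the snippet runs without PyTorch, and `dtype_tolerance` stands in for the test file's existing `_dtype_tolerance` helper, whose real values are not shown in this thread.

```python
# Hypothetical per-dtype tolerances; the real helper's values may differ.
_TOLERANCE = {
    "float16": 1e-3,
    "bfloat16": 1e-2,
    "float32": 1e-5,
    "float64": 1e-5,
}


def dtype_tolerance(dtype_name: str) -> float:
    """Stand-in for the test suite's _dtype_tolerance helper."""
    return _TOLERANCE[dtype_name]


def comparison_tolerance(input_dtype: str, output_dtype: str) -> float:
    """Use the looser of the input and output tolerances."""
    return max(dtype_tolerance(input_dtype), dtype_tolerance(output_dtype))


# fp16 logits written to an fp32 output are judged at fp16 precision.
print(comparison_tolerance("float16", "float32"))  # 0.001
```

Taking the maximum reflects that precision lost when the logits were stored in a narrow dtype cannot be recovered by widening the output.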

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2c68736f-e4f2-4223-9b39-5fb0ca3961ce

📥 Commits

Reviewing files that changed from the base of the PR and between 3f44526 and 302386d.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/cuda/loss/logp.py
  • tests/test_deterministic_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/ops/cuda/loss/logp.py

Comment thread tests/test_deterministic_logp.py
Signed-off-by: inaniloquentee <3051000145@qq.com>

coderabbitai Bot left a comment

🧹 Nitpick comments (1)
tests/test_deterministic_logp.py (1)

449-449: 💤 Low value

Consider explicit dtype for consistency.

Line 449 creates token_ids without specifying dtype=torch.long. While this works (integer lists default to Long), other tests in this file consistently specify dtype explicitly (lines 96, 176, 213). Adding dtype=torch.long would improve consistency and make the intent clearer.

✨ Consistency improvement
-    token_ids = torch.tensor([0, vocab_size - 1, vocab_size // 2, vocab_size - 1, 0], device=device)
+    token_ids = torch.tensor([0, vocab_size - 1, vocab_size // 2, vocab_size - 1, 0], device=device, dtype=torch.long)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_deterministic_logp.py` at line 449, The test creates token_ids
without an explicit dtype; update the token_ids tensor creation (the token_ids
variable in tests/test_deterministic_logp.py) to include dtype=torch.long (e.g.,
torch.tensor([...], device=device, dtype=torch.long)) to match other tests and
make the integer intent explicit and consistent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0eec0468-69e7-4287-87e1-6606cabea35c

📥 Commits

Reviewing files that changed from the base of the PR and between 302386d and 08c7369.

📒 Files selected for processing (1)
  • tests/test_deterministic_logp.py

Signed-off-by: inaniloquentee <3051000145@qq.com>
@Flink-ddd

Collaborator

Great work! Could you please add an issue to this PR?

Flink-ddd left a comment

Collaborator

Solid, well-tested PR. Happy to re-review once the shared-memory fix lands.

Comment thread csrc/deterministic_logp_kernel.cu Outdated
}
__syncthreads();

val = threadIdx.x < WarpCount ? shared[lane] : kDeterministicLogpNegInf;

After the per-warp shuffle, the broadcast-back line reads shared[lane]:

val = threadIdx.x < WarpCount ? shared[lane] : kDeterministicLogpNegInf;

shared is sized WarpCount. With BlockSize=512, WarpCount=16 but lane ranges 0–31, so threads with lane >= 16 evaluate shared[16..31], which is out of bounds. The guarding condition is in the ternary, but both branches are evaluated, so the OOB shared-memory read still happens. This triggers on any vocab_size > 4096. Please index by the guarded value:

val = (threadIdx.x < WarpCount) ? shared[threadIdx.x] : kDeterministicLogpNegInf;

Same fix for the sum reducer (else branch 0.0f).

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (vocab_size <= kDeterministicLogpSmallVocabLimit) {
deterministic_logp_forward_kernel<

In indexed mode multiple blocks can map to the same output[row] when row_indices contains duplicates. It's currently correct only because identical inputs produce bit-identical results, i.e. the write is idempotent. Please add a short comment documenting that duplicate/overlapping row indices are well-defined by idempotent writes, so a future change to a non-idempotent write path doesn't silently introduce a race.
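
To make the idempotence argument concrete, here is a small serial model (an assumption-heavy sketch, not the kernel): each "block" writes a value that depends only on its row, and every possible completion order of the blocks is tried. `indexed_out`, `row_values`, and the sentinel are hypothetical names, and a serial loop only models whole-word writes landing in some order, not true concurrency.

```python
import itertools

SENTINEL = 123.0


def indexed_out(row_values, row_indices, output, schedule):
    """Apply one write per launched block, in the order given by `schedule`."""
    for block in schedule:
        r = row_indices[block]
        if 0 <= r < len(output):       # out-of-range rows never write
            output[r] = row_values[r]  # value depends on the row only
    return output


row_values = [-1.5, -0.25, -3.0, -0.75]  # per-row deterministic results
row_indices = [2, 0, 2, 2]               # row 2 is requested three times

results = set()
for schedule in itertools.permutations(range(len(row_indices))):
    out = indexed_out(row_values, row_indices, [SENTINEL] * 4, schedule)
    results.add(tuple(out))

print(results)  # {(-1.5, 123.0, -3.0, 123.0)}
```

Every ordering yields the same output because duplicate blocks store identical bits. If the write ever became an accumulation (for example `output[r] += ...`), the result would depend on how many duplicates ran, which is the hazard the requested comment is meant to flag.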

self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor
) -> torch.Tensor:
logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids)
output_1d = self._prepare_output(output, orig_shape)

test_..._rejects_bad_shapes_and_output_dtype_cuda expects ValueError(match="output shape") on shape mismatch:

with pytest.raises(ValueError, match="output shape"):
    op.out(logits, token_ids, torch.empty(2, 2, ...))

That error must be raised in _prepare_output. Please confirm _prepare_output validates output.shape == orig_shape and raises a ValueError containing "output shape" — otherwise this test will fail. (Not visible in the diff.)
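
For reference, the contract being asked about could look like the sketch below. It is hypothetical: the real `_prepare_output` is not visible in the diff, shapes are modeled as plain tuples so the snippet runs without PyTorch, and the exact message wording is an assumption beyond the required "output shape" substring.

```python
import re


def check_output_shape(output_shape, orig_shape):
    """Raise ValueError mentioning "output shape" when the shapes differ."""
    if tuple(output_shape) != tuple(orig_shape):
        raise ValueError(
            f"output shape {tuple(output_shape)} must match "
            f"token_ids shape {tuple(orig_shape)}"
        )
    return tuple(orig_shape)


# pytest.raises(..., match=...) applies re.search to the message,
# so any message containing the substring satisfies the test.
try:
    check_output_shape((2, 2), (2, 3))
    message = None
except ValueError as exc:
    message = str(exc)

print(message)
```

A check of this kind on the Python side also keeps a mismatched buffer from reaching the C++ binding.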

Comment thread examples/grpo_single_gpu.py Outdated

 def is_fused_logp_backend(backend_name: str) -> bool:
-    return backend_name.startswith("FusedLogp")
+    return backend_name.startswith("FusedLogp") or backend_name == "DeterministicLogpCUDAOp"

String-name matching is brittle — renaming the class or adding a new deterministic variant (e.g. an SM90 deterministic op) will silently break this. Prefer a capability flag on the op:

class DeterministicLogpCUDAOp(FusedLogpGenericOp):
    is_batch_invariant = True

def is_fused_logp_backend(op) -> bool:
    return getattr(op, "is_batch_invariant", False) or isinstance(op, FusedLogpGenericOp)

Comment thread setup.py Outdated
 cc_major, cc_minor = torch.cuda.get_device_capability()
-nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"]
+nvcc_flags = ["-O3", "-Xfatbin", "-compress-all"]
+if os.environ.get("KERNEL_ALIGN_USE_FAST_MATH") == "1":

This scatters raw env-var string parsing across the codebase (same pattern repeats for KERNEL_ALIGN_NCU_LINEINFO, KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC). I'd suggest centralizing these into an envs.py module like vLLM does, so every flag has one documented definition, consistent parsing, and a single import surface for call sites.

Create rl_engine/envs.py (or top-level envs.py next to setup.py):

# rl_engine/envs.py
import os
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
    KERNEL_ALIGN_USE_FAST_MATH: bool = False
    KERNEL_ALIGN_NCU_LINEINFO: bool = False
    KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC: bool = False

environment_variables: dict[str, Callable[[], Any]] = {
    # Opt-in --use_fast_math for CUDA kernels. Off by default because it
    # breaks bit-for-bit reproducibility required by the deterministic logp path.
    "KERNEL_ALIGN_USE_FAST_MATH": lambda: bool(
        int(os.getenv("KERNEL_ALIGN_USE_FAST_MATH", "0"))
    ),
    "KERNEL_ALIGN_NCU_LINEINFO": lambda: bool(
        int(os.getenv("KERNEL_ALIGN_NCU_LINEINFO", "0"))
    ),
    # Windows-only escape hatch for unsupported MSVC/CUDA combinations.
    "KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC": lambda: bool(
        int(os.getenv("KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC", "0"))
    ),
}

def __getattr__(name: str):
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    return list(environment_variables.keys())

Then call sites read it the same way vLLM does:

from rl_engine import envs

nvcc_flags = ["-O3", "-Xfatbin", "-compress-all"]
if envs.KERNEL_ALIGN_USE_FAST_MATH:
    nvcc_flags.append("--use_fast_math")

Comment thread setup.py Outdated
)
if os.environ.get("KERNEL_ALIGN_NCU_LINEINFO") == "1":
nvcc_flags.append("-lineinfo")
if os.name == "nt" and os.environ.get("KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC") == "1":

same

if os.name == "nt" and envs.KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC:
    nvcc_flags.append("-allow-unsupported-compiler")
    nvcc_flags.append("-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH")

One caveat: setup.py runs before the package is importable, so either keep envs.py import-safe with zero heavy deps (as above), or place it at repo root and import it directly in setup.py.
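
The second option in the caveat above (a repo-root `envs.py` loaded directly by `setup.py`) can be done with the standard library alone. The following is a sketch under assumptions: `load_module_from_path` and the module name `_build_envs` are hypothetical, and the demo writes a throwaway helper into a temporary directory so the snippet runs on its own. The helper's `env_flag(name, environ)` signature is invented for the demo and is not the PR's actual API.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_module_from_path(module_name: str, path) -> object:
    """Import a single .py file by path, without importing its package."""
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {module_name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Demo: a stand-in helper file with no third-party imports.
with tempfile.TemporaryDirectory() as tmp:
    helper = Path(tmp) / "envs.py"
    helper.write_text(
        "def env_flag(name, environ):\n"
        "    return environ.get(name, '0') == '1'\n"
    )
    envs = load_module_from_path("_build_envs", helper)
    fast_math_enabled = envs.env_flag(
        "KERNEL_ALIGN_USE_FAST_MATH", {"KERNEL_ALIGN_USE_FAST_MATH": "1"}
    )
    fast_math_default = envs.env_flag("KERNEL_ALIGN_USE_FAST_MATH", {})

print(fast_math_enabled, fast_math_default)  # True False
```

In a real `setup.py` the path would be something like `Path(__file__).parent / "envs.py"`, which avoids importing `rl_engine` (and therefore the not-yet-built extension) at build time.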

@inaniloquentee

Collaborator Author

> Great work! Could you please add an issue to this PR?

Sure! The issue is here: #96

Signed-off-by: inaniloquentee <3051000145@qq.com>
Signed-off-by: inaniloquentee <3051000145@qq.com>
@inaniloquentee

Collaborator Author

> Solid, well-tested PR. Happy to re-review once the shared-memory fix lands.

Good catch. The previous reducer used shared[lane], which is unsafe when WarpCount < 32.

I fixed both deterministic reducers to read from shared memory only when threadIdx.x < WarpCount; otherwise the lane uses the reduction identity value. I also added a regression/source-contract test to prevent this pattern from coming back.
