Add batch-invariant deterministic CUDA logp#98
Conversation
Signed-off-by: inaniloquentee <3051000145@qq.com>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (8)
🚧 Files skipped from review as they are similar to previous changes (4)
📝 WalkthroughWalkthroughAdds a deterministic CUDA log-probability kernel with FP32/indexed variants, exposes PyBind11 bindings and type stubs, extends the kernel registry and operator wrapper, wires config/CLI and benchmarks to select deterministic backends, updates build sources/flags using environment helpers, and adds comprehensive deterministic tests. ChangesDeterministic Log-Probability Feature
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/executors/rollout.py`:
- Around line 43-46: The call to resolve_logp_op_type is being passed
require_batch_invariant=bool(self.config.get("require_batch_invariant_logp",
False)), which incorrectly coerces string values like "false" or "0" to True;
update the logic that reads self.config.get("require_batch_invariant_logp") in
rollout.py so it treats explicit booleans correctly and only returns True for
genuine truthy indicators (e.g., if the value is a bool use it directly,
otherwise normalize strings by lowercasing and compare to "true" or "1"), then
pass that normalized boolean to resolve_logp_op_type when assigning
self.logp_op_type.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4cff2109-e66b-4284-bf90-fda84b84ff29
📒 Files selected for processing (14)
benchmarks/benchmark_rl_kernels.pycsrc/deterministic_logp_kernel.cucsrc/ops.cppexamples/grpo_single_gpu.pyrl_engine/_C.pyirl_engine/executors/deepspeed_trainer.pyrl_engine/executors/rollout.pyrl_engine/executors/training_contract.pyrl_engine/kernels/ops/cuda/loss/logp.pyrl_engine/kernels/registry.pysetup.pytests/test_deepspeed_training_worker.pytests/test_deterministic_logp.pytests/test_vllm_rollout_sampler.py
Signed-off-by: inaniloquentee <3051000145@qq.com>
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_deterministic_logp.py`:
- Around line 65-79: The tolerance currently derives only from output_dtype in
_assert_close_to_reference, which forces fp16/bf16-input tests to use fp32-level
tolerance; update the function to consider the input dtype (logits.dtype) as
well by computing input_tol = _dtype_tolerance(logits.dtype) and output_tol =
_dtype_tolerance(output_dtype) and set tolerance = max(input_tol, output_tol)
before the torch.allclose check; make the same change in the other analogous
assertion block(s) referenced (lines ~150-184) so input-aware tolerances are
used consistently; use the existing helpers _dtype_tolerance and keep the rest
of _assert_close_to_reference (and its counterparts) unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2c68736f-e4f2-4223-9b39-5fb0ca3961ce
📒 Files selected for processing (2)
rl_engine/kernels/ops/cuda/loss/logp.pytests/test_deterministic_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/kernels/ops/cuda/loss/logp.py
Signed-off-by: inaniloquentee <3051000145@qq.com>
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/test_deterministic_logp.py (1)
449-449: 💤 Low valueConsider explicit dtype for consistency.
Line 449 creates
token_idswithout specifyingdtype=torch.long. While this works (integer lists default to Long), other tests in this file consistently specify dtype explicitly (lines 96, 176, 213). Addingdtype=torch.longwould improve consistency and make the intent clearer.✨ Consistency improvement
- token_ids = torch.tensor([0, vocab_size - 1, vocab_size // 2, vocab_size - 1, 0], device=device) + token_ids = torch.tensor([0, vocab_size - 1, vocab_size // 2, vocab_size - 1, 0], device=device, dtype=torch.long)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_deterministic_logp.py` at line 449, The test creates token_ids without an explicit dtype; update the token_ids tensor creation (the token_ids variable in tests/test_deterministic_logp.py) to include dtype=torch.long (e.g., torch.tensor([...], device=device, dtype=torch.long)) to match other tests and make the integer intent explicit and consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/test_deterministic_logp.py`:
- Line 449: The test creates token_ids without an explicit dtype; update the
token_ids tensor creation (the token_ids variable in
tests/test_deterministic_logp.py) to include dtype=torch.long (e.g.,
torch.tensor([...], device=device, dtype=torch.long)) to match other tests and
make the integer intent explicit and consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0eec0468-69e7-4287-87e1-6606cabea35c
📒 Files selected for processing (1)
tests/test_deterministic_logp.py
Signed-off-by: inaniloquentee <3051000145@qq.com>
|
Great work! Could you please add an issue to this PR? |
Flink-ddd
left a comment
There was a problem hiding this comment.
Solid, well-tested PR, Happy to re-review once the shared-memory fix lands.
| } | ||
| __syncthreads(); | ||
|
|
||
| val = threadIdx.x < WarpCount ? shared[lane] : kDeterministicLogpNegInf; |
There was a problem hiding this comment.
After the per-warp shuffle, the broadcast-back line reads shared[lane]:
val = threadIdx.x < WarpCount ? shared[lane] : kDeterministicLogpNegInf;
shared is sized WarpCount. With BlockSize=512, WarpCount=16 but lane ranges 0–31, so threads with lane >= 16 evaluate shared[16..31], which is out of bounds. The guarding condition is in the ternary, but both branches are evaluated, so the OOB shared-memory read still happens. This triggers on any vocab_size > 4096. Please index by the guarded value:
val = (threadIdx.x < WarpCount) ? shared[threadIdx.x] : kDeterministicLogpNegInf;
Same fix for the sum reducer (else branch 0.0f).
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
|
||
| if (vocab_size <= kDeterministicLogpSmallVocabLimit) { | ||
| deterministic_logp_forward_kernel< |
There was a problem hiding this comment.
In indexed mode multiple blocks can map to the same output[row] when row_indices contains duplicates. It's currently correct only because identical inputs produce bit-identical results, i.e. the write is idempotent. Please add a short comment documenting that duplicate/overlapping row indices are well-defined by idempotent writes, so a future change to a non-idempotent write path doesn't silently introduce a race.
| self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) | ||
| output_1d = self._prepare_output(output, orig_shape) |
There was a problem hiding this comment.
test_..._rejects_bad_shapes_and_output_dtype_cuda expects ValueError(match="output shape") on shape mismatch:
with pytest.raises(ValueError, match="output shape"):
op.out(logits, token_ids, torch.empty(2, 2, ...))
That error must be raised in _prepare_output. Please confirm _prepare_output validates output.shape == orig_shape and raises a ValueError containing "output shape" — otherwise this test will fail. (Not visible in the diff.)
|
|
||
| def is_fused_logp_backend(backend_name: str) -> bool: | ||
| return backend_name.startswith("FusedLogp") | ||
| return backend_name.startswith("FusedLogp") or backend_name == "DeterministicLogpCUDAOp" |
There was a problem hiding this comment.
String-name matching is brittle — renaming the class or adding a new deterministic variant (e.g. an SM90 deterministic op) will silently break this. Prefer a capability flag on the op:
class DeterministicLogpCUDAOp(FusedLogpGenericOp):
is_batch_invariant = True
def is_fused_logp_backend(op) -> bool:
return getattr(op, "is_batch_invariant", False) or isinstance(op, FusedLogpGenericOp)
| cc_major, cc_minor = torch.cuda.get_device_capability() | ||
| nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"] | ||
| nvcc_flags = ["-O3", "-Xfatbin", "-compress-all"] | ||
| if os.environ.get("KERNEL_ALIGN_USE_FAST_MATH") == "1": |
There was a problem hiding this comment.
This scatters raw env-var string parsing across the codebase (same pattern repeats for KERNEL_ALIGN_NCU_LINEINFO, KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC). I'd suggest centralizing these into an envs.py module like vLLM does, so every flag has one documented definition, consistent parsing, and a single import surface for call sites.
Create rl_engine/envs.py (or top-level envs.py next to setup.py):
# rl_engine/envs.py
import os
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
KERNEL_ALIGN_USE_FAST_MATH: bool = False
KERNEL_ALIGN_NCU_LINEINFO: bool = False
KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC: bool = False
environment_variables: dict[str, Callable[[], Any]] = {
# Opt-in --use_fast_math for CUDA kernels. Off by default because it
# breaks bit-for-bit reproducibility required by the deterministic logp path.
"KERNEL_ALIGN_USE_FAST_MATH": lambda: bool(
int(os.getenv("KERNEL_ALIGN_USE_FAST_MATH", "0"))
),
"KERNEL_ALIGN_NCU_LINEINFO": lambda: bool(
int(os.getenv("KERNEL_ALIGN_NCU_LINEINFO", "0"))
),
# Windows-only escape hatch for unsupported MSVC/CUDA combinations.
"KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC": lambda: bool(
int(os.getenv("KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC", "0"))
),
}
def __getattr__(name: str):
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
Then call sites read it the same way vLLM does:
from rl_engine import envs
nvcc_flags = ["-O3", "-Xfatbin", "-compress-all"]
if envs.KERNEL_ALIGN_USE_FAST_MATH:
nvcc_flags.append("--use_fast_math")
| ) | ||
| if os.environ.get("KERNEL_ALIGN_NCU_LINEINFO") == "1": | ||
| nvcc_flags.append("-lineinfo") | ||
| if os.name == "nt" and os.environ.get("KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC") == "1": |
There was a problem hiding this comment.
same
if os.name == "nt" and envs.KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC:
nvcc_flags.append("-allow-unsupported-compiler")
nvcc_flags.append("-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH")
One caveat: setup.py runs before the package is importable, so either keep envs.py import-safe with zero heavy deps (as above), or place it at repo root and import it directly in setup.py.
Sure! The issue is here: #96 |
Signed-off-by: inaniloquentee <3051000145@qq.com>
Signed-off-by: inaniloquentee <3051000145@qq.com>
Good catch. The previous reducer used shared[lane], which is unsafe when WarpCount < 32. I fixed both deterministic reducers to read from shared memory only when threadIdx.x < WarpCount; otherwise the lane uses the reduction identity value. I also added a regression/source-contract test to prevent this pattern from coming back. |
Summary
This PR adds an opt-in batch-invariant deterministic CUDA selected-token logprob backend for RL-Kernel. It is intended for RL train/eval paths where the same flattened row must produce the same selected logprob regardless of batch size, batch position, sparse indexing pattern, or unrelated rows in the launch.
Why
GRPO/PPO-style training is sensitive to tiny logprob drift: policy ratios and KL penalties can amplify rollout-time vs train-time differences. Existing high-throughput reductions may change floating-point accumulation order with launch shape or sparse indexing. This PR adds a CUDA path whose reduction topology is locked by vocab bucket, so a row's result is independent of the surrounding batch while avoiding the worst cost of one universal block size.
Implementation
csrc/deterministic_logp_kernel.cu.vocab_sizeonly:vocab_size <= 128usesBlockSize=128129 <= vocab_size <= 4096usesBlockSize=256vocab_size > 4096usesBlockSize=512atomicAdd, CUB block reductions, autotuned block sizes, and dynamic reduction topology.out, fp32, indexedout, and indexed fp32 bindings through_C.DeterministicLogpCUDAOpplus registry op types:logp_deterministiclogp_deterministic_indexedresolve_logp_op_type, includingdeterministic,deterministic_cuda, andbatch_invariant.examples/grpo_single_gpu.pythrough:--logp-backend--require-batch-invariant-logpdeterministic_fp32deterministic_indexed_fp32--use_fast_math; it is now opt-in withKERNEL_ALIGN_USE_FAST_MATH=1.KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC=1.Test Matrix
CUDA Shape / Dtype / Bucket Matrix
1, 2, 31, 32, 33, 127, 128, 129, 255, 256, 257, 1024, 4095, 4096, 4097, 4099, 8192log_softmax(...).gather(...)within dtype tolerance128, 129, 4096, 4097128/256/512block constants and128/4096vocab limits(1,1), (1,3), (2,5), (2,7), (3,4), (4,3), (4,5), (2,6), (2,3)paired with vocab matrixfp16,bf16when supported,fp32outdtypefp16,bf16when supported,fp32,fp64Determinism Matrix
1, 2, 4, 8, 160..732, unrelated rows randomized for seeds20..29Indexed / Edge-Case Matrix
indexed_outwith zero row indicesindexed_fp32with zero row indices>= vocab_sizeidsint32row indices passed directly to_Crow_indices must be int64Numerical-Stability Matrix
0.0409980.04099-80.04099[-80, 80]4099[80, -80]4099Validation Results
bucketed_deterministic_logp_cuda_production_matrix_passeddeterministic_logp_cuda_error_contract_passedpy -3.13 -m pytest rl_engine/tests/test_dispatch.py tests\test_grpo_single_gpu_example.py tests\test_vllm_rollout_sampler.py tests\test_deepspeed_training_worker.py tests\test_deterministic_logp.py tests\test_op_accuracy.py -q51 passed, 73 skippedpy -3.13 -m pytest tests\test_deterministic_logp.py -q1 passed, 73 skippedon py3.13, because CUDA extension is cp39 in this local envpy -3.13 -m pre_commit run --all-filespy -3.13 -m mypy --ignore-missing-imports rl_engine/py -3.13 -m compileall rl_engine tests\test_deterministic_logp.pypy -3.13 -m mkdocs build --strict -f mkdocs.yamlgit diff --checkpython benchmarks\benchmark_rl_kernels.py --candidate deterministic_fp32 --smoke --warmup 1 --repeat 2python benchmarks\benchmark_rl_kernels.py --candidate deterministic_indexed_fp32 --smoke --warmup 1 --repeat 2python benchmarks\benchmark_rl_kernels.py --candidate deterministic_fp32 --vocab-sizes 4097 --completion-lens 4 --g-sizes 2 --num-prompts 1 --mask-densities 1.0 --warmup 1 --repeat 2python benchmarks\benchmark_rl_kernels.py --candidate deterministic_indexed_fp32 --vocab-sizes 4097 --completion-lens 4 --g-sizes 2 --num-prompts 1 --mask-densities 1.0 --warmup 1 --repeat 2Summary by CodeRabbit
Release Notes
New Features
logp_backendand optional batch-invariant enforcement (require_batch_invariant_logp).Bug Fixes
Tests