[FEAT][kernels] Fused linear log-prob without materializing logits — SM90 TMA + tensor-core forward, chunked backward#122

Open
KJLdefeated wants to merge 7 commits into RL-Align:main from KJLdefeated:feat/fused-linear-logp

Conversation

@KJLdefeated

@KJLdefeated KJLdefeated commented Jun 15, 2026

Collaborator

Summary

Fused linear log-prob op (CUDA (SM90) + Triton + native reference + chunked backward landed earlier on the branch).

| Backend | Status |
|---|---|
| CUDA SM90 (Hopper) | TMA-streamed, double-buffered, register-blocked tensor-core forward (ldmatrix + mma.sync.m16n8k16, fp32 accum) + split-V; online softmax. |
| CUDA / ROCm (Triton) | Portable online-softmax forward + Liger-style chunked backward; the semantic baseline and tolerance target. |
| PyTorch native | F.linear + log_softmax + gather reference / CPU fallback. |

The SM90 op requires bf16 hidden/weight with D % 32 == 0; for any other input (fp32/fp16, awkward D, CPU) it transparently falls back to Triton (else native), so it stays a drop-in.
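The eligibility rule above can be sketched as a small predicate plus a fallback chain. This is only an illustration: `TensorSpec`, `sm90_eligible`, and `pick_backend` are hypothetical names standing in for whatever metadata the real wrapper inspects, and the `cc_major == 9` gate is taken from the commit message rather than from the wrapper code itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for the tensor metadata a dispatcher would inspect."""
    device_type: str  # "cuda" or "cpu"
    dtype: str        # "bf16", "fp16" or "fp32"
    last_dim: int     # hidden size D


def sm90_eligible(hidden: TensorSpec, weight: TensorSpec, cc_major: int) -> bool:
    """True only for inputs the PR description says the SM90 kernel accepts."""
    return (
        cc_major == 9
        and hidden.device_type == "cuda"
        and weight.device_type == "cuda"
        and hidden.dtype == "bf16"
        and weight.dtype == "bf16"
        and hidden.last_dim == weight.last_dim
        and hidden.last_dim % 32 == 0
    )


def pick_backend(hidden: TensorSpec, weight: TensorSpec, cc_major: int,
                 triton_available: bool = True) -> str:
    """SM90 first, then Triton on any GPU, then the native reference."""
    if sm90_eligible(hidden, weight, cc_major):
        return "sm90"
    if hidden.device_type == "cuda" and triton_available:
        return "triton"
    return "native"
```

Keeping the predicate separate from the dispatch makes it cheap to tighten later (for example with a same-device check) without touching the fallback order.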

Implementation (SM90 forward)

  • TMA (cp.async.bulk.tensor, mbarrier-completed) streams H/W tiles into shared memory, double-buffered so the next hidden slice loads while the current one feeds the MMAs.
  • Register-blocked M-tiling (BM = 256 rows/CTA) so the weight matrix is re-read only N/BM times from HBM.
  • Split-V: the vocab loop is partitioned across blockIdx.y so the grid fills all SMs even when N/BM alone is too few CTAs; each split emits a partial online-softmax state and a tiny combine kernel merges them (standard log-sum-exp merge). Adds only 3·n_split·N floats of scratch — independent of V.
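The split-V merge described in the last bullet can be illustrated on a single row in plain Python. This is a sketch of the arithmetic, not the CUDA kernel: it assumes the three per-split values are the running max, the rescaled sum of exponentials, and the target logit (the PR only says "3·n_split·N floats"), and the function names are made up for the example.

```python
import math


def partial_state(logits, target_idx, start):
    """Stream one vocab split; return (max, sum of exp(x - max), target logit or None)."""
    m, s, t = -math.inf, 0.0, None
    for j, x in enumerate(logits, start=start):
        new_m = max(m, x)
        # Rescale the running sum whenever the running max moves.
        s = s * math.exp(m - new_m) + math.exp(x - new_m)
        m = new_m
        if j == target_idx:
            t = x
    return m, s, t


def combine(states):
    """Log-sum-exp merge of the per-split states into one log-probability."""
    m = max(st[0] for st in states)
    s = sum(st[1] * math.exp(st[0] - m) for st in states)
    target_logit = next(st[2] for st in states if st[2] is not None)
    return target_logit - (m + math.log(s))


def logp_split(row, target_idx, n_split):
    """Partition one row of logits into n_split vocab ranges and merge them."""
    chunk = math.ceil(len(row) / n_split)
    states = [partial_state(row[i:i + chunk], target_idx, i)
              for i in range(0, len(row), chunk)]
    return combine(states)


row = [0.5, -1.0, 2.0, 0.1, 3.0, -0.7, 1.2]
reference = row[4] - math.log(sum(math.exp(x) for x in row))
print(abs(logp_split(row, 4, 3) - reference) < 1e-12)  # True
```

Because each split only carries a constant-size state, the scratch grows with the number of splits and rows but never with the vocabulary size.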

Performance (H100, bf16, N=4096, D=2048)

Forward latency (ms) and peak forward VRAM:

| shape (N×D×V) | native | Triton | SM90 | SM90 vs Triton | peak fwd VRAM (native → fused) |
|---|---|---|---|---|---|
| 4096×2048×32768 | 1.79 | 6.42 | 3.41 | 1.88× | 1280 MB → ~0 MB |
| 4096×2048×50257 | 9.96 | 9.82 | 4.88 | 2.01× | 1965 MB → ~0 MB |
| 4096×2048×131072 | 7.28 | 25.56 | 12.88 | 1.98× | 5120 MB → ~0 MB |

Forward + backward latency (ms):

| shape (N×D×V) | native | Triton | SM90 |
|---|---|---|---|
| 4096×2048×32768 | 4.25 | 15.86 | 12.69 |
| 4096×2048×50257 | 23.29 | 47.20 | 42.20 |
| 4096×2048×131072 | 17.05 | 117.62 | 104.97 |

Numbers come from python benchmarks/benchmark_linear_logp.py. On the fused path, forward peak memory is just the per-CTA shared-memory tiles, so it is independent of V.
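As a rough sanity check on the memory column, the native peaks are consistent with about 10 bytes per logit element at N=4096. That per-element figure is inferred from the table, not stated in the PR, and the `n_split = 8` used for the split-V scratch is purely illustrative.

```python
MIB = 2**20


def native_peak_mib(n, v, bytes_per_logit=10):
    """Peak size of the materialized [N, V] temporaries (bytes_per_logit is an assumption)."""
    return n * v * bytes_per_logit / MIB


def splitv_scratch_kib(n, n_split):
    """Scratch for split-V: 3 fp32 values per (split, row), as described in the PR."""
    return 3 * n_split * n * 4 / 1024


for v in (32768, 50257, 131072):
    print(v, round(native_peak_mib(4096, v)), "MiB")

print(splitv_scratch_kib(4096, n_split=8), "KiB")
```

Under these assumptions the scratch stays in the hundreds of KiB while the materialized logits reach GiB scale, which is the gap the table reports as "~0 MB".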

Correctness / tests

python -m pytest tests/test_linear_logp.py — 17/17 pass.

Bug fix: fused_logp_sm90.cu

While enabling the SM90 build I found that the existing TMA online-softmax kernel (csrc/cuda/fused_logp_sm90.cu, the materialized-logits fused_logp) fails to compile whenever the SM90 path is actually built. There are two latent, version-independent errors that were dormant because the prebuilt extension had no SM90 symbols (the path had never been exercised).

  1. CUDART_INF_F undefined — the kernel uses -CUDART_INF_F but never includes <math_constants.h> (not pulled in transitively).

    fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
    

    Fix: #include <math_constants.h>.

  2. TmaTypeTraits<c10::BFloat16> is an incomplete type: the host wrapper passes logits.data_ptr<at::BFloat16>() (i.e. c10::BFloat16*) straight into init_tensor_map, but TmaTypeTraits is only specialized for nv_bfloat16/float, so template instantiation fails.

    tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed

    Fix: reinterpret_cast<const nv_bfloat16*>(logits.data_ptr<at::BFloat16>()) (matching how the linear kernel and the rest of the codebase bridge at::BFloat16 to nv_bfloat16).

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a CUDA SM90 fused Linear LogP backend (with optional bias) and automatic dispatch across fused, Triton, and native implementations.
    • Added a CUDA-only benchmark script to compare latency, forward+backward performance, and peak VRAM across configurable shapes (optionally including fused SM90 results when available).
  • Documentation

    • Added Linear LogP operator documentation, including semantics, constraints, supported backends, and benchmark guidance.
  • Tests

    • Added comprehensive correctness, gradient, fallback, error-handling, and dispatch tests covering native, Triton, and SM90 fused backends (with hardware-dependent skips).

KJLdefeated and others added 3 commits June 14, 2026 13:20
… backward + SM90 TMA/WGMMA)

Implements `linear_logp`: log_softmax(hidden @ W^T + b)[target] without
materializing the [N, V] logits, differentiable w.r.t. hidden/weight/bias.

- PyTorch `NativeLinearLogpOp`: naive F.linear + log_softmax + gather reference.
- Triton `TritonLinearLogpOp`: online-softmax forward (zero [N,V] materialization,
  tensor cores via native-dtype tl.dot), Liger-style chunked backward (cuBLAS
  matmuls, sequential grad_weight accumulation -> deterministic, peak mem chunk*V).
- CUDA SM90 `FusedLinearLogpSM90Op`: TMA + WGMMA streaming forward + smem online
  softmax (csrc/cuda/fused_linear_logp_sm90.cu); build-guarded behind
  KERNEL_ALIGN_FORCE_SM90, registry-gated to cc_major==9. Assembles for sm_90a
  under CUDA 13.1; numerics pending validation on Hopper.
- Registry dispatch, tests, benchmark, operator + design docs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Jun 15, 2026


Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ebcffb91-012c-4dfa-ac90-8cf76ce0c04c

📥 Commits

Reviewing files that changed from the base of the PR and between bc93e36 and dc3cd73.

📒 Files selected for processing (3)
  • csrc/cuda/fused_linear_logp_sm90.cu
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • tests/test_linear_logp.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • csrc/cuda/fused_linear_logp_sm90.cu
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • tests/test_linear_logp.py

📝 Walkthrough

Walkthrough

Adds a new linear_logp operator that computes per-token selected log-probabilities from hidden states and LM-head weights without materializing full [N, V] logit tensors. Three backends are introduced: a PyTorch native reference, a Triton fused kernel with online-softmax streaming, and an SM90 Hopper CUDA kernel using TMA-based tiling and tensor-core MMA. All backends are wired into the KernelRegistry hardware-priority dispatch, with a comprehensive test suite, benchmark script, and documentation.

Changes

Fused Linear LogP operator

| Layer | File(s) | Summary |
|---|---|---|
| Native PyTorch reference op | `rl_engine/kernels/ops/pytorch/loss/linear_logp.py` | Adds NativeLinearLogpOp: validates tensor shape compatibility, flattens leading batch dimensions, computes F.linear logits (with optional bias), applies log_softmax, gathers selected log-probabilities at target indices, and reshapes to original batch shape. Serves as correctness oracle for backend validation. |
| Triton fused kernel and autograd wrapper | `rl_engine/kernels/ops/triton/loss/__init__.py`, `rl_engine/kernels/ops/triton/loss/linear_logp.py` | Adds _linear_logp_fwd_kernel Triton kernel streaming vocab tiles with online log-softmax, _LinearLogpFunction autograd wrapper with chunked recompute backward, and TritonLinearLogpOp callable interface with CUDA validation and shape checks. |
| SM90 Hopper CUDA fused kernel | `csrc/cuda/fused_linear_logp_sm90.cu` | Adds fused_linear_logp_sm90_kernel using TMA double-buffered shared-memory tiling, bf16 tensor-core MMA, and per-row online softmax tracking partials across vocab splits. fused_linear_logp_sm90_combine_kernel reduces partials to final logp and lse. Helper init_tensor_map_noswizzle constructs TMA maps without swizzle. Exported fused_linear_logp_sm90_forward validates inputs, manages shared memory, and launches both kernels. |
| SM90 Python wrapper, C++ binding, and build integration | `rl_engine/kernels/ops/cuda/loss/linear_logp.py`, `csrc/ops.cpp`, `csrc/cuda/fused_logp_sm90.cu`, `setup.py` | Adds _FusedLinearLogpSM90Function (fused bf16 forward via _C.fused_linear_logp_sm90, chunked recompute backward) and FusedLinearLogpSM90Op with _sm90_supported gating and fallback to Triton/Native. Registers fused_linear_logp_sm90 PyBind11 binding. Fixes missing <math_constants.h> include and adds explicit reinterpret_cast in fused_logp_sm90.cu. Refactors setup.py SM90 source discovery. |
| KernelRegistry hardware-aware dispatch | `rl_engine/kernels/registry.py` | Extends OpBackend with CUDA_FUSED_LINEAR_LOGP_SM90, TRITON_LINEAR_LOGP, and PYTORCH_LINEAR_LOGP. Adds linear_logp to dispatch tables for CUDA, ROCm, and CPU. Probes for SM90 extension symbol and prepends SM90 backend for Hopper devices (cc_major == 9). |
| Comprehensive test suite | `tests/test_linear_logp.py` | Tests native correctness vs manual oracle, Triton forward/backward parity with bias variants, SM90 forward/backward parity vs Triton, fallback behavior for unsupported inputs, leading-shape preservation, large-vocab smoke tests, error handling for mismatched shapes/devices, and registry dispatch. |
| Benchmark script and operator documentation | `benchmarks/benchmark_linear_logp.py`, `docs/operators/linear-logp.md`, `docs/operators/README.md`, `docs/.nav.yml` | Adds CLI benchmark comparing Native/Triton/SM90 for forward latency, forward+backward latency, and peak VRAM. Adds comprehensive operator documentation covering API, tensor contract, reference semantics, backend availability, performance tables, and test coverage. Updates doc navigation. |

Sequence Diagram(s)

sequenceDiagram
  participant App
  participant KernelRegistry
  participant BackendSelector
  participant FusedSM90Op
  participant TritonOp
  participant NativeOp

  App->>KernelRegistry: get_op("linear_logp")
  KernelRegistry->>BackendSelector: probe device and extension symbols
  
  alt CUDA device with cc_major==9 and SM90 extension
    BackendSelector-->>KernelRegistry: FusedLinearLogpSM90Op has priority
    KernelRegistry-->>App: FusedLinearLogpSM90Op instance
    App->>FusedSM90Op: apply(hidden_bf16, weight, target, bias)
    FusedSM90Op-->>App: logp tensor
  else CUDA available, Triton compiled
    BackendSelector-->>KernelRegistry: TritonLinearLogpOp fallback
    KernelRegistry-->>App: TritonLinearLogpOp instance
    App->>TritonOp: apply(hidden, weight, target, bias)
    TritonOp-->>App: logp tensor
  else CPU or others
    BackendSelector-->>KernelRegistry: NativeLinearLogpOp fallback
    KernelRegistry-->>App: NativeLinearLogpOp instance
    App->>NativeOp: apply(hidden, weight, target, bias)
    NativeOp-->>App: logp tensor
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related issues

  • Implements the feature described in [Feat][Kernel] Fused CE LogProb without materializing logits + fused backward #155: provides a fused CE LogProb operator that avoids materializing large logits tensors by using SM90 TMA/WGMMA streaming with online softmax in the forward path, and a fused backward that recomputes logits in chunks to trade compute for memory, along with a Triton portable baseline implementation.
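The "recompute logits in chunks" backward mentioned above can be sketched in plain Python to show the gradient formula and the memory trade-off. This is an illustrative reference under simplifying assumptions (no bias, nested lists instead of tensors, hypothetical function names), not the Triton or cuBLAS implementation from the PR.

```python
import math


def row_logits(weight, h):
    """One row of hidden @ W^T for W stored as [V][D]."""
    return [sum(w_k * h_k for w_k, h_k in zip(w_row, h)) for w_row in weight]


def softmax(z):
    m = max(z)
    e = [math.exp(x - m) for x in z]
    total = sum(e)
    return [x / total for x in e]


def linear_logp(hidden, weight, target):
    """Reference forward: log_softmax(hidden @ W^T)[target], one value per row."""
    return [math.log(softmax(row_logits(weight, h))[t])
            for h, t in zip(hidden, target)]


def chunked_backward(hidden, weight, target, grad_out, chunk):
    """Gradients w.r.t. hidden and weight, recomputing softmax `chunk` rows at a time."""
    n, d, v = len(hidden), len(hidden[0]), len(weight)
    grad_hidden = [[0.0] * d for _ in range(n)]
    grad_weight = [[0.0] * d for _ in range(v)]
    for start in range(0, n, chunk):
        rows = range(start, min(start + chunk, n))
        # Only this chunk's probabilities are live: at most chunk * V values.
        probs = [softmax(row_logits(weight, hidden[i])) for i in rows]
        for i, p in zip(rows, probs):
            for j in range(v):
                # d logp[target] / d logit_j = 1[j == target] - softmax_j
                g = grad_out[i] * ((1.0 if j == target[i] else 0.0) - p[j])
                for k in range(d):
                    grad_hidden[i][k] += g * weight[j][k]
                    # Chunks are visited in order, so accumulation order is fixed.
                    grad_weight[j][k] += g * hidden[i][k]
    return grad_hidden, grad_weight
```

Changing `chunk` changes only how many rows of probabilities exist at once; the accumulation order, and therefore the result, stays the same, which is the property the commit message describes as deterministic.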

Suggested reviewers

  • inaniloquentee
  • Flink-ddd

Poem

🐇 Hop, hop through the vocab tiles so wide,
No giant logit matrix left to hide!
TMA streams bf16 at Hopper speed,
Online softmax — just the log-prob we need.
Three backends bound, the registry knows best,
This rabbit ships fused kernels with panache! 🚀

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 25.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely describes the main addition: a fused linear log-probability operator optimized for SM90 using TMA and tensor cores with chunked backward computation. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
rl_engine/kernels/ops/triton/loss/linear_logp.py (1)

174-174: 💤 Low value

Optional: Use tuple unpacking for clarity.

Per static analysis (RUF005), prefer (*ctx.lead_shape, d) over ctx.lead_shape + (d,).

♻️ Suggested fix
-        grad_hidden = grad_h.to(ctx.hidden_dtype).reshape(ctx.lead_shape + (d,))
+        grad_hidden = grad_h.to(ctx.hidden_dtype).reshape((*ctx.lead_shape, d))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/loss/linear_logp.py` at line 174, In the
grad_hidden assignment on the line with the reshape call, replace the tuple
concatenation syntax ctx.lead_shape + (d,) with tuple unpacking syntax
(*ctx.lead_shape, d) to improve code clarity and comply with the RUF005 style
recommendation.

Source: Linters/SAST tools

tests/test_linear_logp.py (1)

26-33: ⚡ Quick win

Narrow SM90 availability exception handling to avoid silently skipping real failures.

Catching Exception here can hide real extension/runtime breakages as test skips. Catch only expected “unavailable” failures so unexpected errors still fail fast.

Suggested patch
 def _sm90_available():
@@
-    except Exception:  # pragma: no cover
+    except (ImportError, AttributeError, OSError):  # pragma: no cover
         return False
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_linear_logp.py` around lines 26 - 33, The broad Exception catch in
the function (around the import and availability check) masks real failures like
import errors or runtime issues. Replace the generic Exception handler with
specific exception types that indicate the extension is genuinely unavailable,
such as ImportError or ModuleNotFoundError, while allowing unexpected errors to
propagate and fail the test setup appropriately.

Source: Linters/SAST tools
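A narrowed availability probe along the lines of this suggestion might look like the sketch below. The helper name and its arguments are hypothetical; the exception tuple is the one from the suggested patch, and which exceptions the real extension import can raise is an assumption.

```python
import importlib


def extension_has_symbol(module_name: str, symbol: str) -> bool:
    """Return False only for failures that mean "extension not available"."""
    try:
        module = importlib.import_module(module_name)
        getattr(module, symbol)
    except (ImportError, AttributeError, OSError):
        # ImportError also covers ModuleNotFoundError; OSError covers a shared
        # library that exists but cannot be loaded.
        return False
    return True


print(extension_has_symbol("math", "sqrt"))            # True
print(extension_has_symbol("math", "no_such_symbol"))  # False
```

Anything outside the tuple, such as a RuntimeError raised while the extension initializes, propagates and fails test collection instead of turning into a silent skip.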

benchmarks/benchmark_linear_logp.py (1)

107-115: ⚡ Quick win

Bind loop-scoped tensors directly in benchmark closures.

The current closure pattern relies on loop-scope capture (target, hidden, weight) and is flagged by B023. Binding them as default args removes late-binding risk and future-proofs the timing helpers.

Suggested patch
-        def fwd(op, h=hidden, w=weight):
+        def fwd(op, h=hidden, w=weight, t=target):
             with torch.no_grad():
-                op(h, w, target)
+                op(h, w, t)
 
-        def fwd_bwd(op):
-            h = hidden.clone().requires_grad_(True)
-            w = weight.clone().requires_grad_(True)
-            op(h, w, target).sum().backward()
+        def fwd_bwd(op, h_src=hidden, w_src=weight, t=target):
+            h = h_src.clone().requires_grad_(True)
+            w = w_src.clone().requires_grad_(True)
+            op(h, w, t).sum().backward()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_linear_logp.py` around lines 107 - 115, The closures fwd
and fwd_bwd are relying on late-binding capture of loop-scoped tensors,
particularly target, which creates a B023 linting issue. Fix this by adding
target as a default argument to both function definitions: add target=target to
the parameter list of the fwd function (alongside the existing h=hidden,
w=weight defaults), and add target=target as a default argument to the fwd_bwd
function definition. This binds the loop-scoped variable directly at closure
definition time rather than relying on late binding.

Source: Linters/SAST tools
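The late-binding behavior that B023 guards against is easy to reproduce in isolation. The toy loop below is unrelated to the benchmark's tensors; it only shows why a default argument freezes the loop variable. Note that closures called within the same iteration, as the benchmark helpers appear to be, still see the expected value, so the lint mainly protects against closures that outlive their iteration.

```python
late, bound = [], []
for scale in (1, 2, 3):
    late.append(lambda x: x * scale)        # looks up `scale` when called
    bound.append(lambda x, s=scale: x * s)  # freezes the current value

print([f(10) for f in late])   # [30, 30, 30]
print([f(10) for f in bound])  # [10, 20, 30]
```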

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/linear-logp.md`:
- Around line 68-72: In the code snippet showing the linear operation, the bias
parameter passed to torch.nn.functional.linear is not being cast to float, which
is inconsistent with the fp32-upcast reference implementation. Update the bias
argument in the torch.nn.functional.linear call to conditionally cast it to
float using the pattern: None if bias is None else bias.float(). This ensures
the documentation snippet matches the actual reference implementation behavior
for bf16/fp16 inputs.

In `@rl_engine/kernels/ops/cuda/loss/linear_logp.py`:
- Around line 22-29: The _sm90_supported function is too permissive and can
route unsupported inputs to the fused kernel, causing runtime errors. Add
stricter eligibility checks: verify that the GPU is actually SM90-capable
(Hopper or newer) using torch.cuda.get_device_capability() or similar device
properties, and ensure that both the hidden and lm_head_weight tensors are on
the same CUDA device by checking their device attribute matches. These checks
should be added to the existing dtype and tensor dimension validations in the
_sm90_supported predicate.

In `@setup.py`:
- Around line 118-124: The current condition `if enable_sm90 and present_sm90:`
enables the KERNEL_ALIGN_WITH_SM90 macro whenever any SM90 source exists, but
since csrc/ops.cpp registers SM90 entry points under this macro, enabling it
with only a partial set of SM90 sources will cause unresolved symbols at link
time. Modify the condition to only enable KERNEL_ALIGN_WITH_SM90 and append the
related link flags when the complete set of SM90 sources is present (not just a
subset), by verifying that present_sm90 contains all the expected sources from
sm90_srcs.
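The setup.py point above amounts to an all-or-nothing check on the SM90 sources. A minimal sketch, assuming `sm90_srcs` is a list of paths relative to the repository root; the helper name and return convention are hypothetical, and the real setup.py may structure this differently:

```python
from pathlib import Path


def sm90_sources_to_build(root: Path, sm90_srcs: list[str], enable_sm90: bool) -> list[str]:
    """Return the SM90 sources only when every expected file exists."""
    if not enable_sm90:
        return []
    present_sm90 = [src for src in sm90_srcs if (root / src).is_file()]
    if len(present_sm90) != len(sm90_srcs):
        # Partial set: building it would leave the bindings with unresolved symbols,
        # so treat it the same as "SM90 disabled".
        return []
    return present_sm90
```

The caller would define the SM90 macro and add the link flags only when this returns a non-empty list.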


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 256fc756-1f5a-457f-ab20-239d81d61a10

📥 Commits

Reviewing files that changed from the base of the PR and between e3446bc and 6c366c6.

📒 Files selected for processing (14)
  • benchmarks/benchmark_linear_logp.py
  • csrc/cuda/fused_linear_logp_sm90.cu
  • csrc/cuda/fused_logp_sm90.cu
  • csrc/ops.cpp
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/linear-logp.md
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • rl_engine/kernels/ops/pytorch/loss/linear_logp.py
  • rl_engine/kernels/ops/triton/loss/__init__.py
  • rl_engine/kernels/ops/triton/loss/linear_logp.py
  • rl_engine/kernels/registry.py
  • setup.py
  • tests/test_linear_logp.py

Comment on lines +68 to +72
```python
logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias) # [N, V]
logp = torch.log_softmax(logits, dim=-1)
out = logp.gather(-1, target_ids.long().unsqueeze(-1)).squeeze(-1)
```


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, locate and read the file to verify the current state
fd -type f -name "linear-logp.md" | head -5

Repository: RL-Align/RL-Kernel

Length of output: 233


🏁 Script executed:

# Check the structure to understand the repository
git ls-files | grep -E "(docs/operators|test)" | head -20

Repository: RL-Align/RL-Kernel

Length of output: 631


🏁 Script executed:

# Search for bias casting patterns in the codebase
rg "bias\.float\(\)" --type py -B 2 -A 2

Repository: RL-Align/RL-Kernel

Length of output: 1680


🏁 Script executed:

# Read the specific file at the mentioned lines
sed -n '60,80p' docs/operators/linear-logp.md

Repository: RL-Align/RL-Kernel

Length of output: 818


🏁 Script executed:

# Get more context around those lines
cat -n docs/operators/linear-logp.md | sed -n '50,90p'

Repository: RL-Align/RL-Kernel

Length of output: 1969


🏁 Script executed:

# Check the complete test file to understand the reference implementation
cat -n tests/test_linear_logp.py | head -50

Repository: RL-Align/RL-Kernel

Length of output: 1893


🏁 Script executed:

# Check if bias can be None and understand the full reference implementation pattern
grep -B 5 -A 10 "def native" tests/test_linear_logp.py

Repository: RL-Align/RL-Kernel

Length of output: 44


🏁 Script executed:

# Look for the full test function to understand the bf16/fp16 test context
grep -B 3 -A 8 "bfloat16\|fp16" tests/test_linear_logp.py

Repository: RL-Align/RL-Kernel

Length of output: 2378


🏁 Script executed:

# Get the complete test function to see all the patterns
cat -n tests/test_linear_logp.py | sed -n '70,140p'

Repository: RL-Align/RL-Kernel

Length of output: 3427


🏁 Script executed:

# Check if there are any other reference implementations or examples
rg "NativeLinearLogpOp\|F\.linear" docs/ --context 3

Repository: RL-Align/RL-Kernel

Length of output: 44


🏁 Script executed:

# Verify the tensor contract to confirm bias is optional
sed -n '55,65p' docs/operators/linear-logp.md

Repository: RL-Align/RL-Kernel

Length of output: 632


🏁 Script executed:

# Check if the _inputs function returns None for bias when bias=False
sed -n '55,75p' tests/test_linear_logp.py

Repository: RL-Align/RL-Kernel

Length of output: 923


Cast bias to float for consistency with the fp32-upcast reference path.

The reference snippet claims to match the "fp32-upcast reference" for bf16/fp16 inputs (line 76), but doesn't cast bias. The actual reference implementation in tests (line 77 of test_linear_logp.py) conditionally casts bias: None if bias is None else bias.float(). Update the snippet to match.

Suggested patch
-logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias)  # [N, V]
+logits = torch.nn.functional.linear(
+    hidden.float(),
+    weight.float(),
+    None if bias is None else bias.float(),
+)  # [N, V]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/linear-logp.md` around lines 68 - 72, In the code snippet
showing the linear operation, the bias parameter passed to
torch.nn.functional.linear is not being cast to float, which is inconsistent
with the fp32-upcast reference implementation. Update the bias argument in the
torch.nn.functional.linear call to conditionally cast it to float using the
pattern: None if bias is None else bias.float(). This ensures the documentation
snippet matches the actual reference implementation behavior for bf16/fp16
inputs.

Comment thread rl_engine/kernels/ops/cuda/loss/linear_logp.py
Comment thread setup.py

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)

287-296: 💤 Low value

ROCm device handling gap in registry dispatch test.

The device selection logic at line 292 treats ROCm as CPU since "rocm" != "cuda". On a ROCm system, kernel_registry.get_op("linear_logp") returns a ROCm backend (Triton or Native per the registry's "rocm" priority list), but the test would create CPU tensors, causing either a device mismatch or testing the wrong code path.

Suggested fix to handle ROCm
 def test_registry_dispatch_matches_native():
     from rl_engine.kernels.registry import kernel_registry
     from rl_engine.platforms.device import device_ctx

     op = kernel_registry.get_op("linear_logp")
-    device = device_ctx.device if device_ctx.device_type == "cuda" else "cpu"
+    device = device_ctx.device if device_ctx.device_type in ("cuda", "rocm") else "cpu"
     hidden, weight, target, bias = _inputs(6, device=device)
     out = op(hidden, weight, target, bias)
     ref = NativeLinearLogpOp()(hidden, weight, target, bias)
     assert torch.allclose(out.cpu(), ref.cpu(), atol=1e-3)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_linear_logp.py` around lines 287 - 296, The device selection logic
in test_registry_dispatch_matches_native does not account for ROCm devices. The
current condition at line 292 treats ROCm as CPU since the check only looks for
"cuda" device type, causing a mismatch when
kernel_registry.get_op("linear_logp") returns a ROCm backend but the test
creates CPU tensors. Update the device selection logic to check for both "cuda"
and "rocm" device types, assigning device_ctx.device to the device variable when
either is present, and only default to "cpu" for other cases.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ccf841fc-c977-45c1-9176-9c62f141fcbf

📥 Commits

Reviewing files that changed from the base of the PR and between 6c366c6 and 3138dbb.

📒 Files selected for processing (2)
  • benchmarks/benchmark_linear_logp.py
  • tests/test_linear_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/benchmark_linear_logp.py

@Flink-ddd Flink-ddd requested a review from maxiaosong1124 June 15, 2026 09:19
@Flink-ddd

Collaborator

cc @inaniloquentee @maxiaosong1124 PTAL

# to a portable backend so the op stays a drop-in for any input.
if not _sm90_supported(hidden, lm_head_weight):
    return _fallback_op()(hidden, lm_head_weight, target_ids, bias)
return _FusedLinearLogpSM90Function.apply(hidden, lm_head_weight, bias, target_ids)

Collaborator


Could we add the same target/bias validation here before launching the SM90 kernel? Right now this fast path only checks the hidden dim; if target_ids has the wrong shape/length, the CUDA kernel can read target[row_base + r] out of bounds, and a bad bias shape/device can similarly become an invalid device access. Catching this as a Python/C++ validation error would keep bad inputs from turning into CUDA illegal memory access.

Collaborator Author


Thanks for the suggestion! I will add the target/bias validation before the kernel launch.

@KJLdefeated

Collaborator Author

Hi @inaniloquentee, sorry for the late reply. I have addressed the requested changes.

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)

295-300: ⚡ Quick win

Assert the validation errors, not any RuntimeError.

RuntimeError is broad enough to also catch CUDA error: illegal memory access, so this test can pass while missing the clean pre-launch validation it is meant to guarantee. Since FusedLinearLogpSM90Op.apply() now raises ValueError for these three cases, match those errors directly.

🧪 Proposed test tightening
-    with pytest.raises((ValueError, RuntimeError)):  # wrong target length
+    with pytest.raises(ValueError, match="target_ids must have one id per token"):
         sm90(hidden, weight, target[:-1], bias)
-    with pytest.raises((ValueError, RuntimeError)):  # wrong bias length
+    with pytest.raises(ValueError, match="bias must have V="):
         sm90(hidden, weight, target, bias[:-1])
-    with pytest.raises((ValueError, RuntimeError)):  # bias on the wrong device
+    with pytest.raises(ValueError, match=r"bias device .* must match hidden device"):
         sm90(hidden, weight, target, bias.cpu())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_linear_logp.py` around lines 295 - 300, The test is catching both
ValueError and RuntimeError, which is too broad since RuntimeError could match
unintended errors like CUDA illegal memory access errors, allowing the test to
pass without actually validating the expected behavior. Replace the three
pytest.raises calls in the test (for wrong target length, wrong bias length, and
bias on wrong device) to expect only ValueError instead of the tuple
(ValueError, RuntimeError), since FusedLinearLogpSM90Op.apply() now raises
ValueError for these specific validation cases in the sm90 function calls.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fbd3e088-76b8-4ea2-bfa4-ba50d8f4bf69

📥 Commits

Reviewing files that changed from the base of the PR and between 01c53b2 and bc93e36.

📒 Files selected for processing (3)
  • csrc/cuda/fused_linear_logp_sm90.cu
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • tests/test_linear_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/cuda/fused_linear_logp_sm90.cu

@maxiaosong1124

Collaborator

Sorry for the late reply. I fully agree with the suggestion provided by @inaniloquentee. Since you've already addressed the issue on your side, this PR is now quite solid.

@Flink-ddd

Collaborator

Good work! Thank you for your contribution. Could you add an issue to this PR?

@inaniloquentee

Collaborator

> Hi @inaniloquentee, sorry for the late reply; I have resolved the requests.

Nice work! This covers the original target length and bias cases. One small remaining guard: can we also check lm_head_weight.device == hidden.device before the SM90 launch?
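A minimal sketch of the guard requested above, assuming a hypothetical helper `check_same_device` that is not part of the PR and would run before the SM90 launch. A meta tensor stands in for "some other device" so the example runs without a GPU.

```python
import torch


def check_same_device(hidden: torch.Tensor, lm_head_weight: torch.Tensor) -> None:
    """Hypothetical pre-launch guard: fail in Python rather than inside the kernel."""
    if lm_head_weight.device != hidden.device:
        raise ValueError(
            f"lm_head_weight device {lm_head_weight.device} must match "
            f"hidden device {hidden.device}"
        )


hidden = torch.zeros(4, 32)
weight_ok = torch.zeros(8, 32)
weight_bad = torch.zeros(8, 32, device="meta")  # stand-in for a mismatched device

check_same_device(hidden, weight_ok)  # same device: passes silently
try:
    check_same_device(hidden, weight_bad)
except ValueError as err:
    message = str(err)
```

The error text mirrors the existing bias-device message so a test could match it with `pytest.raises(ValueError, match=...)`.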

@KJLdefeated

Collaborator Author

Hi @Flink-ddd, @inaniloquentee and @maxiaosong1124, thx for the review; the requested changes are resolved. The issue for this PR is #155.

@inaniloquentee

Collaborator

> Hi @Flink-ddd, @inaniloquentee and @maxiaosong1124, thx for the review; the requested changes are resolved. The issue for this PR is #155.

LGTM now, thx for your contribution!

@Flink-ddd Flink-ddd left a comment

Collaborator

Thanks for the excellent PR! It's a highly optimized piece of work that brings significant memory savings for our RL workflows. Once the boundary validation and the DRY refactoring are addressed, this will be good to merge.

        if bias.device != hidden.device:
            raise ValueError(
                f"bias device {bias.device} must match hidden device {hidden.device}"
            )

Collaborator

We need a value-bounds check for target_ids here before launching the kernel. The current fast path only checks the shape. If target_ids contains padding tokens (-100) or out-of-bounds values, it will cause silent errors in the forward pass and wrap-around memory corruption in the backward pass.

Could you add something like this?

        min_target = target_ids.min().item()
        max_target = target_ids.max().item()
        if min_target < 0 or max_target >= lm_head_weight.size(0):
            raise ValueError(
                f"target_ids contains out-of-bounds values. Expected range [0, {lm_head_weight.size(0)}-1], "
                f"got [{min_target}, {max_target}]."
            )
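The suggestion above can be sketched as a standalone function. The name `check_target_range` and the early return for empty input are assumptions for illustration; the PR may place this check inline instead.

```python
import torch


def check_target_range(target_ids: torch.Tensor, vocab_size: int) -> None:
    """Hypothetical helper: reject ids outside [0, vocab_size - 1]."""
    if target_ids.numel() == 0:
        return  # min()/max() raise on empty tensors, and there is nothing to index
    # Each .item() forces a host-device sync when target_ids lives on the GPU.
    min_target = int(target_ids.min().item())
    max_target = int(target_ids.max().item())
    if min_target < 0 or max_target >= vocab_size:
        raise ValueError(
            f"target_ids out of bounds: expected [0, {vocab_size - 1}], "
            f"got [{min_target}, {max_target}]"
        )


check_target_range(torch.tensor([0, 3, 7]), vocab_size=8)  # in range: passes
try:
    check_target_range(torch.tensor([0, -100, 7]), vocab_size=8)  # padding id
except ValueError as err:
    padding_error = str(err)
```

If callers rely on the `-100` ignore-index convention, those positions would need to be masked or remapped before this check; whether the op should support that is a design decision for the PR.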

            raise ValueError(
                f"hidden dim {hidden.size(-1)} must match lm_head_weight dim "
                f"{lm_head_weight.size(-1)}"
            )

Collaborator

Same comment as on FusedLinearLogpSM90Op.apply in rl_engine/kernels/ops/cuda/loss/linear_logp.py.

grad_w = torch.zeros(v, d, device=weight.device, dtype=torch.float32)
grad_b = torch.zeros(v, device=weight.device, dtype=torch.float32) if ctx.has_bias else None

chunk = max(1, min(n, _BWD_CHUNK_ELEMS // v))

Collaborator

This entire chunked backward logic is duplicated verbatim in triton/loss/linear_logp.py. To keep the codebase maintainable with a single source of truth, could we extract this for loop into a shared utility function?

For example, creating a helper in a shared utils.py or base class:

def compute_chunked_linear_logp_backward(hidden_2d, weight, bias_t, target_1d, g, chunk_elems):
    # Move the Liger-style chunked loop here
    # ...
    return grad_h, grad_w, grad_b
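One way the shared helper could look, as a sketch rather than the PR's actual loop. It uses the identity d logp / d logits = one_hot(target) - softmax(logits), scaled by the upstream gradient `g`. Unlike the PR, this version does every matmul in fp32 for simplicity; the function name and argument order are assumptions.

```python
import torch


@torch.no_grad()
def chunked_linear_logp_backward(hidden_2d, weight, bias, target_1d, g, chunk_elems):
    """Sketch: grads of logp = log_softmax(hidden @ weight.T + bias)[target].

    Only a (chunk, V) slice of logits is alive at any time.
    """
    n = hidden_2d.size(0)
    v = weight.size(0)
    grad_h = torch.empty_like(hidden_2d)
    grad_w = torch.zeros_like(weight, dtype=torch.float32)
    grad_b = None
    if bias is not None:
        grad_b = torch.zeros(v, dtype=torch.float32, device=weight.device)
    chunk = max(1, min(n, chunk_elems // v))  # rows per chunk, as in the PR
    w32 = weight.float()
    for i0 in range(0, n, chunk):
        i1 = min(i0 + chunk, n)
        h = hidden_2d[i0:i1].float()
        logits = h @ w32.t()
        if bias is not None:
            logits = logits + bias.float()
        dz = -torch.softmax(logits, dim=-1)        # minus softmax ...
        rows = torch.arange(i1 - i0, device=dz.device)
        dz[rows, target_1d[i0:i1].long()] += 1.0   # ... plus one-hot(target)
        dz *= g[i0:i1].float().unsqueeze(1)        # chain rule with upstream grad
        grad_h[i0:i1] = (dz @ w32).to(hidden_2d.dtype)
        grad_w += dz.t() @ h
        if grad_b is not None:
            grad_b += dz.sum(dim=0)
    return grad_h, grad_w, grad_b


# Compare against autograd on a small CPU problem (made-up sizes).
torch.manual_seed(0)
n, d, v = 7, 16, 11
hidden = torch.randn(n, d, requires_grad=True)
weight = torch.randn(v, d, requires_grad=True)
bias = torch.randn(v, requires_grad=True)
target = torch.randint(0, v, (n,))
g = torch.randn(n)

logp = torch.log_softmax(hidden @ weight.t() + bias, dim=-1)[torch.arange(n), target]
logp.backward(g)

grad_h, grad_w, grad_b = chunked_linear_logp_backward(
    hidden.detach(), weight.detach(), bias.detach(), target, g, chunk_elems=3 * v
)
```

With `chunk_elems=3 * v` the seven rows are processed in chunks of 3, 3 and 1, so the ragged last chunk is exercised too. Both the CUDA and Triton autograd functions could then call the same helper from their `backward`.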



def run_benchmark(args):
    if device_ctx.device_type != "cuda":

Collaborator

Since the Triton implementation is designed to be portable (and ROCm compatible as mentioned in the PR), this strict "cuda" check will prevent benchmarking on AMD devices.

Let's relax this check. You can replace it with:

    if device_ctx.device_type not in ["cuda", "xpu", "hip"]:
        raise RuntimeError("linear_logp benchmark requires a compatible GPU device.")
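A small sketch of the relaxed guard as a reusable function. The name `require_gpu` and the constant are hypothetical, and the allowed set is copied from the suggestion above. One caveat worth confirming: ROCm builds of PyTorch typically report AMD GPUs with device type `cuda`, so what `device_ctx.device_type` returns on AMD hardware depends on how this repo defines `device_ctx`.

```python
# Allowed set copied from the review suggestion; adjust to what device_ctx reports.
SUPPORTED_DEVICE_TYPES = frozenset({"cuda", "xpu", "hip"})


def require_gpu(device_type: str) -> None:
    """Hypothetical guard for the benchmark entry point."""
    if device_type not in SUPPORTED_DEVICE_TYPES:
        raise RuntimeError(
            "linear_logp benchmark requires a compatible GPU device, "
            f"got {device_type!r}."
        )


require_gpu("cuda")  # passes silently
try:
    require_gpu("cpu")
except RuntimeError as err:
    cpu_error = str(err)
```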

dz[rows, target_1d[i0:i1].long()] += 1.0
dz *= g[i0:i1].unsqueeze(1)

dz_dt = dz.to(dt)

Collaborator

Just a quick confirmation on this explicit downcast: dz is computed in fp32, and here we cast it back to dt (bf16/fp16) before the matmul.

While this perfectly matches the Triton test precision and maximizes Tensor Core throughput, it might introduce slight quantization errors. If TF32 is enabled, doing the matmul directly in fp32 is also very fast and keeps higher gradient fidelity.

Is this explicit downcast intentionally kept to strictly pass the cross-architecture precision tests? If yes, it's totally fine to leave it as is!
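To make the trade-off concrete, here is a rough CPU-only comparison on made-up data. It measures only the rounding effect of casting `dz` to bf16 before the matmul; it does not exercise TF32 or tensor cores, so it says nothing about speed.

```python
import torch

torch.manual_seed(0)
dz = torch.randn(64, 256) * 1e-2             # fp32 gradient w.r.t. logits (made up)
weight = torch.randn(256, 128).bfloat16()    # bf16 weight, as on the SM90 path

reference = dz.double() @ weight.double()    # fp64 ground truth

# Option A: downcast dz to the weight dtype, then matmul in bf16.
out_bf16 = (dz.to(weight.dtype) @ weight).double()
# Option B: keep dz in fp32 and upcast the weight instead.
out_fp32 = (dz @ weight.float()).double()

err_bf16 = (out_bf16 - reference).abs().max().item()
err_fp32 = (out_fp32 - reference).abs().max().item()
print(f"max abs error, bf16 matmul: {err_bf16:.3e}")
print(f"max abs error, fp32 matmul: {err_fp32:.3e}")
```

The bf16 path loses precision both when `dz` is rounded and when the output is stored in bf16, so its error should be noticeably larger; whether that matters for training is what the tolerance tests in the PR are meant to decide.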
