[FEAT][kernels] Fused linear log-prob without materializing logits — SM90 TMA + tensor-core forward, chunked backward#122
Conversation
… backward + SM90 TMA/WGMMA) Implements `linear_logp`: log_softmax(hidden @ W^T + b)[target] without materializing the [N, V] logits, differentiable w.r.t. hidden/weight/bias. - PyTorch `NativeLinearLogpOp`: naive F.linear + log_softmax + gather reference. - Triton `TritonLinearLogpOp`: online-softmax forward (zero [N,V] materialization, tensor cores via native-dtype tl.dot), Liger-style chunked backward (cuBLAS matmuls, sequential grad_weight accumulation -> deterministic, peak mem chunk*V). - CUDA SM90 `FusedLinearLogpSM90Op`: TMA + WGMMA streaming forward + smem online softmax (csrc/cuda/fused_linear_logp_sm90.cu); build-guarded behind KERNEL_ALIGN_FORCE_SM90, registry-gated to cc_major==9. Assembles for sm_90a under CUDA 13.1; numerics pending validation on Hopper. - Registry dispatch, tests, benchmark, operator + design docs. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 WalkthroughWalkthroughAdds a new ChangesFused Linear LogP operator
Sequence Diagram(s)sequenceDiagram
participant App
participant KernelRegistry
participant BackendSelector
participant FusedSM90Op
participant TritonOp
participant NativeOp
App->>KernelRegistry: get_op("linear_logp")
KernelRegistry->>BackendSelector: probe device and extension symbols
alt CUDA device with cc_major==9 and SM90 extension
BackendSelector-->>KernelRegistry: FusedLinearLogpSM90Op has priority
KernelRegistry-->>App: FusedLinearLogpSM90Op instance
App->>FusedSM90Op: apply(hidden_bf16, weight, target, bias)
FusedSM90Op-->>App: logp tensor
else CUDA available, Triton compiled
BackendSelector-->>KernelRegistry: TritonLinearLogpOp fallback
KernelRegistry-->>App: TritonLinearLogpOp instance
App->>TritonOp: apply(hidden, weight, target, bias)
TritonOp-->>App: logp tensor
else CPU or others
BackendSelector-->>KernelRegistry: NativeLinearLogpOp fallback
KernelRegistry-->>App: NativeLinearLogpOp instance
App->>NativeOp: apply(hidden, weight, target, bias)
NativeOp-->>App: logp tensor
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related issues
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (3)
rl_engine/kernels/ops/triton/loss/linear_logp.py (1)
174-174: 💤 Low valueOptional: Use tuple unpacking for clarity.
Per static analysis (RUF005), prefer
(*ctx.lead_shape, d)overctx.lead_shape + (d,).♻️ Suggested fix
- grad_hidden = grad_h.to(ctx.hidden_dtype).reshape(ctx.lead_shape + (d,)) + grad_hidden = grad_h.to(ctx.hidden_dtype).reshape((*ctx.lead_shape, d))🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/ops/triton/loss/linear_logp.py` at line 174, In the grad_hidden assignment on the line with the reshape call, replace the tuple concatenation syntax ctx.lead_shape + (d,) with tuple unpacking syntax (*ctx.lead_shape, d) to improve code clarity and comply with the RUF005 style recommendation.Source: Linters/SAST tools
tests/test_linear_logp.py (1)
26-33: ⚡ Quick winNarrow SM90 availability exception handling to avoid silently skipping real failures.
Catching
Exceptionhere can hide real extension/runtime breakages as test skips. Catch only expected “unavailable” failures so unexpected errors still fail fast.Suggested patch
def _sm90_available(): @@ - except Exception: # pragma: no cover + except (ImportError, AttributeError, OSError): # pragma: no cover return False🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_linear_logp.py` around lines 26 - 33, The broad Exception catch in the function (around the import and availability check) masks real failures like import errors or runtime issues. Replace the generic Exception handler with specific exception types that indicate the extension is genuinely unavailable, such as ImportError or ModuleNotFoundError, while allowing unexpected errors to propagate and fail the test setup appropriately.Source: Linters/SAST tools
benchmarks/benchmark_linear_logp.py (1)
107-115: ⚡ Quick winBind loop-scoped tensors directly in benchmark closures.
The current closure pattern relies on loop-scope capture (
target,hidden,weight) and is flagged by B023. Binding them as default args removes late-binding risk and future-proofs the timing helpers.Suggested patch
- def fwd(op, h=hidden, w=weight): + def fwd(op, h=hidden, w=weight, t=target): with torch.no_grad(): - op(h, w, target) + op(h, w, t) - def fwd_bwd(op): - h = hidden.clone().requires_grad_(True) - w = weight.clone().requires_grad_(True) - op(h, w, target).sum().backward() + def fwd_bwd(op, h_src=hidden, w_src=weight, t=target): + h = h_src.clone().requires_grad_(True) + w = w_src.clone().requires_grad_(True) + op(h, w, t).sum().backward()🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/benchmark_linear_logp.py` around lines 107 - 115, The closures fwd and fwd_bwd are relying on late-binding capture of loop-scoped tensors, particularly target, which creates a B023 linting issue. Fix this by adding target as a default argument to both function definitions: add target=target to the parameter list of the fwd function (alongside the existing h=hidden, w=weight defaults), and add target=target as a default argument to the fwd_bwd function definition. This binds the loop-scoped variable directly at closure definition time rather than relying on late binding.Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/linear-logp.md`:
- Around line 68-72: In the code snippet showing the linear operation, the bias
parameter passed to torch.nn.functional.linear is not being cast to float, which
is inconsistent with the fp32-upcast reference implementation. Update the bias
argument in the torch.nn.functional.linear call to conditionally cast it to
float using the pattern: None if bias is None else bias.float(). This ensures
the documentation snippet matches the actual reference implementation behavior
for bf16/fp16 inputs.
In `@rl_engine/kernels/ops/cuda/loss/linear_logp.py`:
- Around line 22-29: The _sm90_supported function is too permissive and can
route unsupported inputs to the fused kernel, causing runtime errors. Add
stricter eligibility checks: verify that the GPU is actually SM90-capable
(Hopper or newer) using torch.cuda.get_device_capability() or similar device
properties, and ensure that both the hidden and lm_head_weight tensors are on
the same CUDA device by checking their device attribute matches. These checks
should be added to the existing dtype and tensor dimension validations in the
_sm90_supported predicate.
In `@setup.py`:
- Around line 118-124: The current condition `if enable_sm90 and present_sm90:`
enables the KERNEL_ALIGN_WITH_SM90 macro whenever any SM90 source exists, but
since csrc/ops.cpp registers SM90 entry points under this macro, enabling it
with only a partial set of SM90 sources will cause unresolved symbols at link
time. Modify the condition to only enable KERNEL_ALIGN_WITH_SM90 and append the
related link flags when the complete set of SM90 sources is present (not just a
subset), by verifying that present_sm90 contains all the expected sources from
sm90_srcs.
---
Nitpick comments:
In `@benchmarks/benchmark_linear_logp.py`:
- Around line 107-115: The closures fwd and fwd_bwd are relying on late-binding
capture of loop-scoped tensors, particularly target, which creates a B023
linting issue. Fix this by adding target as a default argument to both function
definitions: add target=target to the parameter list of the fwd function
(alongside the existing h=hidden, w=weight defaults), and add target=target as a
default argument to the fwd_bwd function definition. This binds the loop-scoped
variable directly at closure definition time rather than relying on late
binding.
In `@rl_engine/kernels/ops/triton/loss/linear_logp.py`:
- Line 174: In the grad_hidden assignment on the line with the reshape call,
replace the tuple concatenation syntax ctx.lead_shape + (d,) with tuple
unpacking syntax (*ctx.lead_shape, d) to improve code clarity and comply with
the RUF005 style recommendation.
In `@tests/test_linear_logp.py`:
- Around line 26-33: The broad Exception catch in the function (around the
import and availability check) masks real failures like import errors or runtime
issues. Replace the generic Exception handler with specific exception types that
indicate the extension is genuinely unavailable, such as ImportError or
ModuleNotFoundError, while allowing unexpected errors to propagate and fail the
test setup appropriately.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 256fc756-1f5a-457f-ab20-239d81d61a10
📒 Files selected for processing (14)
benchmarks/benchmark_linear_logp.pycsrc/cuda/fused_linear_logp_sm90.cucsrc/cuda/fused_logp_sm90.cucsrc/ops.cppdocs/.nav.ymldocs/operators/README.mddocs/operators/linear-logp.mdrl_engine/kernels/ops/cuda/loss/linear_logp.pyrl_engine/kernels/ops/pytorch/loss/linear_logp.pyrl_engine/kernels/ops/triton/loss/__init__.pyrl_engine/kernels/ops/triton/loss/linear_logp.pyrl_engine/kernels/registry.pysetup.pytests/test_linear_logp.py
| ```python | ||
| logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias) # [N, V] | ||
| logp = torch.log_softmax(logits, dim=-1) | ||
| out = logp.gather(-1, target_ids.long().unsqueeze(-1)).squeeze(-1) | ||
| ``` |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, locate and read the file to verify the current state
fd -type f -name "linear-logp.md" | head -5Repository: RL-Align/RL-Kernel
Length of output: 233
🏁 Script executed:
# Check the structure to understand the repository
git ls-files | grep -E "(docs/operators|test)" | head -20Repository: RL-Align/RL-Kernel
Length of output: 631
🏁 Script executed:
# Search for bias casting patterns in the codebase
rg "bias\.float\(\)" --type py -B 2 -A 2Repository: RL-Align/RL-Kernel
Length of output: 1680
🏁 Script executed:
# Read the specific file at the mentioned lines
sed -n '60,80p' docs/operators/linear-logp.mdRepository: RL-Align/RL-Kernel
Length of output: 818
🏁 Script executed:
# Get more context around those lines
cat -n docs/operators/linear-logp.md | sed -n '50,90p'Repository: RL-Align/RL-Kernel
Length of output: 1969
🏁 Script executed:
# Check the complete test file to understand the reference implementation
cat -n tests/test_linear_logp.py | head -50Repository: RL-Align/RL-Kernel
Length of output: 1893
🏁 Script executed:
# Check if bias can be None and understand the full reference implementation pattern
grep -B 5 -A 10 "def native" tests/test_linear_logp.pyRepository: RL-Align/RL-Kernel
Length of output: 44
🏁 Script executed:
# Look for the full test function to understand the bf16/fp16 test context
grep -B 3 -A 8 "bfloat16\|fp16" tests/test_linear_logp.pyRepository: RL-Align/RL-Kernel
Length of output: 2378
🏁 Script executed:
# Get the complete test function to see all the patterns
cat -n tests/test_linear_logp.py | sed -n '70,140p'Repository: RL-Align/RL-Kernel
Length of output: 3427
🏁 Script executed:
# Check if there are any other reference implementations or examples
rg "NativeLinearLogpOp\|F\.linear" docs/ --context 3Repository: RL-Align/RL-Kernel
Length of output: 44
🏁 Script executed:
# Verify the tensor contract to confirm bias is optional
sed -n '55,65p' docs/operators/linear-logp.mdRepository: RL-Align/RL-Kernel
Length of output: 632
🏁 Script executed:
# Check if the _inputs function returns None for bias when bias=False
sed -n '55,75p' tests/test_linear_logp.pyRepository: RL-Align/RL-Kernel
Length of output: 923
🏁 Script executed:
# One more check - see the _inputs function to confirm bias parameter behavior
sed -n '55,75p' tests/test_linear_logp.pyRepository: RL-Align/RL-Kernel
Length of output: 923
Cast bias to float for consistency with the fp32-upcast reference path.
The reference snippet claims to match the "fp32-upcast reference" for bf16/fp16 inputs (line 76), but doesn't cast bias. The actual reference implementation in tests (line 77 of test_linear_logp.py) conditionally casts bias: None if bias is None else bias.float(). Update the snippet to match.
Suggested patch
-logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias) # [N, V]
+logits = torch.nn.functional.linear(
+ hidden.float(),
+ weight.float(),
+ None if bias is None else bias.float(),
+) # [N, V]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/operators/linear-logp.md` around lines 68 - 72, In the code snippet
showing the linear operation, the bias parameter passed to
torch.nn.functional.linear is not being cast to float, which is inconsistent
with the fp32-upcast reference implementation. Update the bias argument in the
torch.nn.functional.linear call to conditionally cast it to float using the
pattern: None if bias is None else bias.float(). This ensures the documentation
snippet matches the actual reference implementation behavior for bf16/fp16
inputs.
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)
287-296: 💤 Low valueROCm device handling gap in registry dispatch test.
The device selection logic at line 292 treats ROCm as CPU since
"rocm" != "cuda". On a ROCm system,kernel_registry.get_op("linear_logp")returns a ROCm backend (Triton or Native per the registry's"rocm"priority list), but the test would create CPU tensors, causing either a device mismatch or testing the wrong code path.Suggested fix to handle ROCm
def test_registry_dispatch_matches_native(): from rl_engine.kernels.registry import kernel_registry from rl_engine.platforms.device import device_ctx op = kernel_registry.get_op("linear_logp") - device = device_ctx.device if device_ctx.device_type == "cuda" else "cpu" + device = device_ctx.device if device_ctx.device_type in ("cuda", "rocm") else "cpu" hidden, weight, target, bias = _inputs(6, device=device) out = op(hidden, weight, target, bias) ref = NativeLinearLogpOp()(hidden, weight, target, bias) assert torch.allclose(out.cpu(), ref.cpu(), atol=1e-3)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_linear_logp.py` around lines 287 - 296, The device selection logic in test_registry_dispatch_matches_native does not account for ROCm devices. The current condition at line 292 treats ROCm as CPU since the check only looks for "cuda" device type, causing a mismatch when kernel_registry.get_op("linear_logp") returns a ROCm backend but the test creates CPU tensors. Update the device selection logic to check for both "cuda" and "rocm" device types, assigning device_ctx.device to the device variable when either is present, and only default to "cpu" for other cases.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/test_linear_logp.py`:
- Around line 287-296: The device selection logic in
test_registry_dispatch_matches_native does not account for ROCm devices. The
current condition at line 292 treats ROCm as CPU since the check only looks for
"cuda" device type, causing a mismatch when
kernel_registry.get_op("linear_logp") returns a ROCm backend but the test
creates CPU tensors. Update the device selection logic to check for both "cuda"
and "rocm" device types, assigning device_ctx.device to the device variable when
either is present, and only default to "cpu" for other cases.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ccf841fc-c977-45c1-9176-9c62f141fcbf
📒 Files selected for processing (2)
benchmarks/benchmark_linear_logp.pytests/test_linear_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/benchmark_linear_logp.py
|
cc @inaniloquentee @maxiaosong1124 PTAL |
| # to a portable backend so the op stays a drop-in for any input. | ||
| if not _sm90_supported(hidden, lm_head_weight): | ||
| return _fallback_op()(hidden, lm_head_weight, target_ids, bias) | ||
| return _FusedLinearLogpSM90Function.apply(hidden, lm_head_weight, bias, target_ids) |
There was a problem hiding this comment.
Could we add the same target/bias validation here before launching the SM90 kernel? Right now this fast path only checks the hidden dim; if target_ids has the wrong shape/length, the CUDA kernel can read target[row_base + r] out of bounds, and a bad bias shape/device can similarly become an invalid device access. Catching this as a Python/C++ validation error would keep bad inputs from turning into CUDA illegal memory access.
There was a problem hiding this comment.
Thx for suggestion! I will add the target/bias validation before kernel launch.
|
Hi @inaniloquentee, sorry for late reply, I have resolved the requests. |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)
295-300: ⚡ Quick winAssert the validation errors, not any
RuntimeError.
RuntimeErroris broad enough to also catchCUDA error: illegal memory access, so this test can pass while missing the clean pre-launch validation it is meant to guarantee. SinceFusedLinearLogpSM90Op.apply()now raisesValueErrorfor these three cases, match those errors directly.🧪 Proposed test tightening
- with pytest.raises((ValueError, RuntimeError)): # wrong target length + with pytest.raises(ValueError, match="target_ids must have one id per token"): sm90(hidden, weight, target[:-1], bias) - with pytest.raises((ValueError, RuntimeError)): # wrong bias length + with pytest.raises(ValueError, match="bias must have V="): sm90(hidden, weight, target, bias[:-1]) - with pytest.raises((ValueError, RuntimeError)): # bias on the wrong device + with pytest.raises(ValueError, match=r"bias device .* must match hidden device"): sm90(hidden, weight, target, bias.cpu())🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_linear_logp.py` around lines 295 - 300, The test is catching both ValueError and RuntimeError, which is too broad since RuntimeError could match unintended errors like CUDA illegal memory access errors, allowing the test to pass without actually validating the expected behavior. Replace the three pytest.raises calls in the test (for wrong target length, wrong bias length, and bias on wrong device) to expect only ValueError instead of the tuple (ValueError, RuntimeError), since FusedLinearLogpSM90Op.apply() now raises ValueError for these specific validation cases in the sm90 function calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/test_linear_logp.py`:
- Around line 295-300: The test is catching both ValueError and RuntimeError,
which is too broad since RuntimeError could match unintended errors like CUDA
illegal memory access errors, allowing the test to pass without actually
validating the expected behavior. Replace the three pytest.raises calls in the
test (for wrong target length, wrong bias length, and bias on wrong device) to
expect only ValueError instead of the tuple (ValueError, RuntimeError), since
FusedLinearLogpSM90Op.apply() now raises ValueError for these specific
validation cases in the sm90 function calls.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fbd3e088-76b8-4ea2-bfa4-ba50d8f4bf69
📒 Files selected for processing (3)
csrc/cuda/fused_linear_logp_sm90.curl_engine/kernels/ops/cuda/loss/linear_logp.pytests/test_linear_logp.py
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/cuda/fused_linear_logp_sm90.cu
|
Sorry for the late reply. I fully agree with the suggestion provided by @inaniloquentee. Since you've already addressed the issue on your side, this PR is now quite solid. |
|
Good work! Thank you for your contribution, Could you add an issue to this PR? |
Nice work! This covers the original target length and bias cases. One small remaining guard: can we also check lm_head_weight.device == hidden.device before the SM90 launch? |
|
Hi @Flink-ddd, @inaniloquentee and @maxiaosong1124, thx for review, the requested changes are resolved. The issue for this PR is here #155 . |
LGTM now, thx for your contribution! |
Flink-ddd
left a comment
There was a problem hiding this comment.
Thanks for the excellent PR! It's a highly optimized piece of work that brings significant memory savings for our RL workflows. Once the boundary validation and the DRY refactoring are addressed, this will be good to merge.
| if bias.device != hidden.device: | ||
| raise ValueError( | ||
| f"bias device {bias.device} must match hidden device {hidden.device}" | ||
| ) |
There was a problem hiding this comment.
We need a value-bounds check for target_ids here before launching the kernel. The current fast path only checks the shape. If target_ids contains padding tokens (-100) or out-of-bounds values, it will cause silent errors in the forward pass and wrap-around memory corruption in the backward pass.
Could you add something like this?
min_target = target_ids.min().item()
max_target = target_ids.max().item()
if min_target < 0 or max_target >= lm_head_weight.size(0):
raise ValueError(
f"target_ids contains out-of-bounds values. Expected range [0, {lm_head_weight.size(0)}-1], "
f"got [{min_target}, {max_target}]."
)| raise ValueError( | ||
| f"hidden dim {hidden.size(-1)} must match lm_head_weight dim " | ||
| f"{lm_head_weight.size(-1)}" | ||
| ) |
There was a problem hiding this comment.
same with rl_engine/kernels/ops/cuda/loss/linear_logp.py FusedLinearLogpSM90Op.apply
| grad_w = torch.zeros(v, d, device=weight.device, dtype=torch.float32) | ||
| grad_b = torch.zeros(v, device=weight.device, dtype=torch.float32) if ctx.has_bias else None | ||
|
|
||
| chunk = max(1, min(n, _BWD_CHUNK_ELEMS // v)) |
There was a problem hiding this comment.
This entire chunked backward logic is duplicated identically in triton/loss/linear_logp.py. To keep the codebase maintainable and maintain a Single Source of Truth, could we extract this for loop into a shared utility function?
For example, creating a helper in a shared utils.py or base class:
def compute_chunked_linear_logp_backward(hidden_2d, weight, bias_t, target_1d, g, chunk_elems):
# Move the Liger-style chunked loop here
# ...
return grad_h, grad_w, grad_b|
|
||
|
|
||
| def run_benchmark(args): | ||
| if device_ctx.device_type != "cuda": |
There was a problem hiding this comment.
Since the Triton implementation is designed to be portable (and ROCm compatible as mentioned in the PR), this strict "cuda" check will prevent benchmarking on AMD devices.
Let's relax this check. You can replace it with:
if device_ctx.device_type not in ["cuda", "xpu", "hip"]:
raise RuntimeError("linear_logp benchmark requires a compatible GPU device.")| dz[rows, target_1d[i0:i1].long()] += 1.0 | ||
| dz *= g[i0:i1].unsqueeze(1) | ||
|
|
||
| dz_dt = dz.to(dt) |
There was a problem hiding this comment.
Just a quick confirmation on this explicit downcast: dz is computed in fp32, and here we cast it back to dt (bf16/fp16) before the matmul.
While this perfectly matches the Triton test precision and maximizes Tensor Core throughput, it might introduce slight quantization errors. If TF32 is enabled, doing the matmul directly in fp32 is also very fast and keeps higher gradient fidelity.
Is this explicit downcast intentionally kept to strictly pass the cross-architecture precision tests? If yes, it's totally fine to leave it as is!
Summary
Fused linear log-prob op (CUDA(SM90) + Triton + native reference + chunked backward landed earlier on the branch).
ldmatrix+mma.sync.m16n8k16, fp32 accum) + split-V; online softmax.F.linear+log_softmax+gatherreference / CPU fallback.The SM90 op requires bf16 hidden/weight with
D % 32 == 0; for any other input(fp32/fp16, awkwardD, CPU) it transparently falls back to Triton (else native), so it stays a drop-in.Implementation (SM90 forward)
cp.async.bulk.tensor, mbarrier-completed) streamsH/Wtiles into shared memory, double-buffered so the next hidden slice loads while the current one feeds the MMAs.BM = 256rows/CTA) so the weight matrix is re-read onlyN/BMtimes from HBM.blockIdx.yso the grid fills all SMs even whenN/BMalone is too few CTAs; each split emits a partial online-softmax state and a tiny combine kernel merges them (standard log-sum-exp merge). Adds only3·n_split·Nfloats of scratch — independent ofV.Performance (H100, bf16, N=4096, D=2048)
Forward latency (ms) and peak forward VRAM:
Forward + backward latency (ms):
python benchmarks/benchmark_linear_logp.py. Forward peak memory is the per-CTA shared-memory tiles — independent ofV.Correctness / tests
python -m pytest tests/test_linear_logp.py— 17/17 pass.Bug fix:
fused_logp_sm90.cuWhile enabling the SM90 build I found the existing TMA online-softmax kernel (
csrc/cuda/fused_logp_sm90.cu, the materialized-logitsfused_logp) fails to compile whenever the SM90 path is actually built — two latent, version-independent errors that were dormant because the prebuilt extension had no SM90 symbols (the pathhad never been exercised).
CUDART_INF_Fundefined — the kernel uses-CUDART_INF_Fbut never includes<math_constants.h>(not pulled in transitively).Fix:
#include <math_constants.h>.TmaTypeTraits<c10::BFloat16>is an incomplete type — the host wrapper passeslogits.data_ptr<at::BFloat16>()(i.e.c10::BFloat16*) straight intoinit_tensor_map, butTmaTypeTraitsis only specialized fornv_bfloat16/float,so template instantiation fails.
Fix:
reinterpret_cast<const nv_bfloat16*>(logits.data_ptr<at::BFloat16>())(matching how the linear kernel and the rest of the codebase bridgeat::BFloat16→nv_bfloat16).Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests