Skip to content

[FEAT][kernels]: add ROCm FlashAttention backend#104

Open
FED4 wants to merge 2 commits into
RL-Align:mainfrom
FED4:issue-39-rocm-attention
Open

[FEAT][kernels]: add ROCm FlashAttention backend#104
FED4 wants to merge 2 commits into
RL-Align:mainfrom
FED4:issue-39-rocm-attention

Conversation

@FED4

@FED4 FED4 commented Jun 13, 2026

Copy link
Copy Markdown

Summary

Related to #39.

This PR adds ROCm attention support and makes external ROCm FlashAttention the default ROCm attention backend to match the issue contract. PyTorch SDPA remains available as fallback / explicit opt-out:

export RL_KERNEL_ROCM_ATTN_BACKEND=sdpa

The SDPA/AOTriton vs external ROCm FlashAttention performance gap is real, but I am treating that as follow-up work rather than changing the default in this PR.

Changes

  • Add RocmFlashAttentionOp for external FlashAttention 2 Triton AMD.
  • Use ROCM_FLASH_ATTN as the default ROCm attn backend.
  • Add NativeAttentionOp as the PyTorch SDPA fallback for FlashAttention-layout tensors.
  • Add RL_KERNEL_ROCM_ATTN_BACKEND=sdpa opt-out for PyTorch SDPA dispatch.
  • Add GQA/MQA handling to NativeAttentionOp for q_heads > kv_heads.
  • Add CUDA/ROCm/native attention correctness tests against PyTorch SDPA math backend.
  • Add ROCm FlashAttention environment checker and installation notes.

ROCm backend note

This PR does not implement RL-Kernel's generic ROCM_CK backend.

For FlashAttention itself, this PR uses FlashAttention 2 Triton AMD. I also tried FlashAttention 2's CK backend, but its standalone build failed on the MI300/gfx942 machine before RL-Kernel was involved, so this PR does not depend on that path.

Test environment

| Item | Value |
|---|---|
| GPU | AMD Radeon Graphics, gfx942 / MI300-class |
| PyTorch | 2.9.1+rocm6.3 |
| HIP | 6.3.42134-a9a80e791 |
| PyTorch ROCm FA backend | AOTriton |
| External FlashAttention backend | Triton AMD |

Correctness

Command:

FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs

Result:

| Test | Result |
|---|---|
| NativeAttentionOp + ROCm FlashAttention correctness | 87 passed |
| CUDA external FlashAttention cases on ROCm machine | 40 skipped |
| Skip reason | current torch build is not CUDA platform |

Observed external ROCm FlashAttention diff vs PyTorch SDPA math backend:

| dtype | max abs diff observed | tolerance |
|---|---:|---:|
| fp16 | ~1.01e-3 | atol=1e-3, rtol=1e-3 |
| bf16 | ~8.84e-3 | atol=2e-2, rtol=2e-2 |

Other checks:

| Check | Result |
|---|---|
| python scripts/check_rocm_env.py | passed |
| pytest tests/test_kernel_registry.py rl_engine/tests/test_dispatch.py tests/test_attention_correctness.py -q -rs | 102 passed, 40 skipped |
| ruff check changed files | passed |
| python -m mkdocs build --strict -f mkdocs.yaml | passed |

One-off benchmark

This is a quick sanity benchmark, not a formal performance claim. It compares PyTorch SDPA/AOTriton with external RocmFlashAttentionOp.

With FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=TRUE:

| dtype | shape (B,S,H,D) | causal | NativeAttentionOp ms | RocmFlashAttentionOp autotune ms | native/external |
|---|---:|---:|---:|---:|---:|
| fp16 | (1,128,4,64) | false | 0.020 | 0.101 | 0.19x |
| fp16 | (2,256,8,128) | false | 0.047 | 0.103 | 0.45x |
| fp16 | (1,512,8,256) | false | 0.099 | 0.116 | 0.85x |
| fp16 | (1,1024,16,64) | false | 0.118 | 0.111 | 1.07x |
| fp16 | (1,2048,16,64) | false | 0.311 | 0.344 | 0.90x |
| bf16 | (1,512,8,256) | false | 0.100 | 0.114 | 0.87x |
| bf16 | (1,1024,16,64) | false | 0.136 | 0.116 | 1.17x |
| bf16 | (1,2048,16,64) | false | 0.345 | 0.377 | 0.92x |

External FlashAttention wins on some cases, but not broadly enough to claim a general performance win. Further performance work can investigate CK or shape-specific routing in a separate PR.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added ROCm backend support with configurable attention kernel selection via RL_KERNEL_ROCM_ATTN_BACKEND environment variable.
    • Added PyTorch SDPA fallback attention implementation.
  • Documentation

    • Added ROCm backend installation and setup guide with environment configuration instructions.
  • Tests

    • Enhanced attention correctness validation and dispatch behavior verification.

@coderabbitai

coderabbitai Bot commented Jun 13, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 16cc12ef-0ce9-45ed-b665-83daeabd14c4

📥 Commits

Reviewing files that changed from the base of the PR and between 2dd883d and 2ce40fc.

📒 Files selected for processing (8)
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
  • tests/test_kernel_registry.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

📝 Walkthrough

Walkthrough

Adds ROCm FlashAttention and PyTorch native SDPA fallback attention operations, updates kernel registry dispatch with environment-driven ROCm backend override, introduces comprehensive correctness tests comparing backends, provides ROCm environment verification, documents ROCm setup, and integrates tests into CI.

Changes

Attention Backend Implementations and Registry Dispatch

Layer / File(s) Summary
PyTorch Native Attention Fallback
rl_engine/kernels/ops/pytorch/attention/__init__.py
NativeAttentionOp wraps F.scaled_dot_product_attention, transposing between FlashAttention layout (batch, seqlen, nheads, headdim) and PyTorch SDPA layout, with GQA/MQA support via head repetition when q heads divide evenly by k/v heads.
ROCm FlashAttention Wrapper
rl_engine/kernels/ops/rocm/attention/flash_attn.py, rl_engine/kernels/ops/rocm/attention/__init__.py
RocmFlashAttentionOp selects Triton backend, requires ROCm PyTorch build, dynamically imports and binds flash_attn_func, validates FP16/BF16 dtype and GPU device placement (same device, rank-4 tensors), and forwards dropout, softmax scale, and causal flag.
Kernel Registry Dispatch and Environment Override
rl_engine/kernels/registry.py
Adds OpBackend.PYTORCH_ATTN and OpBackend.ROCM_FLASH_ATTN enum members, updates _priority_map attn ordering for CUDA/ROCm/CPU, and adds _adjust_priority_from_env() hook that reads RL_KERNEL_ROCM_ATTN_BACKEND to override ROCm attn priority (aliases for flash-attn, sdpa variants).
Registry Dispatch Tests
rl_engine/tests/test_dispatch.py
Two tests verify ROCm attn backend ordering: default behavior with ROCM_FLASH_ATTN top priority, and env opt-out via RL_KERNEL_ROCM_ATTN_BACKEND="sdpa" placing PYTORCH_ATTN first.

Attention Correctness and Integration

Layer / File(s) Summary
Attention Correctness Test Suite
tests/test_attention_correctness.py
Parameterized correctness tests comparing CUDA FlashAttention, ROCm FlashAttention, and native attention against PyTorch SDPA math backend reference across dtypes (FP32, FP16, BF16), shapes, causal/non-causal modes, and default/explicit softmax scales; includes GQA/MQA scenarios with head repetition validation, negative tests for unsupported head dimensions and invalid GQA ratios, and optional diff reporting via PRINT_ATTENTION_DIFF env var.
Kernel Registry Environment Tests
tests/test_kernel_registry.py
Validates KernelRegistry ROCm attention backend selection covering default ordering, env alias normalization (case/whitespace handling), env opt-in/opt-out variants, precedence of env override over hardware-based adjustment, and unknown value fallback with warning emission.
ROCm Environment Verification Script
scripts/check_rocm_env.py
Checks PyTorch ROCm build availability, CUDA/GPU presence, GPU device name, Triton import status, sets FLASH_ATTENTION_TRITON_AMD_ENABLE, attempts flash_attn_func import, reports backend availability, and fails if flash-attn is unavailable.
ROCm Installation Documentation and CI
docs/getting_started/installation.md, .github/workflows/ci.yml
Adds ROCm Backend section documenting PyTorch build selection, FlashAttention source install with Triton AMD backend, environment verification via check script, and RL_KERNEL_ROCM_ATTN_BACKEND override usage; updates CI workflow with restricted GitHub token permissions (contents: read) and adds attention correctness test run with PYTEST_DISABLE_PLUGIN_AUTOLOAD=1.

Sequence Diagram(s)

sequenceDiagram
  participant Dev as Developer/Test Runner
  participant KernelRegistry
  participant RocmOp as RocmFlashAttentionOp
  participant PyTorchOp as NativeAttentionOp
  Dev->>KernelRegistry: request "attn" op (platform, env)
  KernelRegistry->>RocmOp: dispatch (if ROCm + env favors flash_attn)
  KernelRegistry->>PyTorchOp: dispatch (fallback / native)
  Dev->>RocmOp: q,k,v + dropout,scale,causal
  RocmOp->>Dev: attention output
  Dev->>PyTorchOp: q,k,v + dropout,scale,causal
  PyTorchOp->>Dev: attention output
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Poem

🐇 I hopped through ops and registry lanes,
ROCm and native minds align the chains,
FlashAttention chose its speedy track,
Tests compare outputs, watching each back,
Docs and CI tuck the build to bed.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 10.53% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title accurately describes the main feature being added: ROCm FlashAttention backend support, which is the primary objective of this changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
.github/workflows/ci.yml (2)

12-13: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Set explicit least-privilege GitHub token permissions.

This workflow relies on default permissions, which is broader than necessary for these jobs.

Suggested patch
 name: CI-Pipeline
+permissions:
+  contents: read

Also applies to: 34-66

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.github/workflows/ci.yml around lines 12 - 13, The workflow relies on
default (broad) permissions; add an explicit least-privilege permissions block
at the top of the workflow and restrict each job to only what it needs—e.g., add
a top-level permissions: block with minimal scopes such as "contents: read" (and
"checks: write" only if you create check runs), and for the "linting" job (and
the other CI jobs referenced later) add or override job-level permissions only
when a job requires extra scopes; ensure any job that needs write access (if
any) explicitly requests it, otherwise keep read-only permissions to limit token
privileges.

Source: Linters/SAST tools


60-67: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Current CI wiring does not exercise the new GPU attention correctness paths.

This job installs CPU-only PyTorch and runs on ubuntu-latest CPU runners, while the new correctness tests are GPU/ROCm-gated. The step will mostly skip instead of validating the new backend behavior.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.github/workflows/ci.yml around lines 60 - 67, The CI step named "Run Mocked
Hardware Discovery Tests" currently installs CPU-only PyTorch (the pip install
torch --index-url ... CPU wheel) and runs the attention correctness test, so it
mostly skips GPU/ROCm paths; fix this by creating a separate CI job (e.g., "Run
GPU Attention Correctness Tests") that runs on a GPU-enabled runner and installs
the appropriate GPU/ROCm PyTorch wheel (replace the CPU pip install torch line
with the matching GPU/ROCm index or wheel), then run the attention-correctness
test there; alternatively, gate the existing step so it only runs the CPU tests
and move the GPU-gated tests into the new GPU job to ensure the new backend
paths are exercised.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/getting_started/installation.md`:
- Around line 39-47: After the step that does "cd flash-attention" ensure you
change back to the RL-Kernel root before running "python
scripts/check_rocm_env.py --require-flash-attn"; update the docs to insert a
directory-return step (e.g., cd back to the RL-Kernel root) between the
flash-attention install block and the verifier command so the verifier runs from
the correct repository.

In `@tests/test_attention_correctness.py`:
- Around line 115-116: Replace the broad "except Exception as exc" that converts
any failure into a skip with a narrow catch for expected availability/import
errors (e.g., "except (ImportError, ModuleNotFoundError, OSError) as exc")
around the FlashAttention availability check that returns "False, f'CUDA
FlashAttentionOp is unavailable: {exc}'"; let all other exceptions propagate (do
not catch/return False) so unexpected errors fail the test. Apply the same
change to the two other identical catch sites flagged by the review so only
import/availability-related exceptions produce the skip message and all other
exceptions are re-raised.

---

Outside diff comments:
In @.github/workflows/ci.yml:
- Around line 12-13: The workflow relies on default (broad) permissions; add an
explicit least-privilege permissions block at the top of the workflow and
restrict each job to only what it needs—e.g., add a top-level permissions: block
with minimal scopes such as "contents: read" (and "checks: write" only if you
create check runs), and for the "linting" job (and the other CI jobs referenced
later) add or override job-level permissions only when a job requires extra
scopes; ensure any job that needs write access (if any) explicitly requests it,
otherwise keep read-only permissions to limit token privileges.
- Around line 60-67: The CI step named "Run Mocked Hardware Discovery Tests"
currently installs CPU-only PyTorch (the pip install torch --index-url ... CPU
wheel) and runs the attention correctness test, so it mostly skips GPU/ROCm
paths; fix this by creating a separate CI job (e.g., "Run GPU Attention
Correctness Tests") that runs on a GPU-enabled runner and installs the
appropriate GPU/ROCm PyTorch wheel (replace the CPU pip install torch line with
the matching GPU/ROCm index or wheel), then run the attention-correctness test
there; alternatively, gate the existing step so it only runs the CPU tests and
move the GPU-gated tests into the new GPU job to ensure the new backend paths
are exercised.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 897751d9-4ddd-4872-9b11-3d00a88b8be5

📥 Commits

Reviewing files that changed from the base of the PR and between 04c014d and 46b19b2.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py

Comment thread docs/getting_started/installation.md Outdated
Comment thread rl_engine/kernels/ops/rocm/attention/flash_attn.py Outdated
Comment thread tests/test_attention_correctness.py Outdated
@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 46b19b2 to 0ddfb91 Compare June 13, 2026 06:08

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/test_attention_correctness.py (1)

104-149: ⚡ Quick win

Cache backend availability probes to reduce repeated operator initialization.

Each parametrized case re-runs availability checks and re-instantiates ops, which adds avoidable overhead across the full matrix.

♻️ Suggested refactor
 import os
+from functools import lru_cache
@@
+@lru_cache(maxsize=1)
 def cuda_flash_attention_availability():
@@
+@lru_cache(maxsize=1)
 def rocm_flash_attention_availability():
@@
+@lru_cache(maxsize=1)
 def native_attention_availability():
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_attention_correctness.py` around lines 104 - 149, The three
availability probe functions (cuda_flash_attention_availability,
rocm_flash_attention_availability, native_attention_availability) currently
re-instantiate operator classes (FlashAttentionOp, RocmFlashAttentionOp,
NativeAttentionOp) on every call; cache their results to avoid repeated
initialization overhead by decorating each probe with a memoization strategy
(e.g., functools.lru_cache(maxsize=1)) or by using a module-level cached tuple
variable populated on first call, and return the cached (bool, message)
thereafter; ensure the caching covers both success and failure results so
subsequent parametrized tests reuse the single probe outcome instead of
recreating the ops.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/registry.py`:
- Around line 102-103: The env var normalization for RL_KERNEL_ROCM_ATTN_BACKEND
is only lowercasing and can miss values with surrounding whitespace; update the
assignment to strip whitespace before lowercasing (e.g., set rocm_attn_backend =
os.getenv("RL_KERNEL_ROCM_ATTN_BACKEND", "").strip().lower()) so values like
"flash_attn " will match the alias set used in the subsequent if-check.

In `@scripts/check_rocm_env.py`:
- Around line 48-55: The function _flash_attn_backend currently checks for the
"flash_attn_2_cuda" package before honoring the
FLASH_ATTENTION_TRITON_AMD_ENABLE env override, causing a mismatch with runtime
selection; change the check order in _flash_attn_backend so that it first
returns "triton" if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE",
"").upper() == "TRUE", then checks for
importlib.util.find_spec("flash_attn_2_cuda") to return "ck", preserving the
initial existence check for "flash_attn" and the final fallback
"triton-available-if-enabled".

In `@tests/test_attention_correctness.py`:
- Line 9: The import hard-codes torch.nn.attention symbols and can fail on
PyTorch <2.6; change the top-level import so you try to import SDPBackend and
sdpa_kernel from torch.nn.attention, and if that ImportError/AttributeError
occurs fall back to assigning sdpa_kernel = torch.backends.cuda.sdp_kernel (and
set SDPBackend to None or omit it); then update any backend-selection logic to
use sdpa_kernel unconditionally and to only reference SDPBackend when it is not
None (i.e., gate selections on the presence of SDPBackend). Ensure you touch the
symbols SDPBackend and sdpa_kernel in the test so imports won't raise during
collection and backend selection is guarded.

---

Nitpick comments:
In `@tests/test_attention_correctness.py`:
- Around line 104-149: The three availability probe functions
(cuda_flash_attention_availability, rocm_flash_attention_availability,
native_attention_availability) currently re-instantiate operator classes
(FlashAttentionOp, RocmFlashAttentionOp, NativeAttentionOp) on every call; cache
their results to avoid repeated initialization overhead by decorating each probe
with a memoization strategy (e.g., functools.lru_cache(maxsize=1)) or by using a
module-level cached tuple variable populated on first call, and return the
cached (bool, message) thereafter; ensure the caching covers both success and
failure results so subsequent parametrized tests reuse the single probe outcome
instead of recreating the ops.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 32feadf9-79d3-4436-9f99-03fbaa796f7a

📥 Commits

Reviewing files that changed from the base of the PR and between 46b19b2 and 0ddfb91.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • rl_engine/kernels/ops/rocm/attention/init.py
  • rl_engine/tests/test_dispatch.py
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py

Comment thread rl_engine/kernels/registry.py Outdated
Comment thread scripts/check_rocm_env.py Outdated
Comment thread tests/test_attention_correctness.py Outdated
@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 0ddfb91 to 7819654 Compare June 13, 2026 06:24

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)

51-54: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Align backend precedence with runtime selector to avoid false diagnostics.

At Line 51, _flash_attn_backend() checks flash_attn_2_cuda before FLASH_ATTENTION_TRITON_AMD_ENABLE. Runtime selection does the opposite, so this script can report ck while runtime actually uses triton.

Suggested patch
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"

As per coding guidelines, keep script behavior consistent with the runtime backend contract in rl_engine/kernels/ops/rocm/attention/flash_attn.py to avoid contradictory state reporting.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/check_rocm_env.py` around lines 51 - 54, The backend precedence is
reversed: update the _flash_attn_backend() logic to match the runtime selector
by checking the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable first (if
os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE" return
"triton"), and only if that is not set/true then check
importlib.util.find_spec("flash_attn_2_cuda") to return "ck"; this ensures the
script's reported backend follows the runtime contract in _flash_attn_backend.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 86-89: The try/except around subprocess.check_output([hipcc,
"--version"]) only catches subprocess.CalledProcessError and therefore OSError
(e.g., executable not found or permissions) will raise an unhandled traceback;
update the except to catch both subprocess.CalledProcessError and OSError (e.g.
except (subprocess.CalledProcessError, OSError) as exc:) and call _fail(f"Could
not run {hipcc} --version: {exc}") so all failures are handled gracefully,
keeping the existing hipcc_output call and message format.

---

Duplicate comments:
In `@scripts/check_rocm_env.py`:
- Around line 51-54: The backend precedence is reversed: update the
_flash_attn_backend() logic to match the runtime selector by checking the
FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable first (if
os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE" return
"triton"), and only if that is not set/true then check
importlib.util.find_spec("flash_attn_2_cuda") to return "ck"; this ensures the
script's reported backend follows the runtime contract in _flash_attn_backend.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7a935686-d6f0-42c9-a81e-55ef9fa51641

📥 Commits

Reviewing files that changed from the base of the PR and between 0ddfb91 and 7819654.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (2)
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/rocm/attention/init.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • .github/workflows/ci.yml
  • rl_engine/tests/test_dispatch.py
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • tests/test_attention_correctness.py
  • rl_engine/kernels/registry.py

Comment thread scripts/check_rocm_env.py Outdated
@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 7819654 to c19cfe0 Compare June 13, 2026 06:33

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)

86-89: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Missing OSError handling (previously flagged).

Line 87 catches only CalledProcessError, but subprocess.check_output can raise OSError if the executable is not found, lacks permissions, or encounters other OS-level failures, bypassing _fail(...) with an unhandled traceback.

This was previously raised but not yet addressed.

🛡️ Proposed fix
-    except subprocess.CalledProcessError as exc:
+    except (subprocess.CalledProcessError, OSError) as exc:
         _fail(f"Could not run {hipcc} --version: {exc}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/check_rocm_env.py` around lines 86 - 89, The try/except around
subprocess.check_output([hipcc, "--version"], text=True) only catches
subprocess.CalledProcessError and misses OSError (e.g., executable not found),
so update the except to catch both exceptions (e.g., except
(subprocess.CalledProcessError, OSError) as exc:) and call _fail(f"Could not run
{hipcc} --version: {exc}") so OS-level errors are handled the same way as
subprocess errors; keep the call to subprocess.check_output and the _fail
message unchanged otherwise.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@scripts/check_rocm_env.py`:
- Around line 86-89: The try/except around subprocess.check_output([hipcc,
"--version"], text=True) only catches subprocess.CalledProcessError and misses
OSError (e.g., executable not found), so update the except to catch both
exceptions (e.g., except (subprocess.CalledProcessError, OSError) as exc:) and
call _fail(f"Could not run {hipcc} --version: {exc}") so OS-level errors are
handled the same way as subprocess errors; keep the call to
subprocess.check_output and the _fail message unchanged otherwise.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b5ebb6d1-924e-4280-b780-5c9dab6b99f5

📥 Commits

Reviewing files that changed from the base of the PR and between 7819654 and c19cfe0.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (1)
  • docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (7)
  • rl_engine/tests/test_dispatch.py
  • .github/workflows/ci.yml
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • tests/test_attention_correctness.py

@FED4 FED4 force-pushed the issue-39-rocm-attention branch from c19cfe0 to 84b8625 Compare June 13, 2026 06:37

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)

48-55: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Backend precedence still misaligned with runtime selection.

The function checks flash_attn_2_cuda before honoring FLASH_ATTENTION_TRITON_AMD_ENABLE, which can report ck while runtime actually uses triton (when the env var is set). The past review comment on these lines suggested checking the env var first but appears not to have been applied.

🔧 Recommended fix (from past review)
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/check_rocm_env.py` around lines 48 - 55, The _flash_attn_backend
function reports the wrong precedence: check the environment flag
FLASH_ATTENTION_TRITON_AMD_ENABLE first and honor it before inspecting modules;
change the logic in _flash_attn_backend so it returns "triton" when
os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE"
regardless of whether importlib.util.find_spec("flash_attn_2_cuda") exists,
otherwise fall back to checking flash_attn_2_cuda (return "ck"), then check for
flash_attn presence and finally return the default "triton-available-if-enabled"
or None as currently implemented.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@scripts/check_rocm_env.py`:
- Around line 48-55: The _flash_attn_backend function reports the wrong
precedence: check the environment flag FLASH_ATTENTION_TRITON_AMD_ENABLE first
and honor it before inspecting modules; change the logic in _flash_attn_backend
so it returns "triton" when os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE",
"").upper() == "TRUE" regardless of whether
importlib.util.find_spec("flash_attn_2_cuda") exists, otherwise fall back to
checking flash_attn_2_cuda (return "ck"), then check for flash_attn presence and
finally return the default "triton-available-if-enabled" or None as currently
implemented.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a0063888-6a0e-43ed-837f-97c9f61c84b8

📥 Commits

Reviewing files that changed from the base of the PR and between c19cfe0 and 84b8625.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (3)
  • .github/workflows/ci.yml
  • rl_engine/kernels/ops/rocm/attention/init.py
  • docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (5)
  • rl_engine/tests/test_dispatch.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • tests/test_attention_correctness.py

@FED4

FED4 commented Jun 13, 2026

Copy link
Copy Markdown
Author

@coderabbitai resume

@coderabbitai

coderabbitai Bot commented Jun 13, 2026

Copy link
Copy Markdown
✅ Action performed

Reviews resumed.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 110-111: Replace the os.environ.setdefault call with an explicit
assignment so the env var is forced to TRUE when flash_attn_backend ==
"triton-available-if-enabled": change the code that checks flash_attn_backend
and currently calls os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_ENABLE",
"TRUE") to instead assign os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] =
"TRUE" (ensuring the FLASH_ATTENTION_TRITON_AMD_ENABLE variable is
unconditionally set to "TRUE" in that branch).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7366e8c7-0a91-4c85-99ea-9e8c08350ecc

📥 Commits

Reviewing files that changed from the base of the PR and between c19cfe0 and 84b8625.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (2)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (6)
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/tests/test_dispatch.py
  • tests/test_attention_correctness.py
  • rl_engine/kernels/registry.py

Comment thread scripts/check_rocm_env.py Outdated
@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 84b8625 to 789fe1d Compare June 13, 2026 07:22

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)

48-55: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Honor Triton env override before CK detection in backend probe.

Line 51 currently returns "ck" before Line 53 checks FLASH_ATTENTION_TRITON_AMD_ENABLE, which conflicts with runtime selection order in rl_engine/kernels/ops/rocm/attention/flash_attn.py and can report the wrong backend.

Suggested patch
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/check_rocm_env.py` around lines 48 - 55, The probe function
_flash_attn_backend currently prefers the CK backend before honoring the Triton
override; update _flash_attn_backend so it checks the
FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable
(os.environ.get(...).upper() == "TRUE") immediately after confirming
"flash_attn" is present and before testing for the "flash_attn_2_cuda" spec,
returning "triton" when the env var is set; leave the other return values ("ck"
if flash_attn_2_cuda exists, otherwise "triton-available-if-enabled" or None)
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@scripts/check_rocm_env.py`:
- Around line 48-55: The probe function _flash_attn_backend currently prefers
the CK backend before honoring the Triton override; update _flash_attn_backend
so it checks the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable
(os.environ.get(...).upper() == "TRUE") immediately after confirming
"flash_attn" is present and before testing for the "flash_attn_2_cuda" spec,
returning "triton" when the env var is set; leave the other return values ("ck"
if flash_attn_2_cuda exists, otherwise "triton-available-if-enabled" or None)
unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 479c04eb-c895-4cb6-bdd4-28b7ceb5698b

📥 Commits

Reviewing files that changed from the base of the PR and between 84b8625 and 789fe1d.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (2)
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/rocm/attention/init.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • rl_engine/kernels/registry.py
  • tests/test_attention_correctness.py

@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 789fe1d to 46c6c61 Compare June 13, 2026 07:42

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 112-117: The try/except around "from flash_attn import
flash_attn_func" only catches ImportError but compiled-extension import failures
may raise OSError or RuntimeError; update the except to catch those as well
(e.g., except (ImportError, OSError, RuntimeError) as exc or a broad except
Exception as exc), set flash_attn_error = str(exc) and ensure
flash_attn_func_available is left False in the error case so the checker reports
a controlled failure instead of crashing.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 454c6d3f-b8de-4b1a-895e-2efddd8cbb62

📥 Commits

Reviewing files that changed from the base of the PR and between 789fe1d and 46c6c61.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/getting_started/installation.md
  • rl_engine/kernels/ops/pytorch/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/__init__.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/kernels/registry.py
  • rl_engine/tests/test_dispatch.py
  • scripts/check_rocm_env.py
  • tests/test_attention_correctness.py
✅ Files skipped from review due to trivial changes (1)
  • docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (6)
  • .github/workflows/ci.yml
  • rl_engine/kernels/ops/rocm/attention/init.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • rl_engine/tests/test_dispatch.py
  • rl_engine/kernels/ops/pytorch/attention/init.py
  • tests/test_attention_correctness.py

Comment thread scripts/check_rocm_env.py Outdated
@FED4 FED4 force-pushed the issue-39-rocm-attention branch from 46c6c61 to 2dd883d Compare June 13, 2026 09:41
@Flink-ddd Flink-ddd added component: kernels Tasks involving the development of CUDA and Triton underlying operators platform: rocm Specific tasks specific to AMD graphics cards (such as CK, bpreshuffle/FA) labels Jun 15, 2026

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the thorough PR, Happy to re-review once the blocking items are addressed. The core wrapper and test scaffolding are in good shape

Comment thread scripts/check_rocm_env.py Outdated
Comment thread rl_engine/kernels/ops/rocm/attention/flash_attn.py
Comment thread rl_engine/kernels/ops/rocm/attention/flash_attn.py
Comment thread rl_engine/kernels/ops/rocm/attention/flash_attn.py Outdated
Comment thread rl_engine/kernels/ops/rocm/attention/flash_attn.py
Comment thread rl_engine/kernels/registry.py
Comment thread rl_engine/kernels/registry.py
Comment thread rl_engine/tests/test_dispatch.py Outdated
Comment thread tests/test_attention_correctness.py
Comment thread rl_engine/kernels/ops/pytorch/attention/__init__.py
@Flink-ddd

Copy link
Copy Markdown
Collaborator

@FED4 please resolve CI error first, then We can merge this PR first.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

component: kernels Tasks involving the development of CUDA and Triton underlying operators platform: rocm Specific tasks specific to AMD graphics cards (such as CK, bpreshuffle/FA)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants