[FEAT][kernels]: add ROCm FlashAttention backend#104
📝 Walkthrough
Adds ROCm FlashAttention and PyTorch native SDPA fallback attention operations, updates kernel registry dispatch with environment-driven ROCm backend override, introduces comprehensive correctness tests comparing backends, provides ROCm environment verification, documents ROCm setup, and integrates tests into CI.
Changes
Attention Backend Implementations and Registry Dispatch
Attention Correctness and Integration
Sequence Diagram(s)
sequenceDiagram
participant Dev as Developer/Test Runner
participant KernelRegistry
participant RocmOp as RocmFlashAttentionOp
participant PyTorchOp as NativeAttentionOp
Dev->>KernelRegistry: request "attn" op (platform, env)
KernelRegistry->>RocmOp: dispatch (if ROCm + env favors flash_attn)
KernelRegistry->>PyTorchOp: dispatch (fallback / native)
Dev->>RocmOp: q,k,v + dropout,scale,causal
RocmOp->>Dev: attention output
Dev->>PyTorchOp: q,k,v + dropout,scale,causal
PyTorchOp->>Dev: attention output
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related issues
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
.github/workflows/ci.yml (2)
12-13: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Set explicit least-privilege GitHub token permissions.
This workflow relies on default permissions, which are broader than necessary for these jobs.
Suggested patch
 name: CI-Pipeline
+permissions:
+  contents: read
Also applies to: 34-66
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In @.github/workflows/ci.yml around lines 12 - 13, The workflow relies on default (broad) permissions; add an explicit least-privilege permissions block at the top of the workflow and restrict each job to only what it needs, e.g., add a top-level permissions: block with minimal scopes such as "contents: read" (and "checks: write" only if you create check runs), and for the "linting" job (and the other CI jobs referenced later) add or override job-level permissions only when a job requires extra scopes; ensure any job that needs write access (if any) explicitly requests it, otherwise keep read-only permissions to limit token privileges.
Source: Linters/SAST tools
60-67: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Current CI wiring does not exercise the new GPU attention correctness paths.
This job installs CPU-only PyTorch and runs on ubuntu-latest CPU runners, while the new correctness tests are GPU/ROCm-gated. The step will mostly skip instead of validating the new backend behavior.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In @.github/workflows/ci.yml around lines 60 - 67, The CI step named "Run Mocked Hardware Discovery Tests" currently installs CPU-only PyTorch (the pip install torch --index-url ... CPU wheel) and runs the attention correctness test, so it mostly skips GPU/ROCm paths; fix this by creating a separate CI job (e.g., "Run GPU Attention Correctness Tests") that runs on a GPU-enabled runner and installs the appropriate GPU/ROCm PyTorch wheel (replace the CPU pip install torch line with the matching GPU/ROCm index or wheel), then run the attention-correctness test there; alternatively, gate the existing step so it only runs the CPU tests and move the GPU-gated tests into the new GPU job to ensure the new backend paths are exercised.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/getting_started/installation.md`:
- Around line 39-47: After the step that does "cd flash-attention" ensure you
change back to the RL-Kernel root before running "python
scripts/check_rocm_env.py --require-flash-attn"; update the docs to insert a
directory-return step (e.g., cd back to the RL-Kernel root) between the
flash-attention install block and the verifier command so the verifier runs from
the correct repository.
In `@tests/test_attention_correctness.py`:
- Around line 115-116: Replace the broad "except Exception as exc" that converts
any failure into a skip with a narrow catch for expected availability/import
errors (e.g., "except (ImportError, ModuleNotFoundError, OSError) as exc")
around the FlashAttention availability check that returns "False, f'CUDA
FlashAttentionOp is unavailable: {exc}'"; let all other exceptions propagate (do
not catch/return False) so unexpected errors fail the test. Apply the same
change to the two other identical catch sites flagged by the review so only
import/availability-related exceptions produce the skip message and all other
exceptions are re-raised.
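To illustrate the narrower catch requested for `tests/test_attention_correctness.py`, here is a minimal sketch. The helper name `probe_availability` and its `factory` argument are hypothetical and not taken from the PR; only the exception types and the shape of the skip message come from the comment above.

```python
from typing import Callable, Tuple


def probe_availability(factory: Callable[[], object], label: str) -> Tuple[bool, str]:
    """Return (available, reason) for a backend op constructor.

    Hypothetical helper: only import/load failures count as "unavailable".
    Any other exception propagates so the test fails loudly instead of skipping.
    """
    try:
        factory()
    except (ImportError, OSError) as exc:  # ModuleNotFoundError subclasses ImportError
        return False, f"{label} is unavailable: {exc}"
    return True, ""
```

A probe written this way turns a missing `flash_attn` install into a skip reason, while a genuine bug in the op constructor (for example a `ValueError`) still surfaces as a test failure.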
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 897751d9-4ddd-4872-9b11-3d00a88b8be5
📒 Files selected for processing (9)
- .github/workflows/ci.yml
- docs/getting_started/installation.md
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/kernels/registry.py
- rl_engine/tests/test_dispatch.py
- scripts/check_rocm_env.py
- tests/test_attention_correctness.py
46b19b2 to 0ddfb91
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/test_attention_correctness.py (1)
104-149: ⚡ Quick win
Cache backend availability probes to reduce repeated operator initialization.
Each parametrized case re-runs availability checks and re-instantiates ops, which adds avoidable overhead across the full matrix.
♻️ Suggested refactor
 import os
+from functools import lru_cache
@@
+@lru_cache(maxsize=1)
 def cuda_flash_attention_availability():
@@
+@lru_cache(maxsize=1)
 def rocm_flash_attention_availability():
@@
+@lru_cache(maxsize=1)
 def native_attention_availability():
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_attention_correctness.py` around lines 104 - 149, The three availability probe functions (cuda_flash_attention_availability, rocm_flash_attention_availability, native_attention_availability) currently re-instantiate operator classes (FlashAttentionOp, RocmFlashAttentionOp, NativeAttentionOp) on every call; cache their results to avoid repeated initialization overhead by decorating each probe with a memoization strategy (e.g., functools.lru_cache(maxsize=1)) or by using a module-level cached tuple variable populated on first call, and return the cached (bool, message) thereafter; ensure the caching covers both success and failure results so subsequent parametrized tests reuse the single probe outcome instead of recreating the ops.
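The caching behaviour described in this nitpick can be sketched as follows. This is a stand-in under stated assumptions: `backend_availability` and the `PROBE_CALLS` counter are hypothetical, and the probe only checks whether `flash_attn` is importable rather than instantiating a real op.

```python
import importlib.util
from functools import lru_cache

PROBE_CALLS = 0  # counts how often the probe body actually runs


@lru_cache(maxsize=1)
def backend_availability() -> tuple:
    """Hypothetical stand-in for one of the three availability probes."""
    global PROBE_CALLS
    PROBE_CALLS += 1
    # find_spec only looks the package up; it does not import it.
    available = importlib.util.find_spec("flash_attn") is not None
    reason = "" if available else "flash_attn is not installed"
    return available, reason


first = backend_availability()
second = backend_availability()  # served from the cache; the body does not run again
```

Because `lru_cache` stores the returned tuple, both the success and the failure outcome are reused across parametrized cases. An exception raised inside the probe would not be cached, so the probe should return its failure result rather than raise.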
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/registry.py`:
- Around line 102-103: The env var normalization for RL_KERNEL_ROCM_ATTN_BACKEND
is only lowercasing and can miss values with surrounding whitespace; update the
assignment to strip whitespace before lowercasing (e.g., set rocm_attn_backend =
os.getenv("RL_KERNEL_ROCM_ATTN_BACKEND", "").strip().lower()) so values like
"flash_attn " will match the alias set used in the subsequent if-check.
In `@scripts/check_rocm_env.py`:
- Around line 48-55: The function _flash_attn_backend currently checks for the
"flash_attn_2_cuda" package before honoring the
FLASH_ATTENTION_TRITON_AMD_ENABLE env override, causing a mismatch with runtime
selection; change the check order in _flash_attn_backend so that it first
returns "triton" if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE",
"").upper() == "TRUE", then checks for
importlib.util.find_spec("flash_attn_2_cuda") to return "ck", preserving the
initial existence check for "flash_attn" and the final fallback
"triton-available-if-enabled".
In `@tests/test_attention_correctness.py`:
- Line 9: The import hard-codes torch.nn.attention symbols and can fail on
PyTorch <2.6; change the top-level import so you try to import SDPBackend and
sdpa_kernel from torch.nn.attention, and if that ImportError/AttributeError
occurs fall back to assigning sdpa_kernel = torch.backends.cuda.sdp_kernel (and
set SDPBackend to None or omit it); then update any backend-selection logic to
use sdpa_kernel unconditionally and to only reference SDPBackend when it is not
None (i.e., gate selections on the presence of SDPBackend). Ensure you touch the
symbols SDPBackend and sdpa_kernel in the test so imports won't raise during
collection and backend selection is guarded.
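For the `RL_KERNEL_ROCM_ATTN_BACKEND` normalization comment above, a minimal sketch of the suggested `strip().lower()` handling might look like this. The function name `read_rocm_attn_backend` and the optional `env` parameter are assumptions added for testability; the registry itself reads `os.getenv` directly according to the review.

```python
import os

ENV_VAR = "RL_KERNEL_ROCM_ATTN_BACKEND"


def read_rocm_attn_backend(env=None) -> str:
    """Return the override with surrounding whitespace removed and case folded.

    `env` is a hypothetical injection point; it defaults to os.environ.
    """
    env = os.environ if env is None else env
    return env.get(ENV_VAR, "").strip().lower()
```

With this normalization, values such as `"flash_attn "` or `" SDPA"` compare equal to `flash_attn` and `sdpa`, and an unset variable yields an empty string.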
🚧 Files skipped from review as they are similar to previous changes (4)
- rl_engine/kernels/ops/rocm/attention/__init__.py
- rl_engine/tests/test_dispatch.py
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
0ddfb91 to 7819654
Actionable comments posted: 1
♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)
51-54: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Align backend precedence with runtime selector to avoid false diagnostics.
At Line 51, _flash_attn_backend() checks flash_attn_2_cuda before FLASH_ATTENTION_TRITON_AMD_ENABLE. Runtime selection does the opposite, so this script can report ck while runtime actually uses triton.
Suggested patch
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"
As per coding guidelines, keep script behavior consistent with the runtime backend contract in rl_engine/kernels/ops/rocm/attention/flash_attn.py to avoid contradictory state reporting.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/check_rocm_env.py` around lines 51 - 54, The backend precedence is reversed: update the _flash_attn_backend() logic to match the runtime selector by checking the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable first (if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE" return "triton"), and only if that is not set/true then check importlib.util.find_spec("flash_attn_2_cuda") to return "ck"; this ensures the script's reported backend follows the runtime contract in _flash_attn_backend.
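The reordered precedence can also be written as a small testable function. This is a sketch, not the script's actual code: the public name `flash_attn_backend` and the `env` / `has_module` parameters are hypothetical injection points added so the order can be exercised without installing FlashAttention.

```python
import importlib.util
import os


def _has_module(name: str) -> bool:
    """True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def flash_attn_backend(env=None, has_module=_has_module):
    """Report the backend, honoring the Triton env override before CK detection."""
    env = os.environ if env is None else env
    if not has_module("flash_attn"):
        return None
    if env.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
        return "triton"
    if has_module("flash_attn_2_cuda"):
        return "ck"
    return "triton-available-if-enabled"
```

With both modules present, the result is `triton` when the variable is set to `TRUE` (in any letter case) and `ck` otherwise, which is the ordering the review asks the script to share with the runtime selector.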
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 86-89: The try/except around subprocess.check_output([hipcc,
"--version"]) only catches subprocess.CalledProcessError and therefore OSError
(e.g., executable not found or permissions) will raise an unhandled traceback;
update the except to catch both subprocess.CalledProcessError and OSError (e.g.
except (subprocess.CalledProcessError, OSError) as exc:) and call _fail(f"Could
not run {hipcc} --version: {exc}") so all failures are handled gracefully,
keeping the existing hipcc_output call and message format.
✅ Files skipped from review due to trivial changes (2)
- docs/getting_started/installation.md
- rl_engine/kernels/ops/rocm/attention/__init__.py
🚧 Files skipped from review as they are similar to previous changes (6)
- .github/workflows/ci.yml
- rl_engine/tests/test_dispatch.py
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- tests/test_attention_correctness.py
- rl_engine/kernels/registry.py
7819654 to c19cfe0
♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)
86-89: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Missing OSError handling (previously flagged).
Line 87 catches only CalledProcessError, but subprocess.check_output can raise OSError if the executable is not found, lacks permissions, or encounters other OS-level failures, bypassing _fail(...) with an unhandled traceback.
This was previously raised but not yet addressed.
🛡️ Proposed fix
-except subprocess.CalledProcessError as exc:
+except (subprocess.CalledProcessError, OSError) as exc:
     _fail(f"Could not run {hipcc} --version: {exc}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/check_rocm_env.py` around lines 86 - 89, The try/except around subprocess.check_output([hipcc, "--version"], text=True) only catches subprocess.CalledProcessError and misses OSError (e.g., executable not found), so update the except to catch both exceptions (e.g., except (subprocess.CalledProcessError, OSError) as exc:) and call _fail(f"Could not run {hipcc} --version: {exc}") so OS-level errors are handled the same way as subprocess errors; keep the call to subprocess.check_output and the _fail message unchanged otherwise.
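A minimal sketch of the broader catch follows. It assumes a hypothetical helper `tool_version` that returns a `(ok, text)` tuple instead of calling the script's `_fail`, so the behaviour can be checked in isolation.

```python
import subprocess


def tool_version(executable: str):
    """Run `<executable> --version` and report failures as data.

    Hypothetical helper: returns (True, output) on success and
    (False, message) when the command fails or cannot be started.
    """
    try:
        output = subprocess.check_output([executable, "--version"], text=True)
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a missing executable (FileNotFoundError) and permission errors.
        return False, f"Could not run {executable} --version: {exc}"
    return True, output
```

Calling it with a path that does not exist yields a controlled failure message rather than a traceback, which is the outcome the comment asks for.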
✅ Files skipped from review due to trivial changes (1)
- docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (7)
- rl_engine/tests/test_dispatch.py
- .github/workflows/ci.yml
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/kernels/registry.py
- tests/test_attention_correctness.py
c19cfe0 to 84b8625
♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)
48-55: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Backend precedence still misaligned with runtime selection.
The function checks flash_attn_2_cuda before honoring FLASH_ATTENTION_TRITON_AMD_ENABLE, which can report ck while runtime actually uses triton (when the env var is set). The past review comment on these lines suggested checking the env var first but appears not to have been applied.
🔧 Recommended fix (from past review)
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/check_rocm_env.py` around lines 48 - 55, The _flash_attn_backend function reports the wrong precedence: check the environment flag FLASH_ATTENTION_TRITON_AMD_ENABLE first and honor it before inspecting modules; change the logic in _flash_attn_backend so it returns "triton" when os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE" regardless of whether importlib.util.find_spec("flash_attn_2_cuda") exists, otherwise fall back to checking flash_attn_2_cuda (return "ck"), then check for flash_attn presence and finally return the default "triton-available-if-enabled" or None as currently implemented.
✅ Files skipped from review due to trivial changes (3)
- .github/workflows/ci.yml
- rl_engine/kernels/ops/rocm/attention/__init__.py
- docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (5)
- rl_engine/tests/test_dispatch.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/kernels/registry.py
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- tests/test_attention_correctness.py
@coderabbitai resume
✅ Action performed
Reviews resumed.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 110-111: Replace the os.environ.setdefault call with an explicit
assignment so the env var is forced to TRUE when flash_attn_backend ==
"triton-available-if-enabled": change the code that checks flash_attn_backend
and currently calls os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_ENABLE",
"TRUE") to instead assign os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] =
"TRUE" (ensuring the FLASH_ATTENTION_TRITON_AMD_ENABLE variable is
unconditionally set to "TRUE" in that branch).
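The difference between `setdefault` and explicit assignment that this comment relies on can be shown with a plain dictionary standing in for `os.environ`. The two function names are hypothetical and exist only for the comparison.

```python
KEY = "FLASH_ATTENTION_TRITON_AMD_ENABLE"


def enable_with_setdefault(env: dict) -> str:
    """setdefault only writes when the key is absent, so an existing value survives."""
    env.setdefault(KEY, "TRUE")
    return env[KEY]


def enable_with_assignment(env: dict) -> str:
    """Assignment overwrites whatever was there, forcing the value to TRUE."""
    env[KEY] = "TRUE"
    return env[KEY]
```

If a user has exported `FLASH_ATTENTION_TRITON_AMD_ENABLE=FALSE`, the `setdefault` variant leaves it at `FALSE`, whereas assignment forces `TRUE`. Forcing the value also discards a deliberate user setting, so the explicit form makes sense only in the branch where the checker needs Triton to be enabled.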
✅ Files skipped from review due to trivial changes (2)
- .github/workflows/ci.yml
- docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (6)
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/tests/test_dispatch.py
- tests/test_attention_correctness.py
- rl_engine/kernels/registry.py
84b8625 to 789fe1d
♻️ Duplicate comments (1)
scripts/check_rocm_env.py (1)
48-55: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Honor Triton env override before CK detection in backend probe.
Line 51 currently returns "ck" before Line 53 checks FLASH_ATTENTION_TRITON_AMD_ENABLE, which conflicts with runtime selection order in rl_engine/kernels/ops/rocm/attention/flash_attn.py and can report the wrong backend.
Suggested patch
 def _flash_attn_backend() -> str | None:
     if importlib.util.find_spec("flash_attn") is None:
         return None
-    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
-        return "ck"
     if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
         return "triton"
+    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
+        return "ck"
     return "triton-available-if-enabled"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/check_rocm_env.py` around lines 48 - 55, The probe function _flash_attn_backend currently prefers the CK backend before honoring the Triton override; update _flash_attn_backend so it checks the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable (os.environ.get(...).upper() == "TRUE") immediately after confirming "flash_attn" is present and before testing for the "flash_attn_2_cuda" spec, returning "triton" when the env var is set; leave the other return values ("ck" if flash_attn_2_cuda exists, otherwise "triton-available-if-enabled" or None) unchanged.
✅ Files skipped from review due to trivial changes (2)
- docs/getting_started/installation.md
- rl_engine/kernels/ops/rocm/attention/__init__.py
🚧 Files skipped from review as they are similar to previous changes (4)
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- rl_engine/kernels/registry.py
- tests/test_attention_correctness.py
789fe1d to 46c6c61
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/check_rocm_env.py`:
- Around line 112-117: The try/except around "from flash_attn import
flash_attn_func" only catches ImportError but compiled-extension import failures
may raise OSError or RuntimeError; update the except to catch those as well
(e.g., except (ImportError, OSError, RuntimeError) as exc or a broad except
Exception as exc), set flash_attn_error = str(exc) and ensure
flash_attn_func_available is left False in the error case so the checker reports
a controlled failure instead of crashing.
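A guarded import along the lines of this comment might look like the sketch below. The helper `load_symbol` is hypothetical; it uses `importlib.import_module` plus `getattr`, so it also catches `AttributeError`, which a `from ... import ...` statement would instead surface as `ImportError`.

```python
import importlib


def load_symbol(module_name: str, symbol: str):
    """Return (object, None) on success or (None, error_text) on a load failure.

    Hypothetical helper: compiled extensions can fail with OSError or
    RuntimeError at import time, not only ImportError.
    """
    try:
        module = importlib.import_module(module_name)
        return getattr(module, symbol), None
    except (ImportError, OSError, RuntimeError, AttributeError) as exc:
        return None, str(exc)
```

In the checker, `load_symbol("flash_attn", "flash_attn_func")` would then give either the function or an error string to report, keeping the availability flag `False` whenever loading fails.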
✅ Files skipped from review due to trivial changes (1)
- docs/getting_started/installation.md
🚧 Files skipped from review as they are similar to previous changes (6)
- .github/workflows/ci.yml
- rl_engine/kernels/ops/rocm/attention/__init__.py
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
- rl_engine/tests/test_dispatch.py
- rl_engine/kernels/ops/pytorch/attention/__init__.py
- tests/test_attention_correctness.py
46c6c61 to 2dd883d
Flink-ddd left a comment
Thanks for the thorough PR. Happy to re-review once the blocking items are addressed. The core wrapper and test scaffolding are in good shape.
@FED4 please resolve the CI error first, then we can merge this PR.
Summary
Related to #39.
This PR adds ROCm attention support and makes external ROCm FlashAttention the default ROCm attention backend to match the issue contract. PyTorch SDPA remains available as fallback / explicit opt-out:
export RL_KERNEL_ROCM_ATTN_BACKEND=sdpa
The SDPA/AOTriton vs external ROCm FlashAttention performance gap is real, but I am treating that as follow-up work rather than changing the default in this PR.
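As a rough illustration of the default and opt-out described in this summary, the selection could be sketched as below. This is not the registry's code: `select_rocm_attn_op` is a hypothetical name, it returns op class names as strings, only the documented `sdpa` value is recognized, and it does not model any fallback for a missing FlashAttention install.

```python
import os


def select_rocm_attn_op(env=None) -> str:
    """Pick the ROCm attention op name from the environment override.

    Assumption: external FlashAttention is the default and
    RL_KERNEL_ROCM_ATTN_BACKEND=sdpa opts out to PyTorch SDPA.
    """
    env = os.environ if env is None else env
    override = env.get("RL_KERNEL_ROCM_ATTN_BACKEND", "").strip().lower()
    if override == "sdpa":
        return "NativeAttentionOp"
    return "RocmFlashAttentionOp"
```

Leaving the variable unset selects `RocmFlashAttentionOp`, and exporting `sdpa` switches to `NativeAttentionOp`.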
Changes
- RocmFlashAttentionOp for external FlashAttention 2 Triton AMD.
- ROCM_FLASH_ATTN as the default ROCm attn backend.
- NativeAttentionOp as the PyTorch SDPA fallback for FlashAttention-layout tensors.
- RL_KERNEL_ROCM_ATTN_BACKEND=sdpa opt-out for PyTorch SDPA dispatch.
- NativeAttentionOp for q_heads > kv_heads.
ROCm backend note
This PR does not implement RL-Kernel's generic ROCM_CK backend.
For FlashAttention itself, this PR uses FlashAttention 2 Triton AMD. I also tried FlashAttention 2's CK backend, but its standalone build failed on the MI300/gfx942 machine before RL-Kernel was involved, so this PR does not depend on that path.
Test environment
Correctness
Command:
Result:
Observed external ROCm FlashAttention diff vs PyTorch SDPA math backend:
Other checks:
One-off benchmark
This is a quick sanity benchmark, not a formal performance claim. It compares PyTorch SDPA/AOTriton with external RocmFlashAttentionOp.
With FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=TRUE:
External FlashAttention wins on some cases, but not broadly enough to claim a general performance win. Further performance work can investigate CK or shape-specific routing in a separate PR.
Summary by CodeRabbit
Release Notes
New Features
RL_KERNEL_ROCM_ATTN_BACKEND environment variable.
Documentation
Tests