Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ on:
pull_request:
branches: [ main ]

permissions:
contents: read

jobs:
linting:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -63,6 +66,7 @@ jobs:
- name: Run Mocked Hardware Discovery Tests
run: |
python -m pytest rl_engine/tests/test_dispatch.py -v
PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs

docs:
runs-on: ubuntu-latest
Expand Down
27 changes: 27 additions & 0 deletions docs/getting_started/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ vLLM runtime. Core CI and mocked integration tests do not require it.
For common CUDA, ROCm, vLLM, fallback, and CI questions, see the
[FAQ](faq.md).

### ROCm Backend

Use a ROCm PyTorch build that matches the installed ROCm toolchain. Then install
FlashAttention with an AMD backend:

```bash
python -m pip install ninja packaging wheel psutil einops
git clone --recurse-submodules https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
python -m pip install --no-build-isolation --no-deps .
cd ..
```

Verify the environment from the RL-Kernel checkout:

```bash
python scripts/check_rocm_env.py
```

RL-Kernel uses external FlashAttention as the default ROCm attention path. To
fall back to PyTorch SDPA for ROCm attention dispatch, set:

```bash
export RL_KERNEL_ROCM_ATTN_BACKEND=sdpa
```

## Development Dependencies

```bash
Expand Down
49 changes: 49 additions & 0 deletions rl_engine/kernels/ops/pytorch/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

import torch
import torch.nn.functional as F


class NativeAttentionOp:
"""PyTorch SDPA fallback for FlashAttention-layout tensors."""

def __call__(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
softmax_scale: float | None = None,
causal: bool = False,
) -> torch.Tensor:
# Convert FlashAttention layout to PyTorch SDPA layout:
# (batch, seqlen, nheads, headdim) -> (batch, nheads, seqlen, headdim)
q_ref = q.transpose(1, 2)
k_ref = k.transpose(1, 2)
v_ref = v.transpose(1, 2)

q_head_num = q_ref.shape[1]
k_head_num = k_ref.shape[1]
if k_head_num != v_ref.shape[1]:
raise ValueError("k and v must have the same number of heads")

if q_head_num != k_head_num:
if q_head_num % k_head_num != 0:
raise ValueError("q heads must be divisible by k/v heads for GQA/MQA")
repeat = q_head_num // k_head_num
k_ref = k_ref.repeat_interleave(repeat, dim=1)
v_ref = v_ref.repeat_interleave(repeat, dim=1)

out = F.scaled_dot_product_attention(
Comment thread
FED4 marked this conversation as resolved.
q_ref,
k_ref,
v_ref,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)
return out.transpose(1, 2)


__all__ = ["NativeAttentionOp"]
8 changes: 8 additions & 0 deletions rl_engine/kernels/ops/rocm/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from .flash_attn import RocmFlashAttentionOp

__all__ = [
"RocmFlashAttentionOp",
]
92 changes: 92 additions & 0 deletions rl_engine/kernels/ops/rocm/attention/flash_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

import os

import torch

from rl_engine.utils.logger import logger

_MAX_TESTED_ROCM_TRITON_HEAD_DIM = 512


def _select_flash_attn_backend() -> str:
Comment thread
FED4 marked this conversation as resolved.
"""Select the installed FlashAttention ROCm backend."""
return "triton"


class RocmFlashAttentionOp:
"""
Standard FlashAttention wrapper for ROCm.
Demonstrates the reference structure for adding new operator families.
"""

def __init__(self):
if torch.version.hip is None:
raise RuntimeError("RocmFlashAttentionOp requires a ROCm PyTorch build.")

backend = _select_flash_attn_backend()
if backend == "triton":
# flash-attn selects the ROCm CK/Triton backend at import time.
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
try:
from flash_attn import flash_attn_func

self.op = flash_attn_func
logger.info("Successfully linked to external flash_attn library (%s backend).", backend)
except (ImportError, OSError, RuntimeError) as exc:
raise RuntimeError(
"ROCm FlashAttention requires a ROCm-compatible flash-attn installation. "
"See docs/getting_started/installation.md#rocm-backend."
) from exc

def __call__(
Comment thread
FED4 marked this conversation as resolved.
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
softmax_scale: float | None = None,
Comment thread
FED4 marked this conversation as resolved.
causal: bool = False,
) -> torch.Tensor:
"""
Standard attention forward pass.
Args:
q: (batch, seqlen, nheads, headdim)
k: (batch, seqlen, nheads_k, headdim)
v: (batch, seqlen, nheads_k, headdim)
"""
valid_dtypes = (torch.float16, torch.bfloat16)
if (
q.dtype not in valid_dtypes
or k.dtype not in valid_dtypes
or v.dtype not in valid_dtypes
):
raise TypeError("FlashAttention requires FP16 or BF16 for q/k/v")
# PyTorch uses the CUDA device API for both CUDA and ROCm tensors.
if not (q.is_cuda and k.is_cuda and v.is_cuda):
raise ValueError("Inputs must be on a CUDA/ROCm GPU device")
if not (q.device == k.device == v.device):
raise ValueError("q, k, and v must be on the same device")
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError(
"q, k, and v must be rank-4 tensors: (batch, seqlen, nheads, head_dim)"
)

head_dim = q.shape[-1]
if head_dim == 0:
raise ValueError("head_dim must be positive")
if k.shape[-1] != head_dim or v.shape[-1] != head_dim:
raise ValueError("q, k, and v must have the same head_dim")
if head_dim > _MAX_TESTED_ROCM_TRITON_HEAD_DIM:
raise NotImplementedError(
"RL-Kernel's ROCm FlashAttention wrapper currently supports "
f"head_dim <= {_MAX_TESTED_ROCM_TRITON_HEAD_DIM}; got {head_dim}"
)

if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5

q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

return self.op(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal)
34 changes: 31 additions & 3 deletions rl_engine/kernels/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2026 RL-Kernel Contributors

import importlib
import os
from enum import Enum, EnumMeta
from typing import Any, Dict, Optional, Set, Type

Expand Down Expand Up @@ -32,6 +33,7 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta):
# AMD ROCm optimized stack
ROCM_AITER = "rl_engine.kernels.ops.rocm.aiter.AiterOp"
ROCM_CK = "rl_engine.kernels.ops.rocm.composable_kernel.CKOp"
ROCM_FLASH_ATTN = "rl_engine.kernels.ops.rocm.attention.flash_attn.RocmFlashAttentionOp"

# GRPO loss (group reward normalization + clipped surrogate + KL)
TRITON_GRPO_LOSS = "rl_engine.kernels.ops.triton.loss.grpo_loss.TritonGRPOLossOp"
Expand All @@ -43,6 +45,7 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta):

# Generic fallback
TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp"
PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp"
PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp"


Expand Down Expand Up @@ -76,26 +79,51 @@ def __init__(self):
OpBackend.CUDA_FUSED_LOGP_GENERIC,
OpBackend.PYTORCH_NATIVE,
],
"attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
"attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_ATTN],
"grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
"ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL],
# Default dispatch logic for new operators
},
"rocm": {
"logp": [OpBackend.ROCM_AITER, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
"attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
"attn": [
Comment thread
FED4 marked this conversation as resolved.
OpBackend.ROCM_FLASH_ATTN,
OpBackend.PYTORCH_ATTN,
OpBackend.TRITON_GENERIC,
],
"grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
"ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL],
},
"cpu": {
"logp": [OpBackend.PYTORCH_NATIVE],
"attn": [OpBackend.PYTORCH_NATIVE],
"attn": [OpBackend.PYTORCH_ATTN],
"grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS],
"ratio_kl": [OpBackend.PYTORCH_RATIO_KL],
},
}
logger.info(f"KernelRegistry initialized for {device_ctx.device_type}")
self._adjust_priority_for_hardware()
self._adjust_priority_from_env()

def _adjust_priority_from_env(self):
Comment thread
FED4 marked this conversation as resolved.
rocm_attn_backend = os.getenv("RL_KERNEL_ROCM_ATTN_BACKEND", "").strip().lower()
if rocm_attn_backend in {"flash_attn", "flash-attn", "flash_attention"}:
self._priority_map["rocm"]["attn"] = [
OpBackend.ROCM_FLASH_ATTN,
OpBackend.PYTORCH_ATTN,
OpBackend.TRITON_GENERIC,
]
elif rocm_attn_backend in {"native", "pytorch", "sdpa"}:
self._priority_map["rocm"]["attn"] = [
OpBackend.PYTORCH_ATTN,
OpBackend.ROCM_FLASH_ATTN,
OpBackend.TRITON_GENERIC,
]
elif rocm_attn_backend and rocm_attn_backend not in {"native", "pytorch", "sdpa"}:
logger.warning(
"Unknown RL_KERNEL_ROCM_ATTN_BACKEND=%s; using default ROCm attention priority.",
rocm_attn_backend,
)

def _adjust_priority_for_hardware(self):
"""Prioritize the fused TMA LogP kernel only when it is compiled into the
Expand Down
19 changes: 18 additions & 1 deletion rl_engine/tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from rl_engine.executors.rollout import RolloutExecutor
from rl_engine.kernels.registry import kernel_registry
from rl_engine.kernels.registry import KernelRegistry, OpBackend, kernel_registry
from rl_engine.platforms.device import device_ctx
from rl_engine.utils.logger import logger

Expand All @@ -25,6 +25,23 @@ def test_device_and_registry():
logger.info(f"Retrieved Attention Operator: {attn_op}")


def test_rocm_attention_uses_flash_attention_by_default(monkeypatch):
monkeypatch.delenv("RL_KERNEL_ROCM_ATTN_BACKEND", raising=False)

registry = KernelRegistry()

assert registry._priority_map["rocm"]["attn"][0] == OpBackend.ROCM_FLASH_ATTN


def test_rocm_attention_native_sdpa_opt_out(monkeypatch):
monkeypatch.setenv("RL_KERNEL_ROCM_ATTN_BACKEND", " sdpa ")

registry = KernelRegistry()

assert registry._priority_map["rocm"]["attn"][0] == OpBackend.PYTORCH_ATTN
assert registry._priority_map["rocm"]["attn"][1] == OpBackend.ROCM_FLASH_ATTN


def test_executor_flow():
executor = RolloutExecutor()
mock_input_ids = torch.ones((1, 16), dtype=torch.long)
Expand Down
54 changes: 54 additions & 0 deletions scripts/check_rocm_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

import importlib.util
import os


def _fail(message: str) -> None:
raise SystemExit(f"ERROR: {message}")


def main() -> None:
try:
import torch
except ImportError as exc:
_fail(f"PyTorch is not installed: {exc}")

if torch.version.hip is None:
_fail(f"PyTorch is not a ROCm build: torch={torch.__version__}")

if not torch.cuda.is_available():
_fail("ROCm GPU is not available to PyTorch")

device_name = torch.cuda.get_device_name(0)
triton_available = importlib.util.find_spec("triton") is not None
flash_attn_func_available = False
# flash-attn selects the ROCm CK/Triton backend at import time.
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
try:
from flash_attn import flash_attn_func
except (ImportError, OSError, RuntimeError) as exc:
flash_attn_status = f"not available ({exc})"
else:
flash_attn_func_available = flash_attn_func is not None
flash_attn_status = "available" if flash_attn_func_available else "not available"

print("backend availability:")
print(
" ROCm PyTorch runtime: "
f"available (torch={torch.__version__}, hip={torch.version.hip}, GPU={device_name})"
)
print(" PyTorch SDPA fallback: available")
print(f" Triton package: {'available' if triton_available else 'not available'}")
print(f" flash-attn AMD Triton: {flash_attn_status}")
print(" ROCm CK: not selected by this checker")

if not flash_attn_func_available:
_fail("flash_attn AMD Triton backend is required but could not be imported")


if __name__ == "__main__":
main()
Loading
Loading