
[rocm7.2_internal_testing] Enable variable-length attention unit tests (#170969) and Bump AOTriton to 0.11.2b (#174105) #2964

Open
xinyazhang wants to merge 2 commits into rocm7.2_internal_testing from xinyazhang/aotriton-0.11.2-rocm
Conversation

@xinyazhang

Cherry-pick upstream PR

We need both upstream PRs to enable all tests in `test/test_varlen_attention.py`.

chinmaydk99 and others added 2 commits February 5, 2026 10:43
## Summary
Enable the variable-length attention unit tests for ROCm that were previously skipped. ROCm now supports varlen attention via the Composable Kernel (CK) and AOTriton backends.

## Changes

### 1. Remove `@skipIfRocm` decorators
- `test_basic_functionality`
- `test_custom_op_compliance`
- `test_custom_op_registration`
- `test_varlen_vs_sdpa`
- `test_batch_invariance`

### 2. Fix platform-aware LSE shape in `_varlen_attn_fake` (`torch/nn/attention/varlen.py`)
ROCm and NVIDIA use different logsumexp tensor shapes:
- **NVIDIA**: `[num_heads, total_q]` (packed format)
- **ROCm**: `[batch_size, num_heads, max_seqlen_q]` (batched format)

Updated the fake tensor implementation to return the correct shape based on `torch.version.hip`.
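A minimal sketch of the shape dispatch described above. The helper name and signature are hypothetical; the real fake-tensor implementation lives in `torch/nn/attention/varlen.py` and checks `torch.version.hip`:

```python
# Hypothetical helper illustrating the platform-aware LSE shape selection.
# The actual fake impl returns a fake tensor; here we just compute the shape.
def varlen_lse_shape(
    is_hip: bool,
    batch_size: int,
    num_heads: int,
    total_q: int,
    max_seqlen_q: int,
) -> tuple[int, ...]:
    if is_hip:
        # ROCm: batched format [batch_size, num_heads, max_seqlen_q]
        return (batch_size, num_heads, max_seqlen_q)
    # NVIDIA: packed format [num_heads, total_q]
    return (num_heads, total_q)
```

Under fake-tensor tracing (e.g. `torch.compile`), returning the wrong shape here would make compiled varlen attention fail shape checks on ROCm even though the eager kernel is correct.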

### 3. Fix causal mask handling in test's `forward_sdpa` (`test/test_varlen_attention.py`)
The test's SDPA reference implementation passed both `attn_mask` and `is_causal=True` to `F.scaled_dot_product_attention`. Per the documentation this combination should raise an error, but the fused backends silently ignore `is_causal` when `attn_mask` is provided.

Fixed by explicitly combining the padding mask with a causal (lower triangular) mask and passing `is_causal=False`.
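A dependency-free sketch of the mask combination (hypothetical helper; the actual fix in `test/test_varlen_attention.py` builds the mask with tensor ops). Position `(i, j)` is attendable only if both positions are within the valid sequence length (padding mask) and `j <= i` (causal, lower-triangular mask):

```python
# Hypothetical illustration: combine a per-sequence padding mask with a
# causal lower-triangular mask into one boolean mask, so the SDPA call
# can pass the combined mask with is_causal=False.
def combined_attn_mask(seq_lens: list[int], max_len: int) -> list[list[list[bool]]]:
    masks = []
    for n in seq_lens:
        masks.append([
            [
                # True = position may be attended to:
                # padding mask (i < n and j < n) AND causal mask (j <= i)
                i < n and j < n and j <= i
                for j in range(max_len)
            ]
            for i in range(max_len)
        ])
    return masks
```

With the combined mask in hand, the reference call becomes `F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)`, avoiding the ambiguous `attn_mask` + `is_causal=True` combination.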

Fixes pytorch#168881
Fixes pytorch#168882
Fixes pytorch#168883
Fixes pytorch#168884
Fixes pytorch#168885

Pull Request resolved: pytorch#170969
Approved by: https://github.com/liangel-02, https://github.com/jeffdaily
Notable new features:

* AOTriton 0.11.2b adds support for gfx1151/gfx1152/gfx1153.
* Adds a precompiled AOTriton runtime for ROCm 7.2.
* Matches the sliding-window attention behavior of `_flash_attention_forward/backward` with the CUTLASS backend.

Bug fixes:

* Fixes pytorch#173204. All tests in `test/test_varlen_attention.py` are now enabled on ROCm.

Notes:

This PR replaces pytorch#173820 and pytorch#173469.

Pull Request resolved: pytorch#174105
Approved by: https://github.com/jeffdaily
@rocm-repo-management-api

rocm-repo-management-api bot commented Feb 5, 2026

Jenkins build for 707c9c52222b92a513f3fe3118ab3d7a20860bc2 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results
