[rocm7.2_internal_testing] Enable variable-length attention unit tests (#170969) and Bump AOTriton to 0.11.2b (#174105)#2964
Open — xinyazhang wants to merge 2 commits into `rocm7.2_internal_testing`
Conversation
## Summary

Enable variable-length attention unit tests for ROCm which were previously skipped. ROCm now supports varlen attention via the Composable Kernel (CK) and AOTriton backends.

## Changes

### 1. Remove `@skipIfRocm` decorators

- `test_basic_functionality`
- `test_custom_op_compliance`
- `test_custom_op_registration`
- `test_varlen_vs_sdpa`
- `test_batch_invariance`

### 2. Fix platform-aware LSE shape in `_varlen_attn_fake` (`torch/nn/attention/varlen.py`)

ROCm and NVIDIA use different logsumexp tensor shapes:

- **NVIDIA**: `[num_heads, total_q]` (packed format)
- **ROCm**: `[batch_size, num_heads, max_seqlen_q]` (batched format)

Updated the fake tensor implementation to return the correct shape based on `torch.version.hip`.

### 3. Fix causal mask handling in the test's `forward_sdpa` (`test/test_varlen_attention.py`)

The test's SDPA reference implementation passed both `attn_mask` and `is_causal=True` to `F.scaled_dot_product_attention`. Per the documentation this should be an error, but fused backends silently ignore `is_causal` when `attn_mask` is provided. Fixed by explicitly combining the padding mask with a causal (lower-triangular) mask and passing `is_causal=False`.

Fixes pytorch#168881
Fixes pytorch#168882
Fixes pytorch#168883
Fixes pytorch#168884
Fixes pytorch#168885

Pull Request resolved: pytorch#170969
Approved by: https://github.com/liangel-02, https://github.com/jeffdaily
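The platform-dependent LSE shape in change 2 can be sketched without PyTorch as a plain shape computation. This is an illustrative stand-in for the branching inside `_varlen_attn_fake` (the function name and parameters here are hypothetical; the real code checks `torch.version.hip` directly):

```python
def varlen_lse_shape(is_hip, num_heads, total_q, batch_size, max_seqlen_q):
    """Return the fake logsumexp shape for the varlen attention op.

    Illustrative sketch: ROCm returns a batched logsumexp while NVIDIA
    returns a packed one, matching the shapes listed in the PR summary.
    """
    if is_hip:
        # ROCm / HIP: batched format [batch_size, num_heads, max_seqlen_q]
        return (batch_size, num_heads, max_seqlen_q)
    # NVIDIA / CUDA: packed format [num_heads, total_q]
    return (num_heads, total_q)
```

A fake (meta) implementation only needs to produce tensors of the right shape and dtype, which is why getting this branch wrong breaks tracing on ROCm even though the real kernel works.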
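The mask fix in change 3 combines two constraints into one explicit mask so that `is_causal=False` can be passed. A minimal framework-free sketch of that combination, using boolean matrices in place of tensor masks (the actual test builds `torch` tensors and feeds the result to `F.scaled_dot_product_attention`; the helper name is illustrative):

```python
def combined_attn_mask(seq_lens, max_len):
    """Build per-batch boolean masks that are True where attention is allowed.

    Query position i may attend to key position j only if:
      - j < seq_len   (padding mask: j is a real, unpadded token), and
      - j <= i        (causal mask: lower-triangular constraint).
    """
    masks = []
    for seq_len in seq_lens:
        mask = [
            [(j < seq_len) and (j <= i) for j in range(max_len)]
            for i in range(max_len)
        ]
        masks.append(mask)
    return masks
```

Passing this combined mask with `is_causal=False` avoids the ambiguous `attn_mask` + `is_causal=True` combination that fused backends silently mishandle.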
Notable new features:

- AOTriton 0.11.2b adds gfx1151/1152/1153 support.
- Adds a precompiled AOTriton runtime for ROCm 7.2.
- Matches the sliding window attention behavior of `_flash_attention_forward/backward` with the CUTLASS backend.

Bug fixes:

- Fixes pytorch#173204. All tests in `test/test_varlen_attention.py` are now enabled on ROCm.

Notes: this replaces PRs pytorch#173820 and pytorch#173469.

Pull Request resolved: pytorch#174105
Approved by: https://github.com/jeffdaily
Jenkins build for commit 707c9c52222b92a513f3fe3118ab3d7a20860bc2 finished as FAILURE
Cherry-pick upstream PR
We need both to enable all tests in `test/test_varlen_attention.py`.