[WS1][kernels] Batch-invariant attention (standard softmax) #147
Open
Labels
component: kernels (Tasks involving the development of CUDA and Triton underlying operators)
feature
platform: cuda (Specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer, TMA optimizations)
priority: high (Severe congestion issues require the highest priority for resolution.)
sprint-0615
Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)
Why
Flash-style attention parallelizes over the KV dimension and merges the partial results. The number of KV splits depends on sequence length and batch size, so the online-softmax accumulation order changes with batch configuration. Floating-point addition is not associative, so a different order can produce different bits for the same row, which breaks batch invariance. Because attention mixes information across positions, even tiny drift here propagates to every later token.
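The order dependence can be sketched in pure Python. This is a minimal illustration, not the kernel planned for this issue: `partial_softmax`, `merge`, and `attend` are hypothetical names, the query is a single row with one scalar value per key, and the fixed `split_size` strategy is an assumption (one common way to pin the reduction order), not something this issue specifies.

```python
import math

def partial_softmax(scores, values):
    """Online-softmax state (max, denominator, weighted sum) for one KV chunk."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    acc = sum(w * v for w, v in zip(weights, values))
    return m, l, acc

def merge(a, b):
    """Combine two partial states by rescaling both to their shared max."""
    m = max(a[0], b[0])
    sa, sb = math.exp(a[0] - m), math.exp(b[0] - m)
    return m, a[1] * sa + b[1] * sb, a[2] * sa + b[2] * sb

def attend(scores, values, split_size):
    """Attention output for one query, merging KV chunks left to right."""
    state = None
    for start in range(0, len(scores), split_size):
        part = partial_softmax(scores[start:start + split_size],
                               values[start:start + split_size])
        state = part if state is None else merge(state, part)
    _, l, acc = state
    return acc / l

# Illustrative inputs (made up for this sketch).
scores = [0.1 * i for i in range(64)]
values = [math.sin(i) for i in range(64)]

# Split count picked from batch configuration: chunk boundaries move,
# so the floating-point accumulation order moves with them.
out_a = attend(scores, values, split_size=64)  # one split
out_b = attend(scores, values, split_size=16)  # four splits

# Assumed mitigation: a fixed split size, so the same sequence always
# reduces in the same order regardless of what else is in the batch.
FIXED_SPLIT = 16
out_c = attend(scores, values, split_size=FIXED_SPLIT)
out_d = attend(scores, values, split_size=FIXED_SPLIT)
```

`out_a` and `out_b` are mathematically equal but are not guaranteed to match bit for bit, whereas `out_c` and `out_d` are identical because the reduction order depends only on the sequence.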
Scope
Provide a batch-invariant standard-softmax attention for the forward chain.
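One way to phrase batch invariance as a check is sketched below, assuming bitwise equality between a row processed alone and the same row processed inside a batch is the intended bar (the issue does not state this). A toy split reduction stands in for the attention kernel, and every name here (`naive_sum`, `reduce_in_splits`, `batch_dependent`, `batch_fixed`, `is_batch_invariant`) is hypothetical.

```python
def naive_sum(xs):
    """Plain left-to-right float accumulation, with no compensation."""
    total = 0.0
    for x in xs:
        total += x
    return total

def reduce_in_splits(xs, num_splits):
    """Sum xs as contiguous partial sums, then add the partials in order."""
    size = -(-len(xs) // num_splits)  # ceiling division
    partials = [naive_sum(xs[i:i + size]) for i in range(0, len(xs), size)]
    return naive_sum(partials)

def batch_dependent(batch):
    # Hypothetical heuristic: split more when the batch is small.
    num_splits = 2 if len(batch) == 1 else 1
    return [reduce_in_splits(row, num_splits) for row in batch]

def batch_fixed(batch):
    # Split count ignores the batch, so each row's reduction order is fixed.
    return [reduce_in_splits(row, 2) for row in batch]

def is_batch_invariant(fn, batch):
    """True if each row's output is bitwise identical alone and in the batch."""
    together = fn(batch)
    return all(fn([row])[0] == together[i] for i, row in enumerate(batch))

# Values chosen so that regrouping changes the rounded result.
row = [1.0, 1e100, -1e100, 1.0]
print(is_batch_invariant(batch_dependent, [row, row]))  # False
print(is_batch_invariant(batch_fixed, [row, row]))      # True
```

The same harness shape would apply to an attention function in place of the toy reduction, comparing outputs with exact equality rather than a tolerance.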
Out of scope
Acceptance criteria
Notes
Planned PRs