
[WS1][kernels] Batch-invariant attention (standard softmax) #147

@Flink-ddd

Description

Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

Flash-style attention parallelizes over the KV dimension and merges partial results, and the number of KV splits depends on sequence length and batch. Floating-point addition is not associative, so when the split count changes, the online-softmax accumulation order (and therefore the rounding) changes with batch configuration, breaking invariance. Because attention mixes information across positions, even tiny drift here propagates to every later token.
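A minimal pure-Python sketch of the problem, assuming scalar values and FP64 arithmetic for brevity (a real kernel works on FP16/BF16 tiles with FP32 statistics). `partial_state`, `merge`, and `split_kv_attention` are hypothetical names. Each KV chunk is reduced into an online-softmax state (running max, sum of exponentials, weighted value sum) and the chunk states are merged. Because the chunk boundaries follow `num_splits`, changing the split count changes the sequence of floating-point additions, so the output for the same row can differ in the last bits.

```python
import math
import random
from typing import List, Tuple

# Online-softmax state for one query row: (running max m, sum of exp(s - m), sum of exp(s - m) * v).
State = Tuple[float, float, float]


def partial_state(scores: List[float], values: List[float]) -> State:
    """Reduce one KV chunk left to right into an (m, z, acc) state."""
    m = max(scores)
    z = 0.0
    acc = 0.0
    for s, v in zip(scores, values):
        p = math.exp(s - m)
        z += p
        acc += p * v
    return (m, z, acc)


def merge(a: State, b: State) -> State:
    """Combine two chunk states by rescaling both to the shared max."""
    m = max(a[0], b[0])
    ca, cb = math.exp(a[0] - m), math.exp(b[0] - m)
    return (m, a[1] * ca + b[1] * cb, a[2] * ca + b[2] * cb)


def split_kv_attention(scores: List[float], values: List[float], num_splits: int) -> float:
    """Split-KV style reduction: the chunk boundaries depend on num_splits."""
    n = len(scores)
    bounds = [i * n // num_splits for i in range(num_splits + 1)]
    states = [partial_state(scores[lo:hi], values[lo:hi])
              for lo, hi in zip(bounds, bounds[1:]) if hi > lo]
    out = states[0]
    for st in states[1:]:
        out = merge(out, st)
    return out[2] / out[1]


if __name__ == "__main__":
    rng = random.Random(0)
    scores = [rng.uniform(-8.0, 8.0) for _ in range(4096)]
    values = [rng.uniform(-1.0, 1.0) for _ in range(4096)]
    for k in (1, 2, 3, 7, 32):
        # Same row, same math, different split count: trailing bits may differ.
        print(k, split_kv_attention(scores, values, k).hex())
```

Running the `__main__` block prints one hex value per split count; the values agree to roughly machine precision but are not guaranteed to match bit-for-bit, which is the drift this issue sets out to remove.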

Scope

Provide a batch-invariant standard-softmax attention for the forward chain.

  • Start from a deterministic masked-softmax + logsumexp (LSE) kernel with a fixed max / sum-exp order, then extend to the full attention output (P x V).
  • Avoid split-KV / variable KV-partitioning; use a single-SM whole-sequence reduction, or a dual-kernel design with an identical, fixed accumulation order regardless of split count.
  • Causal mask + padding mask support; variable-length sequence coverage.
  • Keep prefill and decode on a consistent reduction path (coordinate with the KV-cache consistency issue); expose the hooks that issue needs.
  • Accumulate softmax statistics in FP32.
  • Validate with the #108 harness ([WS1] Ground-truth harness + numerical contract for batch-invariant ops) across batch=1/N, chunked-prefill on/off, and different padding layouts.
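A minimal sketch of the masked-softmax + LSE step for one query row, assuming pure Python (FP64 floats, whereas the kernel would keep the max and sum-exp in FP32 as the scope requires). `masked_softmax_lse` and its parameters are hypothetical. Both reductions scan keys strictly left to right. Causal and padding positions are masked to `-inf`, so they contribute exactly `0.0` to the sum, and since `x + 0.0 == x` in IEEE arithmetic, padding a sequence out to a longer batch length leaves its probabilities and LSE bit-identical.

```python
import math
from typing import List, Tuple

NEG_INF = float("-inf")


def masked_softmax_lse(scores: List[float], q_pos: int, kv_len: int,
                       causal: bool = True) -> Tuple[List[float], float]:
    """Masked softmax + log-sum-exp for one query row with a pinned reduction order.

    scores: logits for key slots 0..len(scores)-1; slots >= kv_len are padding.
    q_pos:  absolute position of the query within its own sequence.
    """
    # Padding slots and (if causal) future keys are masked to -inf.
    masked = [NEG_INF if (j >= kv_len or (causal and j > q_pos)) else s
              for j, s in enumerate(scores)]

    # Pass 1: row max, scanned strictly left to right.
    m = NEG_INF
    for s in masked:
        m = max(m, s)
    if m == NEG_INF:
        return [0.0] * len(scores), NEG_INF  # fully masked row

    # Pass 2: sum of exp(s - m), also strictly left to right (no tree, no split).
    # Masked slots add exactly 0.0, so extra padding cannot perturb the sum.
    total = 0.0
    for s in masked:
        total += math.exp(s - m)

    lse = m + math.log(total)
    probs = [0.0 if s == NEG_INF else math.exp(s - lse) for s in masked]
    return probs, lse
```

For example, calling it on a 5-key row with `kv_len=5`, and again on the same row padded with three extra slots (still `kv_len=5`), gives identical `probs[:5]` and `lse`.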

Out of scope

  • Non-standard attention variants (linear / sliding-window / MoE routing) — standard softmax only for this gate.
  • Multi-GPU / sequence-parallel attention (WS2).
  • FP8 attention.
  • A full FlashAttention replacement, or maximum throughput, in the first deterministic baseline.

Acceptance criteria

  • The same query / sequence produces attention output that is bitwise-identical (or within the #108 tolerance, per [WS1] Ground-truth harness + numerical contract for batch-invariant ops) regardless of batch size, batch position, and chunked-prefill on/off.
  • No split-KV path whose split count varies with batch / seq config; the accumulation order is pinned and documented.
  • Softmax + LSE outputs are invariant across the sweep; tests cover both fixed-length and variable-length sequences.
  • Decode-stage output matches the prefill path for the same effective context (handoff with the KV-cache issue).
  • The attention backward path passes the shared gradient-invariance check from the WS1 backward-consistency issue.
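A minimal sketch of the shape of the acceptance check, assuming a pure-Python reference model instead of a real kernel. `attend_row`, `prefill`, `decode_last`, `check_invariance`, and `random_seq` are hypothetical names, not the #108 harness API. Outputs are compared bit-for-bit via `float.hex()` for the target sequence alone (batch=1), inserted at every position in a larger batch, and for a decode step that reuses the same row reduction over the cached keys. The reference passes by construction, because each row is reduced independently in a fixed order; a kernel under test would be swapped in for `prefill`. Chunked-prefill on/off is not modeled here.

```python
import math
import random
from typing import List, Tuple

Vec = List[float]
Seq = Tuple[List[Vec], List[Vec], List[Vec]]  # (queries, keys, values) for one sequence


def attend_row(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """One query row over its visible keys; all reductions run in a fixed order."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    denom = 0.0
    out = [0.0] * len(values[0])
    for wj, v in zip(w, values):  # ascending key order
        denom += wj
        for c in range(len(out)):
            out[c] += wj * v[c]
    return [o / denom for o in out]


def prefill(batch: List[Seq]) -> List[List[Vec]]:
    """Causal attention per sequence; each row depends only on its own prefix."""
    return [[attend_row(qs[t], ks[:t + 1], vs[:t + 1]) for t in range(len(qs))]
            for qs, ks, vs in batch]


def decode_last(cache_k: List[Vec], cache_v: List[Vec], q: Vec) -> Vec:
    """Decode step over the KV cache: same row function, same key order as prefill."""
    return attend_row(q, cache_k, cache_v)


def as_bits(rows: List[Vec]) -> List[str]:
    """Bit-exact fingerprint (float.hex distinguishes -0.0 from 0.0)."""
    return [x.hex() for row in rows for x in row]


def check_invariance(target: Seq, others: List[Seq]) -> bool:
    ref = as_bits(prefill([target])[0])            # batch = 1
    for pos in range(len(others) + 1):             # batch = N, every batch position
        batch = others[:pos] + [target] + others[pos:]
        if as_bits(prefill(batch)[pos]) != ref:
            return False
    qs, ks, vs = target                            # decode must match prefill's last row
    return as_bits([decode_last(ks, vs, qs[-1])]) == as_bits([prefill([target])[0][-1]])


def random_seq(rng: random.Random, length: int, d: int = 4) -> Seq:
    def mk() -> List[Vec]:
        return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(length)]
    return (mk(), mk(), mk())
```

A real harness would run the same comparison against the kernel output and fall back to the #108 tolerance only where bitwise equality is not required.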

Notes

Planned PRs

  • Deterministic masked-softmax + LSE kernel (fixed max / sum-exp order)
  • Causal + padding mask tests; variable-length coverage
  • Deterministic P x V aggregation (single-SM or fixed-order dual-kernel)
  • Prefill/decode handoff hooks (coordinate with the KV-cache issue)
  • Wire through the #108 harness sweep ([WS1] Ground-truth harness + numerical contract for batch-invariant ops); benchmark + tradeoff note
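A minimal sketch of the fixed-order dual-kernel option for P x V, written as a Python reference rather than a kernel. `TILE`, `tile_state`, `merge`, and `fixed_order_attention` are hypothetical, and the two loops stand in for two kernel launches. The tile size is a constant, so `num_workers` (the analogue of the split count) only decides which worker computes which tile. Partial states land in per-tile slots and are merged in ascending tile order, so the floating-point operation sequence, and therefore the output bits, do not change with the worker count.

```python
import math
from typing import List, Optional, Tuple

TILE = 16  # assumption: fixed KV tile size, never derived from batch or split count

State = Tuple[float, float, List[float]]  # (running max, sum of exp, weighted value sum)


def tile_state(scores: List[float], values: List[List[float]]) -> State:
    """'Kernel 1' work item: reduce one fixed KV tile left to right."""
    m = max(scores)
    z = 0.0
    acc = [0.0] * len(values[0])
    for s, v in zip(scores, values):
        p = math.exp(s - m)
        z += p
        for c in range(len(acc)):
            acc[c] += p * v[c]
    return m, z, acc


def merge(a: State, b: State) -> State:
    """Combine two tile states by rescaling both to the shared max."""
    m = max(a[0], b[0])
    ca, cb = math.exp(a[0] - m), math.exp(b[0] - m)
    return m, a[1] * ca + b[1] * cb, [x * ca + y * cb for x, y in zip(a[2], b[2])]


def fixed_order_attention(scores: List[float], values: List[List[float]],
                          num_workers: int) -> List[float]:
    n_tiles = (len(scores) + TILE - 1) // TILE
    slots: List[Optional[State]] = [None] * n_tiles
    # Kernel 1: num_workers changes WHO computes each tile, not the tile boundaries.
    for w in range(num_workers):
        for t in range(w, n_tiles, num_workers):
            lo, hi = t * TILE, min((t + 1) * TILE, len(scores))
            slots[t] = tile_state(scores[lo:hi], values[lo:hi])
    # Kernel 2: merge tile states in ascending tile index, one pinned order.
    out = slots[0]
    for st in slots[1:]:
        out = merge(out, st)
    _, z, acc = out
    return [x / z for x in acc]
```

For example, `fixed_order_attention(scores, values, 1)` and `fixed_order_attention(scores, values, 64)` return bit-identical lists. The cost is an extra pass over per-tile states and less freedom to rebalance work, which is what the benchmark + tradeoff note would need to quantify.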

Metadata

Labels

  • component: kernels
  • feature
  • platform: cuda
  • priority: high
  • sprint-0615

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests