
[WS1][kernels] Batch-invariant matmul / GEMM #146

@Flink-ddd


Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

This is the highest-technical-risk op in WS1. cuBLAS picks a kernel heuristically from the problem shape, and split-K decompositions change the order in which the K dimension is reduced. Either one breaks batch invariance as soon as a change in batch size or sequence length (i.e. in M) causes a different kernel or split factor to be selected. Matmul is also the most frequent op in the network (QKV, MLP, LM head), so any drift it introduces propagates into, and dominates, everything downstream.
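The reduction-order problem can be reproduced without a GPU. The sketch below is a pure-Python model, not cuBLAS: FP32 accumulation is emulated by rounding every intermediate through `struct`, and `pick_split_k` is a hypothetical stand-in for a shape heuristic that enables split-K when M is small. Row 0 is identical in both batches, yet its output changes once M changes the split factor.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (IEEE-754 double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot_fixed_order(a, b):
    """Accumulate a[k] * b[k] in emulated FP32, strictly in k = 0..K-1 order."""
    acc = 0.0
    for ak, bk in zip(a, b):
        acc = f32(acc + f32(ak * bk))
    return acc


def dot_split_k(a, b, split_k):
    """Reduce each of `split_k` equal K-chunks separately, then reduce the partials.

    Assumes len(a) is divisible by split_k.
    """
    chunk = len(a) // split_k
    partials = [
        dot_fixed_order(a[s * chunk:(s + 1) * chunk], b[s * chunk:(s + 1) * chunk])
        for s in range(split_k)
    ]
    acc = 0.0
    for p in partials:
        acc = f32(acc + p)
    return acc


def pick_split_k(m):
    """Hypothetical shape heuristic (not cuBLAS's): use split-K only for small M."""
    return 2 if m <= 4 else 1


def gemm_heuristic(A, B):
    """C = A @ B where the split-K factor depends on M = len(A)."""
    split_k = pick_split_k(len(A))
    cols = list(zip(*B))  # columns of the K x N matrix B
    return [[dot_split_k(row, col, split_k) for col in cols] for row in A]


def gemm_fixed(A, B):
    """C = A @ B with the same K order for every row, regardless of M."""
    cols = list(zip(*B))
    return [[dot_fixed_order(row, col) for col in cols] for row in A]


row0 = [1e8, 1.0, -1e8, 1.0]
B = [[1.0], [1.0], [1.0], [1.0]]                          # K = 4, N = 1
small = [row0]                                            # M = 1
large = [row0] + [[float(i)] * 4 for i in range(1, 8)]    # M = 8, same row 0

print(gemm_fixed(small, B)[0], gemm_fixed(large, B)[0])          # [1.0] [1.0]
print(gemm_heuristic(small, B)[0], gemm_heuristic(large, B)[0])  # [0.0] [1.0]
```

With split-K, each half absorbs its `+1.0` into `±1e8` (the float32 spacing there is 8), so the cancellation loses both ones. `gemm_fixed` returns the same result for row 0 at M = 1 and M = 8 because its K order never looks at M; that independence, rather than the particular order chosen, is the property the deterministic path has to guarantee.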

Scope

Provide a deterministic, batch-invariant GEMM the forward chain can route through.

  • Either implement a deterministic GEMM (fixed tiling, no split-K, fixed K-accumulation order) or integrate a DeepGEMM-style deterministic matmul.
  • Guarantee the K-dimension reduction order is fixed and independent of M (the batch / token dimension), so a row's output does not change when other rows are added or removed.
  • FP32 accumulation for BF16 inputs; TF32 behavior for FP32-input matmuls must be explicitly pinned.
  • Initial target shapes (from the standard-Transformer model): QKV projection, MLP up/gate/down, and LM-head projection (or a representative reduced-vocab CI config).
  • Validate with the harness from #108 ([WS1] Ground-truth harness + numerical contract for batch-invariant ops) across the standard batch-config sweep.
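A CPU reference for the scope above might look like the following sketch. Assumptions: BF16 is emulated with round-to-nearest-even on the float32 bit pattern (NaN handling omitted), FP32 accumulation is emulated via `struct`, K is consumed in fixed `block_k` tiles with no split-K, and TF32 is not modeled. `gemm_bf16_fp32acc` and `check_batch_invariance` are hypothetical names; the latter only mimics a batch-size sweep and is not the #108 harness.

```python
import random
import struct


def f32(x: float) -> float:
    """Round a Python float (double) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def bf16(x: float) -> float:
    """Round to bfloat16 via the float32 bit pattern (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the 16 dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def gemm_bf16_fp32acc(A, B, block_k=16):
    """C = A @ B with BF16 inputs, FP32 accumulation and a fixed K order.

    K is consumed in ascending block_k tiles; each tile's partial sum is added
    to the accumulator in tile order and there is no split-K, so the order
    depends only on K and block_k, never on M = len(A).
    """
    K = len(B)
    cols = list(zip(*B))  # columns of the K x N matrix B
    C = []
    for row in A:
        a = [bf16(v) for v in row]
        c_row = []
        for col in cols:
            b = [bf16(v) for v in col]
            acc = 0.0  # FP32 accumulator
            for k0 in range(0, K, block_k):
                partial = 0.0
                for k in range(k0, min(k0 + block_k, K)):
                    # A BF16 x BF16 product is exact in float32 (barring overflow/underflow).
                    partial = f32(partial + f32(a[k] * b[k]))
                acc = f32(acc + partial)
            c_row.append(acc)
        C.append(c_row)
    return C


def check_batch_invariance(gemm, A_full, B, batch_sizes):
    """Hypothetical sweep: each row must be bitwise identical to its full-batch result."""
    def as_bits(row):
        return [struct.pack("<f", v) for v in row]

    reference = [as_bits(r) for r in gemm(A_full, B)]
    for m in batch_sizes:
        for i, row in enumerate(gemm(A_full[:m], B)):
            if as_bits(row) != reference[i]:
                return False
    return True


random.seed(0)
K, N, M_MAX = 64, 8, 32
A = [[random.uniform(-1.0, 1.0) for _ in range(K)] for _ in range(M_MAX)]
B = [[random.uniform(-1.0, 1.0) for _ in range(N)] for _ in range(K)]
print(check_batch_invariance(gemm_bf16_fp32acc, A, B, [1, 2, 3, 8, 16, 32]))  # True
```

Because each row's accumulator only ever sees that row's inputs in a fixed K order, slicing the batch to any M leaves every surviving row bitwise unchanged. A real GPU kernel must also keep the intra-tile (MMA) reduction order fixed, which this model does not capture.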

Possible implementation routes:

  • A deterministic baseline GEMM for the selected WS1 shapes.
  • CUTLASS with fixed tile shape, fixed epilogue, split-K disabled.
  • A DeepGEMM-style deterministic matmul integration.

Out of scope

  • Squeezing out peak TFLOPs / full perf tuning: correctness and invariance come first, and a perf pass can follow.
  • FP8 GEMM (out of scope this month).
  • Full cuBLAS replacement / all possible matrix shapes.
  • Distributed / tensor-parallel GEMM (WS2).

Acceptance criteria

Notes

Planned PRs

  • Design note: deterministic GEMM approach (custom fixed-tile vs DeepGEMM integration)
  • Implement / integrate deterministic GEMM (no split-K, fixed K-accumulation)
  • Tests for QKV / MLP / LM-head projection shapes
  • Wire one real projection (e.g. LM head) through the deterministic path
  • Benchmark vs cuBLAS baseline; document overhead + supported shapes

Labels

  • component: kernels
  • feature
  • platform: cuda
  • priority: high
  • sprint-0615
