[WS1] dtype coverage in the numerical contract #154

@Flink-ddd

Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

"Aligned" is only meaningful relative to a pinned set of dtypes. RL training runs in BF16, so BF16 invariance is mandatory; FP32 is the reference for tolerance. Without pinning this, different ops could be validated under different dtypes and the chain-level guarantee would be meaningless. This issue locks the dtype axis of the #108 contract.
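To make the dependence on dtype concrete, here is a minimal sketch in pure Python that emulates BF16 and FP32 rounding with `struct`. The helper names (`round_bf16`, `round_fp32`, `accumulate`) are illustrative and not part of any WS1 code. It sums the same BF16 inputs in two orders, once with a BF16 accumulator and once with an FP32 accumulator. A different reduction order is one way a different batch configuration can show up numerically.

```python
import struct


def round_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_bf16(x: float) -> float:
    """Round to the nearest BF16 value (ties to even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # BF16 keeps the top 16 bits of the FP32 encoding.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, round_acc):
    """Sum BF16 inputs left to right, rounding the accumulator after each add."""
    acc = 0.0
    for v in values:
        acc = round_acc(acc + round_bf16(v))
    return acc


# Same multiset of inputs, two reduction orders.
large_first = [256.0] + [1.0] * 256
small_first = list(reversed(large_first))

bf16_large_first = accumulate(large_first, round_bf16)
bf16_small_first = accumulate(small_first, round_bf16)
fp32_large_first = accumulate(large_first, round_fp32)
fp32_small_first = accumulate(small_first, round_fp32)

print("BF16 accumulator:", bf16_large_first, bf16_small_first)  # 256.0 512.0
print("FP32 accumulator:", fp32_large_first, fp32_small_first)  # 512.0 512.0
```

With a BF16 accumulator the two orders disagree, because 256 + 1 rounds back to 256 at BF16 precision. With an FP32 accumulator both orders agree for this input. A tolerance measured under one accumulation dtype therefore says little about an op validated under another.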

Scope

Pin the dtype set every WS1 op validates against, and block finalization of the #108 contract until this axis is resolved.

  • Declare the tested dtype set: BF16 mandatory (the RL training dtype), FP32 as the reference path for computing tolerances; FP16 optional if supported; FP8 explicitly out of scope this month.
  • Specify the accumulation policy (FP32 accumulation for BF16 inputs) as part of the contract so ops don't each choose differently, and state explicitly whether TF32 is enabled or disabled.
  • Ensure the #108 threshold table ([WS1] Ground-truth harness + numerical contract for batch-invariant ops) has per-dtype rows, and that every op runs its batch-config sweep under the same pinned set.
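One way the dtype axis could be written down is as a small declarative record that refuses to finalize while anything is undecided. The sketch below is an assumption about shape only: `DType`, `Tolerance`, `DTypeAxis`, and their fields are hypothetical names, and no tolerance values are proposed here.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, FrozenSet, List, Optional


class DType(Enum):
    BF16 = "bf16"
    FP32 = "fp32"
    FP16 = "fp16"
    FP8 = "fp8"


@dataclass(frozen=True)
class Tolerance:
    atol: float
    rtol: float


@dataclass
class DTypeAxis:
    """Hypothetical record of the dtype axis of the numerical contract."""
    mandatory: FrozenSet[DType] = frozenset({DType.BF16})
    reference: DType = DType.FP32
    optional: FrozenSet[DType] = frozenset({DType.FP16})
    out_of_scope: FrozenSet[DType] = frozenset({DType.FP8})
    # Input dtype -> accumulation dtype.
    accumulation: Dict[DType, DType] = field(
        default_factory=lambda: {DType.BF16: DType.FP32}
    )
    # None means the TF32 decision has not been made yet.
    tf32_enabled: Optional[bool] = None
    # Per-dtype rows of the threshold table; empty until ratified.
    thresholds: Dict[DType, Tolerance] = field(default_factory=dict)

    def unresolved(self) -> List[str]:
        """List what still blocks finalization."""
        missing = []
        if self.tf32_enabled is None:
            missing.append("TF32 decision")
        required = sorted(self.mandatory | {self.reference}, key=lambda d: d.value)
        for dtype in required:
            if dtype not in self.thresholds:
                missing.append(f"threshold row for {dtype.value}")
        return missing

    def is_finalizable(self) -> bool:
        return not self.unresolved()


axis = DTypeAxis()
print(axis.unresolved())
```

A freshly constructed `DTypeAxis` reports the TF32 decision and the BF16 and FP32 threshold rows as unresolved, which mirrors the requirement that the contract stays blocked until this axis is settled.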

Initial per-op recommendations (to be ratified in the contract):

  • RMSNorm: BF16 input, FP32 accumulation, BF16/FP32 output.
  • Logprob: BF16/FP32 logits, FP32 accumulation, FP32 output.
  • Matmul/GEMM: BF16 input, FP32 accumulation, TF32 pinned.
  • Attention: BF16 Q/K/V, FP32 softmax accumulation, FP32 LSE where applicable.
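As an illustration of what the RMSNorm line means in practice, the following pure-Python sketch emulates BF16 inputs, an FP32 accumulator, and a BF16 or FP32 output. It is a toy reference under stated assumptions, not the WS1 kernel: the function names, the `eps` default, and the rounding helpers are all hypothetical. Each row is reduced in a fixed order that does not depend on how many other rows share the batch.

```python
import math
import struct


def round_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_bf16(x: float) -> float:
    """Round to the nearest BF16 value (ties to even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rmsnorm_row(x, weight, eps=1e-6, out_bf16=True):
    """RMSNorm over one row: BF16 inputs, FP32 accumulation, BF16/FP32 output."""
    x = [round_bf16(v) for v in x]
    weight = [round_bf16(w) for w in weight]
    # Sum of squares kept in FP32, always left to right.
    acc = 0.0
    for v in x:
        acc = round_fp32(acc + round_fp32(v * v))
    inv_rms = round_fp32(1.0 / math.sqrt(round_fp32(acc / len(x) + eps)))
    out = [round_fp32(round_fp32(v * inv_rms) * w) for v, w in zip(x, weight)]
    return [round_bf16(o) for o in out] if out_bf16 else out


def rmsnorm_batch(batch, weight, **kwargs):
    """Apply the per-row reference to every row; rows never interact."""
    return [rmsnorm_row(row, weight, **kwargs) for row in batch]


weight = [1.0, 0.5, 2.0, 1.0]
row = [0.1, -1.25, 3.0, 7.5]
alone = rmsnorm_batch([row], weight)[0]
in_batch = rmsnorm_batch([[9.0, 9.0, 9.0, 9.0], row, [0.0, 1.0, 0.0, 1.0]], weight)[1]
print(alone == in_batch)  # True
```

Because the reduction order is fixed per row, the result for a row is bitwise identical whether it is processed alone or inside a larger batch, which is the property the batch-config sweep is meant to check on the real ops.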

Out of scope

Acceptance criteria

  • The numerical contract names the exact dtype set (BF16 mandatory, FP32 reference) and states FP8 is out of scope this month.
  • The #108 threshold table ([WS1] Ground-truth harness + numerical contract for batch-invariant ops) has explicit per-dtype entries and is used as the single source of truth.
  • Every WS1 op validates against the same pinned dtype set (no op tested under a different dtype).
  • The FP32-accumulation-for-BF16 rule (and the TF32 decision) is written down once and referenced by op issues; unsupported dtypes skip or fail clearly.
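The "same pinned set" and "skip or fail clearly" criteria could be enforced by a single sweep driver shared by all ops. The sketch below is hypothetical: the op names, the `supported` mapping, and the status strings are placeholders, and each op check is a stand-in callable that returns whether the op stayed within tolerance.

```python
MANDATORY = ("bf16", "fp32")  # BF16 plus the FP32 reference path
OPTIONAL = ("fp16",)
OUT_OF_SCOPE = ("fp8",)


def sweep(ops, supported, requested=MANDATORY + OPTIONAL):
    """Run every op under the same dtype list and report one status per pair.

    ops:       op name -> callable(dtype) returning True when within tolerance
    supported: op name -> set of dtypes the op implements
    """
    for dtype in requested:
        if dtype in OUT_OF_SCOPE:
            raise ValueError(f"{dtype} is out of scope for this contract")
        if dtype not in MANDATORY + OPTIONAL:
            raise ValueError(f"{dtype} is not part of the pinned dtype set")
    dropped = [d for d in MANDATORY if d not in requested]
    if dropped:
        raise ValueError(f"sweep must include mandatory dtypes: {dropped}")

    results = {}
    for name, check in ops.items():
        for dtype in requested:
            if dtype not in supported[name]:
                # Mandatory dtypes may never be silently skipped.
                if dtype in MANDATORY:
                    results[(name, dtype)] = "FAIL: mandatory dtype unsupported"
                else:
                    results[(name, dtype)] = "SKIP: optional dtype unsupported"
                continue
            ok = check(dtype)
            results[(name, dtype)] = "PASS" if ok else "FAIL: tolerance exceeded"
    return results


# Stand-in checks; real ones would compare against the FP32 reference.
ops = {"rmsnorm": lambda dtype: True, "matmul": lambda dtype: True}
supported = {"rmsnorm": {"bf16", "fp32"}, "matmul": {"bf16", "fp32", "fp16"}}
results = sweep(ops, supported)
for key in sorted(results):
    print(key, results[key])
```

Routing every op through one driver means an op cannot quietly validate under a different dtype list, and a missing mandatory dtype surfaces as an explicit failure rather than a gap in the table.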

Notes

Planned PRs

Metadata

Labels

feature, platform: cuda, priority: high, sprint-0615, type: design
