[WS1] KV-cache path consistency (prefill & decode) #152

@Flink-ddd

Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

Rollout generates token by token through the decode path; training re-runs the same sequence through prefill. Floating-point addition is not associative, so if the two paths reduce in different orders, the same token gets different logprobs in rollout and in training. This is a classic, high-impact source of rollout-vs-training drift. WS1 invariance work tends to focus on chunked prefill; the decode stage must be covered explicitly.
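To make the failure mode concrete, here is a minimal sketch in plain Python (no GPU kernels involved). It sums the same three values in two groupings: strictly left to right, and as two partial sums that are combined afterwards, which is roughly what a split-KV decode kernel does. The function names are hypothetical and exist only for this illustration.

```python
import math

def sum_left_to_right(xs):
    """Reduce in one fixed order: ((x0 + x1) + x2) + ..."""
    total = 0.0
    for x in xs:
        total += x
    return total

def sum_split(xs, split):
    """Reduce two partitions separately, then combine the partial sums."""
    return sum_left_to_right(xs[:split]) + sum_left_to_right(xs[split:])

xs = [0.1, 0.2, 0.3]
sequential = sum_left_to_right(xs)   # (0.1 + 0.2) + 0.3
partitioned = sum_split(xs, 1)       # 0.1 + (0.2 + 0.3)

print(sequential == partitioned)              # False: the bit patterns differ
print(math.isclose(sequential, partitioned))  # True: the gap is only rounding
```

The two results agree to within rounding but are not bitwise equal, which is enough to shift a logprob between rollout and training.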

Scope

Ensure the prefill and decode paths produce the same reductions for the same effective context.

  • Verify that attention over a cached context (decode: one query against N cached KV) reduces in the same fixed order as the equivalent prefill over the full sequence.
  • Cover the decode-stage path explicitly in tests, not only chunked-prefill.
  • Confirm cache writes/reads (layout, dtype of stored KV) do not introduce a precision difference between the path that wrote the cache and the path that consumes it.
  • Validate "generate then re-score" equivalence against the harness from #108 ([WS1] Ground-truth harness + numerical contract for batch-invariant ops).
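The parity property in the list above can be sketched with a toy single-head attention in plain Python. This is not the project's kernel or test harness: `attend`, `prefill`, `decode`, and `to_f16` are hypothetical names, and float16 storage is only an assumed example of a cache dtype. The sketch shows that when both paths read the same stored KV values and reduce in the same fixed order, decode output matches prefill output bitwise.

```python
import math
import random
import struct

def to_f16(x):
    """Round to IEEE half precision and back (assumed cache dtype)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def attend(q, keys, values):
    """One query against N keys/values, reduced strictly left to right."""
    scale = 1.0 / math.sqrt(len(q))
    scores = []
    for k in keys:
        acc = 0.0
        for qi, ki in zip(q, k):
            acc += qi * ki
        scores.append(acc * scale)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = 0.0
    for w in weights:
        denom += w
    out = [0.0] * len(values[0])
    for w, v in zip(weights, values):
        for j, vj in enumerate(v):
            out[j] += w * vj
    return [o / denom for o in out]

def prefill(qs, ks, vs):
    """Causal attention over the full sequence: position t sees keys 0..t."""
    return [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]

def decode(qs, ks, vs):
    """Token by token: append to the cache, then attend with the newest query."""
    k_cache, v_cache, outs = [], [], []
    for q, k, v in zip(qs, ks, vs):
        k_cache.append(k)
        v_cache.append(v)
        outs.append(attend(q, k_cache, v_cache))
    return outs

rng = random.Random(0)
T, D = 6, 4  # toy sequence length and head dimension

def rand_rows():
    return [[rng.uniform(-1.0, 1.0) for _ in range(D)] for _ in range(T)]

qs, ks, vs = rand_rows(), rand_rows(), rand_rows()

# Both paths must consume the values as stored, not the pre-rounding values.
ks_stored = [[to_f16(x) for x in row] for row in ks]
vs_stored = [[to_f16(x) for x in row] for row in vs]

parity_ok = prefill(qs, ks_stored, vs_stored) == decode(qs, ks_stored, vs_stored)
storage_changed_values = ks_stored != ks
print(parity_ok, storage_changed_values)
```

The second flag is the writer-vs-reader hazard: if prefill attended over the unrounded `ks` while decode read `ks_stored`, the inputs themselves would differ before any reduction-order question arises.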

Out of scope

  • The attention kernel's internal accumulation design (covered by the attention issue; this issue checks prefill/decode parity on top of it).
  • Paged-attention / cache-eviction policy beyond reduction-order correctness.
  • Multi-GPU KV sharding (WS2).
  • FP8 KV cache.

Acceptance criteria

Notes

Planned PRs

  • Full-prefill reference vs chunked-prefill consistency test
  • Decode-stage path test (one query vs N cached KV) with reduction order matching prefill
  • Stored-KV layout/dtype: show no writer-vs-reader precision drift
  • "Generate then re-score" equivalence vs the harness from #108 ([WS1] Ground-truth harness + numerical contract for batch-invariant ops)
  • CI-friendly decode-path smoke test (short / long / varlen / padded)
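For the "generate then re-score" and smoke-test items, one possible shape for the comparison step is a bitwise check over per-token logprobs that ignores padded positions. The sketch below is an assumption about how such a check could look, not the #108 harness; `bits`, `first_mismatch`, and `batch_mismatches` are hypothetical helpers, and logprobs are modelled as Python floats.

```python
import struct

def bits(x):
    """IEEE-754 float64 bit pattern of x, so 0.0 and -0.0 compare as different."""
    return struct.unpack("<Q", struct.pack("<d", x))[0]

def first_mismatch(rollout, rescore):
    """Index of the first token whose logprob differs bitwise, or None."""
    if len(rollout) != len(rescore):
        raise ValueError("sequence lengths differ")
    for i, (a, b) in enumerate(zip(rollout, rescore)):
        if bits(a) != bits(b):
            return i
    return None

def batch_mismatches(rollout, rescore, lengths):
    """Compare only the first lengths[r] tokens of each row; padding is ignored."""
    found = []
    for row, (a, b, n) in enumerate(zip(rollout, rescore, lengths)):
        idx = first_mismatch(a[:n], b[:n])
        if idx is not None:
            found.append((row, idx))
    return found

# Padded batch: the rows differ only in padding positions, so no mismatch.
rollout = [[-0.5, -1.0, 0.0], [-0.2, 0.0, 0.0]]
rescore = [[-0.5, -1.0, 9.9], [-0.2, 7.7, 0.0]]
print(batch_mismatches(rollout, rescore, lengths=[2, 1]))  # []
```

Reporting the first mismatching token index, rather than a single pass/fail flag, makes it easier to see whether drift starts at a chunk boundary or at the first decode step.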

Metadata

Labels

  • component: testing (Add test cases and benchmark-related tasks)
  • feature
  • platform: cuda (Specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer, TMA optimizations)
  • priority: high (Severe congestion issues require the highest priority for resolution.)
  • sprint-0615
