
[CPU] Enable pre-packed weights sharing for MatMulNBits #29163

Merged: derdeljan-msft merged 13 commits into main from derdeljan/matmulnbits_cpu_weight_share on Jun 26, 2026


@derdeljan-msft derdeljan-msft commented Jun 19, 2026


Description

Enable pre-packed weights sharing for the MatMulNBits operator on CPU. When the DQ + MatMul -> MatMulNBits fusion runs, the original weight names are lost, so the standard AddInitializer approach does not work. To overcome this, graph optimization passes can now tag weights that are shareable across sessions (the content is hashed and matched across sessions).

Motivation and Context

When executing ASG SLMs on CPU there are two sessions, one for the prefill stage and one for the decode stage (due to different shapes and session options). This change avoids storing the weights in memory twice: the first session pre-packs the weights, and the second session reuses them.

Confirmed memory reduction through the WPA memory traces.

@derdeljan-msft derdeljan-msft marked this pull request as ready for review June 19, 2026 19:04
@derdeljan-msft derdeljan-msft self-assigned this Jun 19, 2026

@tianleiwu tianleiwu left a comment

Review summary

Thanks for enabling cross-session pre-packed weight sharing for MatMulNBits — the motivation (prefill + decode sessions sharing one in-memory copy of the packed weights) is clear and the share_all_prepacked_cpu_initializers opt-in plus content-addressing is a reasonable approach. The test matrix (4-bit/8-bit, fp32/fp16, symmetric/asymmetric, +/- bias, multiple block sizes and accuracy levels, plus an AddInitializer path and a negative control) is thorough and the new shared test helper keeps the two test files DRY.

Main concern — the cache key is computed from a partially-packed buffer:

The per-B cache key is produced by GenerateKeyForPrepackedWeightsMap() immediately after PrePack(input_idx == B) returns, but the scales and zero_points PrePack calls subsequently mutate that same buffer in place (the MlasQNBitGemmPackQuantBData(..., packed_b_.get(), ...) calls in the scales/zero_points branches). At hash time the zero-point region is still the zeroed placeholder and blksum was computed with no zero point (zp passed as nullptr during B packing). So the hash does not reflect the final packed bytes — specifically it does not capture zero_points.

Consequence under share_all: two CPU MatMulNBits initializers with byte-identical quantized B and identical scales but different zero_points would collide on the same key; the second one adopts the first's already-finalized buffer and, because packed_b_is_shared_ becomes true, skips packing its own zero points — silently producing a wrong result. For the intended prefill/decode same-model case the weights are identical so this never triggers, but share_all makes every initializer eligible and widens the collision surface beyond the old AddInitializer-only path. Could you confirm the B-only hash uniquely determines the fully-finalized buffer (i.e. that no post-hash packing step can differ between two initializers that hash equal), or compute the key after all packing for the node completes / fold zero_points into the hashed content? Inline detail below.

Minor notes:

  • The std::memset(packed_b_.get(), 0, packed_b_size_) runs for every CompInt8 PrePack even when no OrtPrepackedWeightsContainer is configured (sharing disabled). The zero-fill is only needed for hash stability, so it's wasted (one-time, session-init) work on the non-sharing path; consider gating it on prepacked_weights != nullptr.
  • On the ARM64/HQNBIT paths a nullptr placeholder is pushed for scales, and PrePackedWeights::GetHash() skips null buffers — so the scales container key is op_type + hash-of-nothing, identical for every MatMulNBits node. It's benign (no real buffer is shared), but it does increment used_shared_pre_packed_weights_counter_ for unrelated nodes, which could be mistaken for real scale sharing.

Verdict: COMMENT — primarily to confirm the hashing assumption above before merge.
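
To make the main concern above concrete, here is a minimal sketch of the collision. It is not the ONNX Runtime code: `Fnv1a`, `QuantizedWeight`, `KeyAfterBOnly` and `KeyAfterFullPack` are hypothetical names, and FNV-1a merely stands in for whatever hash the container uses. It assumes only what the review states, namely that the key was taken before zero points were folded in.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in content hash (FNV-1a). Assumption: the real container uses a
// different hash; only the "what gets hashed" question matters here.
inline uint64_t Fnv1a(const std::vector<uint8_t>& bytes,
                      uint64_t h = 14695981039346656037ULL) {
  for (uint8_t b : bytes) {
    h ^= b;
    h *= 1099511628211ULL;
  }
  return h;
}

// Hypothetical illustration type for one quantized weight group.
struct QuantizedWeight {
  std::vector<uint8_t> b;
  std::vector<uint8_t> scales;
  std::vector<uint8_t> zero_points;
};

// Key taken right after B is packed: zero points are not part of it yet.
inline uint64_t KeyAfterBOnly(const QuantizedWeight& w) {
  assert(!w.b.empty());
  return Fnv1a(w.b);
}

// Key taken after every input has been folded into the hashed content.
inline uint64_t KeyAfterFullPack(const QuantizedWeight& w) {
  uint64_t h = Fnv1a(w.b);
  h = Fnv1a(w.scales, h);
  return Fnv1a(w.zero_points, h);
}
```

Two weight groups with identical `b` and `scales` but different `zero_points` get the same `KeyAfterBOnly` value and different `KeyAfterFullPack` values, which is the distinction the review asks the PR to guarantee.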

Comment thread onnxruntime/core/framework/session_state.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated
MatMulNBits' CompInt8 (accuracy_level 4) packing is staged and stateful:
PrePack(B) packs the quantized weights and accumulates a partial block sum
into the buffer, then PrePack(scales)/PrePack(zero_points) consume that
state to finalize it. MLAS requires each step to run exactly once per
buffer (see SQ8BitGemmPackQuantBDataAndBlkSum).

Cross-session pre-packed weight sharing broke this contract: the second
session adopts the buffer the first session already finalized and then
re-runs PrePack(scales)/PrePack(zero_points) on it, finalizing a second
time over already-folded data. That corrupts the block-sum correction and
produces wrong results. It reproduces on Linux ARM64, where
ArmNeonIsQuantActivationsUnsigned selects the stateful correction path, and
is latent in the AVX2/AVX512 packers that use the same design.

Track the buffer each instance packs, and in UseSharedPrePackedBuffers
detect when the buffer handed back came from another session (it differs
from the one this instance packed) and skip the staged scale/zero-point
re-pack. The first session and the non-sharing path adopt their own buffer
and are unchanged; only the redundant re-pack in later sessions is removed.
All changes are in PrePack/UseSharedPrePackedBuffers, so inference and the
single-session path are unaffected.
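
The guard described in this commit message can be sketched in miniature as follows. This is an illustration under stated assumptions, not the code in matmul_nbits.cc: `StagedPacker` and its methods are hypothetical, the "packing" is a plain integer multiply, and the only behavior modeled is "remember which buffer this instance packed, and skip the second stage when a different buffer is handed back".

```cpp
#include <cassert>
#include <vector>

// Hypothetical miniature of a staged, stateful packer.
class StagedPacker {
 public:
  StagedPacker() = default;
  StagedPacker(const StagedPacker&) = delete;  // holds pointers into itself
  StagedPacker& operator=(const StagedPacker&) = delete;

  // Stage 1: pack B into a buffer this instance owns and remember it.
  void PrePackB(const std::vector<int>& b) {
    own_ = b;
    active_ = &own_;
    packed_by_me_ = &own_;
  }

  // The shared container hands back the buffer to use. If it is not the one
  // this instance packed, another session already finalized it.
  void UseSharedBuffer(std::vector<int>* buffer) {
    assert(buffer != nullptr);
    active_ = buffer;
    adopted_foreign_ = (buffer != packed_by_me_);
  }

  // Stage 2: fold the scale in, exactly once per buffer.
  void PrePackScale(int scale) {
    assert(active_ != nullptr);
    if (adopted_foreign_) return;  // skip the redundant re-pack
    for (int& v : *active_) v *= scale;
  }

  std::vector<int>* Buffer() { return active_; }
  const std::vector<int>& Packed() const { return *active_; }

 private:
  std::vector<int> own_;
  std::vector<int>* active_ = nullptr;
  const std::vector<int>* packed_by_me_ = nullptr;
  bool adopted_foreign_ = false;
};
```

A first instance that adopts its own buffer still runs stage 2; a second instance that adopts the first one's buffer leaves it untouched, so the scale is applied once rather than twice.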
@derdeljan-msft derdeljan-msft force-pushed the derdeljan/matmulnbits_cpu_weight_share branch from 6da9ed9 to afb93d0 Compare June 22, 2026 20:57
@derdeljan-msft derdeljan-msft requested a review from tianleiwu June 22, 2026 20:59

@tianleiwu tianleiwu left a comment

Re-reviewed at afb93d0; my two earlier threads are resolved. The scales/zero-point folding into packed_b_ during PrePack(B) plus the packed_b_finalized_ guard is consistent across the x64 / ARM64-KleidiAI / ARM64-non-KleidiAI-4bit / 8-bit / fp16-fallback branches, and the buffer ownership flow matches the established CPU-kernel pattern (pushed into prepacked_weights, restored as a non-owning reference via UseSharedPrePackedBuffers) — no leak or double-free. The DifferentZeroPoints regression test guards the central correctness risk and the helper-based tests give good coverage.

No blocking issues. Two low-priority notes inline (one dead-condition nit, one latent determinism assumption for non-CompInt8 shared buffers).

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated

@tianleiwu tianleiwu left a comment

Reviewed the cross-session pre-packed weight sharing for MatMulNBits on the CPU EP. The core correctness work — folding scales + zero points into packed_b_ during PrePack(B) so the content hash reflects them, plus the packed_b_finalized_ guard preventing later staged packs from writing into a buffer that may now be shared from another session — is well reasoned across the x64 / ARM64-KleidiAI / non-KleidiAI-4bit / fp16-fallback / 8-bit branches, and the buffer ownership flow matches the established CPU-kernel pattern.

The DifferentZeroPointsDoNotCollide regression test directly guards the central risk (identical B+scales but different zero points must not collide), and the helper-based tests cover symmetric/asymmetric, +/- bias, several block sizes, accuracy levels, fp32/fp16, and a no-opt-in negative control. Nice coverage.

No blocking issues. A few low-priority clarity / robustness notes left inline:

  • The prepacked_weights != nullptr guard is effectively always true, so the deterministic memset and buffer push now run on every prepack at load — worth confirming the cost and rewording the comment.
  • The memset is CompInt8-only while the buffer is pushed for every compute type; non-CompInt8 dedup then relies on MLAS leaving no uninitialized padding (safe, but undocumented).
  • The dead || HQNBIT_CompInt8 term in the fp16-fallback memset.
  • share_all_prepacked_cpu_initializers broadens sharing to every CPU prepacking kernel, not just MatMulNBits — worth documenting the blast radius and the content-complete invariant it assumes.
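
The memset notes above rest on one observation: a byte-level hash of a packed buffer is only stable if every byte of the buffer is written deterministically. The sketch below illustrates that with a toy packer; `HashBytes` and `PackInto` are hypothetical names, FNV-1a is a stand-in hash, and the "padding" is simply whatever part of the buffer the packer does not write.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in byte hash (FNV-1a); not the hash ONNX Runtime uses.
inline uint64_t HashBytes(const uint8_t* data, size_t size) {
  uint64_t h = 14695981039346656037ULL;
  for (size_t i = 0; i < size; ++i) {
    h ^= data[i];
    h *= 1099511628211ULL;
  }
  return h;
}

// Toy packer: writes the payload at the front of a larger buffer and leaves
// the remaining bytes alone unless asked to zero-fill first.
inline void PackInto(uint8_t* buffer, size_t buffer_size,
                     const std::vector<uint8_t>& payload, bool zero_fill_first) {
  assert(payload.size() <= buffer_size);
  if (zero_fill_first) std::memset(buffer, 0, buffer_size);
  if (!payload.empty()) std::memcpy(buffer, payload.data(), payload.size());
}
```

With stale, differing contents in the unwritten tail, the same payload hashes differently in two buffers; with the zero-fill, the two buffers are byte-identical and hash equal. That is why the zero-fill matters only when the buffer will be hashed for sharing.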

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated
Comment thread onnxruntime/core/framework/session_state.cc Outdated
@derdeljan-msft derdeljan-msft requested a review from tianleiwu June 23, 2026 20:01

@tianleiwu tianleiwu left a comment

Requesting changes for one correctness regression in the MatMulNBits prepack path; details inline.

Design note: I agree the root cause here is that some fusions synthesize new MatMulNBits initializer names after SessionOptions::AddInitializer has already been populated. The public AddInitializer path is exact-name/user-OrtValue based, so the fusion cannot simply call it as-is, but a cleaner internal design would be for the transformer to tag the generated B/scales/ZP group with a stable sharing identity and have SessionState enroll only those tagged generated weights. That would avoid exposing a broad session option whose correctness depends on the kernel's packed bytes fully capturing compute semantics.

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@tianleiwu

a cleaner internal design would be for the transformer to tag the generated B/scales/ZP group with a stable sharing identity and have SessionState enroll only those tagged generated weights. That would avoid exposing a broad session option whose correctness depends on the kernel's packed bytes fully capturing compute semantics.

@derdeljan-msft, could you take a look at this design option? It could be a general solution.
The current workaround (adding an option for a single op) is not a long-term solution for similar issues.

DQ->MatMulNBits fusions synthesize new B/scales/zero-point initializers whose
names are generated per-graph and are therefore not stable across sessions, so
the prior content-hash-of-packed-bytes sharing could not safely dedup them.
Replace the broad `session.share_matmulnbits_prepacked_weights` option with a
tagging mechanism: each fusion computes a stable, content-derived identity over
the generated weight/scale/zero-point bytes plus the quant params (N, K,
block_size, bits, accuracy_level) and tags the generated B initializer with it.
SessionState enrolls tagged initializers into the shared pre-packed-weights
container keyed by that identity, so cross-session sharing needs no session
option and cannot false-share across models differing in any semantic input
(e.g. zero points, or accuracy_level which changes the packed layout).

Mechanism:
- graph.h: add Graph::Set/GetSharedPrepackInitializerId and the backing
  name->identity side-map.
- matmul_nbits_sharing_identity.h (new): shared ComputeMatMulNBitsSharingId
  helper (MurmurHash3 over the generated tensors + quant params).
- dq_matmulnbits_fusion.cc: tag the generated B in both fusion patterns via the
  shared helper.
- qdq_actions.cc: tag the generated B in the default DQMatMulToMatMulNBitsAction
  (QDQ selector/action) path, which runs without the fusion flag — closing the
  gap where typical QDQ models never shared.
- session_state.cc: enroll tagged initializers by identity; drop the option gate.
- onnxruntime_session_options_config_keys.h: remove the now-unused
  kOrtSessionOptionsShareMatMulNBitsPrepackedWeights config key.

Tests:
- dq_matmulnbits_fusion_test.cc: fusion-path tag stability/collision-safety +
  end-to-end cross-session sharing tests.
- qdq_matmulnbits_transformer_test.cc: default-path equivalents, plus a test that
  a different accuracy_level yields a different identity (no cross-compute-type
  sharing).
- Remove the option-based sharing tests from matmul_4bits/8bits_test.cc and the
  shared test util; add an opt-in (disabled) pre-pack memset benchmark.
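
As a rough illustration of the tagging mechanism in this commit message, the sketch below derives an identity from the generated tensors plus the quant params and stores it in a name-to-identity side-map. Everything here is an assumption made for illustration: `ComputeSharingIdSketch`, `TagMapSketch`, `QuantParams` and `Mix` are hypothetical, FNV-1a replaces the MurmurHash3 the PR uses, and the initializer names are invented.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct QuantParams {  // hypothetical bundle of the semantic inputs
  int64_t N, K, block_size, bits, accuracy_level;
};

// Stand-in mixing step (FNV-1a) over an arbitrary byte range.
inline uint64_t Mix(uint64_t h, const void* data, size_t size) {
  const unsigned char* p = static_cast<const unsigned char*>(data);
  for (size_t i = 0; i < size; ++i) {
    h ^= p[i];
    h *= 1099511628211ULL;
  }
  return h;
}

inline uint64_t MixTensor(uint64_t h, const std::vector<uint8_t>& t) {
  const uint64_t len = t.size();  // length prefix keeps segment boundaries distinct
  h = Mix(h, &len, sizeof(len));
  return Mix(h, t.data(), t.size());
}

// Identity over weight, scales, zero points and every quant parameter.
inline uint64_t ComputeSharingIdSketch(const std::vector<uint8_t>& weight,
                                       const std::vector<uint8_t>& scales,
                                       const std::vector<uint8_t>& zero_points,
                                       const QuantParams& p) {
  uint64_t h = 14695981039346656037ULL;
  h = MixTensor(h, weight);
  h = MixTensor(h, scales);
  h = MixTensor(h, zero_points);
  const int64_t fields[] = {p.N, p.K, p.block_size, p.bits, p.accuracy_level};
  for (int64_t f : fields) h = Mix(h, &f, sizeof(f));
  return h;
}

// Name -> identity side-map, in the spirit of the Graph accessors.
class TagMapSketch {
 public:
  void Set(const std::string& name, uint64_t id) { ids_[name] = id; }
  const uint64_t* Get(const std::string& name) const {
    auto it = ids_.find(name);
    return it == ids_.end() ? nullptr : &it->second;
  }

 private:
  std::unordered_map<std::string, uint64_t> ids_;
};
```

Because the identity depends on content rather than on the generated name, two sessions that name the fused initializer differently still arrive at the same tag, while a change to the zero points or to accuracy_level changes it.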
@derdeljan-msft

a cleaner internal design would be for the transformer to tag the generated B/scales/ZP group with a stable sharing identity and have SessionState enroll only those tagged generated weights. That would avoid exposing a broad session option whose correctness depends on the kernel's packed bytes fully capturing compute semantics.

@derdeljan-msft, could you take a look at this design option? It could be a general solution. The current workaround (adding an option for a single op) is not a long-term solution for similar issues.

Implemented the suggested approach: graph transformation passes can now tag weights (newly added initializers) that can be shared across sessions.

@tianleiwu tianleiwu left a comment

Re-reviewed at bc2a0e8. The PR has moved to the tagged-identity design — the transformer tags the generated MatMulNBits weight group with a stable content id and SessionState enrolls only those tagged generated weights into the cross-session container. This is exactly the cleaner internal approach I was hoping for and resolves my earlier concern about the broad share_all session option whose correctness depended on the packed bytes fully capturing compute semantics. The scales/zero-point folding plus the packed_b_finalized_ guard from the prior revision is intact, and the new tag tests (stable identity, cross-session share, different-accuracy-no-share) plus the shared-weight-not-fused negative controls are solid.

Two non-blocking findings on the new identity path, inline:

  1. The 128-bit content hash effectively narrows to ~32-bit collision resistance because the four MurmurHash calls are chained through only hash[0]. Low probability for realistic model sizes, but the failure mode is a silent wrong result and it is cheap to harden.
  2. The name-keyed tag is safe only under the fusion's single-consumer guarantee for generated weights; worth recording that invariant explicitly.

Verdict: COMMENT.

Comment thread onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h Outdated
Comment thread include/onnxruntime/core/graph/graph.h
tianleiwu

This comment was marked as duplicate.

Copilot AI left a comment

Pull request overview

This PR enables cross-session sharing of CPU MatMulNBits pre-packed weights when the MatMulNBits weights are synthesized by graph optimizations (DQ+MatMul → MatMulNBits and QDQ selector/action conversion), where the generated initializer names are not stable across sessions. It does so by tagging those generated initializers with a stable, content-derived identity and enrolling them into the shared prepacked-weights container during session initialization.

Changes:

  • Add a stable MatMulNBits sharing identity (hash over weight/scale/(zp) + shape/quant params + accuracy_level) and tag fusion-generated B initializers in both the DQMatMulNBitsFusion pass and the QDQ selector/action conversion.
  • Extend Graph + SessionState to carry and consume per-initializer sharing IDs for enrolling synthesized initializers into the shared prepacked-weights container.
  • Fix/adjust MatMulNBits CPU prepack recording/dedup (incl. LUT path double-append) and add/extend tests for tag stability and cross-session sharing behavior.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc | Tags the generated MatMulNBits B initializer with a computed stable sharing ID in the QDQ action path. |
| onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h | Adds helper to compute a stable, content-derived sharing ID for MatMulNBits fusion-generated weights. |
| onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc | Tags fusion-generated MatMulNBits B initializer with a stable sharing ID in DQMatMulNBitsFusion paths. |
| include/onnxruntime/core/graph/graph.h | Adds Graph storage + accessors for generated-initializer sharing IDs. |
| onnxruntime/core/framework/session_state.cc | Enrolls either AddInitializer-registered or tagged initializers into the shared prepacked-weights container (CPU EP). |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Fixes LUT prepack double-buffer record; makes packed buffer hash stable (zero-fill padding); folds scale/zp into packed B earlier and prevents re-folding when sharing. |
| onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc | Adds tests for default-path QDQ action conversion: tag stability, cross-session reuse, and accuracy-level separation. |
| onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc | Adds tests for fusion path: tag stability and cross-session reuse through shared container. |
| onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h | New shared test utility API for validating cross-session prepack sharing via OpTester. |
| onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc | Implements 2-session sharing checks and negative control for OpTester-based MatMulNBits tests. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds OpTester-based cross-session sharing tests (legacy AddInitializer + negative control) via the new helper. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds OpTester-based cross-session sharing tests (legacy AddInitializer + negative control) via the new helper. |
| onnxruntime/test/contrib_ops/matmul_2bits_test.cc | Adds regression test ensuring LUT GEMM prepack + save-external-prepacked-initializers path doesn’t crash. |

Comment thread onnxruntime/core/framework/session_state.cc Outdated
Comment thread onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc
@tianleiwu tianleiwu dismissed their stale review June 25, 2026 06:16

Superseded by re-review at bc2a0e8: the PR now uses the tagged-identity design and my earlier blocking concern is resolved. Remaining notes are non-blocking (left as COMMENT).

Addresses review feedback on the MatMulNBits cross-session prepacked-weight
sharing feature. Four related changes:

session_state: key the shared-prepack container by the packed-bytes hash
(GenerateKeyForPrepackedWeightsMap) for tagged initializers too, exactly as
the AddInitializer path already does; the fusion tag is now only the
enrollment signal. The tag is derived from the *unpacked* initializer content,
so using it as the key let two sessions that differ in any option affecting
the packed layout (mlas.use_lut_gemm, a CPU backend-selector difference, or
the compute type) reuse an incompatible packed buffer -- wrong results/crash.
Keying by the packed bytes only ever shares byte-identical buffers.

graph: enforce the single-consumer invariant in SetSharedPrepackInitializerId.
A MatMulNBits packed buffer folds in the consuming node's
scales/zero-points/attributes, so a sharing id is valid only for a B
initializer with exactly one consumer (guaranteed today by the
DQ->MatMulNBits producers). ORT_ENFORCE that a name is never re-tagged with a
conflicting id so the guarantee survives later refactors.

matmul_nbits_sharing_identity: fold each segment's full 128-bit MurmurHash3
output into a 64-bit accumulator instead of forwarding only hash[0] (a 32-bit
seed bottleneck). Every input bit now reaches the id, raising collision
resistance from ~2^32 to ~2^64; a collision would silently adopt another
weight group's already-packed buffer.

tests: make the negative "must not share" cases differ in the weight, which
changes the packed bytes on every compute type, instead of the zero points or
accuracy level (those only change the bytes under CompInt8 -- on CompFp32 they
are applied at compute time and left out of the packed B, so such models
correctly share a byte-identical buffer). Rename
DefaultPath_DifferentAccuracyLevelDoesNotShare to ...GetsDistinctIdentity and
assert the identity is distinct rather than a platform-dependent sharing
count. Update comments to reflect packed-bytes keying.
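
The session_state change in this commit message separates two roles: the tag decides whether an initializer is enrolled at all, and the packed bytes decide which buffer it shares. A minimal sketch of that split follows. `SharedContainerSketch` and `Enroll` are hypothetical names, and to keep the example deterministic the map is keyed by the packed bytes themselves rather than by a hash of them, which is a simplification of what the commit describes.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

using PackedBuffer = std::vector<uint8_t>;

// Hypothetical cross-session container keyed by packed content.
class SharedContainerSketch {
 public:
  std::shared_ptr<const PackedBuffer> GetOrAdd(const PackedBuffer& packed) {
    auto it = by_bytes_.find(packed);
    if (it != by_bytes_.end()) {
      assert(it->second != nullptr);
      ++reuse_count_;
      return it->second;
    }
    std::shared_ptr<const PackedBuffer> stored =
        std::make_shared<PackedBuffer>(packed);
    by_bytes_.emplace(packed, stored);
    return stored;
  }

  int reuse_count() const { return reuse_count_; }

 private:
  std::map<PackedBuffer, std::shared_ptr<const PackedBuffer>> by_bytes_;
  int reuse_count_ = 0;
};

// `tag` models the fusion-assigned sharing id. Its presence enrolls the
// initializer; its value is never used as the key.
inline std::shared_ptr<const PackedBuffer> Enroll(SharedContainerSketch& container,
                                                  const uint64_t* tag,
                                                  const PackedBuffer& packed) {
  if (tag == nullptr) {
    // Untagged: keep a private copy, nothing is shared.
    return std::make_shared<PackedBuffer>(packed);
  }
  return container.GetOrAdd(packed);
}
```

Under this arrangement two sessions that carry the same tag but produce different packed layouts simply store two buffers, so an incompatible layout is never adopted.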

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc Outdated
derdeljan-msft and others added 6 commits June 25, 2026 10:07
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…2 for ASan CI

Reverts the debug DISABLED_ and test-count reductions (commits 926a6de..a4b9a46) back to the pre-hack state at 1bb7bd1, so CI runs the full test set with reduced ctest parallelism only.
@derdeljan-msft derdeljan-msft requested a review from tianleiwu June 25, 2026 21:21
@derdeljan-msft derdeljan-msft enabled auto-merge (squash) June 26, 2026 21:18

@tianleiwu tianleiwu left a comment

Review summary — re-review at a074c775

Re-reviewed cross-session pre-packed weight sharing for MatMulNBits on the CPU EP. The commits since my last pass (2f1e6ed8 "Key by packed bytes; harden id and invariant", plus the test-infra commits) address every concern from my earlier rounds, and all of my previously opened threads are resolved.

What I verified is now correct:

  • Container keying is by packed bytes, not the tag. SessionState::PrepackConstantInitializedTensors now keys the shared container with GenerateKeyForPrepackedWeightsMap(op_type, weights_to_be_filled_in) for both the AddInitializer and the tagged-enrollment paths. The tag is used only as the enrollment signal (its presence routes the fusion-generated initializer into the container branch). This closes the earlier correctness hole where two sessions sharing a container but differing in an option that changes the packed layout (e.g. mlas.use_lut_gemm, a CPU backend-selector difference) could have collided on the tag and reused an incompatible buffer.
  • Sharing-id hardening. ComputeMatMulNBitsSharingId now folds the full 128-bit MurmurHash output of every segment through a 64-bit accumulator (with fmix64 avalanche) instead of forwarding only hash[0], so collision resistance tracks the 64-bit id width. Since the id is now only an enrollment signal (never the key), even a hypothetical collision is non-fatal.
  • Single-consumer invariant enforced. Graph::SetSharedPrepackInitializerId rejects re-tagging a name with a conflicting id via ORT_ENFORCE, with a comment documenting why the id is meaningful only for a single-consumer B initializer.
  • LUT double-append fix + regression test. The single shared append of packed_b_ (with the Float32_2Bits_PrepackSaveDoesNotCrash save-path test) correctly prevents the moved-from/null buffer that the prepacked-save path would have dereferenced.
  • Finalization gating. packed_b_finalized_ correctly suppresses the staged scales/zero-point packing on the later PrePack calls (and in UseSharedPrePackedBuffers), so an adopted shared buffer is never re-folded.
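
The sharing-id hardening in the list above can be pictured with the sketch below. It is an assumption-laden illustration rather than the PR's helper: `Digest128`, `ForwardSeedOnly`, `FoldDigest` and `FoldAll` are hypothetical, the 128-bit digest is modeled as two 64-bit words, and the low 32 bits of `lo` stand in for hash[0]. The `Fmix64` function is the standard 64-bit finalizer from MurmurHash3.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct Digest128 {  // hypothetical view of one segment's 128-bit digest
  uint64_t lo;
  uint64_t hi;
};

// MurmurHash3's 64-bit finalizer (avalanche step).
inline uint64_t Fmix64(uint64_t k) {
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccdULL;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53ULL;
  k ^= k >> 33;
  return k;
}

// Earlier scheme: only 32 bits of a digest were carried to the next segment.
inline uint32_t ForwardSeedOnly(const Digest128& d) {
  return static_cast<uint32_t>(d.lo);
}

// Hardened scheme: both 64-bit halves reach a 64-bit accumulator.
inline uint64_t FoldDigest(uint64_t acc, const Digest128& d) {
  acc = Fmix64(acc ^ d.lo);
  acc = Fmix64(acc ^ d.hi);
  return acc;
}

inline uint64_t FoldAll(const Digest128* digests, size_t count) {
  assert(digests != nullptr || count == 0);
  uint64_t acc = 0;
  for (size_t i = 0; i < count; ++i) acc = FoldDigest(acc, digests[i]);
  return acc;
}
```

Two digests that agree in their low 32 bits are indistinguishable to `ForwardSeedOnly` but remain distinct after `FoldDigest`, because each `Fmix64` step is a bijection on 64-bit values.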

Coverage looks good: positive/negative AddInitializer sharing tests for 4-bit and 8-bit, the default-path and fusion-path tag/share end-to-end tests, the distinct-identity-per-accuracy-level test, and the LUT save-path regression test.

LGTM. Approving.

@derdeljan-msft derdeljan-msft merged commit a5d2663 into main Jun 26, 2026
97 of 115 checks passed
@derdeljan-msft derdeljan-msft deleted the derdeljan/matmulnbits_cpu_weight_share branch June 26, 2026 22:25