
implement quantization debug#962

Draft
zhenchaoni wants to merge 1 commit into main from private/zhenni/debug


Design: winml debug — Per-Op Quantization Error Debugging

Status: Implemented
Area: debug, commands
Surface: winml debug CLI · winml.modelkit.debug.debug_quantization API

1. The problem

A user quantizes a model with the WinML CLI and the quantized artifact shows an accuracy drop in winml eval. The existing signals only say that accuracy dropped:

  • winml eval — a single end-to-end metric.
  • winml eval --mode compare — per-output parity (ONNX vs PyTorch).

Neither says which tensors or weights the quantization damaged. Without per-tensor attribution the user cannot tell whether the loss is concentrated in a few ops or spread across the whole graph, and has no starting point for mixed-precision tuning.

winml debug fills that gap: it compares the float ONNX model against its quantized counterpart tensor by tensor and reports where the quantized model diverges.


2. ONNX Runtime provides the measurement

We do not invent the augment / collect / match / compare mechanism. It is ONNX Runtime's officially documented quantization-debugging recipe, shipped as onnxruntime.quantization.qdq_loss_debug.

The idea: two graphs should compute the same function — the float reference and the quantized model (the float graph with QuantizeLinear/DequantizeLinear "QDQ" pairs inserted). Feed both the same input, capture every intermediate tensor from each, and compare them.

flowchart LR
    X[same input] --> F[Float model]
    X --> Q[Quantized model]
    F -->|t1_f, t2_f, ...| C[compare each tensor]
    Q -->|t1_q, t2_q, ...| C
    C --> R[per-tensor SQNR]

ORT does this in a handful of calls:

| Function | Role |
| --- | --- |
| modify_model_output_intermediate_tensors | Augment a model so every intermediate tensor becomes a graph output. |
| collect_activations | Run an augmented model over a calibration reader → {tensor: [values]}. |
| create_activation_matching | Align float/quant tensors and pre/post-QDQ pairs by name. |
| compute_activation_error | Reduce each match to local and cumulative SQNR. |
| create_weight_matching + compute_weight_error | Compare each weight float vs dequantized → weight SQNR. |
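The calls in the table chain together roughly as follows. This is a minimal sketch of the ORT recipe, not WinML's debugger: `run_qdq_debug` and its parameters are hypothetical names, and `reader` is assumed to be a calibration data reader that also offers a `rewind()` method, which is an assumption of this sketch rather than a guaranteed part of the ORT base class.

```python
from pathlib import Path
import tempfile


def run_qdq_debug(float_model_path, quant_model_path, reader):
    """Sketch of the augment / collect / match / compare recipe.

    `reader` is assumed to be a calibration data reader that also
    exposes rewind(); that method is an assumption of this sketch.
    """
    # Imported lazily so the sketch can be loaded without onnxruntime.
    from onnxruntime.quantization.qdq_loss_debug import (
        collect_activations,
        compute_activation_error,
        compute_weight_error,
        create_activation_matching,
        create_weight_matching,
        modify_model_output_intermediate_tensors,
    )

    with tempfile.TemporaryDirectory() as tmp:
        aug_float = str(Path(tmp) / "float_augmented.onnx")
        aug_quant = str(Path(tmp) / "quant_augmented.onnx")

        # 1. Augment: expose every intermediate tensor as a graph output.
        modify_model_output_intermediate_tensors(str(float_model_path), aug_float)
        modify_model_output_intermediate_tensors(str(quant_model_path), aug_quant)

        # 2. Collect: run both augmented models over the same samples.
        float_acts = collect_activations(aug_float, reader)
        reader.rewind()  # replay the identical sample sequence
        quant_acts = collect_activations(aug_quant, reader)

    # 3. Match by tensor name, then reduce each match to an error value.
    matching = create_activation_matching(quant_acts, float_acts)
    activation_error = compute_activation_error(matching)

    # 4. Compare each float weight against its dequantized counterpart.
    weight_matching = create_weight_matching(
        str(float_model_path), str(quant_model_path)
    )
    weight_error = compute_weight_error(weight_matching)
    return activation_error, weight_error
```

In ORT's own example script, each activation entry carries the pre/post-QDQ error under `qdq_err` and the cross-model error under `xmodel_err`; these correspond to the local and cumulative signals in this design.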

The metric: SQNR (dB)

Each comparison reduces two tensors to a single quality number — the signal-to-quantization-noise ratio, in decibels:

SQNR = 20 * log10( norm(x) / norm(x - y) )   dB

where x is the float (reference) tensor, y is the quantized tensor it is compared against, and x - y is the quantization error. Higher dB means the error is small relative to the signal.

Because it is 20 * log10 of an amplitude ratio, every 20 dB is a 10x change in signal-to-error:

  • 40 dB — signal is ~100x the error (error ~1%): clean.
  • 20 dB — signal is ~10x the error (error ~10%): suspect.
  • < 20 dB — error is more than a tenth of the signal: likely culprit.

SQNR is scale-invariant, so it is comparable across tensors of very different magnitude.
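The arithmetic behind those thresholds can be checked with a few lines of Python. This is an illustrative sketch over plain lists, not ORT's implementation; `sqnr_db` is a hypothetical helper, and the handling of zero noise and zero signal is an assumption made here to keep the function total.

```python
import math


def sqnr_db(x, y):
    """SQNR in dB between reference values x and quantized values y."""
    signal = math.sqrt(sum(a * a for a in x))
    noise = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
    if noise == 0.0:
        return math.inf   # identical tensors: no quantization noise
    if signal == 0.0:
        return -math.inf  # all-zero reference with non-zero error
    return 20.0 * math.log10(signal / noise)


x = [1.0, -2.0, 3.0, -4.0]
one_percent = [a * 0.99 for a in x]  # error is 1% of the signal
ten_percent = [a * 0.90 for a in x]  # error is 10% of the signal

print(round(sqnr_db(x, one_percent), 2))  # about 40 dB
print(round(sqnr_db(x, ten_percent), 2))  # about 20 dB

# Scale invariance: multiplying both tensors by 1000 leaves SQNR unchanged.
big_x = [a * 1000.0 for a in x]
big_y = [a * 1000.0 for a in one_percent]
print(round(sqnr_db(big_x, big_y), 2))    # still about 40 dB
```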

Three signals

| Signal | Compares | Question it answers |
| --- | --- | --- |
| Local | pre-QDQ vs post-QDQ (both in the quant model) | How lossy is quantizing this tensor itself? Upstream error is present on both sides and cancels. |
| Cumulative | float model vs quant model | How much total drift has reached this tensor, including error inherited from upstream? Tracks real end-to-end accuracy loss. |
| Weight | a weight's float value vs its dequantized value | How lossy is quantizing this weight? |

Reading local and cumulative together separates originators from inheritors: a tensor with high local SQNR but low cumulative SQNR quantizes cleanly and has merely inherited upstream damage, whereas a tensor with low local and low cumulative SQNR is losing accuracy in its own rounding.
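That reading rule can be written down as a small triage helper. This is a sketch only: `classify_tensor` and its labels are hypothetical, and the 20 dB cut-off is borrowed from the "suspect" threshold above as an assumption, not a value the tool is known to use for this purpose.

```python
def classify_tensor(local_db, cumulative_db, threshold_db=20.0):
    """Label a tensor from its local and cumulative SQNR (in dB).

    cumulative_db may be None when no float reference was matched.
    """
    if local_db < threshold_db:
        # The tensor's own quantization step is lossy.
        return "originator"
    if cumulative_db is not None and cumulative_db < threshold_db:
        # Quantizes cleanly itself, but arrives already damaged.
        return "inheritor"
    return "clean"


# Values taken from the sample output for gemm_output_reshape_arg_token_388:
# local 36.88 dB, cumulative 8.29 dB.
print(classify_tensor(36.88, 8.29))  # inheritor
```

Originators are the natural first candidates for mixed-precision tuning; inheritors usually recover once the upstream originator is fixed.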


3. Integration: debug_quantization

The engine (src/winml/modelkit/debug/debugger.py) wraps the ORT calls with WinML's calibration-data plumbing and result shaping. It is import-light and free of Click and Rich dependencies, mirroring the split between quant/quantizer.py and commands/quantize.py.

def debug_quantization(
    float_model_path, quant_model_path, *,
    samples=8, model_id=None, task=None,
) -> dict: ...

What WinML adds on top of ORT:

  • Calibration inputs via DatasetCalibrationReader: task-aware and correctly preprocessed when model_id/task are given, otherwise random inputs synthesized from the model's I/O spec (self-contained, no downloads). The reader's rewind() is called between the float and quant collection passes so both models see the identical sample sequence.
  • CPU execution, matching ORT's quantization-debugging guidance.
  • A scalar-tolerant SQNR wrapper (_sqnr_db) passed as err_func to compute_weight_error: ORT's SQNR calls len() on its inputs, which fails for a weight that dequantizes to a numpy scalar; numpy.atleast_1d keeps such weights as length-1 arrays.
  • Result shaping + summary statistics: non-finite SQNR values (from overflow or zero-difference tensors) are dropped from the summary so mean/std stay well-defined.
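The summary-statistics step in the last bullet can be sketched as below. `summarize` is a hypothetical stand-in, not WinML's code; the use of the population standard deviation and the all-None result for an empty input are assumptions of this sketch.

```python
import math
import statistics


def summarize(values):
    """Summary stats over SQNR values, ignoring None and non-finite entries."""
    finite = [v for v in values if v is not None and math.isfinite(v)]
    if not finite:
        return {"count": 0, "mean": None, "std": None, "min": None, "max": None}
    return {
        "count": len(finite),
        "mean": statistics.fmean(finite),
        "std": statistics.pstdev(finite),  # population std (assumption)
        "min": min(finite),
        "max": max(finite),
    }


# inf (zero-difference tensor) and nan (overflow) are dropped before averaging.
print(summarize([30.0, 40.0, math.inf, float("nan")]))
```

Without the filter, a single infinite value would make the mean infinite and the standard deviation undefined, which is why the non-finite entries are excluded rather than clamped.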

Return shape

{
    "activations": [
        # cumulative_sqnr_db is None when there is no float reference
        {"tensor_name", "local_sqnr_db", "cumulative_sqnr_db"},
        ...
    ],
    "weights": [
        {"weight_name", "weight_sqnr_db"},
        ...
    ],
    "model_outputs": [
        # cumulative SQNR at each graph output
        {"output_name", "cumulative_sqnr_db"},
        ...
    ],
    "summary": {
        "local":      {"count", "mean", "std", "min", "max"},
        "cumulative": {"count", "mean", "std", "min", "max"},
        "weight":     {"count", "mean", "std", "min", "max"},
    },
}
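A caller can rank tensors straight from this structure. The helper below is a hypothetical sketch that relies only on the keys shown above; the sample dictionary is made-up illustration data, not output from a real run.

```python
def worst_activations(result, key="cumulative_sqnr_db", n=10):
    """Return the n activation rows with the lowest SQNR for `key`.

    Rows whose value is None (no float reference) are skipped.
    """
    rows = [row for row in result["activations"] if row.get(key) is not None]
    return sorted(rows, key=lambda row: row[key])[:n]


# Made-up data in the documented shape.
example = {
    "activations": [
        {"tensor_name": "a", "local_sqnr_db": 32.2, "cumulative_sqnr_db": 11.2},
        {"tensor_name": "b", "local_sqnr_db": 50.0, "cumulative_sqnr_db": None},
        {"tensor_name": "c", "local_sqnr_db": 36.9, "cumulative_sqnr_db": 8.3},
    ]
}

for row in worst_activations(example, n=2):
    print(row["tensor_name"], row["cumulative_sqnr_db"])
```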

4. CLI surface and output

winml debug --float-model FLOAT.onnx --quant-model QUANT.onnx [options]

| Flag | Default | Purpose |
| --- | --- | --- |
| --float-model | (required) | Pre-quantization float ONNX model. |
| --quant-model | (required) | Quantized (QDQ) artifact to debug. |
| --samples | 8 | Input samples to average over. |
| --model-id | | HF id for task-aware calibration inputs. |
| --task | | Task for task-aware calibration (else random inputs). |
| --output | | Write the full per-tensor results as JSON. |

The command (src/winml/modelkit/commands/debug.py) prints four tables — the cumulative SQNR at every graph output (in full), and the top-10 worst local, cumulative, and weight SQNR — each followed by a one-line summary (count, mean, std, min, max). SQNR colouring: > 40 dB green · 20–40 dB yellow · < 20 dB red.

Float model: temp\fairface\optimized.onnx
Quant model: temp\fairface\quantized.onnx
Samples: 8

      Model outputs
┏━━━┳━━━━━━━━━━━┳━━━━━━━━┓
┃ # ┃ SQNR (dB) ┃ Output ┃
┡━━━╇━━━━━━━━━━━╇━━━━━━━━┩
│ 1 │     22.71 │ logits │
└───┴───────────┴────────┘
Local      = error from quantizing this tensor alone, excluding upstream.
Cumulative = error at this tensor, including error inherited from upstream.

                               Top 10 worst local SQNR
┏━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃  # ┃ SQNR (dB) ┃ Tensor                                                            ┃
┡━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  1 │     32.20 │ gemm_output_reshape_arg_token_238                                 │
│  2 │     34.65 │ gemm_output_reshape_arg_token_202                                 │
│  3 │     36.88 │ gemm_output_reshape_arg_token_388                                 │
│  4 │     39.98 │ gemm_output_reshape_arg_token_412                                 │
│  5 │     42.03 │ gemm_output_reshape_arg_token_304                                 │
│  6 │     47.60 │ gemm_output_reshape_arg_token_292                                 │
│  7 │     52.25 │ /vit/encoder/layer.0/layernorm_before/LayerNormalization_output_0 │
│  8 │     53.10 │ gemm_output_reshape_arg_token_340                                 │
│  9 │     53.82 │ gemm_output_reshape_arg_token_64                                  │
│ 10 │     58.91 │ gemm_output_reshape_arg_token_172                                 │
└────┴───────────┴───────────────────────────────────────────────────────────────────┘
                               Top 10 worst cumulative SQNR
┏━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃  # ┃ SQNR (dB) ┃ Tensor                                                                ┃
┡━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  1 │      8.29 │ gemm_output_reshape_arg_token_388                                     │
│  2 │      8.29 │ /vit/encoder/layer.10/output/dense/Add_output_0                       │
│  3 │      9.56 │ gemm_input_reshape_arg_token_385                                      │
│  4 │      9.56 │ /vit/encoder/layer.10/intermediate/intermediate_act_fn/Mul_1_output_0 │
│  5 │     10.66 │ /vit/encoder/layer.11/attention/attention/Softmax_output_0            │
│  6 │     11.23 │ /vit/encoder/layer.11/attention/attention/Transpose_output_0          │
│  7 │     11.23 │ /vit/encoder/layer.11/attention/attention/Reshape_1_output_0          │
│  8 │     11.23 │ gemm_output_reshape_arg_token_394                                     │
│  9 │     11.34 │ /vit/encoder/layer.11/intermediate/intermediate_act_fn/Mul_1_output_0 │
│ 10 │     11.34 │ gemm_input_reshape_arg_token_421                                      │
└────┴───────────┴───────────────────────────────────────────────────────────────────────┘
       Top 10 worst weight SQNR
┏━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃  # ┃ SQNR (dB) ┃ Weight            ┃
┡━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│  1 │    -34.67 │ onnx::MatMul_1310 │
│  2 │    -34.58 │ onnx::MatMul_1426 │
│  3 │    -34.19 │ onnx::MatMul_1281 │
│  4 │    -33.88 │ onnx::MatMul_1279 │
│  5 │    -33.34 │ onnx::MatMul_1455 │
│  6 │    -33.33 │ onnx::MatMul_1397 │
│  7 │    -32.09 │ onnx::MatMul_1454 │
│  8 │    -32.03 │ onnx::MatMul_1339 │
│  9 │    -29.76 │ onnx::MatMul_1368 │
│ 10 │    -29.64 │ onnx::MatMul_1484 │
└────┴───────────┴───────────────────┘
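The colour bands applied to the SQNR column (> 40 dB green, 20 to 40 dB yellow, < 20 dB red) amount to a two-threshold lookup. The function below is a hypothetical sketch, not the command's code; treating exactly 20 dB and exactly 40 dB as yellow is an assumption read off the stated ranges.

```python
def sqnr_colour(db):
    """Map an SQNR value in dB to the colour band used in the tables."""
    if db > 40.0:
        return "green"
    if db >= 20.0:
        return "yellow"
    return "red"


# Values from the sample output above.
for value in (22.71, 42.03, 8.29, -34.67):
    print(f"{value:7.2f} dB -> {sqnr_colour(value)}")
```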

Comment thread tests/unit/test_debug.py


def _build_tiny_model(path: Path) -> None:
    import onnx