implement quantization debug #962
Draft
zhenchaoni wants to merge 1 commit into
Design: `winml debug` - Per-Op Quantization Error Debugging

debug, commands
`winml debug` CLI · `winml.modelkit.debug.debug_quantization` API

1. The problem
A user quantizes a model with the WinML CLI, and the quantized artifact shows an accuracy drop in `winml eval`. The existing signals only say that accuracy dropped:

- `winml eval`: a single end-to-end metric.
- `winml eval --mode compare`: per-output parity (ONNX vs PyTorch).

Neither says which tensors or weights the quantization damaged. Without per-tensor attribution the user cannot tell whether the loss is concentrated in a few ops or spread across the whole graph, and has no starting point for mixed-precision tuning.

`winml debug` fills that gap: it compares the float ONNX model against its quantized counterpart tensor by tensor and reports where the quantized model diverges.

2. ONNX Runtime provides the measurement
We do not invent the augment / collect / match / compare mechanism. It is ONNX Runtime's officially documented quantization-debugging recipe, shipped as `onnxruntime.quantization.qdq_loss_debug`.

The idea: two graphs should compute the same function, namely the float reference and the quantized model (the float graph with QuantizeLinear/DequantizeLinear "QDQ" pairs inserted). Feed both the same input, capture every intermediate tensor from each, and compare them.

```mermaid
flowchart LR
    X[same input] --> F[Float model]
    X --> Q[Quantized model]
    F -->|t1_f, t2_f, ...| C[compare each tensor]
    Q -->|t1_q, t2_q, ...| C
    C --> R[per-tensor SQNR]
```

ORT does this in a handful of calls:

- `modify_model_output_intermediate_tensors`
- `collect_activations`, producing `{tensor: [values]}`
- `create_activation_matching`
- `compute_activation_error`
- `create_weight_matching` + `compute_weight_error`

The metric: SQNR (dB)
Each comparison reduces two tensors to a single quality number, the signal-to-quantization-noise ratio, in decibels:

$$\mathrm{SQNR}(x, y) = 20 \log_{10} \frac{\lVert x \rVert_2}{\lVert x - y \rVert_2}$$

where `x` is the float (reference) tensor, `y` is its quantized counterpart, and `x - y` is the quantization error. Higher dB means the error is small relative to the signal.

Because it is `20 * log10` of an amplitude ratio, every 20 dB is a 10x change in signal-to-error: 20 dB is a ratio of 10, 40 dB of 100, 60 dB of 1000.

SQNR is scale-invariant, so it is comparable across tensors of very different magnitude.
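As a rough illustration of the metric described above, here is a minimal pure-Python sketch. It is not ORT's or WinML's implementation: the function name `sqnr_db`, the epsilon guard, and the sample values are assumptions made for this example only.

```python
import math
import sys
from typing import Sequence


def sqnr_db(x: Sequence[float], y: Sequence[float]) -> float:
    """SQNR in dB: 20 * log10(||x|| / ||x - y||), with x as the float reference."""
    # Assumption for this sketch: clamp both norms to machine epsilon so that
    # an all-zero tensor or a perfect match does not hit log10(0) or divide by 0.
    eps = sys.float_info.epsilon
    signal = max(math.sqrt(sum(a * a for a in x)), eps)
    noise = max(math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y))), eps)
    return 20.0 * math.log10(signal / noise)


reference = [1.0, -2.0, 3.0, -4.0]

# A 1% error on every element, then a 0.1% error: the signal-to-error
# amplitude ratio grows 10x, so SQNR gains 20 dB.
coarse = [v * 1.01 for v in reference]
fine = [v * 1.001 for v in reference]
sqnr_coarse = sqnr_db(reference, coarse)
sqnr_fine = sqnr_db(reference, fine)

# Scale invariance: multiplying both tensors by 1000 leaves SQNR unchanged.
sqnr_scaled = sqnr_db([v * 1000 for v in reference], [v * 1000 for v in coarse])
```

Here `sqnr_coarse` comes out at about 40 dB and `sqnr_fine` at about 60 dB, and `sqnr_scaled` matches `sqnr_coarse` up to floating-point rounding.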
Three signals
Reading local and cumulative together separates originators from inheritors: a tensor with high local but low cumulative quantizes fine and merely inherited upstream damage, whereas low local and low cumulative means the tensor's own rounding is lossy.
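The originator/inheritor reading above can be sketched as a small classifier. This is only an illustration: the function name, the labels, and the 20 dB cut-off are hypothetical choices for this example, not part of the `winml debug` output.

```python
from typing import Optional

# Hypothetical cut-off between "high" and "low" SQNR, chosen for illustration.
LOW_SQNR_DB = 20.0


def classify_tensor(
    local_sqnr_db: float,
    cumulative_sqnr_db: Optional[float],
    threshold_db: float = LOW_SQNR_DB,
) -> str:
    """Label a tensor by reading local and cumulative SQNR together."""
    if cumulative_sqnr_db is None:
        # No float reference is available, so only the local signal exists.
        return "no-float-reference"
    local_ok = local_sqnr_db >= threshold_db
    cumulative_ok = cumulative_sqnr_db >= threshold_db
    if local_ok and cumulative_ok:
        return "healthy"
    if local_ok and not cumulative_ok:
        # Quantizes fine on its own; the damage came from upstream.
        return "inheritor"
    if not local_ok and not cumulative_ok:
        # The tensor's own rounding is lossy.
        return "originator"
    # Low local, high cumulative: a case the text above does not discuss.
    return "low-local-only"


labels = {
    "high local, low cumulative": classify_tensor(45.0, 12.0),
    "low local, low cumulative": classify_tensor(10.0, 8.0),
}
```

Tensors labelled as originators would be the natural first candidates when deciding where to start mixed-precision tuning.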
3. Integration: `debug_quantization`

The engine (`src/winml/modelkit/debug/debugger.py`) wraps the ORT calls with WinML's calibration-data plumbing and result shaping. It is import-light and Click/Rich-free, mirroring the `quant/quantizer.py` ↔ `commands/quantize.py` split.

What WinML adds on top of ORT:

- `DatasetCalibrationReader`: task-aware and correctly preprocessed when `model_id`/`task` are given, otherwise random inputs synthesized from the model's I/O spec (self-contained, no downloads). The reader is `rewind()`-ed between the float and quant collection passes so both models see the identical sample sequence.
- An SQNR function (`_sqnr_db`) passed as `err_func` to `compute_weight_error`: ORT's SQNR calls `len()` on its inputs, which fails for a weight that dequantizes to a numpy scalar; `numpy.atleast_1d` keeps such weights as length-1 arrays.
- ... `mean`/`std` stay well-defined.

Return shape

```python
{
    "activations": [  # cumulative_sqnr_db is None when there is no float reference
        {"tensor_name", "local_sqnr_db", "cumulative_sqnr_db"}, ...
    ],
    "weights": [
        {"weight_name", "weight_sqnr_db"}, ...
    ],
    "model_outputs": [  # cumulative SQNR at each graph output
        {"output_name", "cumulative_sqnr_db"}, ...
    ],
    "summary": {
        "local": {"count", "mean", "std", "min", "max"},
        "cumulative": {"count", "mean", "std", "min", "max"},
        "weight": {"count", "mean", "std", "min", "max"},
    },
}
```

4. CLI surface and output
- `--float-model`
- `--quant-model`
- `--samples` (default: `8`)
- `--model-id`
- `--task`
- `--output`

The command (`src/winml/modelkit/commands/debug.py`) prints four tables: the cumulative SQNR at every graph output (in full), and the top-10 worst local, cumulative, and weight SQNR. Each table is followed by a one-line summary (count, mean, std, min, max). SQNR colouring: > 40 dB green · 20–40 dB yellow · < 20 dB red.
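To make the table logic concrete, here is a minimal sketch of how rows in the documented return shape could be ranked, coloured, and summarised. It is not the code in `commands/debug.py`: the helper names, the sample tensor names and values, the use of population standard deviation, and the skipping of `None` or non-finite values are all assumptions for this example.

```python
import math
import statistics
from typing import Dict, List, Optional


def colour_for(sqnr_db: float) -> str:
    """Map an SQNR value to the colour bands: >40 green, 20-40 yellow, <20 red."""
    if sqnr_db > 40.0:
        return "green"
    if sqnr_db >= 20.0:
        return "yellow"
    return "red"


def worst(rows: List[dict], key: str, top_n: int = 10) -> List[dict]:
    """Return the top_n rows with the lowest SQNR, skipping rows without a value."""
    scored = [row for row in rows if row.get(key) is not None]
    return sorted(scored, key=lambda row: row[key])[:top_n]


def summarise(values: List[Optional[float]]) -> Dict[str, Optional[float]]:
    """count/mean/std/min/max over finite values (population std is an assumption)."""
    finite = [v for v in values if v is not None and math.isfinite(v)]
    if not finite:
        return {"count": 0, "mean": None, "std": None, "min": None, "max": None}
    return {
        "count": len(finite),
        "mean": statistics.fmean(finite),
        "std": statistics.pstdev(finite),
        "min": min(finite),
        "max": max(finite),
    }


# Hypothetical activation rows in the documented shape.
activations = [
    {"tensor_name": "conv1_out", "local_sqnr_db": 45.0, "cumulative_sqnr_db": 38.0},
    {"tensor_name": "attn_out", "local_sqnr_db": 12.5, "cumulative_sqnr_db": 11.0},
    {"tensor_name": "aux_out", "local_sqnr_db": 30.0, "cumulative_sqnr_db": None},
]

worst_local = worst(activations, "local_sqnr_db", top_n=2)
worst_cumulative = worst(activations, "cumulative_sqnr_db")
local_summary = summarise([row["local_sqnr_db"] for row in activations])
colours = {row["tensor_name"]: colour_for(row["local_sqnr_db"]) for row in activations}
```

Sorting ascending puts the most damaged tensors first, which matches the "worst" ordering of the tables; rows whose cumulative SQNR is `None` are simply left out of the cumulative ranking in this sketch.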