diff --git a/src/winml/modelkit/commands/debug.py b/src/winml/modelkit/commands/debug.py new file mode 100644 index 000000000..6940aab0a --- /dev/null +++ b/src/winml/modelkit/commands/debug.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Measure per-op quantization error (local and cumulative SQNR). + +Usage: + winml debug --float-model float.onnx --quant-model quantized.onnx + +Examples: + # Random inputs (self-contained, no downloads) + winml debug --float-model model_optimized.onnx --quant-model model_quantized.onnx + + # Real, task-aware calibration inputs + winml debug --float-model float.onnx --quant-model qdq.onnx \\ + --model-id microsoft/swinv2-tiny-patch4-window16-256 \\ + --task image-classification --samples 16 +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path + +import click +from rich.console import Console +from rich.table import Table + +from ..utils import cli as cli_utils +from ..utils.logging import configure_logging + + +logger = logging.getLogger(__name__) +console = Console() + + +@click.command("debug") +@click.option( + "--float-model", + "float_model", + required=True, + type=click.Path(exists=True, path_type=Path), + help="Float (pre-quantization) ONNX model.", +) +@click.option( + "--quant-model", + "quant_model", + required=True, + type=click.Path(exists=True, path_type=Path), + help="Quantized (QDQ) ONNX model — the build artifact to debug.", +) +@click.option( + "--samples", + type=int, + default=2, + show_default=True, + help="Number of input samples to average over.", +) +@click.option( + "--model-id", + type=str, + default=None, + help="HuggingFace model id for real, task-aware calibration inputs.", +) +@click.option( + "--task", + type=str, + default=None, + help="Task for task-aware calibration (e.g. 'image-classification'). " + "Falls back to random inputs when omitted.", +) +@cli_utils.output_option("Write the full per-tensor results to this JSON file.") +@cli_utils.verbosity_options() +@click.pass_context +def debug( + ctx: click.Context, + float_model: Path, + quant_model: Path, + samples: int, + model_id: str | None, + task: str | None, + output: Path | None, + verbose: int, + quiet: bool, +) -> None: + """Measure per-op quantization error, op by op. + + Runs the float and quantized models over the same inputs and reports, per + activation, the local SQNR and the cumulative SQNR. Lower dB == more damage. + + Local SQNR is the error from quantizing this tensor alone, excluding + upstream. Cumulative SQNR is the error at this tensor, including error + inherited from upstream. + """ + verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet) + configure_logging(verbosity=verbose, quiet=quiet) + + from ..debug import debug_quantization + + console.print(f"[bold blue]Float model:[/bold blue] {float_model}") + console.print(f"[bold blue]Quant model:[/bold blue] {quant_model}") + console.print(f"[bold blue]Samples:[/bold blue] {samples}\n") + + result = debug_quantization( + float_model, + quant_model, + samples=samples, + model_id=model_id, + task=task, + ) + + print_result(result) + + if output is not None: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(result, indent=2), encoding="utf-8") + console.print(f"\n[dim]Full per-tensor results written to {output}[/dim]") + +# Number of worst-ranked rows shown per table. +TOP_N = 10 + + +def print_result(result: dict) -> None: + """Render the debug result dict as console tables.""" + activations = result["activations"] + weights = result["weights"] + model_outputs = result["model_outputs"] + summary = result["summary"] + + console.print( + "Local SQNR = error from quantizing this tensor alone, excluding upstream." + ) + console.print( + "Cumulative SQNR = error at this tensor, including error inherited from upstream." + ) + + # Model outputs: cumulative SQNR at every graph output (shown in full). + _render_table( + "Outputs cumulative SQNR", + "Output", + [(o["output_name"], o["cumulative_sqnr_db"]) for o in model_outputs], + ) + + local_sorted = sorted(activations, key=lambda a: a["local_sqnr_db"]) + _render_table( + f"Top {TOP_N} worst local SQNR", + "Tensor", + [(a["tensor_name"], a["local_sqnr_db"]) for a in local_sorted], + top=TOP_N, + ) + _print_stats(summary["local"]) + + cumulative_sorted = sorted( + activations, + key=lambda a: (a["cumulative_sqnr_db"] is None, a["cumulative_sqnr_db"] or 0.0), + ) + _render_table( + f"Top {TOP_N} worst cumulative SQNR", + "Tensor", + [(a["tensor_name"], a["cumulative_sqnr_db"]) for a in cumulative_sorted], + top=TOP_N, + ) + _print_stats(summary["cumulative"]) + + weights_sorted = sorted(weights, key=lambda w: w["weight_sqnr_db"]) + _render_table( + f"Top {TOP_N} worst weight SQNR", + "Weight", + [(w["weight_name"], w["weight_sqnr_db"]) for w in weights_sorted], + top=TOP_N, + ) + _print_stats(summary["weight"]) + + +def _print_stats(stats: dict) -> None: + # One-line SQNR summary printed below a table. + def _fmt(value: float | None) -> str: + return f"{value:.2f}" if value is not None else "n/a" + + console.print( + f"(count = {stats['count']}, mean = {_fmt(stats['mean'])}, " + f"std = {_fmt(stats['std'])}, min = {_fmt(stats['min'])}, " + f"max = {_fmt(stats['max'])})\n" + ) + + +def _render_table( + title: str, + name_header: str, + rows: list[tuple[str, float | None]], + *, + top: int | None = None, +) -> None: + table = Table(title=title, title_style="bold", title_justify="left", header_style="bold cyan") + table.add_column("#", justify="right", style="dim") + table.add_column("SQNR (dB)", justify="right") + table.add_column(name_header, overflow="fold") + + shown = rows if top is None else rows[:top] + for i, (name, sqnr) in enumerate(shown, 1): + table.add_row(str(i), _fmt_sqnr(sqnr), name) + console.print(table) + + + +def _fmt_sqnr(value: float | None) -> str: + if value is None: + return "[dim]n/a[/dim]" + color = "red" if value < 20 else "yellow" if value < 40 else "green" + return f"[{color}]{value:7.2f}[/{color}]" diff --git a/src/winml/modelkit/debug/__init__.py b/src/winml/modelkit/debug/__init__.py new file mode 100644 index 000000000..50ca1b7db --- /dev/null +++ b/src/winml/modelkit/debug/__init__.py @@ -0,0 +1,18 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Per-op quantization error measurement for ONNX models. + +Usage: + from winml.modelkit.debug import debug_quantization + + errors = debug_quantization("float.onnx", "quantized.onnx") +""" + +from .debugger import debug_quantization + + +__all__ = [ + "debug_quantization", +] diff --git a/src/winml/modelkit/debug/debugger.py b/src/winml/modelkit/debug/debugger.py new file mode 100644 index 000000000..2b60f6907 --- /dev/null +++ b/src/winml/modelkit/debug/debugger.py @@ -0,0 +1,183 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Per-op quantization error measurement. + +Runs a float ONNX model and its quantized counterpart over the same inputs and +reports, per intermediate activation and per weight, the local and cumulative +SQNR (dB) using ``onnxruntime.quantization.qdq_loss_debug``. +""" + +from __future__ import annotations + +import logging +import tempfile +from pathlib import Path +from typing import Any + +from rich.console import Console + + +logger = logging.getLogger(__name__) + + +def _graph_output_names(model_path: str | Path) -> list[str]: + """Return the model's graph output tensor names, in graph order.""" + import onnx + + model = onnx.load(str(model_path), load_external_data=False) + return [o.name for o in model.graph.output] + + +def _sqnr_db(x: Any, y: Any) -> float: + """SQNR (dB) wrapper that tolerates scalar (0-d) tensors. + + ORT's ``compute_signal_to_quantization_noice_ratio`` calls ``len()`` on its + inputs, which fails for a weight that dequantizes to a numpy scalar. Coercing + with ``atleast_1d`` keeps such tensors as length-1 arrays. + """ + import numpy as np + from onnxruntime.quantization.qdq_loss_debug import ( + compute_signal_to_quantization_noice_ratio, + ) + + return compute_signal_to_quantization_noice_ratio(np.atleast_1d(x), np.atleast_1d(y)) + + +def _summarize(values: Any) -> dict: + """Return count/mean/std/min/max for a sequence of SQNR values. + + ``None`` and non-finite (``nan``/``inf``) entries are skipped — the latter + arise when ORT's SQNR hits an overflow or a zero-difference tensor. + ``mean``/``std``/``min``/``max`` are ``None`` when no finite values remain. + """ + import math + import statistics + + finite = [float(v) for v in values if v is not None and math.isfinite(v)] + if not finite: + return {"count": 0, "mean": None, "std": None, "min": None, "max": None} + return { + "count": len(finite), + "mean": statistics.fmean(finite), + "std": statistics.pstdev(finite), + "min": min(finite), + "max": max(finite), + } + + +def debug_quantization( + float_model_path: str | Path, + quant_model_path: str | Path, + *, + samples: int = 8, + model_id: str | None = None, + task: str | None = None, +) -> dict: + """Measure per-activation and per-weight SQNR between two ONNX models. + + Returns a dict with three lists: + + - ``activations``: ``{tensor_name, local_sqnr_db, cumulative_sqnr_db}`` per + intermediate tensor. ``cumulative_sqnr_db`` is ``None`` when the float + reference is unavailable. + - ``weights``: ``{weight_name, weight_sqnr_db}`` per quantized weight. + - ``model_outputs``: ``{output_name, cumulative_sqnr_db}`` per graph output. + - ``summary``: per-category ``{count, mean, std, min, max}`` over the + ``local``, ``cumulative``, and ``weight`` SQNR values. + + Calibration inputs come from ``DatasetCalibrationReader`` (task-aware when + ``model_id``/``task`` are given, random otherwise). Both models run on the + CPU execution provider, matching ORT's quantization debugging guidance. + """ + from onnxruntime.quantization.qdq_loss_debug import ( + collect_activations, + compute_activation_error, + compute_weight_error, + create_activation_matching, + create_weight_matching, + modify_model_output_intermediate_tensors, + ) + + from ..datasets import DatasetCalibrationReader + + console = Console() + + float_model_path = Path(float_model_path) + quant_model_path = Path(quant_model_path) + + reader = DatasetCalibrationReader( + model_name=model_id or "random", + task=task or "random", + max_samples=samples, + model_path=float_model_path, + ) + + with tempfile.TemporaryDirectory() as work_dir: + work_path = Path(work_dir) + aug_float = work_path / "augmented_float.onnx" + aug_quant = work_path / "augmented_quant.onnx" + + console.print("[bold]Augmenting models...[/bold]") + modify_model_output_intermediate_tensors( + str(float_model_path), str(aug_float), save_as_external_data=True + ) + modify_model_output_intermediate_tensors( + str(quant_model_path), str(aug_quant), save_as_external_data=True + ) + + # Both passes must replay the same samples, so rewind between them. + console.print("[bold]Collecting activations...[/bold]") + float_acts = collect_activations(str(aug_float), reader) + reader.rewind() + qdq_acts = collect_activations(str(aug_quant), reader) + + console.print("[bold]Matching activations and weights...[/bold]") + matched = create_activation_matching(qdq_acts, float_acts) + act_err = compute_activation_error(matched) + + weight_match = create_weight_matching(str(float_model_path), str(quant_model_path)) + weight_err = compute_weight_error(weight_match, err_func=_sqnr_db) + + activations = [ + { + "tensor_name": name, + "local_sqnr_db": float(err["qdq_err"]), + "cumulative_sqnr_db": ( + float(err["xmodel_err"]) if "xmodel_err" in err else None + ), + } + for name, err in act_err.items() + ] + + weights = [ + {"weight_name": name, "weight_sqnr_db": float(sqnr)} + for name, sqnr in weight_err.items() + ] + + cumulative_output = { + name: err["xmodel_err"] for name, err in act_err.items() if "xmodel_err" in err + } + model_outputs = [ + { + "output_name": name, + "cumulative_sqnr_db": ( + float(cumulative_output[name]) if name in cumulative_output else None + ), + } + for name in _graph_output_names(float_model_path) + ] + + summary = { + "local": _summarize(a["local_sqnr_db"] for a in activations), + "cumulative": _summarize(a["cumulative_sqnr_db"] for a in activations), + "weight": _summarize(w["weight_sqnr_db"] for w in weights), + } + + return { + "activations": activations, + "weights": weights, + "model_outputs": model_outputs, + "summary": summary, + } diff --git a/tests/unit/test_debug.py b/tests/unit/test_debug.py new file mode 100644 index 000000000..50904b089 --- /dev/null +++ b/tests/unit/test_debug.py @@ -0,0 +1,162 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the quantization debug engine. + +Covers the WinML wrapper around ORT's ``qdq_loss_debug``: activation SQNR, +weight SQNR, and graph-output cumulative SQNR. ORT's measurement functions are +faked so the tests run without an inference session. +""" + +from __future__ import annotations + +import sys +from types import ModuleType +from typing import TYPE_CHECKING + +from winml.modelkit.debug import debug_quantization +from winml.modelkit.debug.debugger import _graph_output_names + + +if TYPE_CHECKING: + from pathlib import Path + + import pytest + + +def _build_tiny_model(path: Path) -> None: + import onnx + from onnx import TensorProto, helper + + matmul = helper.make_node("MatMul", ["X", "W"], ["Y"], name="matmul0") + relu = helper.make_node("Relu", ["Y"], ["Z"], name="relu0") + graph = helper.make_graph( + [matmul, relu], + "tiny", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2])], + outputs=[helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1, 2])], + initializer=[helper.make_tensor("W", TensorProto.FLOAT, [2, 2], [1, 0, 0, 1])], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + onnx.save(model, str(path)) + + +def test_graph_output_names(tmp_path: Path) -> None: + model_path = tmp_path / "tiny.onnx" + _build_tiny_model(model_path) + + assert _graph_output_names(model_path) == ["Z"] + + +class _FakeReader: + def get_next(self) -> None: + return None + + def rewind(self) -> None: + return None + + +def _install_fake_qdq_loss_debug( + monkeypatch: pytest.MonkeyPatch, + act_err: dict[str, dict[str, float]], + weight_err: dict[str, float], +) -> None: + mod = ModuleType("onnxruntime.quantization.qdq_loss_debug") + mod.modify_model_output_intermediate_tensors = ( # type: ignore[attr-defined] + lambda _in, out, **_kw: __import__("pathlib").Path(out).write_text("x") + ) + mod.create_activation_matching = lambda *_a, **_k: {} # type: ignore[attr-defined] + mod.compute_activation_error = lambda _m: act_err # type: ignore[attr-defined] + mod.create_weight_matching = lambda *_a, **_k: {} # type: ignore[attr-defined] + mod.compute_weight_error = lambda _m, **_k: weight_err # type: ignore[attr-defined] + mod.collect_activations = lambda *_a, **_k: {} # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "onnxruntime.quantization.qdq_loss_debug", mod) + + # Avoid real dataset construction and inference. + import winml.modelkit.datasets as datasets_mod + + monkeypatch.setattr( + datasets_mod, "DatasetCalibrationReader", lambda **_kw: _FakeReader() + ) + + +def test_debug_quantization_returns_activations_weights_and_outputs( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + float_model = tmp_path / "tiny.onnx" + _build_tiny_model(float_model) + quant_model = tmp_path / "tiny_quant.onnx" + quant_model.write_text("ignored") + + act_err = { + "Y": {"qdq_err": 80.0, "xmodel_err": 40.0}, + "Z": {"qdq_err": 20.0, "xmodel_err": 10.0}, + } + weight_err = {"W": 12.5} + _install_fake_qdq_loss_debug(monkeypatch, act_err, weight_err) + + result = debug_quantization(float_model, quant_model) + + activations = {a["tensor_name"]: a for a in result["activations"]} + assert activations["Y"]["local_sqnr_db"] == 80.0 + assert activations["Y"]["cumulative_sqnr_db"] == 40.0 + assert activations["Z"]["cumulative_sqnr_db"] == 10.0 + + assert result["weights"] == [{"weight_name": "W", "weight_sqnr_db": 12.5}] + + # The single graph output Z carries its cumulative SQNR. + assert result["model_outputs"] == [ + {"output_name": "Z", "cumulative_sqnr_db": 10.0} + ] + + assert result["summary"]["local"] == { + "count": 2, + "mean": 50.0, + "std": 30.0, + "min": 20.0, + "max": 80.0, + } + assert result["summary"]["cumulative"]["count"] == 2 + assert result["summary"]["weight"] == { + "count": 1, + "mean": 12.5, + "std": 0.0, + "min": 12.5, + "max": 12.5, + } + + +def test_debug_quantization_handles_missing_cumulative( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + float_model = tmp_path / "tiny.onnx" + _build_tiny_model(float_model) + quant_model = tmp_path / "tiny_quant.onnx" + quant_model.write_text("ignored") + + # No xmodel_err -> cumulative stays None (no float reference). + act_err = {"Z": {"qdq_err": 15.0}} + _install_fake_qdq_loss_debug(monkeypatch, act_err, {}) + + result = debug_quantization(float_model, quant_model) + + (z,) = result["activations"] + assert z["cumulative_sqnr_db"] is None + assert z["local_sqnr_db"] == 15.0 + assert result["weights"] == [] + # Output Z has no measured cumulative SQNR. + assert result["model_outputs"] == [ + {"output_name": "Z", "cumulative_sqnr_db": None} + ] + + # No cumulative or weight values remain, so those summaries are empty. + assert result["summary"]["local"]["count"] == 1 + assert result["summary"]["cumulative"] == { + "count": 0, + "mean": None, + "std": None, + "min": None, + "max": None, + } + assert result["summary"]["weight"]["count"] == 0