Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b508d65
feat: add --precision fp16 to optimize, build, and export commands
DingmaomaoBJTU Jun 11, 2026
e82a099
refactor: integrate FP16 into quantize stage as post-processing
github-actions[bot] Jun 23, 2026
ef1926a
chore: remove spurious .data files
github-actions[bot] Jun 23, 2026
6ee9e66
refactor: remove --precision from export/optimize, add fp16 to quantize
github-actions[bot] Jun 23, 2026
e974a75
feat(build): extend --precision to accept all quantization values
github-actions[bot] Jun 23, 2026
118187f
fix: resolve CodeQL import warnings in fp16 module
github-actions[bot] Jun 23, 2026
c1aab3b
fix: resolve rebase conflicts with main
github-actions[bot] Jun 23, 2026
ef8c779
feat: warn when calibration options are ignored in FP16 mode
github-actions[bot] Jun 23, 2026
ae2521e
fix: skip task/model_name validation for fp16_only quant configs
github-actions[bot] Jun 23, 2026
ed2c591
fix: skip calibration validation for rtn and dynamic algorithms
github-actions[bot] Jun 23, 2026
fce06ff
feat: merge --rtn-bits into --precision (int4/w4a16 auto-selects RTN)
github-actions[bot] Jun 23, 2026
9b0e8cc
fix: build pipeline RTN routing and MatMulNBitsQuantizer model extrac…
github-actions[bot] Jun 23, 2026
32211ac
fix: resolve lint warnings (raw regex strings, unused variable)
github-actions[bot] Jun 23, 2026
85a774f
fix: resolve mypy type errors and remove duplicate imports
github-actions[bot] Jun 23, 2026
8c86403
fix: address code review findings
github-actions[bot] Jun 23, 2026
e14cd13
fix: address deep code review findings
github-actions[bot] Jun 23, 2026
8b4dcb0
feat: support w4a32 precision (equivalent to int4) and w4a16 FP16 pos…
github-actions[bot] Jun 24, 2026
011ad21
refactor: unify fp16/fp16_only into algorithm='fp16' + fp16_postprocess
github-actions[bot] Jun 24, 2026
a79c67c
refactor: replace fp16_postprocess with multi-pass pipeline
github-actions[bot] Jun 24, 2026
ddeb3be
fix: clean up intermediate pass files in multi-pass quantize stage
github-actions[bot] Jun 24, 2026
7e17a71
refactor: move multi-pass precision logic into quantize_onnx
github-actions[bot] Jun 24, 2026
c4cb818
chore: remove duplicate is_submodule assignment in build config valid…
github-actions[bot] Jun 24, 2026
7f93d93
refactor: extract warn_ignored_calibration_options to shared cli utils
github-actions[bot] Jun 24, 2026
6bb1e64
refactor: move convert_to_fp16 from optim to quant module
github-actions[bot] Jun 24, 2026
8ae24dc
chore: mark legacy mode field as deprecated in quant config
github-actions[bot] Jun 24, 2026
bb51338
refactor: unify mode and algorithm fields in WinMLQuantizationConfig
github-actions[bot] Jun 24, 2026
0125ec6
refactor: split quantizer into dispatch pattern and consolidate quant…
github-actions[bot] Jun 24, 2026
df3d0a5
fix: type dispatch dict properly to satisfy mypy no-any-return
github-actions[bot] Jun 24, 2026
cd9fcd3
refactor: remove multi-pass w4a16 from quantize_onnx, simplify to sin…
github-actions[bot] Jun 25, 2026
b261de1
cleanup: remove remaining multi-pass references from build.py
github-actions[bot] Jun 25, 2026
eca92a6
cleanup: remove thin wrapper and adopt main's add_pre_process_metadat…
github-actions[bot] Jun 25, 2026
15c151f
test: update e2e test — fp16 is now a valid precision for winml quantize
github-actions[bot] Jun 25, 2026
16bf584
Add explanatory comment to empty except clause (CodeQL fix)
github-actions[bot] Jun 25, 2026
b303514
Remove dead 'algorithm' key compat from from_dict()
github-actions[bot] Jun 25, 2026
f98a24f
Fix FP16 detection to use config.quant.mode; remove _is_weight_only w…
github-actions[bot] Jun 25, 2026
96cc6af
Fix test: w4a16 now raises ValueError (dead guard removed)
github-actions[bot] Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 95 additions & 44 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,12 @@ def build(

# Force rebuild
winml build -c config.json -m microsoft/resnet-50 -o output/ --rebuild

# Build with INT8 quantization
winml build -m microsoft/resnet-50 -o output/ --precision int8

# Build with explicit weight/activation precision (INT8 weights, INT8 activations)
winml build -m microsoft/resnet-50 -o output/ --precision w8a8
"""
# Merge top-level -v/-q with subcommand-level flags so either position works.
verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet)
Expand All @@ -540,6 +546,16 @@ def build(
if not output_dir and not use_cache:
raise click.UsageError("One of --output-dir or --use-cache is required.")

# Validate precision value early for better error messages.
if precision is not None:
    from ..config.precision import _is_valid_precision

    if not _is_valid_precision(precision.lower()):
        # Only the first fragment below is an f-string. The second fragment is
        # a plain literal, so its braces are written singly; doubled braces
        # would be shown to the user verbatim as "w{{x}}a{{y}}".
        raise click.UsageError(
            f"Invalid precision '{precision}'. "
            "Expected: auto, fp32, fp16, int8, int16, or w{x}a{y} (e.g., w8a8, w8a16)."
        )

# If ep unspecified, resolve the target device and pick the highest-priority
# EP compatible with it. Avoids selecting an EP that does not match the host
# hardware -- analyzing for the wrong EP leaves black nodes that block a
Expand Down Expand Up @@ -615,6 +631,15 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# and other calibration settings from the existing config.
cfg.quant.weight_type = resolved_quant.weight_type
cfg.quant.activation_type = resolved_quant.activation_type
cfg.quant.mode = resolved_quant.mode
if resolved_quant.mode == "rtn":
cfg.quant.rtn_bits = resolved_quant.rtn_bits
cfg.quant.rtn_block_size = resolved_quant.rtn_block_size
cfg.quant.rtn_symmetric = resolved_quant.rtn_symmetric
cfg.quant.rtn_accuracy_level = resolved_quant.rtn_accuracy_level
# Store the original precision string for stage display
if precision:
cfg.precision = precision.lower()
if cfg.compile is not None and cfg.compile.ep_config is not None:
provider = cfg.compile.ep_config.provider
patched = WinMLCompileConfig.for_provider(provider, device=device)
Expand Down Expand Up @@ -1116,10 +1141,10 @@ def _run_quantize_stage(
quantized_path: Path,
stage_timings: list[tuple[str, float | None]],
) -> Path:
"""Run the quantize stage inside a StageLive context (if quant is configured).
"""Run the quantize stage (if quant is configured).

Handles QDQ skip detection, shows dataset/calibration/precision details,
and appends timing to stage_timings.
Delegates single-pass quantization to ``quantize_onnx(config=...)``.
The cmd layer only handles UI display and the QDQ skip check.

Args:
config: Build configuration.
Expand All @@ -1137,35 +1162,46 @@ def _run_quantize_stage(
if config.quant is None:
return current_path

if is_quantized_onnx(current_path):
# QDQ skip check: skip the stage when the model already has QDQ nodes and the mode is static/dynamic.
if config.quant.mode in ("static", "dynamic") and is_quantized_onnx(current_path):
print_stage_skip(console, "quantize", "(QDQ nodes already present)")
stage_timings.append(("Quantize", None))
return current_path

with StageLive("quantize", console) as sl:
wt = config.quant.weight_type
sl.set_status(f"Quantizing ({wt})...")
# Calibration info before blocking call
ds = config.quant.dataset_name or "default"
sl.kv(
"Dataset:",
f"[cyan]{ds}[/cyan] [dim]({config.quant.task or 'unknown'})[/dim]",
)
sl.kv(
"Calibration:",
f"[cyan]{config.quant.samples}[/cyan] samples"
f" [dim]({config.quant.calibration_method})[/dim]",
)
# Suppress tqdm/datasets progress bars during quantize
# to keep Live display clean
# Determine stage label from quant mode
is_fp16_only = config.quant.mode == "fp16"
stage_label = "fp16" if is_fp16_only else "quantize"
stage_name = "FP16" if is_fp16_only else "Quantize"

with StageLive(stage_label, console) as sl:
# Show status based on what we're about to do
if is_fp16_only:
sl.set_status("Converting to FP16...")
elif config.quant.mode == "rtn":
sl.set_status(f"Quantizing (RTN {config.quant.rtn_bits}-bit)...")
else:
sl.set_status(f"Quantizing ({config.quant.weight_type})...")
ds = config.quant.dataset_name or "default"
sl.kv(
"Dataset:",
f"[cyan]{ds}[/cyan] [dim]({config.quant.task or 'unknown'})[/dim]",
)
sl.kv(
"Calibration:",
f"[cyan]{config.quant.samples}[/cyan] samples"
f" [dim]({config.quant.calibration_method})[/dim]",
)

# Suppress tqdm/datasets progress bars for QDQ calibration
_datasets_available = False
try:
import datasets
if config.quant.mode in ("static", "dynamic"):
try:
import datasets

datasets.disable_progress_bars()
_datasets_available = True
except ImportError:
pass # datasets package not installed; progress bar suppression not needed
datasets.disable_progress_bars()
_datasets_available = True
except ImportError:
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
pass # datasets package is optional; calibration falls back to random data

t0 = time.monotonic()
try:
Expand All @@ -1177,27 +1213,42 @@ def _run_quantize_stage(
)
finally:
if _datasets_available:
import datasets

datasets.enable_progress_bars()

if not quant_result.success:
errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown"
sl.set_error(errors)
raise RuntimeError(f"Quantization failed: {errors}")
current_path = quantized_path
_quant_elapsed = time.monotonic() - t0
sl.set_done(_quant_elapsed)
sl.kv(
"Precision:",
f"[cyan]{config.quant.weight_type}/"
f"{config.quant.activation_type}[/cyan]"
f" [dim](weight/activation)[/dim]",
)
sl.artifact(
str(quantized_path),
_safe_size(quantized_path),
)
raise RuntimeError(f"{stage_name} failed: {errors}")

elapsed = time.monotonic() - t0
sl.set_done(elapsed)

# Show algorithm-specific result details
if is_fp16_only:
sl.detail("[dim]I/O types preserved as FP32[/dim]")
elif config.quant.mode == "rtn":
sl.kv(
"Algorithm:",
f"[cyan]RTN[/cyan] [dim](weight-only {config.quant.rtn_bits}-bit)[/dim]",
)
sl.kv(
"Config:",
f"block_size={config.quant.rtn_block_size}, symmetric={config.quant.rtn_symmetric}",
)
else:
sl.kv(
"Precision:",
f"[cyan]{config.quant.weight_type}/{config.quant.activation_type}[/cyan]"
f" [dim](weight/activation)[/dim]",
)

sl.artifact(str(quantized_path), _safe_size(quantized_path))
sl.blank()
stage_timings.append(("Quantize", _quant_elapsed))
return current_path

stage_timings.append((stage_name, elapsed))
return quantized_path


def _run_compile_stage(
Expand Down Expand Up @@ -1387,7 +1438,7 @@ def _name(base: str) -> str:
# Persist config after autoconf
config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
# ── Quantize stage ──────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down Expand Up @@ -1482,7 +1533,7 @@ def _build_onnx_pipeline(

config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
# ── Quantize stage ──────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down
134 changes: 84 additions & 50 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import click
from rich.console import Console

from ..config.precision import is_weight_only_precision
from ..utils import cli as cli_utils
from ..utils.logging import configure_logging

Expand All @@ -49,8 +50,10 @@
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.precision_option(
default=None,
help_text="Quantization precision: auto, int8, int16, or w{x}a{y} where "
"x,y in {8,16} (e.g., w8a8, w8a16, w16a16)",
help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). "
"int4/w4a16 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ)",
optional_message="Overridden by explicit --weight-type/--activation-type",
)
@click.option(
Expand Down Expand Up @@ -122,11 +125,11 @@ def quantize(
quiet: bool,
config_file: Path | None,
) -> None:
r"""Quantize ONNX model by inserting QDQ nodes.
r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16.

This command applies static quantization to an ONNX model using calibration
data to determine quantization parameters. The output model contains
QuantizeLinear and DequantizeLinear nodes for quantization-aware inference.
This command applies quantization to an ONNX model. The algorithm is
auto-selected from the precision: int4/w4a16 → RTN weight-only,
int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion.

\b
Examples:
Expand All @@ -136,9 +139,15 @@ def quantize(
# Use precision shorthand (same as --weight-type uint8 --activation-type uint8)
winml quantize -m model.onnx --precision int8

# RTN 4-bit weight-only quantization (no calibration data needed)
winml quantize -m model.onnx --precision int4

# Int16 quantization
winml quantize -m model.onnx --precision int16

# Convert model to FP16 (no QDQ, full-model conversion)
winml quantize -m model.onnx --precision fp16

# Custom output path and more samples
winml quantize -m model.onnx -o quantized.onnx --samples 100

Expand Down Expand Up @@ -174,69 +183,94 @@ def quantize(
# Import quantizer (late import to speed up CLI)
from ..quant import WinMLQuantizationConfig, quantize_onnx

# Resolve weight/activation types from --precision or explicit flags
resolved_weight, resolved_activation = _resolve_quant_types(
precision, weight_type, activation_type
)
# ── Build config based on precision ──────────────────────────
precision_lower = precision.lower() if precision else None

# Determine output path
if output is None:
output = model.parent / f"{model.stem}_qdq.onnx"
output.parent.mkdir(parents=True, exist_ok=True)
if precision_lower == "fp16":
# FP16 conversion
cli_utils.warn_ignored_calibration_options(
ctx, "FP16 conversion does not use calibration data.", console=console
)
if output is None:
output = model.parent / f"{model.stem}_fp16.onnx"
config = WinMLQuantizationConfig(mode="fp16")
label = "FP16 conversion"

elif precision_lower and is_weight_only_precision(precision_lower):
# RTN weight-only
from ..config.precision import extract_weight_bits

cli_utils.warn_ignored_calibration_options(
ctx, "RTN weight-only quantization does not use calibration data.", console=console
)
rtn_bits = extract_weight_bits(precision_lower)
if output is None:
output = model.parent / f"{model.stem}_int{rtn_bits}.onnx"
config = WinMLQuantizationConfig(mode="rtn", rtn_bits=rtn_bits)
label = f"RTN {rtn_bits}-bit quantization"

# Show info
else:
# QDQ calibrated quantization
resolved_weight, resolved_activation = _resolve_quant_types(
precision, weight_type, activation_type
)
if output is None:
output = model.parent / f"{model.stem}_qdq.onnx"
config = WinMLQuantizationConfig(
samples=samples,
calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method),
weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight),
activation_type=cast(
'Literal["uint8", "int8", "uint16", "int16"]', resolved_activation
),
per_channel=per_channel,
symmetric=symmetric,
task=task,
model_name=model_name,
)
label = "Quantization"

# Display QDQ-specific info
console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}")
console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}")
console.print(f"[bold blue]Samples:[/bold blue] {samples}")
console.print(f"[bold blue]Method:[/bold blue] {method}")
if config.dataset_name:
_dataset_display = config.dataset_name
elif config.task and config.task != "random":
_dataset_display = f"Default for task '{config.task}'"
else:
_dataset_display = "Random data (synthetic from ONNX I/O specs)"
console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}")

# ── Shared execution: print header, run, report ──────────────
output.parent.mkdir(parents=True, exist_ok=True)
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}")
console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}")
console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}")
console.print(f"[bold blue]Samples:[/bold blue] {samples}")
console.print(f"[bold blue]Method:[/bold blue] {method}")

# Create config (output_path is passed separately to API).
# Click's Choice validates these strings at parse time, so cast acknowledges
# the Literal[] contract that mypy can't see through the str return type.
config = WinMLQuantizationConfig(
samples=samples,
calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method),
weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight),
activation_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_activation),
per_channel=per_channel,
symmetric=symmetric,
task=task,
model_name=model_name,
)

# Display dataset info from config
if config.dataset_name:
_dataset_display = config.dataset_name
elif config.task and config.task != "random":
_dataset_display = f"Default for task '{config.task}'"
else:
_dataset_display = "Random data (synthetic from ONNX I/O specs)"
console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}")

try:
console.print("\n[bold]Running quantization...[/bold]")
console.print(f"\n[bold]Running {label.lower()}...[/bold]")
result = quantize_onnx(model, output_path=output, config=config)

if result.success:
console.print("\n[bold green]Success![/bold green] Model quantized")
console.print(f"\n[bold green]Success![/bold green] {label} complete")
console.print(f"[dim]Output: {result.output_path}[/dim]")
console.print(f"[dim]QDQ nodes inserted: {result.nodes_quantized}[/dim]")
if result.nodes_quantized:
console.print(f"[dim]QDQ nodes inserted: {result.nodes_quantized}[/dim]")
console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]")
else:
console.print("\n[bold red]Quantization failed:[/bold red]")
console.print(f"\n[bold red]{label} failed:[/bold red]")
for error in result.errors:
console.print(f" {error}")
raise click.ClickException("Quantization failed")
raise click.ClickException(f"{label} failed")

except click.ClickException:
raise
except Exception as e:
console.print(f"\n[bold red]Quantization failed:[/bold red] {e}")
logger.exception("Quantization failed")
raise click.ClickException(f"Quantization failed: {e}") from e
console.print(f"\n[bold red]{label} failed:[/bold red] {e}")
logger.exception("%s failed", label)
raise click.ClickException(f"{label} failed: {e}") from e


def _resolve_quant_types(
Expand Down
Loading
Loading