diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dd3444bc9..a2e24093a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ concurrency: jobs: lint: runs-on: windows-latest - # Bumped from 5: combined mypy on 23 packages cold-starts at ~3-4 min on + # Bumped from 5: combined mypy on 24 packages cold-starts at ~3-4 min on # Windows runners; the original 5-min ceiling cancelled mid-run. timeout-minutes: 10 @@ -65,6 +65,7 @@ jobs: -p winml.modelkit.onnx -p winml.modelkit.optim -p winml.modelkit.optracing + -p winml.modelkit.pattern -p winml.modelkit.quant -p winml.modelkit.serve -p winml.modelkit.session diff --git a/src/winml/modelkit/analyze/core/doc_constraint_checker.py b/src/winml/modelkit/analyze/core/doc_constraint_checker.py index 3b4757d83..5cd399c75 100644 --- a/src/winml/modelkit/analyze/core/doc_constraint_checker.py +++ b/src/winml/modelkit/analyze/core/doc_constraint_checker.py @@ -254,7 +254,7 @@ def _get_node_actual_dtype(self, node: onnx.NodeProto) -> str | None: return dtype.upper() return None - def _get_node_shape(self, tensor_name: str) -> list[int] | None: + def _get_node_shape(self, tensor_name: str) -> tuple[int | str | None, ...] | None: """Get shape of a tensor. Args: diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index e58928b34..18382f57d 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -836,7 +836,7 @@ def _tensor_to_array_with_fallback(tensor: onnx.TensorProto) -> np.ndarray: type_vars[type_annotation] = dtype else: vi = valueinfo.get(inp_name) - shape_seq: list | tuple[int, ...] | None = None + shape_seq: tuple[int | str | None, ...] | None = None dtype = None if vi is not None: shape_seq, dtype = shape_and_dtype_from_valueinfo(vi) diff --git a/src/winml/modelkit/pattern/attention_patterns.py b/src/winml/modelkit/pattern/attention_patterns.py index ab618960b..e783c47d1 100644 --- a/src/winml/modelkit/pattern/attention_patterns.py +++ b/src/winml/modelkit/pattern/attention_patterns.py @@ -48,12 +48,11 @@ from .base import ( Pattern, PatternInputGenerator, - PatternMatchResult, PatternSchema, Skeleton, - SkeletonMatchResult, register_pattern_input_generator, ) +from .match import PatternMatchResult, SkeletonMatchResult from .op_input_gen import InputShapeConstraint @@ -539,7 +538,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputShapeConstraint]]: + ) -> list[dict[str, object]]: """Returns input combinations for expanded attention with mask pattern testing. Provides various 4D input shapes for Q, K, V, and attn_mask tensors. @@ -596,7 +595,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputShapeConstraint]]: + ) -> list[dict[str, object]]: """Returns input combinations for Transpose+Attention pattern testing. Provides various 4D input shapes for Q, K, V, and attn_mask tensors. diff --git a/src/winml/modelkit/pattern/base.py b/src/winml/modelkit/pattern/base.py index 245e5d51b..3ba29fff9 100644 --- a/src/winml/modelkit/pattern/base.py +++ b/src/winml/modelkit/pattern/base.py @@ -169,7 +169,7 @@ def get_schema(self) -> PatternSchema: from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, cast import numpy as np import onnx @@ -599,7 +599,10 @@ def _check_skeleton_result_impl( if type_str: # Use InputShapeConstraint to create dummy value type_annotation = SupportedONNXType.from_onnx_type(type_str).annotation - inputs[name] = InputShapeConstraint(info.shape).get_value(type_annotation) + # Matched-tensor shapes are concrete here (dynamic dims resolved). + inputs[name] = InputShapeConstraint( + cast("tuple[int, ...]", info.shape) + ).get_value(type_annotation) # Build is_constant_map from input_infos is_constant_map = {name: info.is_constant for name, info in input_infos.items()} @@ -649,7 +652,7 @@ def _check_skeleton_result_impl( node_domain = skeleton.node_domains[node_idx] op_type = skeleton.node_op_types[node_idx] opset_versions = ONNXDomain.get_model_domain_opset_versions( - skeleton_match_result.model + skeleton_match_result.matcher.model ) opset_version = opset_versions[node_domain] op_schema = node_domain.get_op_schema(op_type, opset_version) @@ -845,7 +848,7 @@ def get_onnx_model( # Create nodes nodes = [] - node_output_names = {} # node_idx -> output_name + node_output_names: dict[int, str] = {} # node_idx -> output_name for node_idx in range(skeleton.n_nodes): op_type = skeleton.node_op_types[node_idx] @@ -1037,7 +1040,7 @@ def _infer_schema_attributes( class PatternInputGenerator(OpInputGenerator): """Input generator that wraps a Pattern for runtime checking.""" - pattern: Pattern = None + pattern: Pattern | None = None # subclasses set a real Pattern (asserted in __init__) registration_name: str def __init__( @@ -1056,7 +1059,9 @@ def __init__( self.domain_versions = domain_versions schema = self.pattern.get_schema() self.op_name = schema.name # compatibility with OpInputGenerator - super().__init__(schema, onnx_types_to_check) + # OpInputGenerator duck-types the schema (OpSchema or PatternSchema); it + # guards OpSchema-specific access with isinstance internally. + super().__init__(cast("OpSchema", schema), onnx_types_to_check) def _create_model( self, @@ -1086,8 +1091,8 @@ def _create_model( SupportedONNXType.from_annotation(dtype).onnx_type for dtype in output_dtypes ] - # Use the pattern's get_onnx_model method - return self.pattern.get_onnx_model( + # Use the pattern's get_onnx_model method (pattern is set by the subclass). + return cast("Pattern", self.pattern).get_onnx_model( inputs=input_kwargs, attributes=attr_kwargs, is_constant_map=is_constant_map, @@ -1538,7 +1543,7 @@ def _get_registered_edge_info(self, tensor_name: str, consumer_name: str) -> Edg def _check_constant_constraints( self, - matched_nodes: list[str], + matched_nodes: list[onnx.NodeProto], constant_constraints: list[tuple[int, int, np.ndarray]], ) -> bool: """Check constant value constraints for a skeleton match. @@ -1649,7 +1654,9 @@ def match(self) -> list[PatternMatchResult]: # Validate each result using pattern's check_skeleton_result validated_results = [] for result in skeleton_results: - pattern_match_result = result.pattern.check_skeleton_result(result) + # match_skeleton() yields results whose .pattern is a registered ABC + # Pattern (the pydantic PatternModel is only used for serialization). + pattern_match_result = cast("Pattern", result.pattern).check_skeleton_result(result) if pattern_match_result is not None: validated_results.append(pattern_match_result) @@ -1773,7 +1780,7 @@ def _match_single_skeleton( # check 3: the mappings must be compatible valid_merged_mappings = [] for mapping_combination in it.product(*dst_slot_partial_mappings): - merged_mapping = _merge_mappings(mapping_combination) + merged_mapping = _merge_mappings(list(mapping_combination)) if merged_mapping is not None: # valid mapping merged_mapping[subgraph_node] = node_name @@ -1955,7 +1962,7 @@ def _allocate_graph_node_key(node: Any) -> str: nonlocal generated_node_key_counter if node.name and node.name not in used_graph_node_keys: - key = node.name + key: str = node.name elif node.name: suffix = 1 key = f"{node.name}__{suffix}" @@ -2018,7 +2025,8 @@ def _allocate_graph_node_key(node: Any) -> str: # Create the new pattern instance new_pattern = new_pattern_class() - assert skeleton_match.pattern.get_schema() == new_pattern.get_schema(), ( + matched_pattern = cast("Pattern", skeleton_match.pattern) + assert matched_pattern.get_schema() == new_pattern.get_schema(), ( f"New pattern {new_pattern_class.__name__} schema does not match " f"the matched pattern {skeleton_match.pattern.__class__.__name__} schema." ) diff --git a/src/winml/modelkit/pattern/config.py b/src/winml/modelkit/pattern/config.py index dfebb6332..d5bcf4956 100644 --- a/src/winml/modelkit/pattern/config.py +++ b/src/winml/modelkit/pattern/config.py @@ -15,7 +15,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast if TYPE_CHECKING: @@ -95,7 +95,7 @@ def load_pattern(self) -> Pattern: except (ImportError, AttributeError): continue # Instantiation errors should propagate, not be silently caught - return pattern_cls() + return cast("Pattern", pattern_cls()) msg = f"Failed to load pattern {self.pattern_class} from {self.module}" logger.error(msg) diff --git a/src/winml/modelkit/pattern/gelu_patterns.py b/src/winml/modelkit/pattern/gelu_patterns.py index d1e175df9..13c9586b4 100644 --- a/src/winml/modelkit/pattern/gelu_patterns.py +++ b/src/winml/modelkit/pattern/gelu_patterns.py @@ -124,7 +124,7 @@ def get_schema(self) -> PatternSchema: @register_pattern_input_generator -class Gelu1PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): +class Gelu1PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op) """Input generator for GELU activation pattern variant 1.""" pattern = Gelu1Pattern() @@ -224,7 +224,7 @@ def get_schema(self) -> PatternSchema: @register_pattern_input_generator -class Gelu2PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): +class Gelu2PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op) """Input generator for GELU activation pattern variant 2.""" pattern = Gelu2Pattern() @@ -331,7 +331,7 @@ def get_schema(self) -> PatternSchema: @register_pattern_input_generator -class Gelu3PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): +class Gelu3PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op) """Input generator for GELU activation pattern variant 3.""" pattern = Gelu3Pattern() @@ -439,7 +439,7 @@ def get_schema(self) -> PatternSchema: @register_pattern_input_generator -class Gelu4PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): +class Gelu4PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op) """Input generator for GELU activation pattern variant 4.""" pattern = Gelu4Pattern() diff --git a/src/winml/modelkit/pattern/gemm_patterns.py b/src/winml/modelkit/pattern/gemm_patterns.py index 4e5ca0894..5967c9535 100644 --- a/src/winml/modelkit/pattern/gemm_patterns.py +++ b/src/winml/modelkit/pattern/gemm_patterns.py @@ -387,7 +387,8 @@ def get_schema(self) -> PatternSchema: class GemmPatternInputGenerator(PatternInputGenerator): """PatternInputGenerator for Gemm patterns.""" - pattern = ReshapeGemmReshapePattern() + # Typed as the base Pattern so subclasses can set a different concrete pattern. + pattern: Pattern = ReshapeGemmReshapePattern() def get_finite_attribute_sets(self) -> dict[str, list[Any]]: """Return finite attribute sets for ReshapeGemmReshape (none).""" @@ -418,7 +419,7 @@ def get_input_and_infinite_attribute_combinations( for a_shape in a_shapes: for b_shape in b_shapes: for c_option in c_options: - combination = { + combination: dict[str, object] = { "A": InputShapeConstraint(a_shape), "B": InputShapeConstraint(b_shape), } diff --git a/src/winml/modelkit/pattern/layernorm_patterns.py b/src/winml/modelkit/pattern/layernorm_patterns.py index 4c03236c1..3998a5c55 100644 --- a/src/winml/modelkit/pattern/layernorm_patterns.py +++ b/src/winml/modelkit/pattern/layernorm_patterns.py @@ -14,7 +14,7 @@ """ from abc import abstractmethod -from typing import Any +from typing import Any, cast import numpy as np from onnx.defs import OpSchema @@ -23,13 +23,12 @@ from .base import ( Pattern, PatternInputGenerator, - PatternMatchResult, PatternMismatchedError, PatternSchema, Skeleton, - SkeletonMatchResult, register_pattern_input_generator, ) +from .match import PatternMatchResult, SkeletonMatchResult from .op_input_gen import get_runtime_checker_op from .utils import ( get_attribute_proto_value, @@ -253,6 +252,8 @@ def _infer_schema_attributes( if axes_value is None: raise PatternMismatchedError("ReduceMean missing axes attribute") + if axes_value is None: + raise PatternMismatchedError("ReduceMean axes tensor value is None") if len(axes_value) != 1: raise PatternMismatchedError( f"Only single-axis normalization supported, got axes={axes_value}" @@ -495,7 +496,7 @@ def _get_normalized_dim(self, inputs: dict[str, np.ndarray], attributes: dict[st axis = attributes["axis"] rank = len(x_shape) normalized_axis = axis if axis >= 0 else rank + axis - return x_shape[normalized_axis] + return int(x_shape[normalized_axis]) def get_internal_constants_and_attributes( self, @@ -534,7 +535,7 @@ def get_internal_constants_and_attributes( class LayerNormalizationPatternInputGenerator( - PatternInputGenerator, get_runtime_checker_op("LayerNormalization") + PatternInputGenerator, get_runtime_checker_op("LayerNormalization") # type: ignore[misc] # dynamic base class (runtime-checker op) ): """Base PatternInputGenerator for LayerNormalization patterns. @@ -547,14 +548,15 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, Any]]: """Return input combinations with broadcast-compatible Scale/B shapes.""" - from .op_input_gen import InputValueConstraint + from .op_input_gen import InputShapeConstraint, InputValueConstraint - combinations = super().get_input_and_infinite_attribute_combinations() + # Dynamic base provides the real combinations method at runtime. + combinations = super().get_input_and_infinite_attribute_combinations() # type: ignore[safe-super] adapted = [] for combo in combinations: - axis = combo["axis"] - x_shape = combo["X"].shape + axis = cast("int", combo["axis"]) + x_shape = cast("InputShapeConstraint", combo["X"]).shape rank = len(x_shape) normalized_axis = axis if axis >= 0 else rank + axis normalized_dim = x_shape[normalized_axis] @@ -569,8 +571,8 @@ def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, Any]]: broadcast_shape = [1] * rank broadcast_shape[normalized_axis] = normalized_dim - scale_value = combo["Scale"].value - bias_value = combo["B"].value + scale_value = cast("InputValueConstraint", combo["Scale"]).value + bias_value = cast("InputValueConstraint", combo["B"]).value new_scale = np.ones((normalized_dim,), dtype=scale_value.dtype).reshape( broadcast_shape ) diff --git a/src/winml/modelkit/pattern/match.py b/src/winml/modelkit/pattern/match.py index 25a9909f7..ea5b2b034 100644 --- a/src/winml/modelkit/pattern/match.py +++ b/src/winml/modelkit/pattern/match.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from onnx import NodeProto + from ..analyze import ONNXOp from .base import Pattern, PatternMatcher from .models import Pattern as PatternModel @@ -141,40 +142,27 @@ def matched_node_keys(self) -> list[str]: return self.skeleton_match_result.matched_node_keys @property - def matched_node_names(self): + def matched_node_names(self) -> list[ONNXOp]: """Get matched nodes as ONNXOp objects. Note: Despite the name, this returns ONNXOp objects, not strings. This is for backward compatibility. Use matched_nodes for string names. Returns: - List of ONNXOp instances containing node metadata (when used from analyze). - Falls back to dicts when ONNXOp is not available. + List of ONNXOp instances containing node metadata. """ - try: - from ..analyze import ONNXOp - - node_keys = self.skeleton_match_result.matched_node_keys - - return [ - ONNXOp( - node_name=node_keys[idx], - op_type=node.op_type, - namespace=node.domain if node.domain else "ai.onnx", - ) - for idx, node in enumerate(self.skeleton_match_result.matched_nodes) - ] - except ImportError: - # When used outside analyze context, return node info as dicts - node_keys = self.skeleton_match_result.matched_node_keys - return [ - { - "node_name": node_keys[idx], - "op_type": node.op_type, - "namespace": node.domain if node.domain else "ai.onnx", - } - for idx, node in enumerate(self.skeleton_match_result.matched_nodes) - ] + from ..analyze import ONNXOp + + node_keys = self.skeleton_match_result.matched_node_keys + + return [ + ONNXOp( + node_name=node_keys[idx], + op_type=node.op_type, + namespace=node.domain if node.domain else "ai.onnx", + ) + for idx, node in enumerate(self.skeleton_match_result.matched_nodes) + ] @property def type_vars(self) -> dict[str, str]: diff --git a/src/winml/modelkit/pattern/op_input_gen/binary_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/binary_input_generator.py index c39efaa81..f09276c3f 100644 --- a/src/winml/modelkit/pattern/op_input_gen/binary_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/binary_input_generator.py @@ -24,7 +24,6 @@ from typing import Any from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -70,7 +69,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for binary operators. Coverage strategy: @@ -367,7 +366,7 @@ def get_infinite_property_names(self) -> list[str]: y_name = self.op_input_names[1] return [f"{x_name}_shape", f"{y_name}_shape"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for binary operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig( @@ -412,7 +411,7 @@ class MulInputGenerator(BinaryInputGenerator): op_name = "Mul" - def derive_properties(self, properties): + def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: """Derive properties including broadcasting information.""" item = super().derive_properties(properties) return self._derive_broadcasting_properties(item) @@ -424,14 +423,14 @@ class DivInputGenerator(BinaryInputGenerator): op_name = "Div" - def derive_properties(self, properties): + def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: """Derive properties including broadcasting information.""" item = super().derive_properties(properties) return self._derive_broadcasting_properties(item) def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns input combinations for Div operator with division-by-zero protection. Overrides parent method to set min_max=(2, 3) for the divisor (B parameter) @@ -445,10 +444,10 @@ def get_input_and_infinite_attribute_combinations( # Rebuild combinations with min_max set for B parameter new_combinations = [] for combo in combinations: - new_combo = {} + new_combo: dict[str, object] = {} for key, value in combo.items(): # Check if this is the second parameter (divisor) - if key == divisor_name: + if key == divisor_name and isinstance(value, InputShapeConstraint): # Create a new InputShapeConstraint with min_max set new_constraint = InputShapeConstraint(value.shape, min_max=(2, 3)) new_combo[key] = new_constraint @@ -465,7 +464,7 @@ class PowInputGenerator(BinaryInputGenerator): op_name = "Pow" - def derive_properties(self, properties): + def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: """Derive properties including broadcasting information.""" item = super().derive_properties(properties) return self._derive_broadcasting_properties(item) @@ -523,7 +522,7 @@ class ComparisonInputGenerator(BinaryInputGenerator): here if needed in the future. """ - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for comparison operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig( diff --git a/src/winml/modelkit/pattern/op_input_gen/binary_like_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/binary_like_input_generator.py index 03c450655..c546bbf95 100644 --- a/src/winml/modelkit/pattern/op_input_gen/binary_like_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/binary_like_input_generator.py @@ -13,13 +13,12 @@ with additional attributes or inputs. """ -from typing import Any +from typing import Any, cast import numpy as np from .binary_input_generator import BinaryInputGenerator from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, QDQParameterConfig, @@ -93,7 +92,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Mod operator. Strategy: @@ -128,7 +127,8 @@ def get_input_and_infinite_attribute_combinations( a_constraint = parent_combo[parent_first] # Get B shape from parent, but override with min_max to ensure non-zero divisor - b_shape = parent_combo[parent_second].shape + # Parent binary combos hold InputShapeConstraint values for these inputs. + b_shape = cast("InputShapeConstraint", parent_combo[parent_second]).shape # Create non-zero divisor: sample values from 1 to 10 # This avoids divide-by-zero while providing varied test data @@ -176,7 +176,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: # TODO: use InputValueConstraint or InputShapeConstraint for condition input? def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Where operator. Strategy: @@ -210,12 +210,12 @@ def get_input_and_infinite_attribute_combinations( parent_first = parent_param_names[0] if parent_param_names else "A" parent_second = parent_param_names[1] if len(parent_param_names) > 1 else "B" - combinations = [] + combinations: list[dict[str, object]] = [] for parent_combo in parent_combinations: # Get X and Y shapes from parent combination - x_constraint = parent_combo[parent_first] - y_constraint = parent_combo[parent_second] + x_constraint = cast("InputShapeConstraint", parent_combo[parent_first]) + y_constraint = cast("InputShapeConstraint", parent_combo[parent_second]) x_shape = x_constraint.shape y_shape = y_constraint.shape @@ -312,7 +312,7 @@ def get_infinite_property_names(self) -> list[str]: f"{y_name}_shape", ] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Where operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_non_qdq=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/constant_of_shape_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/constant_of_shape_input_generator.py index 7273e0794..754c4a644 100644 --- a/src/winml/modelkit/pattern/op_input_gen/constant_of_shape_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/constant_of_shape_input_generator.py @@ -10,7 +10,6 @@ from ...onnx import SupportedONNXType from .op_input_gen import ( - InputConstraint, InputValueConstraint, OpInputGenerator, register_runtime_checker_op, @@ -48,9 +47,9 @@ def get_finite_attribute_sets(self) -> dict[str, list]: ] } - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input combinations for ConstantOfShape.""" - combinations = [] + combinations: list[dict[str, object]] = [] # We want to test creating tensors of various shapes. # Common shapes from 1D to 5D diff --git a/src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py index 1064eac87..b141dbf18 100644 --- a/src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py @@ -18,7 +18,6 @@ from ...onnx import SupportedONNXType from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -125,7 +124,7 @@ def get_attr_options(self, spatial_dims: int) -> dict[str, list]: "auto_pad": auto_pad_opts, } - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Conv operator inputs.""" # https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/operators/conv.py return { @@ -169,7 +168,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input and infinite attribute combinations for Conv operator.""" combinations = [] for x_shape, m, k_shape in self.get_base_conv_shapes(): @@ -259,7 +258,7 @@ def get_infinite_property_names(self) -> list[str]: "B_shape", ] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Conv operator inputs.""" # B/Y can be non-QDQ from P1 models # TODO: INT8 bias is a workaround — P1 models observed with INT8-quantized @@ -364,7 +363,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input and infinite attribute combinations for ConvTranspose operator.""" combinations = [] for x_shape, m, k_shape in self.get_base_conv_shapes(): diff --git a/src/winml/modelkit/pattern/op_input_gen/expand_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/expand_input_generator.py index 616ddb14d..6b58f787f 100644 --- a/src/winml/modelkit/pattern/op_input_gen/expand_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/expand_input_generator.py @@ -7,7 +7,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -67,7 +66,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Expand operator. Coverage strategy: @@ -343,7 +342,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["input_shape", "input_value", "shape_shape", "shape_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Expand operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/flatten_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/flatten_input_generator.py index 1c1527328..13b3a9033 100644 --- a/src/winml/modelkit/pattern/op_input_gen/flatten_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/flatten_input_generator.py @@ -4,8 +4,9 @@ # -------------------------------------------------------------------------- """Input generator for Flatten operator.""" +from typing import Any + from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -46,7 +47,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[int]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint | int]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Flatten operator. Coverage strategy: @@ -76,7 +77,7 @@ def get_input_and_infinite_attribute_combinations( return combinations # noqa: RET504 - def derive_properties(self, properties: dict[str, any]) -> dict[str, any]: + def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: """Derive additional properties for Flatten operator testing. Args: @@ -100,7 +101,7 @@ def get_infinite_property_names(self) -> list[str]: """Returns names of infinite properties for Flatten operator.""" return ["input_shape", "input_value", "attr_axis"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Flatten operator inputs.""" return { "input": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/global_pooling_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/global_pooling_input_generator.py index 99f87ab00..963a77d5a 100644 --- a/src/winml/modelkit/pattern/op_input_gen/global_pooling_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/global_pooling_input_generator.py @@ -9,7 +9,6 @@ """ from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -46,7 +45,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns input combinations for global pooling operators. Coverage strategy: @@ -114,7 +113,7 @@ class GlobalAveragePoolInputGenerator(GlobalPoolingInputGenerator): op_name = "GlobalAveragePool" - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for GlobalAveragePool operator inputs.""" return { "X": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py index 790ec5b4f..b438365da 100644 --- a/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py @@ -17,7 +17,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -72,7 +71,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[int]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Gather operator. Strategy: @@ -135,7 +134,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["data_shape", "indices_shape", "attr_axis"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Gather operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), @@ -171,7 +170,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for GatherElements operator.""" combinations = [] @@ -200,9 +199,9 @@ def get_input_and_infinite_attribute_combinations( indices_shape_2[axis_idx] = ( indices_shape_2[axis_idx] * 2 if indices_shape_2[axis_idx] > 0 else 2 ) - indices_shape_2 = tuple(indices_shape_2) + indices_shape_2_t = tuple(indices_shape_2) - indices_val_2 = InputShapeConstraint(indices_shape_2, min_max=(0, axis_size - 1)) + indices_val_2 = InputShapeConstraint(indices_shape_2_t, min_max=(0, axis_size - 1)) combinations.append( { "data": InputShapeConstraint(data_shape), @@ -227,7 +226,7 @@ def get_infinite_property_names(self) -> list[str]: """Return names of properties with infinite possible values.""" return ["data_shape", "indices_shape", "attr_axis"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for GatherElements operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), @@ -265,7 +264,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for GatherND operator.""" combinations = [] @@ -409,7 +408,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for ScatterND operator. Optimized strategy for performance: @@ -424,7 +423,7 @@ def get_input_and_infinite_attribute_combinations( - Limits q to essential cases: scalar indices (q=1) and 2D indices (q=2) - Reduces total test combinations from ~500+ to ~100 cases """ - combinations = [] + combinations: list[dict[str, object]] = [] # Generate optimized test cases using nested loops # Coverage: data 1D-6D, focused k values, limited q values for performance @@ -518,7 +517,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["data_shape", "indices_value", "updates_shape"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for ScatterND operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), # data @@ -561,7 +560,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Unsqueeze operator. Strategy: @@ -570,7 +569,7 @@ def get_input_and_infinite_attribute_combinations( - Test single axis and multiple axes - Test both positive and negative axis values """ - combinations = [] + combinations: list[dict[str, object]] = [] # Define (data_shape, axes_values) test cases # Format: (data_shape, axes_array) @@ -655,7 +654,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["data_shape", "axes_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Unsqueeze operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), @@ -701,7 +700,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Split operator. Strategy: @@ -865,7 +864,7 @@ def infer_output_types( annotation = tags["type_vars"][f"{type_var_key}_{self.op_name}"] return [annotation] * num_output - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Split operator inputs.""" return { "input": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/matmul_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/matmul_input_generator.py index 72d18cdaf..a40a57a9f 100644 --- a/src/winml/modelkit/pattern/op_input_gen/matmul_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/matmul_input_generator.py @@ -43,7 +43,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for MatMul. Covers various matrix shapes and broadcasting patterns. @@ -58,7 +58,7 @@ def get_input_and_infinite_attribute_combinations( - Unidirectional broadcast B -> A (B has smaller/broadcastable dims) - Bidirectional broadcast A <-> B (both have broadcastable dims) """ - combinations = [] + combinations: list[dict[str, object]] = [] # Common matrix shape pairs for comprehensive testing shape_pairs = [ @@ -167,7 +167,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["A_shape", "B_shape"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for MatMul operator inputs.""" return { "A": QDQParameterConfig(support_activation=True, support_weight=True), @@ -211,7 +211,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Gemm. Gemm operates on 2D matrices only. Shapes depend on transpose flags: @@ -223,7 +223,7 @@ def get_input_and_infinite_attribute_combinations( the inner dimensions match after transpose (a_shape[1] == b_shape[0]). For each valid pair we emit C options: full bias, 1D bias, scalar, or None. """ - combinations = [] + combinations: list[dict[str, object]] = [] def _after_transpose(shape: tuple[int, int], flag: int | None) -> tuple[int, int]: return (shape[1], shape[0]) if flag == 1 else shape @@ -261,7 +261,7 @@ def _after_transpose(shape: tuple[int, int], flag: int | None) -> tuple[int, int ] for c_option in c_options: - combination: dict[str, InputConstraint | int] = { + combination: dict[str, object] = { "A": InputShapeConstraint(a_shape), "B": InputShapeConstraint(b_shape), "C": c_option, @@ -302,7 +302,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["A_shape", "B_shape", "C_shape", "C_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Gemm operator inputs.""" # https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/operators/gemm.py return { diff --git a/src/winml/modelkit/pattern/op_input_gen/normalization_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/normalization_input_generator.py index 53c0440da..9b9330ee4 100644 --- a/src/winml/modelkit/pattern/op_input_gen/normalization_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/normalization_input_generator.py @@ -17,7 +17,6 @@ from ...onnx import SupportedONNXType from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -155,14 +154,14 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for BatchNormalization. For each data shape: - X has the full shape (N, C, ...) - scale, bias, mean, var all have shape (C,) """ - combinations = [] + combinations: list[dict[str, object]] = [] # BatchNormalization uses mean/var in older opsets and input_mean/input_var in newer opsets. x_name, scale_name, bias_name, mean_name, var_name = self.op_input_names[:5] @@ -228,7 +227,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for GroupNormalization. For each data shape, test multiple num_groups values where @@ -301,12 +300,12 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for InstanceNormalization. InstanceNorm requires at least 3D tensors (N, C, spatial_dims). """ - combinations = [] + combinations: list[dict[str, object]] = [] for shape in self.get_common_data_shapes(): # Skip 2D shapes - InstanceNorm requires spatial dimensions @@ -376,7 +375,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for LayerNormalization. Test different axis values: -1 (last dim), -2 (last 2 dims), etc. @@ -433,9 +432,9 @@ def get_finite_attribute_sets(self) -> dict[str, list]: """Return finite attribute values for LpNormalization.""" return {"p": [1, 2]} - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input combinations for LpNormalization.""" - combinations = [] + combinations: list[dict[str, object]] = [] for shape in self.get_common_data_shapes(): if len(shape) < 3: continue @@ -483,7 +482,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for MeanVarianceNormalization. Test different axes combinations based on tensor rank. @@ -538,7 +537,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for RMSNormalization. Test different axis values: -1 (last dim), 0 (all dims). diff --git a/src/winml/modelkit/pattern/op_input_gen/op_input_gen.py b/src/winml/modelkit/pattern/op_input_gen/op_input_gen.py index a1b95cb17..04fc7de66 100644 --- a/src/winml/modelkit/pattern/op_input_gen/op_input_gen.py +++ b/src/winml/modelkit/pattern/op_input_gen/op_input_gen.py @@ -10,7 +10,7 @@ import time import zlib from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence from pathlib import Path from threading import Lock from typing import TYPE_CHECKING, Any @@ -488,7 +488,7 @@ def _attr_has_default(attr_info: Any) -> bool: default_val = attr_info.default_value return default_val.ByteSize() > 0 if default_val else False - def _type_var_combination_iter(self) -> Any: + def _type_var_combination_iter(self) -> Iterator[dict[str, str]]: options = [ [ (name, value.annotation) for value in dtypes @@ -505,7 +505,7 @@ def _apply_type_var_combination( type_annotation = type_annotation.replace(type_var, dtype) return type_annotation - def _finite_attribute_combination_iter(self) -> Any: + def _finite_attribute_combination_iter(self) -> Iterator[dict[str, Any]]: finite_attribute_sets = self.get_finite_attribute_sets() options = [ [(name, value) for value in values] for name, values in finite_attribute_sets.items() @@ -514,7 +514,9 @@ def _finite_attribute_combination_iter(self) -> Any: # Omit attributes with None value to simulate them being not provided yield {k: v for k, v in attr_comb if v is not None} - def _optional_input_combination_iter(self, input_comb: dict[str, InputConstraint]) -> Any: + def _optional_input_combination_iter( + self, input_comb: dict[str, object] + ) -> Iterator[dict[str, object]]: """Iterate over combinations of optional inputs being provided or None. For each optional input present in input_comb, generates combinations where @@ -545,7 +547,8 @@ def _optional_input_combination_iter(self, input_comb: dict[str, InputConstraint for comb in itertools.product(*options): use_value_map = dict(comb) - modified_input_comb = {} + # None marks an optional input deliberately omitted from this combo. + modified_input_comb: dict[str, object] = {} for k, v in input_comb.items(): if k in use_value_map: if use_value_map[k]: @@ -559,7 +562,9 @@ def _optional_input_combination_iter(self, input_comb: dict[str, InputConstraint modified_input_comb[k] = v yield modified_input_comb - def _optional_attr_combination_iter(self, input_comb: dict[str, InputConstraint]) -> Any: + def _optional_attr_combination_iter( + self, input_comb: dict[str, object] + ) -> Iterator[dict[str, object]]: """Iterate over combinations of optional attributes being provided or omitted. Only covers attributes without defaults. @@ -594,7 +599,7 @@ def _optional_attr_combination_iter(self, input_comb: dict[str, InputConstraint] for comb in itertools.product(*options): use_value_map = dict(comb) - modified_input_comb = {} + modified_input_comb: dict[str, object] = {} for k, v in input_comb.items(): if k in use_value_map: if use_value_map[k]: @@ -737,13 +742,13 @@ def _iter_should_qdq_combinations( for output in self.schema.outputs: output_name = output.name output_config = qdq_config.get(output_name) - options = [] + out_options: list[bool] = [] if output_config is None or output_config.support_activation: - options.append(True) + out_options.append(True) if output_config is not None and output_config.support_non_qdq: - options.append(False) + out_options.append(False) output_names.append(output_name) - output_option_lists.append(options) + output_option_lists.append(out_options) all_names = input_names + output_names all_option_lists = input_option_lists + output_option_lists @@ -812,6 +817,10 @@ def _create_model( ) input_names_to_process = self.op_input_names + variadic_keys + # QDQ generation requires a configured generator; capture it once so each + # per-input QDQ branch can gate on (quant_type, qdq_generator) both present. + qdq_generator = self.qdq_generator + # Iterate over input names in order to maintain correct positional ordering. # For optional inputs that are None, use empty string "" as placeholder. # This is required by ONNX spec for operators with multiple optional inputs. @@ -824,13 +833,9 @@ def _create_model( input_shape = list(input_value.shape) if is_constant_map[input_name]: # Constant input -> create initializer - if ( - qdq_types is not None - and input_name in qdq_types - and qdq_types[input_name] is not None - ): + quant_type = qdq_types.get(input_name) if qdq_types is not None else None + if quant_type is not None and qdq_generator is not None: # For QDQ: create quantized initializer with DequantizeLinear - quant_type = qdq_types[input_name] dq_output_name = f"{input_name}_dq" scale_name = f"{input_name}_scale" zp_name = f"{input_name}_zero_point" @@ -863,7 +868,7 @@ def _create_model( "DequantizeLinear", inputs=[input_name, scale_name, zp_name], outputs=[dq_output_name], - domain=self.qdq_generator.domain.value, + domain=qdq_generator.domain.value, ) input_dq_nodes.append(dq_node) node_inputs.append(dq_output_name) @@ -886,13 +891,9 @@ def _create_model( if axis_idx < len(input_shape): input_shape[axis_idx] = -1 # ONNX unknown dimension - if ( - qdq_types is not None - and input_name in qdq_types - and qdq_types[input_name] is not None - ): + quant_type = qdq_types.get(input_name) if qdq_types is not None else None + if quant_type is not None and qdq_generator is not None: # For QDQ: create quantized graph input with DequantizeLinear - quant_type = qdq_types[input_name] dq_output_name = f"{input_name}_dq" scale_name = f"{input_name}_scale" zp_name = f"{input_name}_zero_point" @@ -925,7 +926,7 @@ def _create_model( "DequantizeLinear", inputs=[input_name, scale_name, zp_name], outputs=[dq_output_name], - domain=self.qdq_generator.domain.value, + domain=qdq_generator.domain.value, ) input_dq_nodes.append(dq_node) node_inputs.append(dq_output_name) @@ -953,14 +954,13 @@ def _create_model( schema_output_name = ( self.schema.outputs[idx].name if idx < len(self.schema.outputs) else None ) - if ( - qdq_types is not None - and schema_output_name is not None - and schema_output_name in qdq_types - and qdq_types[schema_output_name] is not None - ): + quant_type = ( + qdq_types.get(schema_output_name) + if qdq_types is not None and schema_output_name is not None + else None + ) + if quant_type is not None and qdq_generator is not None: # For QDQ: operator outputs to intermediate, then Q to final output - quant_type = qdq_types[schema_output_name] op_output_name = f"op_output_{idx}" final_output_name = f"output_{idx}" scale_name = f"output_{idx}_scale" @@ -988,7 +988,7 @@ def _create_model( "QuantizeLinear", inputs=[op_output_name, scale_name, zp_name], outputs=[final_output_name], - domain=self.qdq_generator.domain.value, + domain=qdq_generator.domain.value, ) output_q_nodes.append(q_node) @@ -1121,7 +1121,7 @@ def iter_const_and_dynamic_models(self, kwargs: dict[str, Any], tags: dict[str, Dynamic axes are only applied to non-constant graph inputs. """ qdq_config = self.get_qdq_config() - qdq_tested_types: set[tuple[tuple[str, str | None], ...]] = set() + qdq_tested_types: set[tuple[tuple[str, str | None] | tuple[str, bool], ...]] = set() for is_constant_map in self._iter_constant_combinations(kwargs): dynamic_axes_variants = self._build_dynamic_axes_variants(kwargs, is_constant_map) for dynamic_axes in dynamic_axes_variants: @@ -1129,11 +1129,11 @@ def iter_const_and_dynamic_models(self, kwargs: dict[str, Any], tags: dict[str, final_tags = tags.copy() final_tags["dynamic_axes"] = dynamic_axes - if self.qdq_generator is None: - # qdq_generator will set input_is_constant - # when parameters are not supported for - # quantization, so only set it here when not - # using qdq_generator + if self.qdq_generator is None or qdq_config is None: + # No generator, or this op doesn't support QDQ (get_qdq_config + # returned None) — emit a plain model. qdq_generator sets + # input_is_constant when params aren't quantizable, so only set + # it here on the non-QDQ path. final_tags["input_is_constant"] = is_constant_map output_dtypes = self.infer_output_types(kwargs, final_tags) model = self._create_model( @@ -1165,9 +1165,9 @@ def iter_qdq_combinations( kwargs: dict[str, Any], tags: dict[str, Any], is_constant_map: dict[str, bool], - should_qdq_map: dict[str, bool], - qdq_config: dict[str, QDQParameterConfig], - qdq_tested_types: set[tuple[tuple[str, str | None], ...]], + should_qdq_map: dict[str, bool | SupportedONNXType], + qdq_config: dict[str, QDQParameterConfig] | None, + qdq_tested_types: set[tuple[tuple[str, str | None] | tuple[str, bool], ...]], ) -> Any: """Iterate over different QDQ combinations. @@ -1185,6 +1185,7 @@ def iter_qdq_combinations( if self.qdq_generator is None: return + # qdq_config may be None when called directly (e.g. unit tests) — yield nothing. if qdq_config is None: return @@ -1348,6 +1349,12 @@ def iter_qdq_combinations( if isinstance(should_val, SupportedONNXType): qdq_types[input_name] = should_val elif is_constant: + # is_constant implies needs_weight_iteration, so weight_onnx_type + # is a real type here, never the `[None]` run-once placeholder. + if weight_onnx_type is None: + raise RuntimeError( + "weight type placeholder reached for a constant input" + ) qdq_types[input_name] = SupportedONNXType.from_onnx_type(weight_onnx_type) else: qdq_types[input_name] = SupportedONNXType.from_onnx_type( @@ -1371,11 +1378,11 @@ def iter_qdq_combinations( if should_qdq_map.get(output_name) is False: qdq_types[output_name] = None # No Q per should_qdq_map else: - output_config = qdq_config.get(output_name) + out_cfg = qdq_config.get(output_name) if ( - output_config is not None - and not output_config.support_weight - and not output_config.support_activation + out_cfg is not None + and not out_cfg.support_weight + and not out_cfg.support_activation ): qdq_types[output_name] = None # support_non_qdq only else: @@ -1667,7 +1674,7 @@ def _dry_run_result() -> dict[str, Any]: break # Add final_tags as ONNX metadata - def _json_default(o: Any): + def _json_default(o: Any) -> Any: if isinstance(o, np.ndarray): return o.tolist() if isinstance(o, np.generic): @@ -1725,7 +1732,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: @abstractmethod def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns a list of dicts: {input_or_attribute_name: constraint_or_value}. Representing a combination of inputs to test the op with. @@ -2004,7 +2011,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Reshape.""" return [ { diff --git a/src/winml/modelkit/pattern/op_input_gen/pad_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/pad_input_generator.py index 37bca04cd..dacf62985 100644 --- a/src/winml/modelkit/pattern/op_input_gen/pad_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/pad_input_generator.py @@ -7,7 +7,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -49,7 +48,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[str]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Pad operator. Coverage strategy: @@ -158,7 +157,7 @@ def get_infinite_property_names(self) -> list[str]: """Returns names of infinite properties for Pad operator.""" return ["data_shape", "pads_value", "constant_value_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Pad operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), # data diff --git a/src/winml/modelkit/pattern/op_input_gen/pooling_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/pooling_input_generator.py index 5d8958d9a..8fc733282 100644 --- a/src/winml/modelkit/pattern/op_input_gen/pooling_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/pooling_input_generator.py @@ -10,7 +10,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -51,7 +50,7 @@ def get_infinite_property_names(self) -> list[str]: + ["attr_kernel_shape", "attr_strides", "attr_pads", "attr_dilations"] ) - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for pooling operator inputs.""" return {self.op_input_names[0]: QDQParameterConfig(support_activation=True)} @@ -70,7 +69,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: "storage_order": [0, 1], } - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input and infinite attribute combinations for MaxPool operator.""" combinations = [] for x_shape, kernel_shape in self.get_common_shapes_and_kernels(): @@ -191,7 +190,7 @@ def derive_properties(self, properties: dict) -> dict: return item - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input and infinite attribute combinations for AveragePool operator.""" combinations = [] for x_shape, kernel_shape in self.get_common_shapes_and_kernels(): @@ -243,7 +242,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: "p": [1, 2], # L1 and L2 norm } - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input and infinite attribute combinations for LpPool operator.""" combinations = [] for x_shape, kernel_shape in self.get_common_shapes_and_kernels(): diff --git a/src/winml/modelkit/pattern/op_input_gen/reduction_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/reduction_input_generator.py index a930dc693..b314b7714 100644 --- a/src/winml/modelkit/pattern/op_input_gen/reduction_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/reduction_input_generator.py @@ -18,7 +18,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -52,7 +51,9 @@ def get_common_data_shapes(self) -> list[tuple[int, ...]]: (2, 2, 2, 2, 2, 3), # 6D ] - def get_common_axes_combinations(self, shape: tuple[int, ...]) -> list[np.ndarray]: + def get_common_axes_combinations( + self, shape: tuple[int, ...] + ) -> list[list[int]] | list[np.ndarray]: """Return common axes patterns to test for given shape. For each shape, test: @@ -68,7 +69,7 @@ def get_common_axes_combinations(self, shape: tuple[int, ...]) -> list[np.ndarra List of axes arrays (always explicit, never None) """ rank = len(shape) - axes_combinations = [ + axes_combinations: list[list[int]] = [ list(range(rank)), # All axes (explicit) [-1], # Reduce last axis [0], # Reduce first axis @@ -82,7 +83,7 @@ def get_common_axes_combinations(self, shape: tuple[int, ...]) -> list[np.ndarra # basically all Reduce* ops have "axes" as attrribute for opset <=17 # and as input for opset >=18, EXCEPT ReduceSum, which has "axes" as input since opset 13 if "axes" in self.op_input_names: - axes_combinations = [np.array(axes, dtype=np.int64) for axes in axes_combinations] + return [np.array(axes, dtype=np.int64) for axes in axes_combinations] return axes_combinations @@ -100,7 +101,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for reduction operators. Strategy: @@ -156,7 +157,7 @@ def get_infinite_property_names(self) -> list[str]: + ["attr_axes"] ) - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for reduction operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), @@ -310,7 +311,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for TopK operator. Strategy: @@ -318,7 +319,7 @@ def get_input_and_infinite_attribute_combinations( - For each shape, K should be smaller than the dimension being reduced - K is an input (not an attribute) """ - combinations = [] + combinations: list[dict[str, object]] = [] # Test shapes from 1D through 6D test_shapes = [ @@ -371,7 +372,7 @@ def get_infinite_property_names(self) -> list[str]: f"{input_name}_shape" for input_name in self.op_input_names ] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for TopK operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), # "X" diff --git a/src/winml/modelkit/pattern/op_input_gen/reshape_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/reshape_input_generator.py index 5c2c582c5..3c2c03419 100644 --- a/src/winml/modelkit/pattern/op_input_gen/reshape_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/reshape_input_generator.py @@ -2,10 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +from typing import Any + import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -43,7 +44,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[int]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Reshape operator. Coverage strategy: @@ -243,7 +244,7 @@ def get_input_and_infinite_attribute_combinations( }, ] - def derive_properties(self, properties: dict[str, any]) -> dict[str, any]: + def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: """Derive additional properties for Reshape operator testing. Args: diff --git a/src/winml/modelkit/pattern/op_input_gen/resize_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/resize_input_generator.py index b38b80bc8..b50c2d88a 100644 --- a/src/winml/modelkit/pattern/op_input_gen/resize_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/resize_input_generator.py @@ -11,7 +11,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -55,7 +54,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: """Return finite attribute combinations for Resize.""" return {"antialias": [0, 1]} - def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]: + def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, object]]: """Return input combinations for Resize operator. CRITICAL: Always provide explicit values for all inputs. @@ -267,7 +266,7 @@ def get_infinite_property_names(self) -> list[str]: + ["attr_cubic_coeff_a", "attr_extrapolation_value", "attr_axes"] ) - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Resize operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/rotary_embedding_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/rotary_embedding_input_generator.py index 3275f6858..352b15075 100644 --- a/src/winml/modelkit/pattern/op_input_gen/rotary_embedding_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/rotary_embedding_input_generator.py @@ -10,7 +10,6 @@ """ from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, register_runtime_checker_op, @@ -58,7 +57,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for RotaryEmbedding. Tests both 4D and 3D input formats with varying rotary_embedding_dim. diff --git a/src/winml/modelkit/pattern/op_input_gen/shape_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/shape_input_generator.py index 3d1531a6d..50b9f0f4b 100644 --- a/src/winml/modelkit/pattern/op_input_gen/shape_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/shape_input_generator.py @@ -5,7 +5,6 @@ """Input generator for Shape operator.""" from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -53,7 +52,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint | int | None]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Shape operator. Coverage strategy: @@ -412,7 +411,7 @@ def get_infinite_property_names(self) -> list[str]: input_name = self.op_input_names[0] return [f"{input_name}_shape", f"{input_name}_value", "attr_start", "attr_end"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Reshape operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py index 1910ef982..7e8cb5b98 100644 --- a/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py @@ -13,7 +13,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -69,7 +68,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Slice operator. Focused coverage strategy: @@ -78,7 +77,7 @@ def get_input_and_infinite_attribute_combinations( - steps: [1,...,1], [-1,...,-1], [1,2,1,1,...,1] - starts/ends: positive values covering [0, dim-1] and [1, dim-2] """ - combinations = [] + combinations: list[dict[str, object]] = [] for data_shape in _SLICE_DATA_SHAPES: rank = len(data_shape) @@ -95,7 +94,7 @@ def get_input_and_infinite_attribute_combinations( axis_dims = [data_shape[int(ax)] for ax in axes] # Step patterns: all 1s, all -1s, and mixed [1,2,1,...] - steps_patterns = [ + steps_patterns: list[np.ndarray] = [ np.ones(num_axes, dtype=np.int64), -np.ones(num_axes, dtype=np.int64), ] @@ -113,7 +112,7 @@ def get_input_and_infinite_attribute_combinations( is_all_backward = np.all(steps <= -1) # Determine start/end patterns based on step direction - starts_ends_patterns = [] + starts_ends_patterns: list[tuple[np.ndarray, np.ndarray]] = [] if is_all_forward: # Forward slicing patterns diff --git a/src/winml/modelkit/pattern/op_input_gen/squeeze_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/squeeze_input_generator.py index 5bc1cd168..433217dd7 100644 --- a/src/winml/modelkit/pattern/op_input_gen/squeeze_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/squeeze_input_generator.py @@ -9,7 +9,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, OpInputGenerator, @@ -74,9 +73,9 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for Squeeze operator.""" - combinations = [] + combinations: list[dict[str, object]] = [] # ===== Systematic generation for 0D through 6D ===== # Test 1: 0D tensor @@ -166,7 +165,7 @@ def get_infinite_property_names(self) -> list[str]: """Returns names of infinite properties for Squeeze operator.""" return ["data_shape", "data_value", "axes_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Squeeze operator inputs.""" return { "data": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/transpose_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/transpose_input_generator.py index fce91f796..24592fae6 100644 --- a/src/winml/modelkit/pattern/op_input_gen/transpose_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/transpose_input_generator.py @@ -5,7 +5,6 @@ """Input generator for Transpose ONNX operator.""" from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -46,7 +45,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Transpose operator. Strategy: @@ -123,7 +122,7 @@ def get_infinite_property_names(self) -> list[str]: input_param_name = self.op_input_names[0] return [f"{input_param_name}_shape", "attr_perm"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Transpose operator inputs.""" return { "data": QDQParameterConfig(support_non_qdq=True, support_activation=True), diff --git a/src/winml/modelkit/pattern/op_input_gen/unary_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/unary_input_generator.py index 66c217c41..3243f4cf9 100644 --- a/src/winml/modelkit/pattern/op_input_gen/unary_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/unary_input_generator.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -54,7 +53,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Returns comprehensive input combinations for unary operators. Coverage strategy: @@ -72,7 +71,7 @@ def get_input_and_infinite_attribute_combinations( # Get the input parameter name from the operator schema input_param_name = self.op_input_names[0] - result = [ + result: list[dict[str, object]] = [ # ===== 1D Input (dimension 1) ===== # Scalar (single element) {input_param_name: InputShapeConstraint((1,))}, @@ -114,7 +113,7 @@ def get_infinite_property_names(self) -> list[str]: input_param_name = self.op_input_names[0] return [f"{input_param_name}_shape"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for unary operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), @@ -244,7 +243,7 @@ class IsNaNInputGenerator(UnaryInputGenerator): op_name = "IsNaN" - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for IsNaN operator inputs.""" return { self.op_input_names[0]: QDQParameterConfig(support_activation=True), @@ -302,7 +301,7 @@ class ReluInputGenerator(UnaryInputGenerator): op_name = "Relu" - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Relu operator inputs.""" # From p1 model MobileNet return { diff --git a/src/winml/modelkit/pattern/op_input_gen/unary_like_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/unary_like_input_generator.py index eb5e5b2fd..79cc0e1d1 100644 --- a/src/winml/modelkit/pattern/op_input_gen/unary_like_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/unary_like_input_generator.py @@ -10,13 +10,12 @@ attributes or input handling. """ -from typing import Any +from typing import Any, cast import numpy as np from ...onnx import SupportedONNXType from .op_input_gen import ( - InputConstraint, InputShapeConstraint, InputValueConstraint, QDQParameterConfig, @@ -240,7 +239,7 @@ def derive_properties(self, properties: dict) -> dict: item["axis_size_is_one"] = False return item - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Softmax operator inputs.""" return { "input": QDQParameterConfig(support_activation=True), @@ -299,7 +298,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for ArgMax/ArgMin. Iterate over all possible axis values for each input shape. @@ -311,10 +310,11 @@ def get_input_and_infinite_attribute_combinations( # Get shapes from parent class parent_combinations = super().get_input_and_infinite_attribute_combinations() - combinations = [] + combinations: list[dict[str, object]] = [] for combo in parent_combinations: - if input_name in combo and isinstance(combo[input_name], InputShapeConstraint): - shape = combo[input_name].shape + constraint = combo.get(input_name) + if isinstance(constraint, InputShapeConstraint): + shape = constraint.shape data_dim = len(shape) # Generate axis values: negative (-1) first, then positive (0, 1, ..., data_dim-1) # Negative axis is more common in real models (e.g., axis=-1 for last dim) @@ -405,7 +405,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for LRN. LRN requires at least 4D input with format (N x C x H x W) or higher dimensions. @@ -450,7 +450,7 @@ def _get_clip_param_names(self) -> tuple[str, str]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Clip. Test with various tensor shapes and optional min/max scalars. @@ -497,7 +497,7 @@ def get_infinite_property_names(self) -> list[str]: parent_infinite_props = super().get_infinite_property_names() return [*parent_infinite_props, min_name + "_value", max_name + "_value"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Clip operator inputs.""" return { "input": QDQParameterConfig(support_activation=True), @@ -525,7 +525,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for CumSum. Test with various tensor shapes and axis values. @@ -536,9 +536,9 @@ def get_input_and_infinite_attribute_combinations( shape_constraints = super().get_input_and_infinite_attribute_combinations() - combinations = [] + combinations: list[dict[str, object]] = [] for constraint in shape_constraints: - shape = constraint[x_name].shape + shape = cast("InputShapeConstraint", constraint[x_name]).shape axis_values = range(-len(shape), len(shape)) # Only add valid axis values for the shape combinations.extend( @@ -550,7 +550,7 @@ def get_input_and_infinite_attribute_combinations( ) return combinations - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for CumSum operator inputs.""" return { "x": QDQParameterConfig(support_activation=True), @@ -575,7 +575,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Dropout. Test with various tensor shapes and ratio/training_mode. @@ -629,7 +629,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: """ from onnx import TensorProto - result = { + result: dict[str, list[Any]] = { "to": [ int(TensorProto.BOOL), int(TensorProto.DOUBLE), @@ -672,7 +672,7 @@ def infer_output_types( onnx_type = SupportedONNXType.from_tensor_proto_type(output_type_enum) return [onnx_type.annotation] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Cast operator inputs.""" return { "input": QDQParameterConfig(support_activation=True, support_non_qdq=True), @@ -696,7 +696,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for CastLike. Test with various source and target types. diff --git a/src/winml/modelkit/pattern/op_input_gen/variadic_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/variadic_input_generator.py index aa4a689c7..fdf196c12 100644 --- a/src/winml/modelkit/pattern/op_input_gen/variadic_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/variadic_input_generator.py @@ -9,7 +9,6 @@ import numpy as np from .op_input_gen import ( - InputConstraint, InputShapeConstraint, OpInputGenerator, QDQParameterConfig, @@ -42,7 +41,7 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]: def get_input_and_infinite_attribute_combinations( self, - ) -> list[dict[str, InputConstraint]]: + ) -> list[dict[str, object]]: """Return input combinations for Concat. Test cases systematically cover: @@ -144,7 +143,7 @@ def get_infinite_property_names(self) -> list[str]: """ return ["inputs_shape", "inputs_value", "attr_axis", "inputs_is_constant"] - def get_qdq_config(self): + def get_qdq_config(self) -> dict[str, QDQParameterConfig]: """Return QDQ configuration for Concat operator inputs.""" return { "inputs": QDQParameterConfig(support_activation=True), diff --git a/src/winml/modelkit/pattern/rmsnorm_patterns.py b/src/winml/modelkit/pattern/rmsnorm_patterns.py index 4bf4cc311..2d40f8ac4 100644 --- a/src/winml/modelkit/pattern/rmsnorm_patterns.py +++ b/src/winml/modelkit/pattern/rmsnorm_patterns.py @@ -27,13 +27,12 @@ from .base import ( Pattern, PatternInputGenerator, - PatternMatchResult, PatternMismatchedError, PatternSchema, Skeleton, - SkeletonMatchResult, register_pattern_input_generator, ) +from .match import PatternMatchResult, SkeletonMatchResult from .op_input_gen import InputShapeConstraint, InputValueConstraint from .utils import ( get_attribute_proto_value, @@ -251,6 +250,8 @@ def _infer_schema_attributes( if axes_value is None: raise PatternMismatchedError("ReduceMean missing axes attribute") + if axes_value is None: + raise PatternMismatchedError("ReduceMean axes tensor value is None") if len(axes_value) != 1: raise PatternMismatchedError( f"Only single-axis normalization supported, got axes={axes_value}" @@ -472,7 +473,7 @@ def _get_normalized_dim(self, inputs: dict[str, np.ndarray], attributes: dict[st axis = attributes["axis"] rank = len(x_shape) normalized_axis = axis if axis >= 0 else rank + axis - return x_shape[normalized_axis] + return int(x_shape[normalized_axis]) def get_internal_constants_and_attributes( self, diff --git a/src/winml/modelkit/pattern/transpose_patterns.py b/src/winml/modelkit/pattern/transpose_patterns.py index 237dd24b0..b8ccd9cf5 100644 --- a/src/winml/modelkit/pattern/transpose_patterns.py +++ b/src/winml/modelkit/pattern/transpose_patterns.py @@ -18,13 +18,12 @@ from .base import ( Pattern, PatternInputGenerator, - PatternMatchResult, PatternMismatchedError, PatternSchema, Skeleton, - SkeletonMatchResult, register_pattern_input_generator, ) +from .match import PatternMatchResult, SkeletonMatchResult from .op_input_gen import InputShapeConstraint diff --git a/src/winml/modelkit/pattern/utils.py b/src/winml/modelkit/pattern/utils.py index 6fe90c418..78766e1c8 100644 --- a/src/winml/modelkit/pattern/utils.py +++ b/src/winml/modelkit/pattern/utils.py @@ -12,7 +12,7 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, cast import numpy as np from google.protobuf import json_format @@ -82,7 +82,7 @@ def validate_scale_bias_shape_for_axis( if non_one_pos_in_input != normalized_axis: return False - return sb_shape[non_one_pos_in_sb] == normalized_dim + return bool(sb_shape[non_one_pos_in_sb] == normalized_dim) def get_tensor_shape(tensor_name: str, matcher: Any) -> tuple | None: @@ -96,8 +96,8 @@ def get_tensor_shape(tensor_name: str, matcher: Any) -> tuple | None: Shape tuple if available, None otherwise. """ if tensor_name in matcher.tensor_values: - return matcher.tensor_values[tensor_name].shape - return matcher.tensor_shapes.get(tensor_name) + return cast("tuple[Any, ...] | None", matcher.tensor_values[tensor_name].shape) + return cast("tuple[Any, ...] | None", matcher.tensor_shapes.get(tensor_name)) def make_stable_node_key(node: Any, index: int) -> str: @@ -109,7 +109,7 @@ def make_stable_node_key(node: Any, index: int) -> str: # From model_utils.py (pattern-relevant functions) # --------------------------------------------------------------------------- -DTYPE_MAP = { +DTYPE_MAP: dict[int, str] = { TensorProto.FLOAT: "FLOAT", TensorProto.UINT4: "UINT4", TensorProto.UINT8: "UINT8", @@ -134,25 +134,26 @@ def dtype_from_tensorproto_enum(tp: int) -> str: return DTYPE_MAP.get(tp, f"unknown({tp})") -def shape_and_dtype_from_valueinfo(vi: ValueInfoProto) -> tuple[list | None, str | None]: +def shape_and_dtype_from_valueinfo( + vi: ValueInfoProto, +) -> tuple[tuple[int | str | None, ...] | None, str | None]: """Extract shape and dtype from a ValueInfoProto.""" if not vi.type.HasField("tensor_type"): return (None, None) tt = vi.type.tensor_type dtype = dtype_from_tensorproto_enum(tt.elem_type) - shape = [] - if tt.HasField("shape"): - for d in tt.shape.dim: - if d.HasField("dim_value"): - shape.append(d.dim_value) - elif d.HasField("dim_param"): - shape.append(d.dim_param) - else: - shape.append(None) - else: - shape = None - return (tuple(shape) if shape is not None else None, dtype) + if not tt.HasField("shape"): + return (None, dtype) + shape: list[int | str | None] = [] + for d in tt.shape.dim: + if d.HasField("dim_value"): + shape.append(d.dim_value) + elif d.HasField("dim_param"): + shape.append(d.dim_param) + else: + shape.append(None) + return (tuple(shape), dtype) def collect_valueinfo_dict(model: ModelProto) -> dict[str, ValueInfoProto]: @@ -168,7 +169,9 @@ def collect_initializers(model: ModelProto) -> dict[str, TensorProto]: return {init.name: init for init in model.graph.initializer} -def get_op_input_properties(schema: OpSchema): +def get_op_input_properties( + schema: OpSchema, +) -> tuple[list[str], str | None, list[str], dict[str, str]]: """Get operator input properties from OpSchema. Args: @@ -205,7 +208,8 @@ def get_op_input_properties(schema: OpSchema): type_annotations[input_param.name] = type_str for name, attribute in schema.attributes.items(): - type_annotations[name] = attribute.type.name + # onnx's pybind11 AttrType enum exposes .name at runtime; its stub doesn't. + type_annotations[name] = attribute.type.name # type: ignore[attr-defined] op_attribute_names = list(schema.attributes.keys()) @@ -273,7 +277,7 @@ def _make_hashable_sequence(value: list | tuple, replace_float_with_dummy: bool) needs_conversion = True break if not needs_conversion: - return tuple(value) if type(value) is list else value + return tuple(value) if isinstance(value, list) else value return tuple([make_hashable(x, replace_float_with_dummy) for x in value])