From 991f06c8eacf39bd63de85d12ac9b5d086f55255 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 15:49:45 +0800 Subject: [PATCH 1/7] # type: ignore[misc] # optimum base is untyped --- src/winml/modelkit/__init__.py | 3 ++- src/winml/modelkit/cli.py | 11 +++++++++-- src/winml/modelkit/models/__init__.py | 4 ++-- src/winml/modelkit/models/hf/bart.py | 6 +++--- src/winml/modelkit/models/hf/bert.py | 2 +- src/winml/modelkit/models/hf/blip.py | 4 ++-- src/winml/modelkit/models/hf/clip.py | 4 ++-- src/winml/modelkit/models/hf/convnext.py | 2 +- src/winml/modelkit/models/hf/depth_pro.py | 4 ++-- src/winml/modelkit/models/hf/marian.py | 6 +++--- src/winml/modelkit/models/hf/mu2.py | 4 ++-- src/winml/modelkit/models/hf/qwen.py | 4 ++-- src/winml/modelkit/models/hf/roberta.py | 8 ++++---- src/winml/modelkit/models/hf/sam.py | 20 ++++++++++---------- src/winml/modelkit/models/hf/segformer.py | 4 ++-- src/winml/modelkit/models/hf/siglip.py | 4 ++-- src/winml/modelkit/models/hf/t5.py | 4 ++-- src/winml/modelkit/models/hf/zoedepth.py | 2 +- 18 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/winml/modelkit/__init__.py b/src/winml/modelkit/__init__.py index 3e3142d71..463a60b7f 100644 --- a/src/winml/modelkit/__init__.py +++ b/src/winml/modelkit/__init__.py @@ -31,6 +31,7 @@ import logging import sys from importlib.metadata import PackageNotFoundError, version +from typing import Any # Force utf-8 stdout/stderr so emoji and Unicode output (rich console, logs, @@ -98,7 +99,7 @@ def _preload_bundled_onnxruntime_dll() -> None: } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy-load heavy exports on first access (PEP 562). 
This avoids importing torch/transformers/optimum (~30s) when only diff --git a/src/winml/modelkit/cli.py b/src/winml/modelkit/cli.py index 2e4745950..d910242c7 100644 --- a/src/winml/modelkit/cli.py +++ b/src/winml/modelkit/cli.py @@ -23,9 +23,14 @@ import logging from importlib import import_module from pathlib import Path +from typing import TYPE_CHECKING import click + +if TYPE_CHECKING: + from rich.console import Console + from . import __version__ from .telemetry import ActionGroup from .telemetry import telemetry as _telemetry_mod @@ -78,7 +83,7 @@ def _gradient_color(t: float) -> tuple[int, int, int]: return _GRADIENT[-1][1] -def _print_banner(version: str, *, _console: object | None = None) -> None: +def _print_banner(version: str, *, _console: Console | None = None) -> None: """Print the WinML CLI gradient banner to stderr using Rich.""" from rich.console import Console # lazy import - keeps startup fast from rich.text import Text @@ -205,7 +210,9 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None discovered = attr return discovered - def resolve_command(self, ctx: click.Context, args: list[str]): + def resolve_command( + self, ctx: click.Context, args: list[str] + ) -> tuple[str | None, click.Command | None, list[str]]: """Seed ``self.commands`` so Click can emit a did-you-mean hint on typos.""" # Click's NoSuchCommand exception uses self.commands to find suggestions. 
for name in self.list_commands(ctx): diff --git a/src/winml/modelkit/models/__init__.py b/src/winml/modelkit/models/__init__.py index 2fa39bda4..c5ce1464c 100644 --- a/src/winml/modelkit/models/__init__.py +++ b/src/winml/modelkit/models/__init__.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from .hf import MODEL_BUILD_CONFIGS @@ -60,7 +60,7 @@ } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy load modules that would cause circular imports.""" if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index bb204ac4f..278b1223b 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -358,7 +358,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("bart", "feature-extraction", library_name="transformers") -class BartEncoderIOConfig(OnnxConfig): +class BartEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for BART encoder (feature-extraction task). Inputs: input_ids, attention_mask @@ -385,7 +385,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } -class _BartDecoderNormalizedConfig(NormalizedConfig): +class _BartDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """NormalizedConfig for BART decoder-side export. Maps NormalizedConfig attributes to BartConfig's decoder-side attrs. @@ -404,7 +404,7 @@ def head_dim(self) -> int: @register_onnx_overwrite("bart", "text2text-generation", library_name="transformers") -class BartDecoderIOConfig(OnnxConfig): +class BartDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for BART decoder with sliding-window KV cache. 
Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/bert.py b/src/winml/modelkit/models/hf/bert.py index d537c8010..c0df6ee7c 100644 --- a/src/winml/modelkit/models/hf/bert.py +++ b/src/winml/modelkit/models/hf/bert.py @@ -44,7 +44,7 @@ @register_onnx_overwrite("bert", *COMMON_TEXT_TASKS, library_name="transformers") -class BertIOConfig(BertOnnxConfig): +class BertIOConfig(BertOnnxConfig): # type: ignore[misc] # optimum base is untyped """BERT OnnxConfig using max_position_embeddings as sequence_length. Inputs: diff --git a/src/winml/modelkit/models/hf/blip.py b/src/winml/modelkit/models/hf/blip.py index 4aa2ee727..5063f9561 100644 --- a/src/winml/modelkit/models/hf/blip.py +++ b/src/winml/modelkit/models/hf/blip.py @@ -85,7 +85,7 @@ @register_onnx_overwrite("blip", "image-to-text", library_name="transformers") @register_onnx_overwrite("blip", "image-text-to-text", library_name="transformers") -class BlipCaptioningIOConfig(OnnxConfig): +class BlipCaptioningIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """Monolithic ONNX config for BLIP captioning — single-graph fallback. Traces ``BlipForConditionalGeneration.forward`` with pixel_values + @@ -148,7 +148,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: @register_onnx_overwrite("blip", "feature-extraction", library_name="transformers") -class BlipVisionEncoderIOConfig(OnnxConfig): +class BlipVisionEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for the BLIP vision encoder. 
``image-feature-extraction`` is a synonym that Optimum's TasksManager diff --git a/src/winml/modelkit/models/hf/clip.py b/src/winml/modelkit/models/hf/clip.py index 045fcb160..28ed62b41 100644 --- a/src/winml/modelkit/models/hf/clip.py +++ b/src/winml/modelkit/models/hf/clip.py @@ -73,7 +73,7 @@ # Optimum ONNX Export Config Registrations # ============================================================================= @register_onnx_overwrite("clip_text_model", "feature-extraction", library_name="transformers") -class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): +class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for CLIPTextModelWithProjection from transformers. Model: openai/clip-vit-base-patch32 (text encoder only) @@ -108,7 +108,7 @@ def inputs(self) -> dict[str, dict[int, str]]: @register_onnx_overwrite("clip_vision_model", "feature-extraction", library_name="transformers") -class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): +class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for CLIPVisionModelWithProjection from transformers. Model: openai/clip-vit-base-patch32 (vision encoder only) diff --git a/src/winml/modelkit/models/hf/convnext.py b/src/winml/modelkit/models/hf/convnext.py index f1d453c74..5a43b0a30 100644 --- a/src/winml/modelkit/models/hf/convnext.py +++ b/src/winml/modelkit/models/hf/convnext.py @@ -102,7 +102,7 @@ def _build_patching_specs() -> list[PatchingSpec]: "image-classification", library_name="transformers", ) -class ConvNextIOConfig(ConvNextOnnxConfig): +class ConvNextIOConfig(ConvNextOnnxConfig): # type: ignore[misc] # optimum base is untyped """ConvNextOnnxConfig override that adds a LayerNorm fusion patch. Inherits all I/O specs from Optimum's ``ConvNextOnnxConfig``. 
The only diff --git a/src/winml/modelkit/models/hf/depth_pro.py b/src/winml/modelkit/models/hf/depth_pro.py index a5d53a770..3fe865c65 100644 --- a/src/winml/modelkit/models/hf/depth_pro.py +++ b/src/winml/modelkit/models/hf/depth_pro.py @@ -30,7 +30,7 @@ from ...export import register_onnx_overwrite -class _DepthProNormalizedConfig(NormalizedConfig): +class _DepthProNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """Normalized config for DepthPro with computed image_size. image_size is derived from patch_size / min(scaled_images_ratios), @@ -47,7 +47,7 @@ def image_size(self) -> int: @register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") -class DepthProIOConfig(OnnxConfig): +class DepthProIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for DepthPro depth estimation. Model: apple/DepthPro-hf diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index 6251ff4ce..8540972a6 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -398,7 +398,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("marian", "feature-extraction", library_name="transformers") -class MarianEncoderIOConfig(OnnxConfig): +class MarianEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Marian encoder (feature-extraction task). Inputs: input_ids, attention_mask @@ -425,7 +425,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } -class _MarianDecoderNormalizedConfig(NormalizedConfig): +class _MarianDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """NormalizedConfig for Marian decoder-side export. Maps NormalizedConfig attributes to MarianConfig's decoder-side attrs. 
@@ -444,7 +444,7 @@ def head_dim(self) -> int: @register_onnx_overwrite("marian", "text2text-generation", library_name="transformers") -class MarianDecoderIOConfig(OnnxConfig): +class MarianDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Marian decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 3efcabc5d..2a8dd11cc 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -194,7 +194,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("mu2", "feature-extraction", library_name="transformers") -class Mu2EncoderIOConfig(OnnxConfig): +class Mu2EncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Mu2 encoder (feature-extraction task).""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -218,7 +218,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("mu2", "text2text-generation", library_name="transformers") -class Mu2DecoderIOConfig(OnnxConfig): +class Mu2DecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Mu2 decoder with static KV cache.""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 6f88a078d..0b8c9b45b 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -262,7 +262,7 @@ def _qwen_io_outputs(num_layers: int) -> dict[str, dict[int, str]]: @register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers") -class QwenPrefillIOConfig(OnnxConfig): +class QwenPrefillIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Qwen3 prefill (feature-extraction task). 
Inputs: input_ids [1, 64], attention_mask [1, 256], position_ids [1, 64], @@ -283,7 +283,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers") -class QwenGenIOConfig(OnnxConfig): +class QwenGenIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Qwen3 generation (text2text-generation task). Inputs: input_ids [1, 1], attention_mask [1, 256], position_ids [1, 1], diff --git a/src/winml/modelkit/models/hf/roberta.py b/src/winml/modelkit/models/hf/roberta.py index db4fb9103..94b31bf7e 100644 --- a/src/winml/modelkit/models/hf/roberta.py +++ b/src/winml/modelkit/models/hf/roberta.py @@ -122,7 +122,7 @@ def __init__(self, config, task, **kwargs): @register_onnx_overwrite("roberta", *COMMON_TEXT_TASKS, library_name="transformers") -class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): +class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): # type: ignore[misc] # optimum base is untyped """Roberta OnnxConfig with position-offset-adjusted sequence_length. 
Inputs (same as DistilBERT — no token_type_ids): @@ -137,17 +137,17 @@ class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): @register_onnx_overwrite("xlm-roberta", *COMMON_TEXT_TASKS, library_name="transformers") -class XLMRobertaIOConfig(_RobertaPositionOffsetMixin, XLMRobertaOnnxConfig): +class XLMRobertaIOConfig(_RobertaPositionOffsetMixin, XLMRobertaOnnxConfig): # type: ignore[misc] # optimum base is untyped """XLM-Roberta OnnxConfig with position-offset-adjusted sequence_length.""" @register_onnx_overwrite("camembert", *COMMON_TEXT_TASKS, library_name="transformers") -class CamemBERTIOConfig(_RobertaPositionOffsetMixin, CamembertOnnxConfig): +class CamemBERTIOConfig(_RobertaPositionOffsetMixin, CamembertOnnxConfig): # type: ignore[misc] # optimum base is untyped """CamemBERT OnnxConfig with position-offset-adjusted sequence_length.""" @register_onnx_overwrite("mpnet", *COMMON_TEXT_TASKS, library_name="transformers") -class MPNetIOConfig(_RobertaPositionOffsetMixin, MPNetOnnxConfig): +class MPNetIOConfig(_RobertaPositionOffsetMixin, MPNetOnnxConfig): # type: ignore[misc] # optimum base is untyped """MPNet OnnxConfig with position-offset-adjusted sequence_length. MPNet, like Roberta-family models, sets: diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index 421bd954d..0d5976118 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -593,7 +593,7 @@ def _patched_sam2_prompt_encoder_forward( } -class Sam2ModelPatcher(ModelPatcher): +class Sam2ModelPatcher(ModelPatcher): # type: ignore[misc] # optimum base is untyped """Custom ModelPatcher that applies SAM2 QNN-compatible patches during export. 
Patches Sam2MultiScaleBlock and Sam2PromptEncoder forward methods on all @@ -636,7 +636,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # ============================================================================= # Custom Dummy Input Generators for SAM2 # ============================================================================= -class Sam2PointsInputGenerator(DummyInputGenerator): +class Sam2PointsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Points input generator for SAM2 decoder. Generates: @@ -684,7 +684,7 @@ def generate( ) -class Sam2EmbeddingsInputGenerator(DummyInputGenerator): +class Sam2EmbeddingsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Embeddings input generator for SAM2 mask generation decoder. Generates raw (pre-projection) encoder outputs: @@ -728,7 +728,7 @@ def generate( return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) -class Sam2MaskInputGenerator(DummyInputGenerator): +class Sam2MaskInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Mask input generator for SAM2 decoder refinement. Generates: @@ -767,7 +767,7 @@ def generate( # ============================================================================= # Normalized Config with Default Image Size # ============================================================================= -class Sam2NormalizedVisionConfig(NormalizedVisionConfig): +class Sam2NormalizedVisionConfig(NormalizedVisionConfig): # type: ignore[misc] # optimum base is untyped """NormalizedVisionConfig with default image_size for SAM2. SAM2 uses 1024x1024 input images by default. 
@@ -798,7 +798,7 @@ def __getattr__(self, attr_name: str): @register_onnx_overwrite("sam2", "feature-extraction", library_name="transformers") @register_onnx_overwrite("sam2_video", "feature-extraction", library_name="transformers") @register_onnx_overwrite("sam2_vision_model", "feature-extraction", library_name="transformers") -class Sam2ImageEncoderIOConfig(OnnxConfig): +class Sam2ImageEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2 image encoder (vision_encoder component). Task: image-feature-extraction (encoder-only export) @@ -839,7 +839,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam2", "image-segmentation", library_name="transformers") @register_onnx_overwrite("sam2_video", "image-segmentation", library_name="transformers") -class Sam2IOConfig(OnnxConfig): +class Sam2IOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2 full model (encoder + decoder monolith). Task: image-segmentation (full model export) @@ -885,7 +885,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam2", "mask-generation", library_name="transformers") @register_onnx_overwrite("sam2_video", "mask-generation", library_name="transformers") -class Sam2MaskGenerationIOConfig(OnnxConfig): +class Sam2MaskGenerationIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2MaskGeneration (decoder with raw FPN inputs). 
Model: facebook/sam2-hiera-small (decoder wrapper) @@ -941,7 +941,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= # SAM v1 Custom Dummy Input Generators # ============================================================================= -class SamEmbeddingsInputGenerator(DummyInputGenerator): +class SamEmbeddingsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Embeddings input generator for SAM v1 mask generation decoder. Generates: @@ -982,7 +982,7 @@ def generate( # Mask generation export (SAMMaskGeneration wrapper) - SAM v1 # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam", "mask-generation", library_name="transformers") -class SamMaskGenerationIOConfig(OnnxConfig): +class SamMaskGenerationIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAMMaskGeneration (SAM v1 decoder). Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base diff --git a/src/winml/modelkit/models/hf/segformer.py b/src/winml/modelkit/models/hf/segformer.py index 2e70b44da..0748545b0 100644 --- a/src/winml/modelkit/models/hf/segformer.py +++ b/src/winml/modelkit/models/hf/segformer.py @@ -32,7 +32,7 @@ } -class _SegformerVisionInputGenerator(DummyVisionInputGenerator): +class _SegformerVisionInputGenerator(DummyVisionInputGenerator): # type: ignore[misc] # optimum base is untyped """Vision input generator that uses preprocessor resolution over config.image_size. Optimum's DummyVisionInputGenerator prioritizes normalized_config.image_size @@ -74,7 +74,7 @@ def __init__( @register_onnx_overwrite("segformer", "image-segmentation", library_name="transformers") -class SegformerIOConfig(OnnxConfig): +class SegformerIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Segformer semantic segmentation. 
Model: nvidia/segformer-b0-finetuned-ade-512-512 diff --git a/src/winml/modelkit/models/hf/siglip.py b/src/winml/modelkit/models/hf/siglip.py index c7f240d5a..6c55f9b72 100644 --- a/src/winml/modelkit/models/hf/siglip.py +++ b/src/winml/modelkit/models/hf/siglip.py @@ -64,7 +64,7 @@ # Optimum ONNX Export Config Registrations # ============================================================================= @register_onnx_overwrite("siglip_text_model", "feature-extraction", library_name="transformers") -class SiglipTextModelIOConfig(SiglipTextOnnxConfig): +class SiglipTextModelIOConfig(SiglipTextOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SiglipTextModel (text encoder only). Uses ``max_position_embeddings`` (64 for SigLIP) as the fixed sequence @@ -83,7 +83,7 @@ class SiglipTextModelIOConfig(SiglipTextOnnxConfig): @register_onnx_overwrite("siglip_vision_model", "feature-extraction", library_name="transformers") -class SiglipVisionModelIOConfig(SiglipVisionModelOnnxConfig): +class SiglipVisionModelIOConfig(SiglipVisionModelOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SiglipVisionModel (vision encoder only). Uses Optimum defaults; no overrides needed. diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 686f43562..48a906f43 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -203,7 +203,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("t5", "feature-extraction", library_name="transformers") -class T5EncoderIOConfig(OnnxConfig): +class T5EncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for T5 encoder (feature-extraction task). 
Inputs: input_ids, attention_mask @@ -231,7 +231,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("t5", "text2text-generation", library_name="transformers") -class T5DecoderIOConfig(OnnxConfig): +class T5DecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for T5 decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/zoedepth.py b/src/winml/modelkit/models/hf/zoedepth.py index e648b8548..83dda0e05 100644 --- a/src/winml/modelkit/models/hf/zoedepth.py +++ b/src/winml/modelkit/models/hf/zoedepth.py @@ -29,7 +29,7 @@ @register_onnx_overwrite("zoedepth", "depth-estimation", library_name="transformers") -class ZoeDepthIOConfig(OnnxConfig): +class ZoeDepthIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for ZoeDepth depth estimation. Model: Intel/zoedepth-nyu-kitti From 9f8f73c9cdf7f8e348b0c49601bdcca73745bf89 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 15:57:26 +0800 Subject: [PATCH 2/7] Fix type-check errors in winml forward() overrides and model outputs --- .../modelkit/models/winml/depth_estimation.py | 9 +++++++-- .../models/winml/feature_extraction.py | 2 +- .../models/winml/image_classification.py | 8 +++++--- .../models/winml/image_segmentation.py | 20 ++++++++++--------- .../modelkit/models/winml/object_detection.py | 10 +++++----- .../models/winml/question_answering.py | 8 +++++--- .../models/winml/sequence_classification.py | 8 +++++--- .../winml/zero_shot_image_classification.py | 2 +- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/src/winml/modelkit/models/winml/depth_estimation.py b/src/winml/modelkit/models/winml/depth_estimation.py index dda6b5c3e..f57844390 100644 --- a/src/winml/modelkit/models/winml/depth_estimation.py +++ b/src/winml/modelkit/models/winml/depth_estimation.py @@ -12,13 +12,16 @@ from __future__ import annotations import logging -from typing import Any 
+from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import DepthEstimatorOutput from .base import WinMLPreTrainedModel +if TYPE_CHECKING: + import torch + logger = logging.getLogger(__name__) @@ -48,4 +51,6 @@ def forward(self, **kwargs: Any) -> DepthEstimatorOutput: # Fall back to first output for non-standard output names. predicted_depth = next(iter(outputs.values())) - return DepthEstimatorOutput(predicted_depth=predicted_depth) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return DepthEstimatorOutput(predicted_depth=cast("torch.FloatTensor", predicted_depth)) diff --git a/src/winml/modelkit/models/winml/feature_extraction.py b/src/winml/modelkit/models/winml/feature_extraction.py index 3b4ffca66..df3444159 100644 --- a/src/winml/modelkit/models/winml/feature_extraction.py +++ b/src/winml/modelkit/models/winml/feature_extraction.py @@ -15,7 +15,7 @@ from collections import OrderedDict from typing import Any -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel diff --git a/src/winml/modelkit/models/winml/image_classification.py b/src/winml/modelkit/models/winml/image_classification.py index 500cee1b3..f4e749dde 100644 --- a/src/winml/modelkit/models/winml/image_classification.py +++ b/src/winml/modelkit/models/winml/image_classification.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import ImageClassifierOutput @@ -32,7 +32,7 @@ class WinMLModelForImageClassification(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. 
""" - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, **kwargs: Any, @@ -53,7 +53,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return ImageClassifierOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return ImageClassifierOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/image_segmentation.py b/src/winml/modelkit/models/winml/image_segmentation.py index 8fe572c63..13d9d180c 100644 --- a/src/winml/modelkit/models/winml/image_segmentation.py +++ b/src/winml/modelkit/models/winml/image_segmentation.py @@ -19,11 +19,11 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from transformers.modeling_outputs import SemanticSegmenterOutput -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel @@ -48,10 +48,10 @@ class ImageSegmentationOutput(ModelOutput): outputs.pred_boxes — [B, num_queries, 4] """ - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - pred_boxes: torch.FloatTensor | None = None - pred_masks: torch.FloatTensor | None = None + loss: torch.Tensor | None = None + logits: torch.Tensor | None = None + pred_boxes: torch.Tensor | None = None + pred_masks: torch.Tensor | None = None class WinMLModelForImageSegmentation(WinMLPreTrainedModel): @@ -65,7 +65,7 @@ class WinMLModelForImageSegmentation(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. 
""" - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, pixel_mask: torch.Tensor | np.ndarray | None = None, @@ -131,7 +131,7 @@ class WinMLModelForSemanticSegmentation(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, **kwargs: Any, @@ -152,7 +152,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return SemanticSegmenterOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return SemanticSegmenterOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/object_detection.py b/src/winml/modelkit/models/winml/object_detection.py index 58e4186b8..d008d3256 100644 --- a/src/winml/modelkit/models/winml/object_detection.py +++ b/src/winml/modelkit/models/winml/object_detection.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any import torch -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel @@ -38,9 +38,9 @@ class ObjectDetectionOutput(ModelOutput): outputs.pred_boxes — [B, num_queries, 4] """ - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - pred_boxes: torch.FloatTensor | None = None + loss: torch.Tensor | None = None + logits: torch.Tensor | None = None + pred_boxes: torch.Tensor | None = None class WinMLModelForObjectDetection(WinMLPreTrainedModel): @@ -51,7 +51,7 @@ class WinMLModelForObjectDetection(WinMLPreTrainedModel): so that image_processor.post_process_object_detection() works. 
""" - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, pixel_mask: torch.Tensor | np.ndarray | None = None, diff --git a/src/winml/modelkit/models/winml/question_answering.py b/src/winml/modelkit/models/winml/question_answering.py index 4dc30ab22..f9ded4fde 100644 --- a/src/winml/modelkit/models/winml/question_answering.py +++ b/src/winml/modelkit/models/winml/question_answering.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import QuestionAnsweringModelOutput @@ -74,7 +74,9 @@ def forward( formatted = self._format_inputs(**inputs) outputs = self._run_inference(formatted) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns real float Tensors. return QuestionAnsweringModelOutput( - start_logits=outputs.get("start_logits"), - end_logits=outputs.get("end_logits"), + start_logits=cast("torch.FloatTensor | None", outputs.get("start_logits")), + end_logits=cast("torch.FloatTensor | None", outputs.get("end_logits")), ) diff --git a/src/winml/modelkit/models/winml/sequence_classification.py b/src/winml/modelkit/models/winml/sequence_classification.py index df84948ac..602590972 100644 --- a/src/winml/modelkit/models/winml/sequence_classification.py +++ b/src/winml/modelkit/models/winml/sequence_classification.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import SequenceClassifierOutput @@ -37,7 +37,7 @@ class WinMLModelForSequenceClassification(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. 
""" - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, input_ids: torch.Tensor | np.ndarray, attention_mask: torch.Tensor | np.ndarray | None = None, @@ -70,7 +70,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return SequenceClassifierOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return SequenceClassifierOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/zero_shot_image_classification.py b/src/winml/modelkit/models/winml/zero_shot_image_classification.py index 94132ede8..9ab089c45 100644 --- a/src/winml/modelkit/models/winml/zero_shot_image_classification.py +++ b/src/winml/modelkit/models/winml/zero_shot_image_classification.py @@ -16,7 +16,7 @@ import numpy as np import torch -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .composite_model import WinMLCompositeModel, register_composite_model From a8958f73b399780645572e00f4f8e1767f765e8e Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 16:26:57 +0800 Subject: [PATCH 3/7] so many.. 
--- src/winml/modelkit/models/hf/bart.py | 30 ++++++++---- .../modelkit/models/hf/decoder_wrapper.py | 2 +- .../modelkit/models/hf/depth_anything.py | 4 +- src/winml/modelkit/models/hf/marian.py | 28 +++++++---- .../models/hf/vision_encoder_decoder.py | 4 +- src/winml/modelkit/models/winml/base.py | 5 +- .../modelkit/models/winml/composite_model.py | 12 +++-- .../modelkit/models/winml/decoder_only.py | 40 ++++++++++------ .../modelkit/models/winml/encoder_decoder.py | 48 ++++++++++++------- .../winml/zero_shot_image_classification.py | 9 +++- 10 files changed, 119 insertions(+), 63 deletions(-) diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index 278b1223b..cb44d2d74 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -71,7 +71,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -93,6 +93,9 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +if TYPE_CHECKING: + from transformers import GenerationConfig, PretrainedConfig + logger = logging.getLogger(__name__) @@ -140,7 +143,7 @@ def _patched_bart_learned_forward( - self, + self: Any, # monkey-patched onto BartLearnedPositionalEmbedding (HF internal) input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, @@ -229,10 +232,14 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - ).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). 
+ return cast( + "torch.Tensor", + self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state, + ) class BartDecoderWrapper(nn.Module): @@ -262,8 +269,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: super().__init__() self.model = model self.num_layers = num_layers - # Expose config for OnnxConfig / NormalizedConfig access - self.config = model.config + # Expose config for OnnxConfig / NormalizedConfig access. + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. + self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> BartDecoderWrapper: @@ -400,7 +409,8 @@ class _BartDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # o @property def head_dim(self) -> int: - return self.hidden_size // self.num_attention_heads + # hidden_size / num_attention_heads come from the untyped NormalizedConfig base. + return cast("int", self.hidden_size // self.num_attention_heads) @register_onnx_overwrite("bart", "text2text-generation", library_name="transformers") @@ -517,7 +527,7 @@ def get_cache_class(cls) -> type: return WinMLStaticCache # static cache (index_put_ → ScatterND) @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/decoder_wrapper.py b/src/winml/modelkit/models/hf/decoder_wrapper.py index 5e7b69dfd..ab6932df6 100644 --- a/src/winml/modelkit/models/hf/decoder_wrapper.py +++ b/src/winml/modelkit/models/hf/decoder_wrapper.py @@ -156,7 +156,7 @@ def _invoke_hf( """Call the HF decoder with ``past_key_values=``. 
Returns logits.""" -class WinMLStaticCacheDecoderIOConfig(OnnxConfig): +class WinMLStaticCacheDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """Semantic-name contract used by ``WinMLDecoderWrapper._make_cache``. Subclasses declare their own ``inputs`` / ``outputs`` bodies (each diff --git a/src/winml/modelkit/models/hf/depth_anything.py b/src/winml/modelkit/models/hf/depth_anything.py index 3e30a34e6..456b93eee 100644 --- a/src/winml/modelkit/models/hf/depth_anything.py +++ b/src/winml/modelkit/models/hf/depth_anything.py @@ -23,7 +23,7 @@ from ...export import register_onnx_overwrite -class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): +class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Vision input generator that lets explicit height/width override config.image_size. Optimum's DummyVisionInputGenerator prioritizes normalized_config.image_size @@ -62,7 +62,7 @@ def __init__( @register_onnx_overwrite("depth_anything", "depth-estimation", library_name="transformers") -class DepthAnythingIOConfig(OnnxConfig): +class DepthAnythingIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """ONNX config for Depth Anything depth estimation. 
Model: depth-anything/Depth-Anything-V2-Small-hf diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index 8540972a6..f3b2abb13 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -85,7 +85,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -107,6 +107,9 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +if TYPE_CHECKING: + from transformers import GenerationConfig, PretrainedConfig + logger = logging.getLogger(__name__) @@ -177,7 +180,7 @@ def _patched_marian_sinusoidal_forward( - self, + self: Any, # monkey-patched onto MarianSinusoidalPositionalEmbedding (HF internal) input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, @@ -262,10 +265,14 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - ).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). + return cast( + "torch.Tensor", + self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state, + ) class MarianDecoderWrapper(nn.Module): @@ -301,7 +308,9 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.model = model self.num_layers = num_layers # Expose config for OnnxConfig / NormalizedConfig access - self.config = model.config + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. 
+ self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> MarianDecoderWrapper: @@ -440,7 +449,8 @@ class _MarianDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # @property def head_dim(self) -> int: - return self.hidden_size // self.num_attention_heads + # hidden_size / num_attention_heads come from the untyped NormalizedConfig base. + return cast("int", self.hidden_size // self.num_attention_heads) @register_onnx_overwrite("marian", "text2text-generation", library_name="transformers") @@ -554,7 +564,7 @@ def get_cache_class(cls) -> type: return WinMLStaticCache # static cache (index_put_ → ScatterND) @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/vision_encoder_decoder.py b/src/winml/modelkit/models/hf/vision_encoder_decoder.py index a576b0d51..14e9a7c74 100644 --- a/src/winml/modelkit/models/hf/vision_encoder_decoder.py +++ b/src/winml/modelkit/models/hf/vision_encoder_decoder.py @@ -104,7 +104,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: @register_onnx_overwrite( "vision-encoder-decoder", "feature-extraction", library_name="transformers" ) -class VisionEncoderIOConfig(OnnxConfig): +class VisionEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """ONNX config for the vision encoder.""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -243,7 +243,7 @@ def _build_ved_patching_specs() -> list[PatchingSpec]: # ============================================================================= -class _VedDecoderNormalizedConfig(NormalizedConfig): +class _VedDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum/transformers base is untyped """VED decoder NormalizedConfig. 
Per-architecture field paths (``decoder.d_model`` vs ``decoder.n_embd`` diff --git a/src/winml/modelkit/models/winml/base.py b/src/winml/modelkit/models/winml/base.py index e8656d646..ea9172032 100644 --- a/src/winml/modelkit/models/winml/base.py +++ b/src/winml/modelkit/models/winml/base.py @@ -21,7 +21,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import torch @@ -246,7 +246,8 @@ def task(self) -> str | None: if build_config is not None: loader = getattr(build_config, "loader", None) if loader: - return loader.task + # loader comes from getattr (Any); task is a str | None field. + return cast("str | None", loader.task) return None @property diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 8c839859a..1d89b6c66 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -41,7 +41,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch @@ -49,7 +49,7 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Callable, Mapping from pathlib import Path from transformers import PretrainedConfig @@ -66,7 +66,7 @@ COMPOSITE_MODEL_REGISTRY: dict[tuple[str, str], type[WinMLCompositeModel]] = {} -def register_composite_model(model_type: str, task: str): +def register_composite_model(model_type: str, task: str) -> Callable[[type], type]: """Class decorator that registers a composite model for `winml config`.""" def decorator(cls: type) -> type: @@ -109,7 +109,7 @@ class WinMLCompositeModel(PreTrainedModel): def __init__( self, sub_models: dict[str, Any], - config: PretrainedConfig, + config: PretrainedConfig | None, device: str = "cpu", ) -> None: 
self.sub_models = sub_models @@ -243,7 +243,9 @@ def from_onnx( # Resolve concrete class from registry model_type = getattr(hf_config, "model_type", None) if hf_config else None if not cls._SUB_MODEL_CONFIG: - resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task)) + # model_type/task may be None; the str-keyed registry simply misses + # (returns None, handled below). dict.get tolerates any hashable key. + resolved_cls = COMPOSITE_MODEL_REGISTRY.get(cast("tuple[str, str]", (model_type, task))) if resolved_cls is None: raise ValueError( f"No composite model for ({model_type!r}, {task!r}). " diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 3bfa77700..d47e78703 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -58,7 +58,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from optimum.utils.input_generators import DummyInputGenerator @@ -71,6 +71,8 @@ if TYPE_CHECKING: from transformers import Cache, PretrainedConfig + from .kv_cache import WinMLCache + logger = logging.getLogger(__name__) @@ -79,7 +81,7 @@ # ========================================================================= -class DecoderOnlyInputGenerator(DummyInputGenerator): +class DecoderOnlyInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Generates base inputs for decoder-only models with static KV cache. 
Produces ``input_ids``, ``attention_mask``, ``position_ids``, and @@ -118,7 +120,9 @@ def __init__( self.batch_size = batch_size self.vocab_size = normalized_config.vocab_size self.max_cache_len = max_cache_len or normalized_config.max_cache_len - self.seq_len: int = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + self.seq_len: int = seq_len or cast( + "int", getattr(normalized_config, "seq_len", self._default_seq_len) + ) def generate( self, @@ -129,11 +133,15 @@ def generate( ) -> torch.Tensor: """Generate a dummy tensor for the given input name.""" if input_name == "input_ids": - return self.random_int_tensor( - (self.batch_size, self.seq_len), - max_value=self.vocab_size, - framework=framework, - dtype=int_dtype, + # optimum's DummyInputGenerator is untyped, so random_int_tensor returns Any. + return cast( + "torch.Tensor", + self.random_int_tensor( + (self.batch_size, self.seq_len), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ), ) if input_name == "attention_mask": mask = torch.zeros(self.batch_size, self.max_cache_len, dtype=torch.int64) @@ -225,7 +233,7 @@ def __init__( # ----- Cache + GenerationMixin interface ----- @classmethod - def get_cache_class(cls) -> type: + def get_cache_class(cls) -> type[WinMLCache]: """Return the WinMLCache subclass. 
Subclasses must override.""" raise NotImplementedError @@ -250,6 +258,8 @@ def _resolve_cache(self, past_key_values: Any) -> Any: if isinstance(past_key_values, WinMLCache): return past_key_values + if self.config is None: + raise ValueError("Decoder-only generation requires an HF config to build the KV cache.") kv_shape = [1, self._num_kv_heads, self._max_cache_len, self._head_dim] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() @@ -258,7 +268,7 @@ def _resolve_cache(self, past_key_values: Any) -> Any: def can_generate(self) -> bool: # noqa: D102 return True - def prepare_inputs_for_generation( + def prepare_inputs_for_generation( # type: ignore[override] # GenerationMixin's base signature differs; static-cache flow self, input_ids: torch.LongTensor, past_key_values: Cache | None = None, @@ -269,7 +279,7 @@ def prepare_inputs_for_generation( from .kv_cache import WinMLCache if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0: - input_ids = input_ids[:, -1:] + input_ids = cast("torch.LongTensor", input_ids[:, -1:]) else: past_key_values = None return { @@ -280,7 +290,7 @@ def prepare_inputs_for_generation( # ----- Forward ----- - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, *, input_ids: torch.Tensor, @@ -311,8 +321,10 @@ def forward( else: logits = self._run_gen(input_ids, cache) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. 
return CausalLMOutputWithPast( - logits=logits, + logits=cast("torch.FloatTensor", logits), past_key_values=cache, ) @@ -395,4 +407,4 @@ def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: outputs = self._gen_model(**feeds) cache.update_all_layers(outputs) - return outputs["logits"] + return cast("torch.Tensor", outputs["logits"]) diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index ae1b669fe..8430b2477 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -57,7 +57,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from optimum.utils.input_generators import DummyInputGenerator @@ -72,6 +72,8 @@ from optimum.utils import NormalizedConfig from transformers import Cache, PretrainedConfig + from .kv_cache import WinMLCache + logger = logging.getLogger(__name__) @@ -80,7 +82,7 @@ # ============================================================================= -class EncoderDecoderInputGenerator(DummyInputGenerator): +class EncoderDecoderInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Generates decoder base inputs for encoder-decoder models. 
Produces ``decoder_input_ids``, ``encoder_hidden_states``, @@ -108,7 +110,9 @@ def __init__( ) -> None: self.batch_size = batch_size self.d_model = normalized_config.hidden_size - self.enc_seq = sequence_length or getattr(normalized_config, "sequence_length", 16) + self.enc_seq: int = sequence_length or cast( + "int", getattr(normalized_config, "sequence_length", 16) + ) self.max_cache_len = max_cache_len or normalized_config.max_cache_len self.vocab_size = normalized_config.vocab_size @@ -120,18 +124,25 @@ def generate( float_dtype: str = "fp32", ) -> torch.Tensor: """Generate a dummy tensor for the given input name.""" + # optimum's DummyInputGenerator is untyped, so random_*_tensor returns Any. if input_name == "decoder_input_ids": - return self.random_int_tensor( - (self.batch_size, 1), - max_value=self.vocab_size, - framework=framework, - dtype=int_dtype, + return cast( + "torch.Tensor", + self.random_int_tensor( + (self.batch_size, 1), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ), ) if input_name == "encoder_hidden_states": - return self.random_float_tensor( - (self.batch_size, self.enc_seq, self.d_model), - framework=framework, - dtype=float_dtype, + return cast( + "torch.Tensor", + self.random_float_tensor( + (self.batch_size, self.enc_seq, self.d_model), + framework=framework, + dtype=float_dtype, + ), ) if input_name == "attention_mask": return torch.ones(self.batch_size, self.enc_seq, dtype=torch.int64) @@ -226,7 +237,8 @@ def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None: def forward(self, **kwargs: Any) -> BaseModelOutput: feeds = pad_inputs(kwargs, self._expected) - return self._encoder(**feeds) + # self._encoder is a torch Module (untyped __call__ -> Any). 
+ return cast("BaseModelOutput", self._encoder(**feeds)) def get_encoder(self) -> torch.nn.Module: """Return encoder for GenerationMixin (already wrapped with padding).""" @@ -235,7 +247,7 @@ def get_encoder(self) -> torch.nn.Module: def can_generate(self) -> bool: # noqa: D102 return True - def prepare_inputs_for_generation( + def prepare_inputs_for_generation( # type: ignore[override] # GenerationMixin's base signature differs; static-cache flow self, input_ids: torch.LongTensor, past_key_values: Cache | None = None, @@ -260,7 +272,7 @@ def prepare_inputs_for_generation( # ----- Cache management ----- @classmethod - def get_cache_class(cls) -> type: + def get_cache_class(cls) -> type[WinMLCache]: """Return the WinMLCache subclass. Subclasses must override.""" raise NotImplementedError @@ -282,6 +294,8 @@ def _resolve_cache(self, past_key_values: Any) -> Any: return past_key_values # (3) Create fresh cache and reset + if self.config is None: + raise ValueError("Encoder-decoder generation requires an HF config to build the KV cache.") kv_shape = self._dec_expected["past_0_key"] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() @@ -316,7 +330,9 @@ def forward( encoder_outputs = self._encoder(input_ids=input_ids, **model_kwargs) if encoder_outputs is None: raise ValueError("Either encoder_outputs or input_ids required") - enc_h = encoder_outputs["last_hidden_state"] + # The encoder wrapper always returns a dict-like BaseModelOutput; the tuple + # arm of the annotation exists only for GenerationMixin signature compat. + enc_h = cast("BaseModelOutput", encoder_outputs)["last_hidden_state"] # Resolve or create cache (subclasses override get_cache_class). 
cache = self._resolve_cache(past_key_values) diff --git a/src/winml/modelkit/models/winml/zero_shot_image_classification.py b/src/winml/modelkit/models/winml/zero_shot_image_classification.py index 9ab089c45..ba40a295d 100644 --- a/src/winml/modelkit/models/winml/zero_shot_image_classification.py +++ b/src/winml/modelkit/models/winml/zero_shot_image_classification.py @@ -12,7 +12,7 @@ import logging from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any, ClassVar, cast import numpy as np import torch @@ -101,7 +101,12 @@ def forward( def _preprocess_vision(self, pixel_values: torch.Tensor | None) -> dict[str, np.ndarray]: """Torch→numpy via the sub-model's formatter.""" - return self.sub_models["image-encoder"]._format_inputs(pixel_values=pixel_values) + # sub_models values are Any (heterogeneous WinML models); _format_inputs + # returns a {name: ndarray} feed dict. + return cast( + "dict[str, np.ndarray]", + self.sub_models["image-encoder"]._format_inputs(pixel_values=pixel_values), + ) def _run_vision(self, inputs: dict[str, np.ndarray]) -> torch.Tensor: """Run vision encoder over ``M`` images, batching per the ONNX's fixed batch dim.""" From 67a272c57d8fbf3a3a357a6665f1d3f80f76b6a0 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 25 Jun 2026 10:50:12 +0800 Subject: [PATCH 4/7] Type-check the whole winml.modelkit package in CI; widen model types and annotate remaining modules --- .github/workflows/lint.yml | 40 ++++------------ src/winml/modelkit/eval/base_evaluator.py | 5 +- src/winml/modelkit/eval/evaluate.py | 3 +- src/winml/modelkit/export/io.py | 6 ++- src/winml/modelkit/inference/engine.py | 3 +- src/winml/modelkit/inference/pipeline.py | 3 +- src/winml/modelkit/models/auto.py | 12 +++-- src/winml/modelkit/models/hf/__init__.py | 31 ++++++++----- src/winml/modelkit/models/hf/blip.py | 15 +++--- src/winml/modelkit/models/hf/convnext.py | 3 +- .../modelkit/models/hf/depth_anything.py | 8 ++-- src/winml/modelkit/models/hf/mu2.py | 30 ++++++++---- src/winml/modelkit/models/hf/qwen.py | 12 +++-- 
src/winml/modelkit/models/hf/roberta.py | 9 ++-- src/winml/modelkit/models/hf/segformer.py | 8 ++-- src/winml/modelkit/models/hf/t5.py | 24 +++++++--- src/winml/modelkit/models/winml/__init__.py | 6 +-- .../modelkit/models/winml/composite_model.py | 4 +- .../modelkit/models/winml/decoder_only.py | 2 - .../modelkit/models/winml/encoder_decoder.py | 2 - src/winml/modelkit/models/winml/kv_cache.py | 46 ++++++++++++------- 21 files changed, 158 insertions(+), 114 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a2e24093a..774efa391 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -36,39 +36,15 @@ jobs: - name: Lint run: uv run ruff check src/ tests/ - # Required type check: these packages are clean against the strict - # config in pyproject.toml. Any new mypy error here blocks the PR. - # Expand the package list as more folders are cleaned up. + # Required type check: the ENTIRE winml.modelkit package is clean against + # the strict config in pyproject.toml. Any new mypy error here blocks the PR. # - # Single mypy invocation across all packages — a per-package loop pays - # cold typeshed/plugin startup per package and tipped the job past the - # 5-minute timeout once the list grew to 12. The combined summary still - # reports total error/file counts; error lines include file paths so - # the failing package is identifiable without per-package groups. + # A single whole-package invocation (not a per-subpackage list) — it covers + # every subpackage plus the top-level modules (cli.py, __init__.py, etc.), + # and surfaces cross-package interactions (import cycles, shared return-type + # unions) that per-subpackage runs miss. Relaxed/ignored modules (tests, + # analyze.onnx_opset) are scoped via [[tool.mypy.overrides]] in pyproject. 
- name: Type check (required) run: >- uv run mypy - -p winml.modelkit.analyze - -p winml.modelkit.build - -p winml.modelkit.cache - -p winml.modelkit.commands - -p winml.modelkit.compiler - -p winml.modelkit.config - -p winml.modelkit.core - -p winml.modelkit.data - -p winml.modelkit.datasets - -p winml.modelkit.eval - -p winml.modelkit.export - -p winml.modelkit.inference - -p winml.modelkit.inspect - -p winml.modelkit.loader - -p winml.modelkit.onnx - -p winml.modelkit.optim - -p winml.modelkit.optracing - -p winml.modelkit.pattern - -p winml.modelkit.quant - -p winml.modelkit.serve - -p winml.modelkit.session - -p winml.modelkit.sysinfo - -p winml.modelkit.telemetry - -p winml.modelkit.utils + -p winml.modelkit diff --git a/src/winml/modelkit/eval/base_evaluator.py b/src/winml/modelkit/eval/base_evaluator.py index 30eba265d..2e689b13f 100644 --- a/src/winml/modelkit/eval/base_evaluator.py +++ b/src/winml/modelkit/eval/base_evaluator.py @@ -18,6 +18,7 @@ from transformers.pipelines.base import Pipeline from ..models.winml.base import WinMLPreTrainedModel + from ..models.winml.composite_model import WinMLCompositeModel from .config import DatasetConfig, WinMLEvaluationConfig logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ class WinMLEvaluator: def __init__( self, config: WinMLEvaluationConfig, - model: WinMLPreTrainedModel, + model: WinMLPreTrainedModel | WinMLCompositeModel, ) -> None: self.model = model self.config = config @@ -138,7 +139,7 @@ def prepare_pipeline(self) -> Pipeline: # can't be statically matched. The string-task fallback handles unknown tasks. 
return cast( "Pipeline", - pipeline( # type: ignore[call-overload] + pipeline( # type: ignore[call-overload, misc] # 60+ Literal overloads + union model arg pipeline_task, model=self.model, framework="pt", diff --git a/src/winml/modelkit/eval/evaluate.py b/src/winml/modelkit/eval/evaluate.py index c5633f1e3..e222bc519 100644 --- a/src/winml/modelkit/eval/evaluate.py +++ b/src/winml/modelkit/eval/evaluate.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from ..models.winml.base import WinMLPreTrainedModel + from ..models.winml.composite_model import WinMLCompositeModel from .base_evaluator import WinMLEvaluator logger = logging.getLogger(__name__) @@ -190,7 +191,7 @@ def to_dict(self) -> dict[str, Any]: } -def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel: +def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel | WinMLCompositeModel: """Load model from ONNX path or HF model ID.""" from ..models import WinMLAutoModel from ..utils import cli as cli_utils diff --git a/src/winml/modelkit/export/io.py b/src/winml/modelkit/export/io.py index 139200802..3bedca226 100644 --- a/src/winml/modelkit/export/io.py +++ b/src/winml/modelkit/export/io.py @@ -60,7 +60,11 @@ class ONNXConfigNotFoundError(ValueError): # Create register with overwrite_existing=True to override Optimum's defaults. # Optimum's register_tasks_manager_onnx uses overwrite_existing=False, which means # registrations are silently skipped if a config already exists for the model/task. -register_onnx_overwrite = TasksManager.create_register("onnx", overwrite_existing=True) +# Explicit annotation: the RHS is Any (optimum is untyped), and the import cycle +# export.io -> models.hf -> export makes the *inferred* type undeterminable in a +# combined mypy run ([has-type] at every @register_onnx_overwrite site). Declaring +# the type breaks that ordering dependency. 
+register_onnx_overwrite: Any = TasksManager.create_register("onnx", overwrite_existing=True) def ensure_hf_models_registered() -> None: diff --git a/src/winml/modelkit/inference/engine.py b/src/winml/modelkit/inference/engine.py index 7d0a86205..5744bccd1 100644 --- a/src/winml/modelkit/inference/engine.py +++ b/src/winml/modelkit/inference/engine.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from ..models.winml.base import WinMLPreTrainedModel + from ..models.winml.composite_model import WinMLCompositeModel from ..utils.constants import EPNameOrAlias logger = logging.getLogger(__name__) @@ -269,7 +270,7 @@ class InferenceEngine: """ def __init__(self) -> None: - self._model: WinMLPreTrainedModel | None = None + self._model: WinMLPreTrainedModel | WinMLCompositeModel | None = None self._pipeline: Any | None = None # transformers.Pipeline self._model_id: str | None = None self._task: str | None = None diff --git a/src/winml/modelkit/inference/pipeline.py b/src/winml/modelkit/inference/pipeline.py index 33e243cfe..fc8bdb66c 100644 --- a/src/winml/modelkit/inference/pipeline.py +++ b/src/winml/modelkit/inference/pipeline.py @@ -27,6 +27,7 @@ from collections.abc import Mapping from ..models.winml.base import WinMLPreTrainedModel + from ..models.winml.composite_model import WinMLCompositeModel logger = logging.getLogger(__name__) @@ -39,7 +40,7 @@ def create_pipeline( task: str, - model: WinMLPreTrainedModel, + model: WinMLPreTrainedModel | WinMLCompositeModel, model_id: str | None = None, ) -> Any: """Create an HF pipeline for a WinML model. 
diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index d184091c5..ee35ea313 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -24,6 +24,7 @@ from __future__ import annotations import logging +from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING, Any @@ -35,13 +36,12 @@ if TYPE_CHECKING: - from collections.abc import Mapping - from transformers import PretrainedConfig from ..config import WinMLBuildConfig from ..utils.constants import EPNameOrAlias from .winml.base import WinMLPreTrainedModel + from .winml.composite_model import WinMLCompositeModel logger = logging.getLogger(__name__) @@ -119,7 +119,7 @@ def from_onnx( session_options: Any | None = None, hf_config: PretrainedConfig | None = None, **kwargs: Any, - ) -> WinMLPreTrainedModel | WinMLCompositeModel: # noqa: F821 + ) -> WinMLPreTrainedModel | WinMLCompositeModel: """Build from a pre-exported ONNX file. Runs optimize -> [quantize] -> [compile] via ``build_onnx_model()``. @@ -149,7 +149,7 @@ def from_onnx( Returns: WinMLPreTrainedModel inference wrapper. """ - if isinstance(onnx_path, dict): + if isinstance(onnx_path, Mapping): from .winml.composite_model import WinMLCompositeModel return WinMLCompositeModel.from_onnx( @@ -268,7 +268,7 @@ def from_pretrained( no_compile: bool = False, allow_unsupported_nodes: bool = False, **kwargs: Any, - ) -> WinMLPreTrainedModel: + ) -> WinMLPreTrainedModel | WinMLCompositeModel: """Load appropriate WinML model based on task detection. 
Supports two input modes: @@ -402,6 +402,8 @@ def from_pretrained( resolved_task = build_config.loader.task logger.debug("Generated config with task: %s", resolved_task) + if resolved_task is None: + raise ValueError(f"Could not resolve a task for model {model_id!r}.") config = build_config task = resolved_task diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..dbafd112d 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -85,18 +85,27 @@ # reverse-looks-up the task name from the matching (model_type, default_task) # entry. See sam.py for the canonical example (mask-generation default for # SAM/SAM2). +# Built via a comprehension (per-key write) rather than ** unpacking: the +# sub-mappings are keyed tuple[str, str] while this aggregate is keyed +# tuple[str, str | None] (for the task=None sentinel). dict key types are +# invariant, so ** unpacking is rejected even though every tuple[str, str] is a +# valid tuple[str, str | None]; per-key writes only need a subtype of the key. 
MODEL_CLASS_MAPPING: dict[tuple[str, str | None], type] = { - **_BART_CLASS_MAPPING, - **_BLIP_CLASS_MAPPING, - **_CLIP_CLASS_MAPPING, - **_MARIAN_CLASS_MAPPING, - **_MU2_CLASS_MAPPING, - **_QWEN_CLASS_MAPPING, - **_SAM2_CLASS_MAPPING, - **_SEGFORMER_CLASS_MAPPING, - **_SIGLIP_CLASS_MAPPING, - **_T5_CLASS_MAPPING, - **_VED_CLASS_MAPPING, + _key: _model_cls + for _sub_mapping in ( + _BART_CLASS_MAPPING, + _BLIP_CLASS_MAPPING, + _CLIP_CLASS_MAPPING, + _MARIAN_CLASS_MAPPING, + _MU2_CLASS_MAPPING, + _QWEN_CLASS_MAPPING, + _SAM2_CLASS_MAPPING, + _SEGFORMER_CLASS_MAPPING, + _SIGLIP_CLASS_MAPPING, + _T5_CLASS_MAPPING, + _VED_CLASS_MAPPING, + ) + for _key, _model_cls in _sub_mapping.items() } # Registry: model_type -> WinMLBuildConfig diff --git a/src/winml/modelkit/models/hf/blip.py b/src/winml/modelkit/models/hf/blip.py index 5063f9561..b430769c0 100644 --- a/src/winml/modelkit/models/hf/blip.py +++ b/src/winml/modelkit/models/hf/blip.py @@ -35,7 +35,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -55,7 +55,7 @@ if TYPE_CHECKING: - from transformers import PretrainedConfig + from transformers import GenerationConfig, PretrainedConfig # ============================================================================= @@ -144,7 +144,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> BlipVisionEn def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Trace ``pixel_values → encoder_hidden_states``.""" - return self.vision_model(pixel_values=pixel_values).last_hidden_state + # self.vision_model is a torch submodule (untyped __call__ -> Any). 
+ return cast("torch.Tensor", self.vision_model(pixel_values=pixel_values).last_hidden_state) @register_onnx_overwrite("blip", "feature-extraction", library_name="transformers") @@ -264,7 +265,9 @@ def _invoke_hf(self, cache: Any, inputs: dict[str, torch.Tensor]) -> torch.Tenso dtype=torch.long, device=encoder_hidden_states.device, ) - outputs = self.model.text_decoder( + # self.model is nn.Module; torch's __getattr__ types text_decoder as + # Tensor | Module, so narrow to a callable Module. + outputs = cast("nn.Module", self.model.text_decoder)( input_ids=inputs["decoder_input_ids"], # HF's causal-mask reconstruction traces as ops the NPU analyzer # doesn't support; pass a 3-D mask to bypass that reconstruction. @@ -280,7 +283,7 @@ def _invoke_hf(self, cache: Any, inputs: dict[str, torch.Tensor]) -> torch.Tenso cache_position=inputs["cache_position"], return_dict=True, ) - return outputs.logits + return cast("torch.Tensor", outputs.logits) # ============================================================================= @@ -315,7 +318,7 @@ def get_cache_class(cls) -> type: # noqa: D102 return WinMLStaticCache @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/convnext.py b/src/winml/modelkit/models/hf/convnext.py index 5a43b0a30..cbde93997 100644 --- a/src/winml/modelkit/models/hf/convnext.py +++ b/src/winml/modelkit/models/hf/convnext.py @@ -24,6 +24,7 @@ from __future__ import annotations import logging +from typing import Any import torch import torch.nn.functional as F @@ -41,7 +42,7 @@ # --------------------------------------------------------------------------- -def _patched_layernorm_forward(self, x: torch.Tensor) -> torch.Tensor: +def _patched_layernorm_forward(self: Any, x: torch.Tensor) -> torch.Tensor: """ConvNextLayerNorm.forward replacement that enables ONNX 
LayerNorm fusion. The stock implementation branches on ``data_format`` with code paths diff --git a/src/winml/modelkit/models/hf/depth_anything.py b/src/winml/modelkit/models/hf/depth_anything.py index 456b93eee..4c6d1808f 100644 --- a/src/winml/modelkit/models/hf/depth_anything.py +++ b/src/winml/modelkit/models/hf/depth_anything.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import Any + from optimum.exporters.onnx import OnnxConfig from optimum.utils import DEFAULT_DUMMY_SHAPES, NormalizedConfig from optimum.utils.input_generators import DummyVisionInputGenerator @@ -37,13 +39,13 @@ class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): # type: ig def __init__( self, task: str, - normalized_config, + normalized_config: NormalizedConfig, batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], width: int = DEFAULT_DUMMY_SHAPES["width"], height: int = DEFAULT_DUMMY_SHAPES["height"], - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( task, normalized_config, diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 2a8dd11cc..e5dfd1cc3 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -59,7 +59,7 @@ class for Mu2 (custom ``trust_remote_code`` model). from __future__ import annotations -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -75,6 +75,10 @@ class for Mu2 (custom ``trust_remote_code`` model). 
from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache +if TYPE_CHECKING: + from transformers import GenerationConfig + + # ============================================================================= # Wrapper nn.Modules # ============================================================================= @@ -85,7 +89,9 @@ class Mu2EncoderWrapper(nn.Module): def __init__(self, model: nn.Module) -> None: super().__init__() - self.encoder = model.encoder + # model is typed nn.Module, so torch's __getattr__ types submodule/attr + # access as Tensor | Module; narrow to their real types. + self.encoder = cast("nn.Module", model.encoder) self.config = model.config @classmethod @@ -100,9 +106,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2EncoderWr def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, attention_mask=attention_mask.bool() - ).last_hidden_state + out = self.encoder(input_ids=input_ids, attention_mask=attention_mask.bool()) + return cast("torch.Tensor", out.last_hidden_state) class Mu2DecoderWrapper(nn.Module): @@ -118,8 +123,11 @@ class Mu2DecoderWrapper(nn.Module): def __init__(self, model: nn.Module) -> None: super().__init__() self.model = model - self.config = model.config - self.num_layers = model.config.n_decoder_layer + # Mu2's config is a custom architecture config with dynamic attributes + # (n_decoder_layer, n_kv_head, head_dim) absent from any stub. 
+ config: Any = model.config + self.config = config + self.num_layers = config.n_decoder_layer @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Mu2DecoderWrapper: @@ -170,7 +178,9 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: # Delegate to model's decoder — position_id is passed as cache_position # for RoPE computation (WinMLSlidingWindowCache.update ignores it for indexing) - hidden_states = self.model.decoder( + # self.model is nn.Module; torch's __getattr__ types submodules as + # Tensor | Module, so narrow decoder/lm_head to callable Modules. + hidden_states = cast("nn.Module", self.model.decoder)( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_hidden_states, @@ -178,7 +188,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: past_key_values=cache, cache_position=position_id, ) - logits = self.model.lm_head(hidden_states) + logits = cast("nn.Module", self.model.lm_head)(hidden_states) # Output new-token KV only (same as T5 — captured during update) result: list[torch.Tensor] = [logits] @@ -298,7 +308,7 @@ def get_cache_class(cls) -> type: # noqa: D102 return WinMLSlidingWindowCache @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 0b8c9b45b..913054997 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -101,7 +101,7 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -123,6 +123,10 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache +if TYPE_CHECKING: + from transformers import GenerationConfig, 
PretrainedConfig + + # ============================================================================= # Wrapper nn.Module # ============================================================================= @@ -144,7 +148,9 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: super().__init__() self.model = model self.num_layers = num_layers - self.config = model.config + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. + self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenDecoderWrapper: @@ -351,7 +357,7 @@ def get_cache_class(cls) -> type: # noqa: D102 return WinMLSlidingWindowCache @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/roberta.py b/src/winml/modelkit/models/hf/roberta.py index 94b31bf7e..9ad74fb8e 100644 --- a/src/winml/modelkit/models/hf/roberta.py +++ b/src/winml/modelkit/models/hf/roberta.py @@ -23,6 +23,7 @@ from __future__ import annotations import logging +from typing import Any from optimum.exporters.onnx.model_configs import ( COMMON_TEXT_TASKS, @@ -57,7 +58,7 @@ # ============================================================================= -def _adjust_position_embeddings(config) -> None: +def _adjust_position_embeddings(config: Any) -> None: """Adjust max_position_embeddings for Roberta-style position offset. 
Roberta-family models define: @@ -116,9 +117,11 @@ class _RobertaPositionOffsetMixin: ) DUMMY_INPUT_GENERATOR_CLASSES = (MaxLengthTextInputGenerator,) - def __init__(self, config, task, **kwargs): + def __init__(self, config: Any, task: str, **kwargs: Any) -> None: _adjust_position_embeddings(config) - super().__init__(config, task, **kwargs) + # Mixin is first in MRO; super() resolves to the (untyped) OnnxConfig + # base at runtime, which mypy can't see, so it thinks super() is object. + super().__init__(config, task, **kwargs) # type: ignore[call-arg] @register_onnx_overwrite("roberta", *COMMON_TEXT_TASKS, library_name="transformers") diff --git a/src/winml/modelkit/models/hf/segformer.py b/src/winml/modelkit/models/hf/segformer.py index 0748545b0..b5860e88f 100644 --- a/src/winml/modelkit/models/hf/segformer.py +++ b/src/winml/modelkit/models/hf/segformer.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import Any + from optimum.exporters.onnx import OnnxConfig from optimum.utils import DEFAULT_DUMMY_SHAPES, NormalizedConfig from optimum.utils.input_generators import DummyVisionInputGenerator @@ -49,13 +51,13 @@ class _SegformerVisionInputGenerator(DummyVisionInputGenerator): # type: ignore def __init__( self, task: str, - normalized_config, + normalized_config: NormalizedConfig, batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], width: int = DEFAULT_DUMMY_SHAPES["width"], height: int = DEFAULT_DUMMY_SHAPES["height"], - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( task, normalized_config, diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 48a906f43..116dc4754 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -24,7 +24,7 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ 
-42,6 +42,10 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLSlidingWindowCache +if TYPE_CHECKING: + from transformers import GenerationConfig, PretrainedConfig + + # ============================================================================= # Wrapper nn.Modules (with from_pretrained, like SAM2 wrappers) # ============================================================================= @@ -71,10 +75,14 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - ).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). + return cast( + "torch.Tensor", + self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state, + ) class T5DecoderWrapper(nn.Module): @@ -110,7 +118,9 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.model = model self.num_layers = num_layers # Expose config for OnnxConfig / NormalizedConfig access - self.config = model.config + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. 
+ self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5DecoderWrapper: @@ -351,7 +361,7 @@ def get_cache_class(cls) -> type: return WinMLSlidingWindowCache @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/winml/__init__.py b/src/winml/modelkit/models/winml/__init__.py index df3c9d94a..c201275e8 100644 --- a/src/winml/modelkit/models/winml/__init__.py +++ b/src/winml/modelkit/models/winml/__init__.py @@ -108,7 +108,7 @@ def _import_winml_class(class_name: str) -> type[WinMLPreTrainedModel]: return class_map[class_name] -def get_winml_class(model_type: str | None, task: str) -> type[WinMLPreTrainedModel]: +def get_winml_class(model_type: str | None, task: str | None) -> type[WinMLPreTrainedModel]: """Get appropriate WinML class using three-level mapping. 
Level 1: Check class mapping (model_type, task) -> specialized class @@ -126,7 +126,7 @@ def get_winml_class(model_type: str | None, task: str) -> type[WinMLPreTrainedMo model_type_normalized = model_type.lower().replace("_", "-") if model_type else None # Level 1: Check for (model_type, task) class mapping - if model_type_normalized is not None: + if model_type_normalized is not None and task is not None: specialized_name = WINML_MODEL_CLASS_MAPPING.get((model_type_normalized, task)) else: specialized_name = None @@ -135,7 +135,7 @@ def get_winml_class(model_type: str | None, task: str) -> type[WinMLPreTrainedMo return _import_winml_class(specialized_name) # Level 2: Universal class by task - class_name = TASK_TO_WINML_CLASS.get(task) + class_name = TASK_TO_WINML_CLASS.get(task) if task is not None else None if class_name is not None: # Try to import the class - if not implemented, fall through to Level 3 try: diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 1d89b6c66..cc5d8e517 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -109,7 +109,7 @@ class WinMLCompositeModel(PreTrainedModel): def __init__( self, sub_models: dict[str, Any], - config: PretrainedConfig | None, + config: PretrainedConfig, device: str = "cpu", ) -> None: self.sub_models = sub_models @@ -267,6 +267,8 @@ def from_onnx( merged = {**kwargs, "task": component_task, **per_component.get(name, {})} sub_models[name] = WinMLAutoModel.from_onnx(Path(path), **merged) + if hf_config is None: + raise ValueError("Composite model construction requires an HF config (hf_config).") return resolved_cls(sub_models=sub_models, config=hf_config) @property diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index d47e78703..1a9aca262 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ 
b/src/winml/modelkit/models/winml/decoder_only.py @@ -258,8 +258,6 @@ def _resolve_cache(self, past_key_values: Any) -> Any: if isinstance(past_key_values, WinMLCache): return past_key_values - if self.config is None: - raise ValueError("Decoder-only generation requires an HF config to build the KV cache.") kv_shape = [1, self._num_kv_heads, self._max_cache_len, self._head_dim] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index 8430b2477..d488ae842 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -294,8 +294,6 @@ def _resolve_cache(self, past_key_values: Any) -> Any: return past_key_values # (3) Create fresh cache and reset - if self.config is None: - raise ValueError("Encoder-decoder generation requires an HF config to build the KV cache.") kv_shape = self._dec_expected["past_0_key"] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() diff --git a/src/winml/modelkit/models/winml/kv_cache.py b/src/winml/modelkit/models/winml/kv_cache.py index 18cced165..db7059832 100644 --- a/src/winml/modelkit/models/winml/kv_cache.py +++ b/src/winml/modelkit/models/winml/kv_cache.py @@ -42,7 +42,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast from optimum.utils.input_generators import DummyInputGenerator from transformers import StaticCache @@ -158,8 +158,10 @@ def reset(self) -> None: self.step = 0 self.captured.clear() for i in range(self.num_layers): - self.layers[i].keys.zero_() - self.layers[i].values.zero_() + # keys/values are typed Tensor | None (None only pre-init); here the + # cache is always initialized, so narrow to Tensor. 
+ cast("torch.Tensor", self.layers[i].keys).zero_() + cast("torch.Tensor", self.layers[i].values).zero_() @classmethod def create( @@ -198,7 +200,7 @@ class WinMLStaticCache(WinMLCache): Mask is left-aligned: ``[1, 1, ..., 1, 0, 0, ..., 0]``. """ - position_input_name: str = "cache_position" + position_input_name: ClassVar[str] = "cache_position" def update( self, @@ -216,10 +218,14 @@ def update( import torch self.captured[layer_idx] = (key_states, value_states) + if cache_kwargs is None: + raise ValueError("update() requires cache_kwargs with 'cache_position'") cache_position = cache_kwargs["cache_position"] - k_out = self.layers[layer_idx].keys - v_out = self.layers[layer_idx].values + # keys/values are typed Tensor | None (None only pre-init); the cache is + # always initialized before update(), so narrow to Tensor. + k_out = cast("torch.Tensor", self.layers[layer_idx].keys) + v_out = cast("torch.Tensor", self.layers[layer_idx].values) bsz, n_heads, n_new, _ = key_states.shape bi = torch.arange(bsz, device=k_out.device).view(bsz, 1, 1).expand(bsz, n_heads, n_new) @@ -281,7 +287,7 @@ class WinMLSlidingWindowCache(WinMLCache): Mask is right-aligned: ``[0, 0, ..., 0, 1, 1, ..., 1]``. """ - position_input_name: str = "position_id" + position_input_name: ClassVar[str] = "position_id" def update( self, @@ -299,11 +305,15 @@ def update( self.captured[layer_idx] = (key_states, value_states) n = key_states.size(2) - old_k = self.layers[layer_idx].keys[:, :, n:, :] + # keys/values are typed Tensor | None (None only pre-init); the cache is + # always initialized before update(), so narrow to Tensor. 
+ cur_k = cast("torch.Tensor", self.layers[layer_idx].keys) + cur_v = cast("torch.Tensor", self.layers[layer_idx].values) + old_k = cur_k[:, :, n:, :] new_k = torch.cat([old_k, key_states], dim=2) self.layers[layer_idx].keys = new_k - old_v = self.layers[layer_idx].values[:, :, n:, :] + old_v = cur_v[:, :, n:, :] new_v = torch.cat([old_v, value_states], dim=2) self.layers[layer_idx].values = new_v @@ -356,7 +366,7 @@ def prepare_prefill_chunk( def get_seq_length(self, layer_idx: int = 0) -> int: """Filled positions: ``min(step, max_cache_len)``.""" - max_len = self.layers[layer_idx].keys.shape[2] + max_len = cast("torch.Tensor", self.layers[layer_idx].keys).shape[2] return min(self.step, max_len) @@ -365,14 +375,14 @@ def get_seq_length(self, layer_idx: int = 0) -> int: # ============================================================================= -class PastKeyValueInputGenerator(DummyInputGenerator): +class PastKeyValueInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Generates ``past_{i}_key`` / ``past_{i}_value`` tensors for static KV cache. Reads ``num_layers``, ``num_attention_heads``, ``head_dim``, and ``max_cache_len`` from the ``NormalizedConfig``. """ - SUPPORTED_INPUT_NAMES = () # dynamic — built in __init__ + SUPPORTED_INPUT_NAMES: tuple[str, ...] = () # dynamic — built in __init__ def __init__( self, @@ -399,8 +409,12 @@ def generate( float_dtype: str = "fp32", ) -> torch.Tensor: """Return a random float tensor of shape ``[batch, heads, max_cache_len, head_dim]``.""" - return self.random_float_tensor( - (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim), - framework=framework, - dtype=float_dtype, + # optimum's DummyInputGenerator is untyped, so random_float_tensor returns Any. 
+ return cast( + "torch.Tensor", + self.random_float_tensor( + (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim), + framework=framework, + dtype=float_dtype, + ), ) From 8faee9d621c15e265e236f58128b4bf27fe605b7 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 25 Jun 2026 11:19:33 +0800 Subject: [PATCH 5/7] more --- .../modelkit/models/hf/decoder_wrapper.py | 13 +- src/winml/modelkit/models/hf/sam.py | 117 +++++++++++------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/src/winml/modelkit/models/hf/decoder_wrapper.py b/src/winml/modelkit/models/hf/decoder_wrapper.py index ab6932df6..e83148268 100644 --- a/src/winml/modelkit/models/hf/decoder_wrapper.py +++ b/src/winml/modelkit/models/hf/decoder_wrapper.py @@ -35,17 +35,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar +from typing import Any, ClassVar import torch import torch.nn as nn from optimum.exporters.onnx import OnnxConfig +from transformers import PreTrainedModel -from ..winml.kv_cache import WinMLStaticCache - - -if TYPE_CHECKING: - from ..winml.kv_cache import WinMLCache +from ..winml.kv_cache import WinMLCache, WinMLStaticCache class WinMLDecoderWrapper(nn.Module, ABC): @@ -63,10 +60,10 @@ class WinMLDecoderWrapper(nn.Module, ABC): num_layers — derived from ``onnx_config._normalized_config.num_layers`` """ - _HF_MODEL_CLS: ClassVar[type] + _HF_MODEL_CLS: ClassVar[type[PreTrainedModel]] # set per-subclass to a concrete HF model class _IO_CONFIG_CLS: ClassVar[type] _TASK: ClassVar[str] = "text2text-generation" - _CACHE_CLS: ClassVar[type] = WinMLStaticCache + _CACHE_CLS: ClassVar[type[WinMLCache]] = WinMLStaticCache # ---- Instance attrs ---- model: nn.Module diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index 0d5976118..deff3fcc7 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -34,7 +34,7 @@ from 
__future__ import annotations import types -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch import torch.nn.functional as F @@ -51,6 +51,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from optimum.utils import NormalizedConfig @@ -88,7 +90,7 @@ def __init__(self, vision_encoder: torch.nn.Module) -> None: self.config = vision_encoder.config @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs) -> Sam2VisionEncoder: + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> Sam2VisionEncoder: full_model = Sam2Model.from_pretrained(model_name_or_path, **kwargs) return cls(full_model.vision_encoder) @@ -134,12 +136,12 @@ class SAM2MaskGeneration(torch.nn.Module): """ @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs) -> SAM2MaskGeneration: + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> SAM2MaskGeneration: """Load from a HuggingFace Sam2Model checkpoint.""" sam2_model = Sam2Model.from_pretrained(model_name_or_path, **kwargs) return cls(sam2_model) - def __init__(self, sam2_model): + def __init__(self, sam2_model: Sam2Model) -> None: super().__init__() self.prompt_encoder = sam2_model.prompt_encoder @@ -157,8 +159,11 @@ def __init__(self, sam2_model): def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor: """Replicates Sam2Model.get_image_wide_positional_embeddings().""" size = self.image_embedding_size - target_device = self.shared_image_embedding.positional_embedding.device - target_dtype = self.shared_image_embedding.positional_embedding.dtype + # positional_embedding is a registered Parameter reached via torch's + # __getattr__ (typed Tensor | Module); narrow to Tensor for device/dtype. 
+ pos_emb = cast("torch.Tensor", self.shared_image_embedding.positional_embedding) + target_device = pos_emb.device + target_dtype = pos_emb.dtype grid = torch.ones(size, device=target_device, dtype=target_dtype) y_embed = grid.cumsum(dim=0) - 0.5 @@ -168,7 +173,8 @@ def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor: positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0) - return positional_embedding.repeat(batch_size, 1, 1, 1) + # shared_image_embedding is a torch submodule (untyped __call__ -> Any). + return cast("torch.Tensor", positional_embedding.repeat(batch_size, 1, 1, 1)) def forward( self, @@ -265,12 +271,12 @@ class SAMMaskGeneration(torch.nn.Module): """ @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs) -> SAMMaskGeneration: + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> SAMMaskGeneration: """Load from a HuggingFace SamModel checkpoint.""" sam_model = SamModel.from_pretrained(model_name_or_path, **kwargs) return cls(sam_model) - def __init__(self, sam_model): + def __init__(self, sam_model: SamModel) -> None: super().__init__() self.prompt_encoder = sam_model.prompt_encoder @@ -282,8 +288,11 @@ def __init__(self, sam_model): def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor: """Replicates SamModel.get_image_wide_positional_embeddings().""" size = self.config.prompt_encoder_config.image_embedding_size - target_device = self.shared_image_embedding.positional_embedding.device - target_dtype = self.shared_image_embedding.positional_embedding.dtype + # positional_embedding is a registered Parameter reached via torch's + # __getattr__ (typed Tensor | Module); narrow to Tensor for device/dtype. 
+ pos_emb = cast("torch.Tensor", self.shared_image_embedding.positional_embedding) + target_device = pos_emb.device + target_dtype = pos_emb.dtype grid = torch.ones((size, size), device=target_device, dtype=target_dtype) y_embed = grid.cumsum(dim=0) - 0.5 @@ -293,7 +302,8 @@ def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor: positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0) - return positional_embedding.repeat(batch_size, 1, 1, 1) + # shared_image_embedding is a torch submodule (untyped __call__ -> Any). + return cast("torch.Tensor", positional_embedding.repeat(batch_size, 1, 1, 1)) def forward( self, @@ -395,7 +405,9 @@ def forward( # discovers optimization flags automatically. See issue #232. -def _window_partition(hidden_state: torch.Tensor, window_size: int): +def _window_partition( + hidden_state: torch.Tensor, window_size: int +) -> tuple[torch.Tensor, tuple[int, int]]: """QNN-compatible window partition (5D instead of 6D). Original HF creates 6D view: [B, H//ws, ws, W//ws, ws, C] @@ -451,7 +463,7 @@ def _window_unpartition( def _patched_sam2_multiscale_block_forward( - self, hidden_states: torch.Tensor, **kwargs + self: Any, hidden_states: torch.Tensor, **kwargs: Any ) -> torch.Tensor: """Patched Sam2MultiScaleBlock.forward with 5D window functions. 
@@ -463,7 +475,7 @@ def _patched_sam2_multiscale_block_forward( """ # No windowing needed, use original if self.window_size <= 0: - return self._original_forward(hidden_states, **kwargs) + return cast("torch.Tensor", self._original_forward(hidden_states, **kwargs)) residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -497,15 +509,20 @@ def _patched_sam2_multiscale_block_forward( pad_hw = (H + pad_h, W + pad_w) if self.window_size > 0: + # Set together with pad_hw under the same window_size > 0 guard above; + # mypy can't correlate the two blocks, so assert the shared invariant. + assert H is not None + assert W is not None + assert pad_hw is not None hidden_states = _window_unpartition(hidden_states, window_size, pad_hw, (H, W)) hidden_states = residual + hidden_states layernorm_output = self.layer_norm2(hidden_states) - return hidden_states + self.mlp(layernorm_output) + return cast("torch.Tensor", hidden_states + self.mlp(layernorm_output)) def _patched_sam2_embed_points( - self, points: torch.Tensor, labels: torch.Tensor, pad: bool + self: Any, points: torch.Tensor, labels: torch.Tensor, pad: bool ) -> torch.Tensor: """Patched _embed_points with arithmetic masking instead of torch.where. 
@@ -531,11 +548,11 @@ def _patched_sam2_embed_points( # Add point type embedding mask_ge0 = (labels >= 0).unsqueeze(-1).to(point_embedding.dtype) point_embed_lookup = self.point_embed(labels.clamp(min=0)) - return point_embedding + point_embed_lookup * mask_ge0 + return cast("torch.Tensor", point_embedding + point_embed_lookup * mask_ge0) def _patched_sam2_prompt_encoder_forward( - self, + self: Any, input_points: torch.Tensor | None, input_labels: torch.Tensor | None, input_boxes: torch.Tensor | None, @@ -556,7 +573,10 @@ def _patched_sam2_prompt_encoder_forward( # If use_mask_input not provided, use original behavior if use_mask_input is None: - return self._original_forward(input_points, input_labels, input_boxes, input_masks) + return cast( + "tuple[torch.Tensor, torch.Tensor]", + self._original_forward(input_points, input_labels, input_boxes, input_masks), + ) # Get batch size batch_size = 1 @@ -587,7 +607,7 @@ def _patched_sam2_prompt_encoder_forward( # Target class names for instance-level patching. # These are matched by type(module).__name__ to stay architecture-agnostic # (no class import needed; the classes come from transformers internals). 
-_SAM2_PATCH_TARGETS = { +_SAM2_PATCH_TARGETS: dict[str, Callable[..., Any]] = { "Sam2MultiScaleBlock": _patched_sam2_multiscale_block_forward, "Sam2PromptEncoder": _patched_sam2_prompt_encoder_forward, } @@ -613,7 +633,7 @@ def __init__( super().__init__(config, model, model_kwargs=model_kwargs) self._sam2_originals: list[tuple[torch.nn.Module, Any]] = [] - def __enter__(self): + def __enter__(self) -> Sam2ModelPatcher: super().__enter__() for _name, module in self._model.named_modules(): class_name = type(module).__name__ @@ -624,7 +644,7 @@ def __enter__(self): module.forward = types.MethodType(patch_fn, module) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: for module, original_forward in self._sam2_originals: module.forward = original_forward if hasattr(module, "_original_forward"): @@ -653,8 +673,8 @@ def __init__( batch_size: int = 1, point_batch_size: int = 1, nb_points_per_image: int = 5, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.task = task self.batch_size = batch_size self.point_batch_size = point_batch_size @@ -666,7 +686,7 @@ def generate( framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32", - ): + ) -> torch.Tensor: if input_name == "input_points": shape = [ self.batch_size, @@ -674,13 +694,20 @@ def generate( self.nb_points_per_image, 2, ] - # Scale to 0-1024 pixel coordinates - return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) * 1024 + # Scale to 0-1024 pixel coordinates. optimum's DummyInputGenerator is + # untyped, so random_*_tensor returns Any. 
+ return cast( + "torch.Tensor", + self.random_float_tensor(shape, framework=framework, dtype=float_dtype) * 1024, + ) # input_labels shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image] # Labels: 1=positive for all test points - return self.random_int_tensor( - shape, max_value=2, min_value=1, framework=framework, dtype=int_dtype + return cast( + "torch.Tensor", + self.random_int_tensor( + shape, max_value=2, min_value=1, framework=framework, dtype=int_dtype + ), ) @@ -704,8 +731,8 @@ def __init__( task: str, normalized_config: NormalizedConfig, batch_size: int = 1, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.task = task self.batch_size = batch_size @@ -715,7 +742,7 @@ def generate( framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32", - ): + ) -> torch.Tensor: if input_name == "image_embeddings": shape = [self.batch_size, 256, 64, 64] elif input_name == "high_res_features0": @@ -725,7 +752,10 @@ def generate( else: raise ValueError(f"Unknown input: {input_name}") - return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + # optimum's DummyInputGenerator is untyped, so random_float_tensor returns Any. 
+ return cast( + "torch.Tensor", self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + ) class Sam2MaskInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped @@ -743,8 +773,8 @@ def __init__( task: str, normalized_config: NormalizedConfig, batch_size: int = 1, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.task = task self.batch_size = batch_size @@ -754,7 +784,7 @@ def generate( framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32", - ): + ) -> torch.Tensor: if input_name == "mask_input": shape = [self.batch_size, 1, 256, 256] return torch.zeros(shape, dtype=torch.float32) @@ -775,7 +805,7 @@ class Sam2NormalizedVisionConfig(NormalizedVisionConfig): # type: ignore[misc] DEFAULT_IMAGE_SIZE = 1024 - def __getattr__(self, attr_name: str): + def __getattr__(self, attr_name: str) -> Any: """Return default image_size when not found in model config.""" try: return super().__getattr__(attr_name) @@ -955,8 +985,8 @@ def __init__( task: str, normalized_config: NormalizedConfig, batch_size: int = 1, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.task = task self.batch_size = batch_size @@ -966,11 +996,14 @@ def generate( framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32", - ): + ) -> torch.Tensor: # SAM v1 decoder export expects the canonical embedding shape from the # vision encoder output; this mirrors the existing SAM2 generator path. shape = [self.batch_size, 256, 64, 64] - return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + # optimum's DummyInputGenerator is untyped, so random_float_tensor returns Any. 
+ return cast( + "torch.Tensor", self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + ) # ============================================================================= From 0a7b0b73f8a9f37d2aad84063ca0852a605ef5fa Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 25 Jun 2026 13:56:39 +0800 Subject: [PATCH 6/7] use more types --- src/winml/modelkit/models/hf/bart.py | 3 +- src/winml/modelkit/models/hf/convnext.py | 8 +- src/winml/modelkit/models/hf/marian.py | 3 +- src/winml/modelkit/models/hf/sam.py | 29 ++++--- .../models/hf/vision_encoder_decoder.py | 76 +++++++++++-------- 5 files changed, 74 insertions(+), 45 deletions(-) diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index cb44d2d74..5e529fcc5 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -95,6 +95,7 @@ if TYPE_CHECKING: from transformers import GenerationConfig, PretrainedConfig + from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding logger = logging.getLogger(__name__) @@ -143,7 +144,7 @@ def _patched_bart_learned_forward( - self: Any, # monkey-patched onto BartLearnedPositionalEmbedding (HF internal) + self: BartLearnedPositionalEmbedding, # monkey-patched onto this HF module input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, diff --git a/src/winml/modelkit/models/hf/convnext.py b/src/winml/modelkit/models/hf/convnext.py index cbde93997..42dc9d48c 100644 --- a/src/winml/modelkit/models/hf/convnext.py +++ b/src/winml/modelkit/models/hf/convnext.py @@ -24,7 +24,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -34,6 +34,10 @@ from ...export import register_onnx_overwrite +if TYPE_CHECKING: + from transformers.models.convnext.modeling_convnext import ConvNextLayerNorm + + logger = logging.getLogger(__name__) 
@@ -42,7 +46,7 @@ # --------------------------------------------------------------------------- -def _patched_layernorm_forward(self: Any, x: torch.Tensor) -> torch.Tensor: +def _patched_layernorm_forward(self: ConvNextLayerNorm, x: torch.Tensor) -> torch.Tensor: """ConvNextLayerNorm.forward replacement that enables ONNX LayerNorm fusion. The stock implementation branches on ``data_format`` with code paths diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index f3b2abb13..5af97fb73 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -109,6 +109,7 @@ if TYPE_CHECKING: from transformers import GenerationConfig, PretrainedConfig + from transformers.models.marian.modeling_marian import MarianSinusoidalPositionalEmbedding logger = logging.getLogger(__name__) @@ -180,7 +181,7 @@ def _patched_marian_sinusoidal_forward( - self: Any, # monkey-patched onto MarianSinusoidalPositionalEmbedding (HF internal) + self: MarianSinusoidalPositionalEmbedding, # monkey-patched onto this HF module input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index deff3fcc7..3e7cb2a6f 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -54,6 +54,7 @@ from collections.abc import Callable from optimum.utils import NormalizedConfig + from transformers.models.sam2.modeling_sam2 import Sam2MultiScaleBlock, Sam2PromptEncoder # ============================================================================= @@ -463,7 +464,7 @@ def _window_unpartition( def _patched_sam2_multiscale_block_forward( - self: Any, hidden_states: torch.Tensor, **kwargs: Any + self: Sam2MultiScaleBlock, hidden_states: torch.Tensor, **kwargs: Any ) -> torch.Tensor: """Patched Sam2MultiScaleBlock.forward with 5D window functions. 
@@ -473,9 +474,12 @@ def _patched_sam2_multiscale_block_forward( Applied via Sam2ModelPatcher during export. """ + # _original_forward is added dynamically by Sam2ModelPatcher (not on the class), + # so it's reached via nn.Module.__getattr__ (typed Tensor | Module); type it callable. + orig_forward = cast("Callable[..., torch.Tensor]", self._original_forward) # No windowing needed, use original if self.window_size <= 0: - return cast("torch.Tensor", self._original_forward(hidden_states, **kwargs)) + return orig_forward(hidden_states, **kwargs) residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -522,7 +526,7 @@ def _patched_sam2_multiscale_block_forward( def _patched_sam2_embed_points( - self: Any, points: torch.Tensor, labels: torch.Tensor, pad: bool + self: Sam2PromptEncoder, points: torch.Tensor, labels: torch.Tensor, pad: bool ) -> torch.Tensor: """Patched _embed_points with arithmetic masking instead of torch.where. @@ -552,7 +556,7 @@ def _patched_sam2_embed_points( def _patched_sam2_prompt_encoder_forward( - self: Any, + self: Sam2PromptEncoder, input_points: torch.Tensor | None, input_labels: torch.Tensor | None, input_boxes: torch.Tensor | None, @@ -568,15 +572,18 @@ def _patched_sam2_prompt_encoder_forward( Applied via Sam2ModelPatcher during export. """ - # Patch _embed_points to use arithmetic masking - self._embed_points = types.MethodType(_patched_sam2_embed_points, self) + # Patch _embed_points to use arithmetic masking (intentional instance-method patch). + self._embed_points = types.MethodType(_patched_sam2_embed_points, self) # type: ignore[method-assign] + + # _original_forward is added dynamically by Sam2ModelPatcher (not on the class), + # so it's reached via nn.Module.__getattr__ (typed Tensor | Module); type it callable. 
+ orig_forward = cast( + "Callable[..., tuple[torch.Tensor, torch.Tensor]]", self._original_forward + ) # If use_mask_input not provided, use original behavior if use_mask_input is None: - return cast( - "tuple[torch.Tensor, torch.Tensor]", - self._original_forward(input_points, input_labels, input_boxes, input_masks), - ) + return orig_forward(input_points, input_labels, input_boxes, input_masks) # Get batch size batch_size = 1 @@ -586,7 +593,7 @@ def _patched_sam2_prompt_encoder_forward( batch_size = input_boxes.shape[0] # Get sparse embeddings (with patched _embed_points) - sparse_embeddings, _ = self._original_forward(input_points, input_labels, input_boxes, None) + sparse_embeddings, _ = orig_forward(input_points, input_labels, input_boxes, None) # Arithmetic mask blending mask_dense = self.mask_embed(input_masks) diff --git a/src/winml/modelkit/models/hf/vision_encoder_decoder.py b/src/winml/modelkit/models/hf/vision_encoder_decoder.py index 14e9a7c74..35e22e360 100644 --- a/src/winml/modelkit/models/hf/vision_encoder_decoder.py +++ b/src/winml/modelkit/models/hf/vision_encoder_decoder.py @@ -32,7 +32,7 @@ import inspect import logging -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -49,12 +49,16 @@ from ...optim import WinMLOptimizationConfig from ..winml.composite_model import register_composite_model from ..winml.encoder_decoder import EncoderDecoderInputGenerator, WinMLEncoderDecoderModel -from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLCache, WinMLStaticCache from .decoder_wrapper import WinMLDecoderWrapper, WinMLStaticCacheDecoderIOConfig if TYPE_CHECKING: - from transformers import PretrainedConfig + from transformers import GenerationConfig, PretrainedConfig + from transformers.models.trocr.modeling_trocr import ( + TrOCRLearnedPositionalEmbedding, + 
TrOCRSinusoidalPositionalEmbedding, + ) logger = logging.getLogger(__name__) @@ -98,7 +102,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> VisionEncode def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Trace ``pixel_values → encoder_hidden_states``.""" - return self.encoder(pixel_values=pixel_values).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). + return cast("torch.Tensor", self.encoder(pixel_values=pixel_values).last_hidden_state) @register_onnx_overwrite( @@ -133,9 +138,9 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 def _patched_trocr_learned_positional_embedding_forward( - self, + self: TrOCRLearnedPositionalEmbedding, # monkey-patched onto this HF module input_ids: torch.Tensor, - past_key_values_length: Any = 0, + past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, ) -> torch.Tensor: """Patched ``TrOCRLearnedPositionalEmbedding.forward``. @@ -169,7 +174,7 @@ def _patched_trocr_learned_positional_embedding_forward( def _patched_trocr_sinusoidal_positional_embedding_forward( - self, + self: TrOCRSinusoidalPositionalEmbedding, # monkey-patched onto this HF module input_ids: torch.Tensor, past_key_values_length: int = 0, ) -> torch.Tensor: @@ -183,27 +188,35 @@ def _patched_trocr_sinusoidal_positional_embedding_forward( use it as the lookup index (shifted by ``padding_idx + 1`` to match HF's offset for non-padded sequences). Otherwise behavior is unchanged. """ + # TrOCR sinusoidal embeddings always set a padding_idx; nn.Embedding types it int | None. + padding_idx = cast("int", self.padding_idx) abs_pos = getattr(self, "position_id", None) if abs_pos is not None: # Patched path: lookup with the side-channel index, shifted by # ``padding_idx + 1`` to match HF's offset for non-padded sequences. 
if abs_pos.dim() == 1: abs_pos = abs_pos.unsqueeze(0) - position_ids = (abs_pos + self.padding_idx + 1).long() + position_ids = (abs_pos + padding_idx + 1).long() weights = self.weights.to(self._float_tensor) - return weights.index_select(0, position_ids.view(-1)).view( - position_ids.size(0), position_ids.size(1), -1 - ).detach() + return cast( + "torch.Tensor", + weights.index_select(0, position_ids.view(-1)) + .view(position_ids.size(0), position_ids.size(1), -1) + .detach(), + ) # Below: identical to HF's original ``TrOCRSinusoidalPositionalEmbedding.forward``. bsz, seq_len = input_ids.size() position_ids = self.create_position_ids_from_input_ids( - input_ids, self.padding_idx, past_key_values_length + input_ids, padding_idx, past_key_values_length ).to(input_ids.device) - max_pos = self.padding_idx + 1 + seq_len + max_pos = padding_idx + 1 + seq_len if self.weights is None or max_pos > self.weights.size(0): - self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) + self.weights = self.get_embedding(max_pos, self.embedding_dim, padding_idx) self.weights = self.weights.to(self._float_tensor) - return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + return cast( + "torch.Tensor", + self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach(), + ) def _build_ved_patching_specs() -> list[PatchingSpec]: @@ -265,22 +278,23 @@ def __init__(self, config: Any, **kwargs: Any) -> None: @property def vocab_size(self) -> int: - return self._dec.vocab_size + # _dec/_enc/config are untyped optimum/HF config objects (Any). + return cast("int", self._dec.vocab_size) @property def hidden_size(self) -> int: - return self._dec.hidden_size + return cast("int", self._dec.hidden_size) @property def num_layers(self) -> int: # Not every model family defines ``decoder_layers`` on # ``config.decoder``; fall back to Optimum's NormalizedConfig. 
decoder_layers = getattr(self.config.decoder, "decoder_layers", None) - return decoder_layers if decoder_layers is not None else self._dec.num_layers + return cast("int", decoder_layers if decoder_layers is not None else self._dec.num_layers) @property def num_attention_heads(self) -> int: - return self._dec.num_attention_heads + return cast("int", self._dec.num_attention_heads) @property def head_dim(self) -> int: @@ -290,7 +304,7 @@ def head_dim(self) -> int: def max_cache_len(self) -> int: # Optimum's normalized configs don't uniformly expose this; read # the raw decoder config field that BART/TrOCR-family use. - return self.config.decoder.max_position_embeddings + return cast("int", self.config.decoder.max_position_embeddings) @property def encoder_hidden_size(self) -> int: @@ -298,17 +312,17 @@ def encoder_hidden_size(self) -> int: # Falls back to encoder.hidden_size when no explicit projection is # configured (HF convention when enc.hidden_size matches). cah = getattr(self.config.decoder, "cross_attention_hidden_size", None) - return cah if cah is not None else self._enc.hidden_size + return cast("int", cah if cah is not None else self._enc.hidden_size) @property def image_size(self) -> int | list[int]: # Some model types ship a scalar (square input); # others ship a ``[H, W]`` list. - return self.config.encoder.image_size + return cast("int | list[int]", self.config.encoder.image_size) @property def patch_size(self) -> int: - return self.config.encoder.patch_size + return cast("int", self.config.encoder.patch_size) @property def encoder_seq_length(self) -> int: @@ -330,12 +344,12 @@ def encoder_seq_length(self) -> int: # Scalar image_size: square ViT-style encoder. if not isinstance(image_size, (list, tuple)): - return (image_size // patch_size) ** 2 + 1 + return cast("int", (image_size // patch_size) ** 2 + 1) # [H, W] image_size: hierarchical Swin-style encoder. 
h, w = image_size[0], image_size[1] shrink = 2 ** (len(enc.depths) - 1) - return (h // patch_size // shrink) * (w // patch_size // shrink) + return cast("int", (h // patch_size // shrink) * (w // patch_size // shrink)) class VedDecoderInputGenerator(EncoderDecoderInputGenerator): @@ -423,13 +437,15 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> VisionDecode self.eval() return self - def _make_cache(self, inputs: dict[str, torch.Tensor]) -> Any: + def _make_cache(self, inputs: dict[str, torch.Tensor]) -> WinMLCache: cache = super()._make_cache(inputs) # HF decoders that use BART-style learned positional embeddings read # ``past_key_values.get_seq_length()`` to drive the position offset. - # Override so the mask shape reflects the actual generation step. + # Override the instance method with a trace-time constant so the mask + # shape reflects the actual generation step (intentional method patch: + # the lambda returns the cache_position Tensor, not get_seq_length's int). 
position = inputs["cache_position"].squeeze() - cache.get_seq_length = lambda layer_idx=0: position + cache.get_seq_length = lambda layer_idx=0: position # type: ignore[method-assign, assignment, return-value] return cache def _invoke_hf(self, cache: Any, inputs: dict[str, torch.Tensor]) -> torch.Tensor: @@ -471,7 +487,7 @@ def _invoke_hf(self, cache: Any, inputs: dict[str, torch.Tensor]) -> torch.Tenso return_dict=True, **extra, ) - return outputs.logits + return cast("torch.Tensor", outputs.logits) # ============================================================================= @@ -509,7 +525,7 @@ def get_cache_class(cls) -> type: # noqa: D102 return WinMLStaticCache @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig From 0aff603b41e0f1afa2a136b05a4c56146c55ad39 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 25 Jun 2026 14:52:24 +0800 Subject: [PATCH 7/7] .. --- src/winml/modelkit/models/hf/bart.py | 2 +- src/winml/modelkit/models/hf/marian.py | 2 +- src/winml/modelkit/models/hf/qwen.py | 2 +- src/winml/modelkit/models/hf/t5.py | 2 +- src/winml/modelkit/models/winml/depth_estimation.py | 3 ++- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index 5e529fcc5..29667fd33 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -273,7 +273,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: # Expose config for OnnxConfig / NormalizedConfig access. # model is typed nn.Module, so torch's __getattr__ types .config as # Tensor | Module; it is really the model's PretrainedConfig. 
- self.config = cast("PretrainedConfig", model.config) + self.config: PretrainedConfig = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> BartDecoderWrapper: diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index 5af97fb73..016dfc7ce 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -311,7 +311,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: # Expose config for OnnxConfig / NormalizedConfig access # model is typed nn.Module, so torch's __getattr__ types .config as # Tensor | Module; it is really the model's PretrainedConfig. - self.config = cast("PretrainedConfig", model.config) + self.config: PretrainedConfig = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> MarianDecoderWrapper: diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 913054997..e20eef452 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -150,7 +150,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.num_layers = num_layers # model is typed nn.Module, so torch's __getattr__ types .config as # Tensor | Module; it is really the model's PretrainedConfig. 
- self.config = cast("PretrainedConfig", model.config) + self.config: PretrainedConfig = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenDecoderWrapper: diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 116dc4754..cffd5d4dc 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -120,7 +120,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: # Expose config for OnnxConfig / NormalizedConfig access # model is typed nn.Module, so torch's __getattr__ types .config as # Tensor | Module; it is really the model's PretrainedConfig. - self.config = cast("PretrainedConfig", model.config) + self.config: PretrainedConfig = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> T5DecoderWrapper: diff --git a/src/winml/modelkit/models/winml/depth_estimation.py b/src/winml/modelkit/models/winml/depth_estimation.py index f57844390..4fd3d17ab 100644 --- a/src/winml/modelkit/models/winml/depth_estimation.py +++ b/src/winml/modelkit/models/winml/depth_estimation.py @@ -53,4 +53,5 @@ def forward(self, **kwargs: Any) -> DepthEstimatorOutput: # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); # the ONNX session returns a real float Tensor. - return DepthEstimatorOutput(predicted_depth=cast("torch.FloatTensor", predicted_depth)) + depth: torch.FloatTensor = cast("torch.FloatTensor", predicted_depth) + return DepthEstimatorOutput(predicted_depth=depth)