Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/winml/modelkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import logging
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Any


# Force utf-8 stdout/stderr so emoji and Unicode output (rich console, logs,
Expand Down Expand Up @@ -98,7 +99,7 @@ def _preload_bundled_onnxruntime_dll() -> None:
}


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy-load heavy exports on first access (PEP 562).

This avoids importing torch/transformers/optimum (~30s) when only
Expand Down
11 changes: 9 additions & 2 deletions src/winml/modelkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@
import logging
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING

import click


if TYPE_CHECKING:
from rich.console import Console

from . import __version__
from .telemetry import ActionGroup
from .telemetry import telemetry as _telemetry_mod
Expand Down Expand Up @@ -78,7 +83,7 @@ def _gradient_color(t: float) -> tuple[int, int, int]:
return _GRADIENT[-1][1]


def _print_banner(version: str, *, _console: object | None = None) -> None:
def _print_banner(version: str, *, _console: Console | None = None) -> None:
"""Print the WinML CLI gradient banner to stderr using Rich."""
from rich.console import Console # lazy import - keeps startup fast
from rich.text import Text
Expand Down Expand Up @@ -205,7 +210,9 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None
discovered = attr
return discovered

def resolve_command(self, ctx: click.Context, args: list[str]):
def resolve_command(
self, ctx: click.Context, args: list[str]
) -> tuple[str | None, click.Command | None, list[str]]:
"""Seed ``self.commands`` so Click can emit a did-you-mean hint on typos."""
# Click's NoSuchCommand exception uses self.commands to find suggestions.
for name in self.list_commands(ctx):
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from .hf import MODEL_BUILD_CONFIGS

Expand Down Expand Up @@ -60,7 +60,7 @@
}


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy load modules that would cause circular imports."""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
Expand Down
36 changes: 23 additions & 13 deletions src/winml/modelkit/models/hf/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from __future__ import annotations

import logging
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import torch
import torch.nn as nn
Expand All @@ -93,6 +93,9 @@
from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache


if TYPE_CHECKING:
from transformers import GenerationConfig, PretrainedConfig

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'PretrainedConfig' is not used.

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -140,7 +143,7 @@


def _patched_bart_learned_forward(
self,
self: Any, # monkey-patched onto BartLearnedPositionalEmbedding (HF internal)
input_ids: torch.Tensor,
past_key_values_length: int = 0,
position_ids: torch.Tensor | None = None,
Expand Down Expand Up @@ -229,10 +232,14 @@
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Return encoder last hidden state."""
return self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
# self.encoder is a torch submodule (untyped __call__ -> Any).
return cast(
"torch.Tensor",
self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state,
)


class BartDecoderWrapper(nn.Module):
Expand Down Expand Up @@ -262,8 +269,10 @@
super().__init__()
self.model = model
self.num_layers = num_layers
# Expose config for OnnxConfig / NormalizedConfig access
self.config = model.config
# Expose config for OnnxConfig / NormalizedConfig access.
# model is typed nn.Module, so torch's __getattr__ types .config as
# Tensor | Module; it is really the model's PretrainedConfig.
self.config = cast("PretrainedConfig", model.config)

@classmethod
def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> BartDecoderWrapper:
Expand Down Expand Up @@ -358,7 +367,7 @@


@register_onnx_overwrite("bart", "feature-extraction", library_name="transformers")
class BartEncoderIOConfig(OnnxConfig):
class BartEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for BART encoder (feature-extraction task).

Inputs: input_ids, attention_mask
Expand All @@ -385,7 +394,7 @@
}


class _BartDecoderNormalizedConfig(NormalizedConfig):
class _BartDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped
"""NormalizedConfig for BART decoder-side export.

Maps NormalizedConfig attributes to BartConfig's decoder-side attrs.
Expand All @@ -400,11 +409,12 @@

@property
def head_dim(self) -> int:
return self.hidden_size // self.num_attention_heads
# hidden_size / num_attention_heads come from the untyped NormalizedConfig base.
return cast("int", self.hidden_size // self.num_attention_heads)


@register_onnx_overwrite("bart", "text2text-generation", library_name="transformers")
class BartDecoderIOConfig(OnnxConfig):
class BartDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for BART decoder with sliding-window KV cache.

Inputs: decoder_input_ids, encoder_hidden_states, attention_mask,
Expand Down Expand Up @@ -517,7 +527,7 @@
return WinMLStaticCache # static cache (index_put_ → ScatterND)

@property
def generation_config(self): # noqa: D102
def generation_config(self) -> GenerationConfig: # noqa: D102
if not hasattr(self, "_generation_config"):
from transformers import GenerationConfig

Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/models/hf/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


@register_onnx_overwrite("bert", *COMMON_TEXT_TASKS, library_name="transformers")
class BertIOConfig(BertOnnxConfig):
class BertIOConfig(BertOnnxConfig): # type: ignore[misc] # optimum base is untyped
"""BERT OnnxConfig using max_position_embeddings as sequence_length.

Inputs:
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/hf/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@

@register_onnx_overwrite("blip", "image-to-text", library_name="transformers")
@register_onnx_overwrite("blip", "image-text-to-text", library_name="transformers")
class BlipCaptioningIOConfig(OnnxConfig):
class BlipCaptioningIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""Monolithic ONNX config for BLIP captioning — single-graph fallback.

Traces ``BlipForConditionalGeneration.forward`` with pixel_values +
Expand Down Expand Up @@ -148,7 +148,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:


@register_onnx_overwrite("blip", "feature-extraction", library_name="transformers")
class BlipVisionEncoderIOConfig(OnnxConfig):
class BlipVisionEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for the BLIP vision encoder.

``image-feature-extraction`` is a synonym that Optimum's TasksManager
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/hf/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
# Optimum ONNX Export Config Registrations
# =============================================================================
@register_onnx_overwrite("clip_text_model", "feature-extraction", library_name="transformers")
class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig):
class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for CLIPTextModelWithProjection from transformers.

Model: openai/clip-vit-base-patch32 (text encoder only)
Expand Down Expand Up @@ -108,7 +108,7 @@ def inputs(self) -> dict[str, dict[int, str]]:


@register_onnx_overwrite("clip_vision_model", "feature-extraction", library_name="transformers")
class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig):
class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for CLIPVisionModelWithProjection from transformers.

Model: openai/clip-vit-base-patch32 (vision encoder only)
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/models/hf/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _build_patching_specs() -> list[PatchingSpec]:
"image-classification",
library_name="transformers",
)
class ConvNextIOConfig(ConvNextOnnxConfig):
class ConvNextIOConfig(ConvNextOnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ConvNextOnnxConfig override that adds a LayerNorm fusion patch.

Inherits all I/O specs from Optimum's ``ConvNextOnnxConfig``. The only
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/models/hf/decoder_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _invoke_hf(
"""Call the HF decoder with ``past_key_values=<cache>``. Returns logits."""


class WinMLStaticCacheDecoderIOConfig(OnnxConfig):
class WinMLStaticCacheDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped
"""Semantic-name contract used by ``WinMLDecoderWrapper._make_cache``.

Subclasses declare their own ``inputs`` / ``outputs`` bodies (each
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/hf/depth_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ...export import register_onnx_overwrite


class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator):
class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped
"""Vision input generator that lets explicit height/width override config.image_size.

Optimum's DummyVisionInputGenerator prioritizes normalized_config.image_size
Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(


@register_onnx_overwrite("depth_anything", "depth-estimation", library_name="transformers")
class DepthAnythingIOConfig(OnnxConfig):
class DepthAnythingIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped
"""ONNX config for Depth Anything depth estimation.

Model: depth-anything/Depth-Anything-V2-Small-hf
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/hf/depth_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ...export import register_onnx_overwrite


class _DepthProNormalizedConfig(NormalizedConfig):
class _DepthProNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped
"""Normalized config for DepthPro with computed image_size.

image_size is derived from patch_size / min(scaled_images_ratios),
Expand All @@ -47,7 +47,7 @@ def image_size(self) -> int:


@register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers")
class DepthProIOConfig(OnnxConfig):
class DepthProIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for DepthPro depth estimation.

Model: apple/DepthPro-hf
Expand Down
34 changes: 22 additions & 12 deletions src/winml/modelkit/models/hf/marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
from __future__ import annotations

import logging
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import torch
import torch.nn as nn
Expand All @@ -107,6 +107,9 @@
from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache


if TYPE_CHECKING:
from transformers import GenerationConfig, PretrainedConfig

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'PretrainedConfig' is not used.

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -177,7 +180,7 @@


def _patched_marian_sinusoidal_forward(
self,
self: Any, # monkey-patched onto MarianSinusoidalPositionalEmbedding (HF internal)
input_ids_shape: torch.Size,
past_key_values_length: int = 0,
position_ids: torch.Tensor | None = None,
Expand Down Expand Up @@ -262,10 +265,14 @@
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Return encoder last hidden state."""
return self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
# self.encoder is a torch submodule (untyped __call__ -> Any).
return cast(
"torch.Tensor",
self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state,
)


class MarianDecoderWrapper(nn.Module):
Expand Down Expand Up @@ -301,7 +308,9 @@
self.model = model
self.num_layers = num_layers
# Expose config for OnnxConfig / NormalizedConfig access
self.config = model.config
# model is typed nn.Module, so torch's __getattr__ types .config as
# Tensor | Module; it is really the model's PretrainedConfig.
self.config = cast("PretrainedConfig", model.config)

@classmethod
def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> MarianDecoderWrapper:
Expand Down Expand Up @@ -398,7 +407,7 @@


@register_onnx_overwrite("marian", "feature-extraction", library_name="transformers")
class MarianEncoderIOConfig(OnnxConfig):
class MarianEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for Marian encoder (feature-extraction task).

Inputs: input_ids, attention_mask
Expand All @@ -425,7 +434,7 @@
}


class _MarianDecoderNormalizedConfig(NormalizedConfig):
class _MarianDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped
"""NormalizedConfig for Marian decoder-side export.

Maps NormalizedConfig attributes to MarianConfig's decoder-side attrs.
Expand All @@ -440,11 +449,12 @@

@property
def head_dim(self) -> int:
return self.hidden_size // self.num_attention_heads
# hidden_size / num_attention_heads come from the untyped NormalizedConfig base.
return cast("int", self.hidden_size // self.num_attention_heads)


@register_onnx_overwrite("marian", "text2text-generation", library_name="transformers")
class MarianDecoderIOConfig(OnnxConfig):
class MarianDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for Marian decoder with sliding-window KV cache.

Inputs: decoder_input_ids, encoder_hidden_states, attention_mask,
Expand Down Expand Up @@ -554,7 +564,7 @@
return WinMLStaticCache # static cache (index_put_ → ScatterND)

@property
def generation_config(self): # noqa: D102
def generation_config(self) -> GenerationConfig: # noqa: D102
if not hasattr(self, "_generation_config"):
from transformers import GenerationConfig

Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/models/hf/mu2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:


@register_onnx_overwrite("mu2", "feature-extraction", library_name="transformers")
class Mu2EncoderIOConfig(OnnxConfig):
class Mu2EncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for Mu2 encoder (feature-extraction task)."""

NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
Expand All @@ -218,7 +218,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102


@register_onnx_overwrite("mu2", "text2text-generation", library_name="transformers")
class Mu2DecoderIOConfig(OnnxConfig):
class Mu2DecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped
"""ONNX config for Mu2 decoder with static KV cache."""

NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
Expand Down
Loading
Loading