Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/winml/modelkit/analyze/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ def _build_runtime_debug_details_summary(
level_bucket[node_stable_key] = candidate_entry
continue

if (
existing_entry.case_indices is None
and candidate_entry.case_indices is not None
):
if existing_entry.case_indices is None and candidate_entry.case_indices is not None:
existing_entry.case_indices = candidate_entry.case_indices

if existing_entry.table_path is None and candidate_entry.table_path is not None:
Expand Down Expand Up @@ -798,7 +795,7 @@ def analyze_from_proto(
if device is not None and device.lower() == "auto":
from ..sysinfo import resolve_device

resolved, _ = resolve_device("auto")
resolved, _ = resolve_device("auto", ep=ep_normalized)
device_to_use = resolved.upper()
logger.info("Device 'auto' resolved to: %s", device_to_use)
else:
Expand Down
26 changes: 15 additions & 11 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,20 +545,19 @@ def build(
# hardware -- analyzing for the wrong EP leaves black nodes that block a
# later build targeting the actual device (#663).
#
# resolve_device() either returns a device with >=1 available EP (auto-mode
# walks the priority list, falls back to cpu which is always valid), or
# raises ValueError for an explicit device with no compatible EP. So the
# following resolve_eps()[0] is safe whenever resolve_device returns.
# resolve_check_device_ep() either returns a device with >=1 available EP
# (auto-mode walks the priority list, falls back to cpu which is always
# valid), or raises ValueError for an explicit device with no compatible EP.
# So the following available_eps[0] is safe whenever it returns.
if ep is None:
from ..sysinfo import resolve_device as _resolve_device
from ..sysinfo import resolve_eps as _resolve_eps
from ..sysinfo import resolve_check_device_ep

try:
resolved_device, _ = _resolve_device(device=device)
resolved_device, _, available_eps = resolve_check_device_ep(device=device, ep=ep)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit: inside the 'if ep is None:' block, passing ep=ep (which is provably None here) is technically correct but slightly opaque. Writing ep=None explicitly would remove any ambiguity.

except ValueError as e:
raise click.UsageError(str(e)) from e
device = resolved_device
ep = _resolve_eps(resolved_device)[0]
ep = available_eps[0]
logger.info("Auto-resolved device=%s, EP=%s", resolved_device, ep)

try:
Expand All @@ -579,6 +578,7 @@ def build(
trust_remote_code=trust_remote_code,
device=device,
precision=precision,
ep=ep,
)
if not quant:
config_or_configs.quant = None
Expand Down Expand Up @@ -1012,11 +1012,15 @@ def _on_iteration_start(iteration: int, max_iter: int) -> None:
_header_shown[0] = False

# Resolve "auto" to a concrete device once so that has_rule_data_for_ep
# doesn't search for non-existent "*_AUTO_*.parquet" files.
# doesn't search for non-existent "*_AUTO_*.parquet" files. Use
# resolve_check_device_ep so an explicit device+ep is validated
# statically (no availability cross-check): a --no-compile build may
# target a device absent on this machine (cross-compile), and this call
# only needs a concrete device name for the rule-data lookup.
from ..analyze.utils.ep_utils import has_rule_data_for_ep
from ..sysinfo import resolve_device as _resolve_device
from ..sysinfo import resolve_check_device_ep

_resolved_device, _ = _resolve_device(device=device or "auto", ep=ep)
_resolved_device, _, _ = resolve_check_device_ep(device=device or "auto", ep=ep)

def _on_ep_start(ep_name: EPName, operator_counts: dict) -> None:
nonlocal _current_ep
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _resolve_device(cfg: WinMLEvaluationConfig) -> None:

console = Console(stderr=True)
console.print("[bold]Detecting available devices...[/bold]")
resolved, _ = resolve_device(cfg.device)
resolved, _ = resolve_device(cfg.device, ep=cfg.ep)
cfg.device = resolved
console.print(f"[dim]Using device:[/dim] {resolved}")

Expand Down
1 change: 1 addition & 0 deletions src/winml/modelkit/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CompileContext:

# Input
model_path: Path
# From WinMLCompileConfig.to_dict()
config: dict[str, Any]
model: onnx.ModelProto | None = None

Expand Down
4 changes: 3 additions & 1 deletion src/winml/modelkit/compiler/stages/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def _compile_multiple(self, context: CompileContext) -> None:
sess_options = context.shared_session_options
if sess_options is None:
register_execution_providers(ort=True)
resolved_device, _ = resolve_device(context.config.get("device", "auto"))
resolved_device, _ = resolve_device(
context.get_config("device", "auto"), ep=context.execution_provider
)
ep = normalize_ep_name(ep_config.provider) or resolve_eps(resolved_device)[0]
device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper())

Expand Down
6 changes: 3 additions & 3 deletions src/winml/modelkit/sysinfo/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def _get_available_eps() -> frozenset[EPName]:


def resolve_device(
device: str = "auto",
device: str,
*,
ep: EPNameOrAlias | None = None,
ep: EPNameOrAlias | None,
) -> tuple[str, list[str]]:
"""Resolve target device with EP availability cross-check.

Expand Down Expand Up @@ -233,7 +233,7 @@ def resolve_eps(resolved_device: str) -> list[EPName]:


def resolve_check_device_ep(
*, device: str = "auto", ep: EPNameOrAlias | None = None
*, device: str, ep: EPNameOrAlias | None

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says 'Ideal for commands that do not need the device + ep actually exists on the system.' But static validation (no availability probe) only applies when BOTH device is explicit (not auto) AND ep is non-None. When either is absent the function delegates to resolve_device which does a full availability cross-check. A caller passing --device npu without --ep on a machine without NPU will still get a ValueError. The docstring should clarify this for users relying on this for cross-compile.

) -> tuple[str, list[str], list[EPName]]:
"""Resolve or check that the requested device and/or EP combination is valid, raising if not.

Expand Down
59 changes: 44 additions & 15 deletions tests/unit/commands/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,25 +1212,25 @@ def test_resolve_device_value_error_surfaces_as_usage_error(
mock_build_api: MagicMock,
tmp_path: Path,
) -> None:
"""resolve_device raising (explicit device w/ no compatible EP) -> UsageError.
"""resolve_check_device_ep raising (explicit device w/ no compatible EP) -> UsageError.

Uses default ``--device auto`` (no CLI flag) so the downstream
device-patch path isn't triggered; the only resolve_device call is
the one inside the auto-select block.
device-patch path isn't triggered; the only resolution call is the
``resolve_check_device_ep`` inside the auto-select block.
"""
from winml.modelkit.commands.build import build

with patch(
"winml.modelkit.sysinfo.resolve_device",
side_effect=ValueError("simulated resolve_device failure"),
"winml.modelkit.sysinfo.resolve_check_device_ep",
side_effect=ValueError("simulated resolve failure"),
):
result = runner.invoke(
build,
["-c", str(sample_config_file), "-m", "test", "-o", str(tmp_path)],
obj={"debug": False},
)
assert result.exit_code != 0
assert "simulated resolve_device failure" in result.output
assert "simulated resolve failure" in result.output
mock_build_api.assert_not_called()

def test_auto_selection_respects_resolve_eps_priority(
Expand All @@ -1240,17 +1240,15 @@ def test_auto_selection_respects_resolve_eps_priority(
mock_build_api: MagicMock,
tmp_path: Path,
) -> None:
"""First element of resolve_eps(resolved_device) is selected, not later ones."""
"""First element of resolve_check_device_ep's available_eps is selected."""
from winml.modelkit.commands.build import build

with (
patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("gpu", ["gpu", "cpu"]),
),
patch(
"winml.modelkit.sysinfo.resolve_eps",
return_value=["DmlExecutionProvider", "OpenVINOExecutionProvider"],
with patch(
"winml.modelkit.sysinfo.resolve_check_device_ep",
return_value=(
"gpu",
["gpu", "cpu"],
["DmlExecutionProvider", "OpenVINOExecutionProvider"],
),
):
result = runner.invoke(
Expand Down Expand Up @@ -1711,3 +1709,34 @@ def test_returns_compiled_path_when_file_exists(

# current_path should be updated to compiled_path
assert result == compiled_path


class TestBuildEpResolution:
"""--ep forwarding into config generation + the compile EP-availability gate."""

def _base_args(self, cfg: str, tmp_path: Path) -> list[str]:
return ["-c", cfg, "-m", "microsoft/resnet-50", "-o", str(tmp_path / "out")]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_base_args is never called by any test in this class -- dead helper. If intended for future tests add a TODO comment; otherwise remove it. The args it builds also mix -c and -m which would not match the auto-config path the class is testing.


def test_ep_forwarded_to_generate_build_config(
self, tmp_path: Path, mock_run_single_build: MagicMock
):
"""On the auto-config path (-m, no -c), --ep reaches generate_build_config.

Regression: the build command dropped --ep when auto-generating a config,
so the requested EP never influenced the generated config (it failed or
analyzed/compiled for the wrong EP).
"""
fake_cfg = MagicMock()
fake_cfg.compile = None # no compile -> EP-availability gate is skipped
with (
patch("winml.modelkit.config.generate_build_config", return_value=fake_cfg) as mock_gen,
patch(
"winml.modelkit.commands.build._validate_loader_tasks_for_model",
return_value=None,
),
):
result = _invoke(
["-m", "microsoft/resnet-50", "--ep", "openvino", "-o", str(tmp_path / "out")]
)
assert result.exit_code == 0, result.output
assert mock_gen.call_args.kwargs["ep"] == "openvino"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing regression test for the cross-compile static path: the PR's headline feature (--device npu --ep openvino --no-compile completing on a machine without NPU) is verified empirically but has no unit test. A test that mocks resolve_check_device_ep to return a static result and confirms the optimize stage proceeds would guard against future regressions (e.g. accidentally restoring resolve_device in the optimize path).

18 changes: 9 additions & 9 deletions tests/unit/sysinfo/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_no_npu_no_gpu(self) -> None:
def test_returns_empty_when_enumeration_fails(self) -> None:
"""If EP enumeration raises, return empty tuple (no devices visible).

``resolve_device("auto")`` is responsible for the CPU fallback when no
``resolve_device("auto", ep=None)`` is responsible for the CPU fallback when no
devices are reachable; ``_get_available_devices`` only reports what is
actually registered.
"""
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_resolve_device_auto_npu_with_ep(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "npu"
assert available == ["npu", "gpu", "cpu"]
Expand All @@ -227,15 +227,15 @@ def test_resolve_device_auto_npu_without_ep(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "gpu"
assert available == ["gpu", "cpu"]

def test_resolve_device_auto_cpu_fallback(self) -> None:
"""Auto mode: only CPU EP registered -> returns "cpu"."""
with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "cpu"
assert available == ["cpu"]
Expand All @@ -248,15 +248,15 @@ def test_resolve_device_explicit_valid(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("gpu")
device, available = resolve_device("gpu", ep=None)

assert device == "gpu"
assert available == ["gpu", "cpu"]

def test_resolve_device_explicit_invalid(self) -> None:
"""Unrecognized device "tpu" -> raises ValueError."""
with pytest.raises(ValueError, match="Unknown device 'tpu'"):
resolve_device("tpu")
resolve_device("tpu", ep=None)

def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
"""Error message must name the compatible EPs so users know what to install."""
Expand All @@ -268,7 +268,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
),
pytest.raises(ValueError) as exc_info,
):
resolve_device("npu")
resolve_device("npu", ep=None)

message = str(exc_info.value)
assert "no compatible EP" in message
Expand All @@ -278,7 +278,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
def test_resolve_device_case_insensitive(self) -> None:
"""Device argument should be case-insensitive."""
with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}):
device, _ = resolve_device("CPU")
device, _ = resolve_device("CPU", ep=None)

assert device == "cpu"

Expand All @@ -293,7 +293,7 @@ def test_resolve_device_no_eps_raises(self) -> None:
_patch_device_ep_map({}),
pytest.raises(RuntimeError, match="No execution providers detected"),
):
resolve_device("auto")
resolve_device("auto", ep=None)


class TestResolveDeviceWithEp:
Expand Down
Loading