diff --git a/src/winml/modelkit/analyze/analyzer.py b/src/winml/modelkit/analyze/analyzer.py index 60d8dac01..890f98540 100644 --- a/src/winml/modelkit/analyze/analyzer.py +++ b/src/winml/modelkit/analyze/analyzer.py @@ -140,10 +140,7 @@ def _build_runtime_debug_details_summary( level_bucket[node_stable_key] = candidate_entry continue - if ( - existing_entry.case_indices is None - and candidate_entry.case_indices is not None - ): + if existing_entry.case_indices is None and candidate_entry.case_indices is not None: existing_entry.case_indices = candidate_entry.case_indices if existing_entry.table_path is None and candidate_entry.table_path is not None: @@ -798,7 +795,7 @@ def analyze_from_proto( if device is not None and device.lower() == "auto": from ..sysinfo import resolve_device - resolved, _ = resolve_device("auto") + resolved, _ = resolve_device("auto", ep=ep_normalized) device_to_use = resolved.upper() logger.info("Device 'auto' resolved to: %s", device_to_use) else: diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index c3ffc660d..15ae654b5 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -545,20 +545,19 @@ def build( # hardware -- analyzing for the wrong EP leaves black nodes that block a # later build targeting the actual device (#663). # - # resolve_device() either returns a device with >=1 available EP (auto-mode - # walks the priority list, falls back to cpu which is always valid), or - # raises ValueError for an explicit device with no compatible EP. So the - # following resolve_eps()[0] is safe whenever resolve_device returns. + # resolve_check_device_ep() either returns a device with >=1 available EP + # (auto-mode walks the priority list, falls back to cpu which is always + # valid), or raises ValueError for an explicit device with no compatible EP. + # So the following available_eps[0] is safe whenever it returns. if ep is None: - from ..sysinfo import resolve_device as _resolve_device - from ..sysinfo import resolve_eps as _resolve_eps + from ..sysinfo import resolve_check_device_ep try: - resolved_device, _ = _resolve_device(device=device) + resolved_device, _, available_eps = resolve_check_device_ep(device=device, ep=ep) except ValueError as e: raise click.UsageError(str(e)) from e device = resolved_device - ep = _resolve_eps(resolved_device)[0] + ep = available_eps[0] logger.info("Auto-resolved device=%s, EP=%s", resolved_device, ep) try: @@ -579,6 +578,7 @@ def build( trust_remote_code=trust_remote_code, device=device, precision=precision, + ep=ep, ) if not quant: config_or_configs.quant = None @@ -1012,11 +1012,15 @@ def _on_iteration_start(iteration: int, max_iter: int) -> None: _header_shown[0] = False # Resolve "auto" to a concrete device once so that has_rule_data_for_ep - # doesn't search for non-existent "*_AUTO_*.parquet" files. + # doesn't search for non-existent "*_AUTO_*.parquet" files. Use + # resolve_check_device_ep so an explicit device+ep is validated + # statically (no availability cross-check): a --no-compile build may + # target a device absent on this machine (cross-compile), and this call + # only needs a concrete device name for the rule-data lookup. from ..analyze.utils.ep_utils import has_rule_data_for_ep - from ..sysinfo import resolve_device as _resolve_device + from ..sysinfo import resolve_check_device_ep - _resolved_device, _ = _resolve_device(device=device or "auto", ep=ep) + _resolved_device, _, _ = resolve_check_device_ep(device=device or "auto", ep=ep) def _on_ep_start(ep_name: EPName, operator_counts: dict) -> None: nonlocal _current_ep diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 414153b09..02779f270 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -364,7 +364,7 @@ def _resolve_device(cfg: WinMLEvaluationConfig) -> None: console = Console(stderr=True) console.print("[bold]Detecting available devices...[/bold]") - resolved, _ = resolve_device(cfg.device) + resolved, _ = resolve_device(cfg.device, ep=cfg.ep) cfg.device = resolved console.print(f"[dim]Using device:[/dim] {resolved}") diff --git a/src/winml/modelkit/compiler/context.py b/src/winml/modelkit/compiler/context.py index fda351c50..b4e484c40 100644 --- a/src/winml/modelkit/compiler/context.py +++ b/src/winml/modelkit/compiler/context.py @@ -37,6 +37,7 @@ class CompileContext: # Input model_path: Path + # From WinMLCompileConfig.to_dict() config: dict[str, Any] model: onnx.ModelProto | None = None diff --git a/src/winml/modelkit/compiler/stages/compile.py b/src/winml/modelkit/compiler/stages/compile.py index 4bc1c28c4..9274e1c7e 100644 --- a/src/winml/modelkit/compiler/stages/compile.py +++ b/src/winml/modelkit/compiler/stages/compile.py @@ -176,7 +176,9 @@ def _compile_multiple(self, context: CompileContext) -> None: sess_options = context.shared_session_options if sess_options is None: register_execution_providers(ort=True) - resolved_device, _ = resolve_device(context.config.get("device", "auto")) + resolved_device, _ = resolve_device( + context.get_config("device", "auto"), ep=context.execution_provider + ) ep = normalize_ep_name(ep_config.provider) or resolve_eps(resolved_device)[0] device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper()) diff --git a/src/winml/modelkit/sysinfo/device.py b/src/winml/modelkit/sysinfo/device.py index 6f6fdb5b9..05e714102 100644 --- a/src/winml/modelkit/sysinfo/device.py +++ b/src/winml/modelkit/sysinfo/device.py @@ -145,9 +145,9 @@ def _get_available_eps() -> frozenset[EPName]: def resolve_device( - device: str = "auto", + device: str, *, - ep: EPNameOrAlias | None = None, + ep: EPNameOrAlias | None, ) -> tuple[str, list[str]]: """Resolve target device with EP availability cross-check. @@ -233,7 +233,7 @@ def resolve_eps(resolved_device: str) -> list[EPName]: def resolve_check_device_ep( - *, device: str = "auto", ep: EPNameOrAlias | None = None + *, device: str, ep: EPNameOrAlias | None ) -> tuple[str, list[str], list[EPName]]: """Resolve or check that the requested device and/or EP combination is valid, raising if not. diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 00f54fc23..1f0c40654 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -1212,17 +1212,17 @@ def test_resolve_device_value_error_surfaces_as_usage_error( mock_build_api: MagicMock, tmp_path: Path, ) -> None: - """resolve_device raising (explicit device w/ no compatible EP) -> UsageError. + """resolve_check_device_ep raising (explicit device w/ no compatible EP) -> UsageError. Uses default ``--device auto`` (no CLI flag) so the downstream - device-patch path isn't triggered; the only resolve_device call is - the one inside the auto-select block. + device-patch path isn't triggered; the only resolution call is the + ``resolve_check_device_ep`` inside the auto-select block. """ from winml.modelkit.commands.build import build with patch( - "winml.modelkit.sysinfo.resolve_device", - side_effect=ValueError("simulated resolve_device failure"), + "winml.modelkit.sysinfo.resolve_check_device_ep", + side_effect=ValueError("simulated resolve failure"), ): result = runner.invoke( build, @@ -1230,7 +1230,7 @@ def test_resolve_device_value_error_surfaces_as_usage_error( obj={"debug": False}, ) assert result.exit_code != 0 - assert "simulated resolve_device failure" in result.output + assert "simulated resolve failure" in result.output mock_build_api.assert_not_called() def test_auto_selection_respects_resolve_eps_priority( @@ -1240,17 +1240,15 @@ def test_auto_selection_respects_resolve_eps_priority( mock_build_api: MagicMock, tmp_path: Path, ) -> None: - """First element of resolve_eps(resolved_device) is selected, not later ones.""" + """First element of resolve_check_device_ep's available_eps is selected.""" from winml.modelkit.commands.build import build - with ( - patch( - "winml.modelkit.sysinfo.resolve_device", - return_value=("gpu", ["gpu", "cpu"]), - ), - patch( - "winml.modelkit.sysinfo.resolve_eps", - return_value=["DmlExecutionProvider", "OpenVINOExecutionProvider"], + with patch( + "winml.modelkit.sysinfo.resolve_check_device_ep", + return_value=( + "gpu", + ["gpu", "cpu"], + ["DmlExecutionProvider", "OpenVINOExecutionProvider"], ), ): result = runner.invoke( @@ -1711,3 +1709,34 @@ def test_returns_compiled_path_when_file_exists( # current_path should be updated to compiled_path assert result == compiled_path + + +class TestBuildEpResolution: + """--ep forwarding into config generation + the compile EP-availability gate.""" + + def _base_args(self, cfg: str, tmp_path: Path) -> list[str]: + return ["-c", cfg, "-m", "microsoft/resnet-50", "-o", str(tmp_path / "out")] + + def test_ep_forwarded_to_generate_build_config( + self, tmp_path: Path, mock_run_single_build: MagicMock + ): + """On the auto-config path (-m, no -c), --ep reaches generate_build_config. + + Regression: the build command dropped --ep when auto-generating a config, + so the requested EP never influenced the generated config (it failed or + analyzed/compiled for the wrong EP). + """ + fake_cfg = MagicMock() + fake_cfg.compile = None # no compile -> EP-availability gate is skipped + with ( + patch("winml.modelkit.config.generate_build_config", return_value=fake_cfg) as mock_gen, + patch( + "winml.modelkit.commands.build._validate_loader_tasks_for_model", + return_value=None, + ), + ): + result = _invoke( + ["-m", "microsoft/resnet-50", "--ep", "openvino", "-o", str(tmp_path / "out")] + ) + assert result.exit_code == 0, result.output + assert mock_gen.call_args.kwargs["ep"] == "openvino" diff --git a/tests/unit/sysinfo/test_device.py b/tests/unit/sysinfo/test_device.py index aa730ea75..5062abd63 100644 --- a/tests/unit/sysinfo/test_device.py +++ b/tests/unit/sysinfo/test_device.py @@ -83,7 +83,7 @@ def test_no_npu_no_gpu(self) -> None: def test_returns_empty_when_enumeration_fails(self) -> None: """If EP enumeration raises, return empty tuple (no devices visible). - ``resolve_device("auto")`` is responsible for the CPU fallback when no + ``resolve_device("auto", ep=None)`` is responsible for the CPU fallback when no devices are reachable; ``_get_available_devices`` only reports what is actually registered. """ @@ -214,7 +214,7 @@ def test_resolve_device_auto_npu_with_ep(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "npu" assert available == ["npu", "gpu", "cpu"] @@ -227,7 +227,7 @@ def test_resolve_device_auto_npu_without_ep(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "gpu" assert available == ["gpu", "cpu"] @@ -235,7 +235,7 @@ def test_resolve_device_auto_npu_without_ep(self) -> None: def test_resolve_device_auto_cpu_fallback(self) -> None: """Auto mode: only CPU EP registered -> returns "cpu".""" with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "cpu" assert available == ["cpu"] @@ -248,7 +248,7 @@ def test_resolve_device_explicit_valid(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("gpu") + device, available = resolve_device("gpu", ep=None) assert device == "gpu" assert available == ["gpu", "cpu"] @@ -256,7 +256,7 @@ def test_resolve_device_explicit_valid(self) -> None: def test_resolve_device_explicit_invalid(self) -> None: """Unrecognized device "tpu" -> raises ValueError.""" with pytest.raises(ValueError, match="Unknown device 'tpu'"): - resolve_device("tpu") + resolve_device("tpu", ep=None) def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: """Error message must name the compatible EPs so users know what to install.""" @@ -268,7 +268,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: ), pytest.raises(ValueError) as exc_info, ): - resolve_device("npu") + resolve_device("npu", ep=None) message = str(exc_info.value) assert "no compatible EP" in message @@ -278,7 +278,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: def test_resolve_device_case_insensitive(self) -> None: """Device argument should be case-insensitive.""" with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}): - device, _ = resolve_device("CPU") + device, _ = resolve_device("CPU", ep=None) assert device == "cpu" @@ -293,7 +293,7 @@ def test_resolve_device_no_eps_raises(self) -> None: _patch_device_ep_map({}), pytest.raises(RuntimeError, match="No execution providers detected"), ): - resolve_device("auto") + resolve_device("auto", ep=None) class TestResolveDeviceWithEp: