diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index 9c12b7c96..0a5b9b991 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -423,11 +423,9 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: When ``self._ep`` is set (and not ``"cpu"``), uses ``add_provider_for_devices`` to explicitly bind that EP. - ``"cpu"`` falls through to policy-based selection so ORT handles - CPU-only inference without any EP registration. - When ``self._ep`` is not set, queries ``get_ep_devices()`` to - discover an available EP for the target device type. Falls back to - policy-based selection only as a last resort. + Otherwise uses policy-based selection via DEVICE_POLICY_MAP so ORT + handles provider assignment (including CPU fallback for unsupported + ops) without triggering ``VerifyEachNodeIsAssignedToAnEp`` warnings. Note: Returns a **fresh** SessionOptions when using explicit EP to avoid "already registered" errors from repeated calls. @@ -445,6 +443,12 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: opts = ort.SessionOptions() opts.add_provider_for_devices([matched], self._provider_options) + # Register CPU as explicit fallback so shape-related ops + # have a registered home and VerifyEachNodeIsAssignedToAnEp + # does not warn about unassigned nodes. + cpu_dev = self._find_ep_device(ep_name="CPUExecutionProvider", device="cpu") + if cpu_dev: + opts.add_provider_for_devices([cpu_dev], {}) resolved = DEVICE_TYPE_TO_DEVICE.get( matched.device.type, str(matched.device.type) ) @@ -463,16 +467,7 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: device, ) - # No explicit EP — discover available EP for this device type - if not self._ep and device.lower() != "cpu": - matched = self._find_ep_device(device=device) - if matched: - opts = ort.SessionOptions() - opts.add_provider_for_devices([matched], self._provider_options) - logger.info("Discovered EP for %s: %s", device, matched.ep_name) - return opts - - # Policy-based selection (last resort) + # Policy-based selection (default when no explicit EP) opts = self._session_options policy = DEVICE_POLICY_MAP.get( device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU @@ -486,23 +481,13 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: def _find_ep_device(device: str, ep_name: str | None = None) -> Any: """Find the first OrtEpDevice matching the given filters. - Behavior: - - ``ep_name`` set, ``device == "auto"`` → first ep_device - matching ``ep_name`` (or None). - - ``ep_name`` unset, ``device == "auto"`` → ``None`` (no - effective filter — refuse to pick an arbitrary ep_device). - - ``ep_name`` unset, ``device`` is a concrete type → first - ep_device matching that device type (or None). - - Both set → ep_device must satisfy both (or None). - - Note: Selection order is determined by the ORT EP registry, which is - not part of any documented contract. On systems where multiple EPs - match the same device type (e.g., QNN and DML both appear as GPU), - a device-only query returns the first one in registry order. Pass - ``ep_name`` to disambiguate. + Filters are AND'd: both ``ep_name`` and ``device`` (when concrete) + must match. ``device="auto"`` acts as no-op device filter. + Returns ``None`` when no effective filter is provided (i.e. both + ``ep_name`` is None and ``device`` resolves to no device type). Args: - device: Device policy ("cpu", "gpu", "npu", "auto"). ``"auto"`` + device: Device type ("cpu", "gpu", "npu", "auto"). ``"auto"`` and unknown strings act as no-op device filters. ep_name: Full EP name (e.g., "DmlExecutionProvider"), or None to skip EP-name filtering. diff --git a/tests/unit/session/test_winml_session.py b/tests/unit/session/test_winml_session.py index c82e0ddd6..28680457a 100644 --- a/tests/unit/session/test_winml_session.py +++ b/tests/unit/session/test_winml_session.py @@ -719,19 +719,6 @@ def test_ep_name_only(self) -> None: assert match is not None assert match.ep_name == "QNNExecutionProvider" - def test_device_only(self) -> None: - """device filter alone returns first matching device type.""" - import onnxruntime as ort - - devs = [ - self._ep_dev("CPUExecutionProvider", ort.OrtHardwareDeviceType.CPU), - self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU), - ] - with self._patch_devices(devs): - match = WinMLSession._find_ep_device(device="gpu") - assert match is not None - assert match.ep_name == "DmlExecutionProvider" - def test_ep_name_and_device_both_required(self) -> None: """When both filters are set, both must match (AND logic).""" import onnxruntime as ort