Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 15 additions & 30 deletions src/winml/modelkit/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,9 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:

When ``self._ep`` is set (and not ``"cpu"``), uses
``add_provider_for_devices`` to explicitly bind that EP.
``"cpu"`` falls through to policy-based selection so ORT handles
CPU-only inference without any EP registration.
When ``self._ep`` is not set, queries ``get_ep_devices()`` to
discover an available EP for the target device type. Falls back to
policy-based selection only as a last resort.
Otherwise uses policy-based selection via DEVICE_POLICY_MAP so ORT
handles provider assignment (including CPU fallback for unsupported
ops) without triggering ``VerifyEachNodeIsAssignedToAnEp`` warnings.

Note: Returns a **fresh** SessionOptions when using explicit EP to
avoid "already registered" errors from repeated calls.
Expand All @@ -445,6 +443,12 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:

opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
# Register CPU as explicit fallback so shape-related ops
# have a registered home and VerifyEachNodeIsAssignedToAnEp
# does not warn about unassigned nodes.
cpu_dev = self._find_ep_device(ep_name="CPUExecutionProvider", device="cpu")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe change only this line.

if cpu_dev:
opts.add_provider_for_devices([cpu_dev], {})
resolved = DEVICE_TYPE_TO_DEVICE.get(
matched.device.type, str(matched.device.type)
)
Expand All @@ -463,16 +467,7 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:
device,
)

# No explicit EP — discover available EP for this device type
if not self._ep and device.lower() != "cpu":
matched = self._find_ep_device(device=device)
if matched:
opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
logger.info("Discovered EP for %s: %s", device, matched.ep_name)
return opts

# Policy-based selection (last resort)
# Policy-based selection (default when no explicit EP)
opts = self._session_options
policy = DEVICE_POLICY_MAP.get(
device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU
Expand All @@ -486,23 +481,13 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:
def _find_ep_device(device: str, ep_name: str | None = None) -> Any:
"""Find the first OrtEpDevice matching the given filters.

Behavior:
- ``ep_name`` set, ``device == "auto"`` → first ep_device
matching ``ep_name`` (or None).
- ``ep_name`` unset, ``device == "auto"`` → ``None`` (no
effective filter — refuse to pick an arbitrary ep_device).
- ``ep_name`` unset, ``device`` is a concrete type → first
ep_device matching that device type (or None).
- Both set → ep_device must satisfy both (or None).

Note: Selection order is determined by the ORT EP registry, which is
not part of any documented contract. On systems where multiple EPs
match the same device type (e.g., QNN and DML both appear as GPU),
a device-only query returns the first one in registry order. Pass
``ep_name`` to disambiguate.
Filters are AND'd: both ``ep_name`` and ``device`` (when concrete)
must match. ``device="auto"`` acts as a no-op device filter.
Returns ``None`` when no effective filter is provided (i.e.
``ep_name`` is None and ``device`` resolves to no device type).

Args:
device: Device policy ("cpu", "gpu", "npu", "auto"). ``"auto"``
device: Device type ("cpu", "gpu", "npu", "auto"). ``"auto"``
and unknown strings act as no-op device filters.
ep_name: Full EP name (e.g., "DmlExecutionProvider"), or None
to skip EP-name filtering.
Expand Down
13 changes: 0 additions & 13 deletions tests/unit/session/test_winml_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,19 +719,6 @@ def test_ep_name_only(self) -> None:
assert match is not None
assert match.ep_name == "QNNExecutionProvider"

def test_device_only(self) -> None:
"""device filter alone returns first matching device type."""
import onnxruntime as ort

devs = [
self._ep_dev("CPUExecutionProvider", ort.OrtHardwareDeviceType.CPU),
self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU),
]
with self._patch_devices(devs):
match = WinMLSession._find_ep_device(device="gpu")
assert match is not None
assert match.ep_name == "DmlExecutionProvider"

def test_ep_name_and_device_both_required(self) -> None:
"""When both filters are set, both must match (AND logic)."""
import onnxruntime as ort
Expand Down
Loading