diff --git a/src/winml/modelkit/commands/export.py b/src/winml/modelkit/commands/export.py index 9335235c3..8b7d151fa 100644 --- a/src/winml/modelkit/commands/export.py +++ b/src/winml/modelkit/commands/export.py @@ -134,6 +134,15 @@ def _delete_onnx_with_external_data(onnx_path: Path) -> None: default=None, help='JSON with shape overrides (e.g., {"sequence_length": 2048, "height": 640}).', ) +@click.option( + "--loader-config-overrides", + "loader_config_overrides_path", + type=click.Path(exists=True, path_type=Path), + default=None, + help="JSON file with overrides patched recursively onto the loaded HF " + 'config (e.g., {"scale": 2} for Real-ESRGAN). Deep-merges with any ' + "loader.loader_config_overrides set in --config; CLI keys win.", +) @cli_utils.build_config_option @click.pass_context def export( @@ -149,6 +158,7 @@ def export( input_specs: Path | None, export_config: Path | None, shape_config: Path | None, + loader_config_overrides_path: Path | None, config_file: Path | None, ) -> None: r"""Export HuggingFace model to ONNX format with HTP. @@ -198,17 +208,50 @@ def export( # Apply build config defaults (CLI explicit options take precedence) _build_export_cfg = None + loader_config_overrides: dict | None = None if config_file is not None: build_cfg = cli_utils.load_build_config(config_file) if build_cfg.export: _build_export_cfg = build_cfg.export if build_cfg.loader and not cli_utils.is_cli_provided(ctx, "task"): task = build_cfg.loader.task + if build_cfg.loader and build_cfg.loader.loader_config_overrides: + loader_config_overrides = dict(build_cfg.loader.loader_config_overrides) if _build_export_cfg and not cli_utils.is_cli_provided(ctx, "no_hierarchy"): no_hierarchy = not _build_export_cfg.enable_hierarchy_tags if _build_export_cfg and not cli_utils.is_cli_provided(ctx, "dynamo"): dynamo = _build_export_cfg.dynamo + # Sidecar --loader-config-overrides JSON deep-merges on top of any value + # carried by --config. CLI keys win on conflicts. 
+ if loader_config_overrides_path is not None: + try: + cli_overrides = json.loads(loader_config_overrides_path.read_text(encoding="utf-8-sig")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise click.ClickException( + f"Invalid JSON in --loader-config-overrides " + f"{loader_config_overrides_path}: {e}" + ) from e + if not isinstance(cli_overrides, dict): + raise click.ClickException( + f"--loader-config-overrides must contain a JSON object, " + f"got {type(cli_overrides).__name__}" + ) + + def _deep_merge(base: dict, top: dict) -> dict: + out = dict(base) + for k, v in top.items(): + if isinstance(v, dict) and isinstance(out.get(k), dict): + out[k] = _deep_merge(out[k], v) + else: + out[k] = v + return out + + loader_config_overrides = _deep_merge(loader_config_overrides or {}, cli_overrides) + console.print( + f"[dim]Loader config overrides: {loader_config_overrides}[/dim]" + ) + from ..export import InputTensorSpec, OutputTensorSpec, WinMLExportConfig from ..export import export_pytorch as export_onnx from ..loader import load_hf_model @@ -362,7 +405,11 @@ def export( console.print("\n[bold]Starting HTP export...[/bold]") # Load model with task detection (CLI is the orchestration layer) - pytorch_model, _, detected_task = load_hf_model(model, task=task) + pytorch_model, _, detected_task = load_hf_model( + model, + task=task, + loader_config_overrides=loader_config_overrides, + ) if task: console.print(f"[dim]Task (override): {detected_task}[/dim]") else: diff --git a/src/winml/modelkit/eval/evaluate.py b/src/winml/modelkit/eval/evaluate.py index f568904dc..a369fa835 100644 --- a/src/winml/modelkit/eval/evaluate.py +++ b/src/winml/modelkit/eval/evaluate.py @@ -183,6 +183,9 @@ def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel: # ignored here (mirrors winml perf's ONNX path).
from transformers import AutoConfig + from ..export.io import ensure_hf_models_registered + + ensure_hf_models_registered() hf_config = AutoConfig.from_pretrained(config.model_id) model = WinMLAutoModel.from_onnx( onnx_path=config.model_path, @@ -214,10 +217,12 @@ def _resolve_task(config: WinMLEvaluationConfig) -> str: from transformers import AutoConfig - from ..loader.task import _detect_task_from_config + from ..export.io import ensure_hf_models_registered + from ..loader.task import _detect_task_and_class_from_config + ensure_hf_models_registered() hf_config = AutoConfig.from_pretrained(config.model_id) - return _detect_task_from_config(hf_config) + return _detect_task_and_class_from_config(hf_config)[0] def evaluate(config: WinMLEvaluationConfig) -> EvalResult: diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index 46e06ca41..0341f144c 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -45,6 +45,12 @@ class WinMLLoaderConfig: Requires trust_remote_code=True for security. trust_remote_code: Whether to trust remote/custom code. Required when using user_script. + loader_config_overrides: Optional patch applied recursively to the HF + config object after ``AutoConfig.from_pretrained``. Keys are + attribute names; nested dicts are merged into sub-configs (e.g. + ``{"vision_config": {"image_size": 320}}``). Use cases include + selecting a non-default hyperparameter (``{"scale": 2}`` for + Real-ESRGAN) without committing a separate ``config.json``. Example: # Standard usage with auto-detection @@ -71,6 +77,7 @@ class WinMLLoaderConfig: module_path: str | None = None user_script: str | None = None trust_remote_code: bool = False + loader_config_overrides: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: """Serialize to dictionary. 
@@ -91,6 +98,8 @@ def to_dict(self) -> dict[str, Any]: result["user_script"] = self.user_script if self.trust_remote_code: result["trust_remote_code"] = self.trust_remote_code + if self.loader_config_overrides: + result["loader_config_overrides"] = self.loader_config_overrides return result @classmethod @@ -110,9 +119,69 @@ def from_dict(cls, data: dict[str, Any]) -> WinMLLoaderConfig: module_path=data.get("module_path"), user_script=data.get("user_script"), trust_remote_code=data.get("trust_remote_code", False), + loader_config_overrides=data.get("loader_config_overrides"), ) +def _deep_merge_dicts(base: dict, top: dict) -> dict: + """Return a new dict deep-merging ``top`` on top of ``base``. + + Nested dicts on both sides are merged recursively; otherwise ``top``'s + value wins. ``base`` is not mutated. + """ + out = dict(base) + for key, value in top.items(): + if isinstance(value, dict) and isinstance(out.get(key), dict): + out[key] = _deep_merge_dicts(out[key], value) + else: + out[key] = value + return out + + +def apply_loader_config_overrides( + hf_config: PretrainedConfig, + overrides: dict[str, Any] | None, +) -> PretrainedConfig: + """Return a new HF config with ``overrides`` deep-merged onto ``hf_config``. + + Serializes the original config via :meth:`PretrainedConfig.to_dict`, + recursively deep-merges ``overrides`` into the resulting plain dict + (``overrides`` keys win on conflict, nested dicts merge into nested + dicts), then reconstructs the config via + ``type(hf_config).from_dict(merged)``. + + Going through ``to_dict`` / ``from_dict`` lets the config class's own + constructor handle validation, defaulting, and sub-config reconstruction + — including nested :class:`PretrainedConfig` fields like + ``CLIPConfig.vision_config`` (HF's ``from_dict`` rebuilds them from + nested dicts automatically). 
A raw ``setattr`` loop, by contrast, can + silently create attributes the model class never reads when the + override key is missing on the original config. + + Empty / ``None`` overrides return the original config unchanged. + + Args: + hf_config: The HF :class:`PretrainedConfig` to patch. + overrides: Nested dict of overrides, or ``None``. + + Returns: + A :class:`PretrainedConfig` of the same concrete type as + ``hf_config`` with the overrides applied. May be the original + instance (when overrides are empty) or a freshly constructed one. + """ + if not overrides: + return hf_config + + merged = _deep_merge_dicts(hf_config.to_dict(), overrides) + new_config = type(hf_config).from_dict(merged) + logger.debug( + "Applied loader_config_overrides to %s: %s", + type(hf_config).__name__, + overrides, + ) + return new_config + + def resolve_loader_config( model_id: str | None = None, *, @@ -121,6 +190,7 @@ def resolve_loader_config( model_type: str | None = None, trust_remote_code: bool = False, library_name: str = "transformers", + loader_config_overrides: dict[str, Any] | None = None, ) -> tuple[WinMLLoaderConfig, PretrainedConfig, type]: """Resolve all loader concerns from raw user inputs. @@ -154,6 +224,11 @@ def resolve_loader_config( When provided without task, the first supported task is used. trust_remote_code: Whether to trust remote/custom code. library_name: Source library for TasksManager lookup. + loader_config_overrides: Optional nested dict patched recursively onto + the HF config after it is loaded — see + :func:`apply_loader_config_overrides`. Stored on the returned + :class:`WinMLLoaderConfig` so downstream + ``resolved_class.from_pretrained`` consumers can re-apply it. 
Returns: Tuple of: @@ -169,8 +244,13 @@ def resolve_loader_config( """ from transformers import AutoConfig + from ..export.io import ensure_hf_models_registered from .task import get_supported_tasks, resolve_task_and_model_class + # Ensure HF model registrations (AutoConfig.register, OnnxConfig overwrites, + # task-mapping fallbacks) have run before any AutoConfig / TasksManager calls. + ensure_hf_models_registered() + # 1. Load hf_config (depends on: model_id, model_type, or model_class) if model_id is not None: hf_config = AutoConfig.from_pretrained( @@ -206,6 +286,10 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." ) + # 1a. Apply caller-supplied overrides — returns a new config when overrides + # are non-empty so the config class's own __init__ / from_dict validates. + hf_config = apply_loader_config_overrides(hf_config, loader_config_overrides) + # 2. Infer task (depends on: model_type param or hf_config.architectures) if task is None and model_type is not None: supported = get_supported_tasks(model_type, library_name=library_name) @@ -241,6 +325,7 @@ def resolve_loader_config( model_class=resolved_class.__name__, model_type=resolved_model_type, trust_remote_code=trust_remote_code, + loader_config_overrides=loader_config_overrides or None, ) return loader_config, resolved_hf_config, resolved_class @@ -313,4 +398,8 @@ def _resolve_hf_config_for_class( return hf_config, hf_config.model_type -__all__ = ["WinMLLoaderConfig", "resolve_loader_config"] +__all__ = [ + "WinMLLoaderConfig", + "apply_loader_config_overrides", + "resolve_loader_config", +] diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 07084499d..83a16cafb 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -145,6 +145,7 @@ def load_hf_model( model_class: str | None = None, user_script: str | None = None, trust_remote_code: bool = False, + loader_config_overrides: dict | None = None, ) -> 
tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -194,6 +195,12 @@ def load_hf_model( """ logger.info("Loading HF model: %s", model_name_or_path) + # Ensure HF model registrations (AutoConfig.register, OnnxConfig overwrites, + # task-mapping fallbacks) have run before any AutoConfig / TasksManager calls. + from ..export.io import ensure_hf_models_registered + + ensure_hf_models_registered() + # Validate user_script requirements before any network calls if user_script is not None: if not trust_remote_code: @@ -210,6 +217,15 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) + # [1a] Apply optional ``loader_config_overrides`` (e.g. ESRGAN ``scale``) + # before task resolution so downstream lookups see the patched values. + # ``apply_loader_config_overrides`` returns a new config instance built via + # ``from_dict`` so the class constructor validates / reconstructs. + if loader_config_overrides: + from .config import apply_loader_config_overrides + + hf_config = apply_loader_config_overrides(hf_config, loader_config_overrides) + # [2] Task & Model Class Resolution if user_script is not None: resolved_class = _load_class_from_script(user_script, model_class) @@ -230,11 +246,18 @@ def load_hf_model( f"Cannot resolve task/model for {model_name_or_path}. Original error: {e}" ) from e - # [4] Model Instantiation + # [4] Model Instantiation. When overrides were applied, pass the patched + # ``hf_config`` as ``config=`` so the model class uses it instead of + # re-loading from disk/Hub and losing the patches. 
logger.debug("Loading model with class: %s", resolved_class.__name__) + from_pretrained_kwargs: dict = { + "trust_remote_code": trust_remote_code, + } + if loader_config_overrides: + from_pretrained_kwargs["config"] = hf_config model = resolved_class.from_pretrained( model_name_or_path, - trust_remote_code=trust_remote_code, + **from_pretrained_kwargs, ) # [5] Export Preparation diff --git a/src/winml/modelkit/loader/task.py b/src/winml/modelkit/loader/task.py index c1604b425..fab4ace05 100644 --- a/src/winml/modelkit/loader/task.py +++ b/src/winml/modelkit/loader/task.py @@ -223,33 +223,42 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t Called by ``resolve_task_and_model_class`` Case 1. Resolution flow: - 1. ``_resolve_model_class_from_config(config)`` -> arch_model_class - 2. ``_detect_task_from_model_class(arch_model_class)`` -> task - 3. ``_get_custom_model_class(model_type, task)`` -> specialization check - 4. If specialization found -> return (task, specialized_class) - 5. Else ``TasksManager.get_model_class_for_task(task)`` -> tm_class - 6. If TasksManager fails -> fallback to arch_model_class + + 0. Read ``config.model_type`` (required); raise if absent. + 0a. Sentinel short-circuit: if ``MODEL_CLASS_MAPPING`` contains + ``(model_type, None) -> default_class``, reverse-lookup the matching + ``(model_type, task) -> default_class`` entry and return + ``(task, default_class)`` immediately. This bypasses steps 1-6 and + covers two cases: + + - Model families like SAM/SAM2 whose architecture class's natural + TasksManager mapping (``"feature-extraction"``) differs from the + canonical export target (``"mask-generation"``). + - Custom architecture classes that live outside ``transformers`` + (e.g. ESRGAN) and would otherwise fail step 1's + ``getattr(transformers, arch_name)``. + 1. ``_resolve_model_class_from_config(config)`` -> arch_model_class. + 2. ``_detect_task_from_model_class(arch_model_class)`` -> task. + 3. 
(Reserved.) + 4. ``_get_custom_model_class(model_type, task)`` -> specialization check; + return early when a specialized class is registered. + 5. Else ``TasksManager.get_model_class_for_task(task)`` -> tm_class. + 6. If TasksManager fails, fall back to arch_model_class. Args: - config: HuggingFace PretrainedConfig + config: HuggingFace PretrainedConfig. Returns: - Tuple of (task, model_class) + Tuple of (task, model_class). Raises: - ValueError: If task cannot be detected or model_type is missing + ValueError: If model_type is missing, task cannot be detected, or + a (model_type, None) sentinel exists with no matching + (model_type, task) entry. """ from optimum.exporters.tasks import TasksManager - # [1] Resolve architecture class from config - arch_model_class = _resolve_model_class_from_config(config) - arch_name = arch_model_class.__name__ - - # [2] Infer task from model class - task = _detect_task_from_model_class(arch_model_class) - logger.info("Detected task: %s (from %s)", task, arch_name) - - # [3] Get model_type - REQUIRED for specialization lookup + # [0] Get model_type - REQUIRED for specialization / sentinel lookup. model_type = getattr(config, "model_type", None) if model_type is None: raise ValueError( @@ -257,14 +266,16 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t "Please specify model_class explicitly." ) - # [3a] Per-model-type default task override. - # Some model families (e.g., SAM/SAM2) have an architecture class whose - # default TasksManager mapping ("feature-extraction") differs from the - # canonical export target ("mask-generation"). The default is encoded as - # a sentinel entry MODEL_CLASS_MAPPING[(model_type, None)] = ; - # we reverse-lookup the task name from the matching - # (model_type, default_task) -> same_class entry. This keeps the data in - # one table and structurally enforces that the matching class entry exists. 
+ # [0a] Per-model-type default task override (consulted before + # architecture-based detection). Some model families either have an + # architecture class whose default TasksManager mapping ("feature- + # extraction") differs from the canonical export target ("mask-generation" + # for SAM/SAM2), or live outside ``transformers`` entirely (e.g. ESRGAN). + # The default is encoded as a sentinel entry + # ``MODEL_CLASS_MAPPING[(model_type, None)] = ``; we reverse-lookup + # the task name from the matching ``(model_type, default_task) -> same_class`` + # entry. Checking this BEFORE arch detection lets us short-circuit for + # custom classes that are not importable from ``transformers``. from ..models.hf import MODEL_CLASS_MAPPING model_type_normalized = model_type.lower().replace("_", "-") @@ -285,15 +296,16 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t f"({model_type_normalized!r}, ) entry maps to that class. " f"Add the corresponding (model_type, task) entry." ) - if default_task != task: - logger.info( - "Overriding auto-detected task %r with model-type default %r for %s", - task, - default_task, - model_type_normalized, - ) return default_task, default_class + # [1] Resolve architecture class from config + arch_model_class = _resolve_model_class_from_config(config) + arch_name = arch_model_class.__name__ + + # [2] Infer task from model class + task = _detect_task_from_model_class(arch_model_class) + logger.info("Detected task: %s (from %s)", task, arch_name) + # [4] Check specializations first (CLIP, SAM2, etc.) 
- highest priority model_class = _get_custom_model_class(model_type, task) if model_class: diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 2864ddab8..46aadf5dc 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -186,10 +186,15 @@ def from_onnx( if skip_build or is_compiled_onnx(onnx_path): logger.info("Skipping build (compiled model or explicit skip). Using original ONNX.") # TODO: run analyze_onnx for validation/lint - winml_class = get_winml_class(None, resolved_task) + # Use hf_config.model_type so WINML_MODEL_CLASS_MAPPING's + # specialised entries (e.g. ESRGAN) are picked when the caller + # knows the model_type; falls back to the generic task class + # otherwise. + model_type = hf_config.model_type if hf_config is not None else None + winml_class = get_winml_class(model_type, resolved_task) return winml_class( onnx_path=onnx_path, - config=None, + config=hf_config, device=device, session_options=session_options, ep=ep, @@ -220,13 +225,16 @@ def from_onnx( **kwargs, ) - # Wrap in inference model (task-specific or generic fallback) - winml_class = get_winml_class(None, resolved_task) + # Wrap in inference model (task-specific or generic fallback). + # When the caller supplies hf_config, pick the specialised class + # registered for its ``model_type`` and propagate the config. 
+ model_type = hf_config.model_type if hf_config is not None else None + winml_class = get_winml_class(model_type, resolved_task) logger.info("Creating inference wrapper: %s", winml_class.__name__) return winml_class( onnx_path=result.final_onnx_path, - config=None, # No HF PretrainedConfig for bare ONNX builds + config=hf_config, device=device, session_options=session_options, ep=ep, diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 98bd30e66..efbceee60 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -25,6 +25,8 @@ from __future__ import annotations +from transformers import AutoConfig + # Import configs - importing triggers ONNX config registration with TasksManager # ConvNeXT and SAM2 modules also register PATCHING_SPECS / _MODEL_PATCHER # on their OnnxConfig classes at import time. @@ -43,6 +45,9 @@ from .depth_anything import DepthAnythingIOConfig as _DepthAnythingIOConfig # triggers registration from .depth_pro import DepthProIOConfig as _DepthProIOConfig # triggers registration from .detr import DETR_CONFIG +from .esrgan import MODEL_CLASS_MAPPING as _ESRGAN_CLASS_MAPPING +from .esrgan import ESRGANConfig, ESRGANForImageSuperResolution +from .esrgan import ESRGANIOConfig as _ESRGANIOConfig # triggers registration from .marian import MARIAN_CONFIG from .marian import MODEL_CLASS_MAPPING as _MARIAN_CLASS_MAPPING from .marian import MarianDecoderIOConfig as _MarianDecoderIOConfig # triggers registration @@ -77,6 +82,13 @@ from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration +# Register ESRGAN with HF's AutoConfig. The model_type string is uppercase +# ``ESRGAN`` so that HF's name-based fallback in AutoConfig.from_pretrained +# matches repo names like ``ai-forever/Real-ESRGAN`` (case-sensitive +# substring) and produces a default ESRGANConfig. 
+AutoConfig.register("ESRGAN", ESRGANConfig, exist_ok=True) + + # Aggregated model class mappings: (model_type, task) -> HF model class. # # A sentinel entry with task=None encodes the per-model-type default task @@ -88,6 +100,7 @@ **_BART_CLASS_MAPPING, **_BLIP_CLASS_MAPPING, **_CLIP_CLASS_MAPPING, + **_ESRGAN_CLASS_MAPPING, **_MARIAN_CLASS_MAPPING, **_MU2_CLASS_MAPPING, **_QWEN_CLASS_MAPPING, @@ -98,6 +111,7 @@ **_VED_CLASS_MAPPING, } + # Registry: model_type -> WinMLBuildConfig # Only models that need non-autoconf-discoverable settings retain configs. # Models with only optim flags rely on the analyzer autoconf loop. diff --git a/src/winml/modelkit/models/hf/esrgan.py b/src/winml/modelkit/models/hf/esrgan.py new file mode 100644 index 000000000..af8ae2b79 --- /dev/null +++ b/src/winml/modelkit/models/hf/esrgan.py @@ -0,0 +1,412 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Real-ESRGAN (RRDBNet) as a HuggingFace PreTrainedModel. + +Implements the RRDBNet architecture from Real-ESRGAN as a PreTrainedModel +so it can be saved/loaded via save_pretrained/from_pretrained and exported +to ONNX via the standard HuggingFace/Optimum pipeline. + +Architecture reference: sberbank-ai/Real-ESRGAN (BSD-3-Clause license). + +Classes: + ESRGANConfig: PretrainedConfig with RRDBNet hyperparameters. + ESRGANPreTrainedModel: Base PreTrainedModel (config_class, init_weights). + ResidualDenseBlock: 5-conv dense block with residual scaling. + RRDB: Residual-in-Residual Dense Block (3x ResidualDenseBlock). + ESRGANForImageSuperResolution: Full RRDBNet for image super-resolution. + +Note: + ``ESRGANForImageSuperResolution.from_pretrained`` is overridden so that + a ``.pth``-only HF repo (e.g. 
``ai-forever/Real-ESRGAN``) is supported + end-to-end without an offline ``save_pretrained`` conversion step. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedVisionConfig +from optimum.utils.input_generators import DummyVisionInputGenerator +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import ImageSuperResolutionOutput + +from ...export import register_onnx_overwrite + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Config +# ============================================================================= + + +class ESRGANConfig(PretrainedConfig): + """Configuration for Real-ESRGAN RRDBNet architecture. + + The ``model_type`` is intentionally uppercase ``"ESRGAN"`` so that + HuggingFace's name-based fallback in + :meth:`AutoConfig.from_pretrained` (which does a case-sensitive + substring match against the repo name) succeeds for community repos + like ``ai-forever/Real-ESRGAN`` that ship raw ``.pth`` weights with + no ``config.json``. + + Attributes: + num_in_ch: Number of input channels. + num_out_ch: Number of output channels. + num_feat: Number of intermediate feature channels. + num_block: Number of RRDB blocks in the body. + num_grow_ch: Growth channel count inside ResidualDenseBlock. + scale: Upscaling factor (1, 2, 4, or 8). + weight_file_format: ``str.format``-style template for the ``.pth`` + filename to download from the Hub repo. Receives the current + :attr:`scale` as the ``scale`` keyword. Defaults to the + sberbank-ai / ai-forever naming convention + (``"RealESRGAN_x{scale}.pth"``); override e.g. via + ``loader_config_overrides`` to point at a fork that uses a + different filename. 
+ """ + + model_type = "ESRGAN" + + def __init__( + self, + num_in_ch: int = 3, + num_out_ch: int = 3, + num_feat: int = 64, + num_block: int = 23, + num_grow_ch: int = 32, + scale: int = 4, + weight_file_format: str = "RealESRGAN_x{scale}.pth", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_block = num_block + self.num_grow_ch = num_grow_ch + self.scale = scale + self.weight_file_format = weight_file_format + + +# ============================================================================= +# Weight initialisation helper +# ============================================================================= + + +def default_init_weights(module_list: list[nn.Module] | nn.Module, scale: float = 1.0) -> None: + """Kaiming normal init for Conv2d layers with optional scale multiplier. + + Mirrors the sberbank-ai/Real-ESRGAN initialisation used in ResidualDenseBlock. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in", nonlinearity="leaky_relu") + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + + +# ============================================================================= +# Building blocks +# ============================================================================= + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block with 5 convolutions. + + Each conv receives the concatenation of all preceding feature maps. + A 0.2 residual scaling is applied before adding back to the input. 
+ """ + + def __init__(self, num_feat: int = 64, num_grow_ch: int = 32) -> None: + super().__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # Initialise weights (all five convs are scaled by 0.1 for stability) + default_init_weights( + [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], + scale=0.1, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Dense forward: each conv sees all prior feature maps.""" + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Residual scaling + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual-in-Residual Dense Block (3x ResidualDenseBlock).""" + + def __init__(self, num_feat: int, num_grow_ch: int = 32) -> None: + super().__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply 3 RDB blocks with 0.2 residual scaling.""" + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + return out * 0.2 + x + + +# ============================================================================= +# PreTrainedModel base +# ============================================================================= + + +class ESRGANPreTrainedModel(PreTrainedModel): + """Base PreTrainedModel for Real-ESRGAN variants.""" + + config_class = ESRGANConfig +
class ESRGANForImageSuperResolution(ESRGANPreTrainedModel):
    """RRDBNet for image super-resolution.

    Architecture:
        - Optional pixel_unshuffle for scale 1 or 2
        - conv_first -> body (N x RRDB) -> conv_body (skip connection)
        - Upsampling via nearest-neighbour interpolation + conv
        - conv_hr -> conv_last for final output

    Attribute names match sberbank-ai/Real-ESRGAN for weight compatibility.
    """

    # Scales the layer layout below can realise: 1 and 2 go through
    # pixel_unshuffle, 4 is the plain two-stage upsampler, 8 adds conv_up3.
    # Any other value would silently fall through to the 4x layout.
    _SUPPORTED_SCALES = (1, 2, 4, 8)

    def __init__(self, config: ESRGANConfig) -> None:
        """Build the RRDBNet layers described by ``config``.

        Args:
            config: Model hyper-parameters (``scale``, ``num_feat``,
                ``num_grow_ch``, ``num_block``, ``num_in_ch``, ``num_out_ch``).

        Raises:
            ValueError: If ``config.scale`` is not one of 1, 2, 4 or 8.
        """
        super().__init__(config)

        scale = config.scale
        if scale not in self._SUPPORTED_SCALES:
            raise ValueError(
                f"Unsupported ESRGAN scale {scale!r}; "
                f"expected one of {self._SUPPORTED_SCALES}."
            )
        num_feat = config.num_feat
        num_grow_ch = config.num_grow_ch

        # For scale <= 2, pixel_unshuffle compresses spatial dims
        # and increases channel count before the network body
        if scale == 2:
            num_in_ch = config.num_in_ch * 4
        elif scale == 1:
            num_in_ch = config.num_in_ch * 16
        else:
            num_in_ch = config.num_in_ch

        self.scale = scale

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(
            *[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(config.num_block)]
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # Upsampling convolutions
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, config.num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Initialize weights and apply final processing (PreTrainedModel)
        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        return_dict: bool | None = None,
    ) -> ImageSuperResolutionOutput | tuple[torch.Tensor]:
        """Run super-resolution on input images.

        Args:
            pixel_values: Input tensor of shape (B, C, H, W).
            return_dict: Whether to return ImageSuperResolutionOutput or tuple.
                ``None`` falls back to ``config.use_return_dict`` (``True``
                when the config does not define it).

        Returns:
            ImageSuperResolutionOutput with reconstruction, or tuple if return_dict=False.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        feat = pixel_values

        # Pixel unshuffle for scale <= 2
        if self.scale == 2:
            feat = F.pixel_unshuffle(feat, downscale_factor=2)
        elif self.scale == 1:
            feat = F.pixel_unshuffle(feat, downscale_factor=4)

        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat

        # Upsample: two 2x stages always, a third one only for scale 8.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")))
        if self.scale == 8:
            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode="nearest")))

        out = self.conv_last(self.lrelu(self.conv_hr(feat)))

        if not return_dict:
            return (out,)

        return ImageSuperResolutionOutput(reconstruction=out)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | Path,
        *model_args: Any,
        **kwargs: Any,
    ) -> ESRGANForImageSuperResolution:
        """Load Real-ESRGAN from a ``.pth``-only HF repo.

        Real-ESRGAN distributions on the Hub (e.g. ``ai-forever/Real-ESRGAN``)
        ship raw ``.pth`` checkpoints with no ``config.json`` /
        ``pytorch_model.bin``, so the standard
        :meth:`PreTrainedModel.from_pretrained` flow can't load them. Because
        :class:`ESRGANConfig` is registered with :class:`AutoConfig`,
        :meth:`ESRGANConfig.from_pretrained` already returns a default-valued
        config when ``config.json`` is missing on the Hub — so we just:

        1. Build the config via that path (or accept one passed through ``config=``).
        2. Override ``scale`` from a ``scale=`` kwarg if the caller specifies one.
        3. Download the file named by ``config.weight_file_format`` to a temp
           dir, ``load_state_dict`` it, then discard the temp dir.

        Local directories that already contain ``config.json``/weights are
        delegated to :meth:`PreTrainedModel.from_pretrained`; the ``scale=``
        kwarg is consumed here and is not forwarded on that path.

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory.
            *model_args: Forwarded to the base implementation on the local path.
            **kwargs: ``scale`` and ``config`` are consumed here. ``revision``
                and ``token``, when present, are forwarded to the weight
                download so pinned revisions and private repos work. The rest
                is forwarded only on the local-directory path.

        Returns:
            The model with the checkpoint loaded (``strict=True``).
        """
        import tempfile

        from huggingface_hub import hf_hub_download

        scale_hint = kwargs.pop("scale", None)

        # Local directory with config.json -> use base implementation as-is.
        local_path = Path(pretrained_model_name_or_path)
        if local_path.exists() and (local_path / "config.json").exists():
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # Build an ESRGANConfig either from a caller-provided one or by
        # invoking ESRGANConfig.from_pretrained (which returns defaults when
        # the Hub repo has no config.json).
        config = kwargs.pop("config", None)
        if not isinstance(config, ESRGANConfig):
            config = ESRGANConfig.from_pretrained(str(pretrained_model_name_or_path))

        if scale_hint is not None:
            config.scale = int(scale_hint)

        weight_file = config.weight_file_format.format(scale=config.scale)
        with tempfile.TemporaryDirectory(prefix="esrgan-weights-") as tmpdir:
            logger.info(
                "Downloading %s from %s (scale=%d) to %s",
                weight_file,
                pretrained_model_name_or_path,
                config.scale,
                tmpdir,
            )
            pth_path = hf_hub_download(
                repo_id=str(pretrained_model_name_or_path),
                filename=weight_file,
                local_dir=tmpdir,
                # None keeps hf_hub_download's defaults (default branch,
                # ambient credentials), i.e. the previous behaviour.
                revision=kwargs.get("revision"),
                token=kwargs.get("token"),
            )
            # Must happen inside the ``with``: the file is gone afterwards.
            state = torch.load(pth_path, map_location="cpu", weights_only=True)

        # Upstream checkpoints wrap the weights; prefer the EMA copy.
        if isinstance(state, dict):
            if "params_ema" in state:
                state = state["params_ema"]
            elif "params" in state:
                state = state["params"]

        model = cls(config)
        model.load_state_dict(state, strict=True)
        return model
@register_onnx_overwrite("ESRGAN", "image-to-image", library_name="transformers")
class ESRGANIOConfig(OnnxConfig):
    """Describe the ONNX I/O signature used when exporting Real-ESRGAN.

    One input tensor ``pixel_values`` and one output tensor
    ``reconstruction``; both mark axes 0, 2 and 3 as dynamic
    (``batch_size``, ``height``, ``width``).
    """

    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

    @property
    def inputs(self) -> dict[str, dict[int, str]]:
        """Map each input tensor name to its dynamic-axis labels."""
        dynamic_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"pixel_values": dynamic_axes}

    @property
    def outputs(self) -> dict[str, dict[int, str]]:
        """Map each output tensor name to its dynamic-axis labels."""
        # NOTE(review): the output reuses the input's "height"/"width" labels
        # although the model upsamples — confirm downstream tooling does not
        # treat equal labels as equal sizes.
        dynamic_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"reconstruction": dynamic_axes}
a patch-based predict() that runs the (shape-specialised) + # ONNX session over overlapping patches and stitches the results back. + ("esrgan", "image-to-image"): "WinMLESRGANForImageToImage", } @@ -74,21 +79,25 @@ def _import_winml_class(class_name: str) -> type[WinMLPreTrainedModel]: ImportError: If class is not implemented yet """ from .base import WinMLModelForGenericTask + from .esrgan import WinMLESRGANForImageToImage from .feature_extraction import WinMLModelForFeatureExtraction from .image_classification import WinMLModelForImageClassification from .image_segmentation import ( WinMLModelForImageSegmentation, WinMLModelForSemanticSegmentation, ) + from .image_to_image import WinMLModelForImageToImage from .object_detection import WinMLModelForObjectDetection from .question_answering import WinMLModelForQuestionAnswering from .sequence_classification import WinMLModelForSequenceClassification # Map class names to modules class_map: dict[str, type] = { + "WinMLESRGANForImageToImage": WinMLESRGANForImageToImage, "WinMLModelForFeatureExtraction": WinMLModelForFeatureExtraction, "WinMLModelForImageClassification": WinMLModelForImageClassification, "WinMLModelForImageSegmentation": WinMLModelForImageSegmentation, + "WinMLModelForImageToImage": WinMLModelForImageToImage, "WinMLModelForObjectDetection": WinMLModelForObjectDetection, "WinMLModelForQuestionAnswering": WinMLModelForQuestionAnswering, "WinMLModelForSemanticSegmentation": WinMLModelForSemanticSegmentation, @@ -183,6 +192,7 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None ) from .decoder_only import WinMLDecoderOnlyModel from .encoder_decoder import WinMLEncoderDecoderModel +from .esrgan import WinMLESRGANForImageToImage from .feature_extraction import WinMLModelForFeatureExtraction from .image_classification import WinMLModelForImageClassification from .image_segmentation import ( @@ -190,6 +200,7 @@ def register_specialization(model_type: str, task: str, class_name: 
str) -> None WinMLModelForImageSegmentation, WinMLModelForSemanticSegmentation, ) +from .image_to_image import ImageReconstructionOutput, WinMLModelForImageToImage from .kv_cache import ( WinMLCache, WinMLSlidingWindowCache, @@ -204,15 +215,18 @@ def register_specialization(model_type: str, task: str, class_name: str) -> None "COMPOSITE_MODEL_REGISTRY", "TASK_TO_WINML_CLASS", "WINML_MODEL_CLASS_MAPPING", + "ImageReconstructionOutput", "ImageSegmentationOutput", "WinMLCache", "WinMLCompositeModel", "WinMLDecoderOnlyModel", + "WinMLESRGANForImageToImage", "WinMLEncoderDecoderModel", "WinMLModelForFeatureExtraction", "WinMLModelForGenericTask", "WinMLModelForImageClassification", "WinMLModelForImageSegmentation", + "WinMLModelForImageToImage", "WinMLModelForObjectDetection", "WinMLModelForSemanticSegmentation", "WinMLModelForSequenceClassification", diff --git a/src/winml/modelkit/models/winml/esrgan.py b/src/winml/modelkit/models/winml/esrgan.py new file mode 100644 index 000000000..f9a445c85 --- /dev/null +++ b/src/winml/modelkit/models/winml/esrgan.py @@ -0,0 +1,230 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Real-ESRGAN-specialised WinML inference wrapper. + +Adds a torch-based ``predict(lr_image) -> PIL.Image`` method on top of the +generic :class:`WinMLModelForImageToImage`. ``predict`` is the official +``RealESRGAN.model.predict`` flow from sberbank-ai/Real-ESRGAN ported as +directly as possible — same kwargs, same patch geometry, same torch ops — +with the underlying ``self.model(...)`` call replaced by ``self.forward()`` +on the ONNX-backed runtime model. 
def pad_reflect(image, pad_size):
    """Mirror-pad ``image`` by ``pad_size`` pixels on each spatial side.

    The border repeats the edge pixel (numpy's ``"symmetric"`` mode), which
    is what the upstream Real-ESRGAN helper produces with its flip-and-copy
    assignments. Unlike that helper, ``pad_size == 0`` is a valid no-op and
    2-D (single-channel) arrays are accepted.

    Args:
        image: Array whose first two axes are height and width; trailing
            axes (e.g. channels) are left unpadded.
        pad_size: Non-negative border width in pixels.

    Returns:
        A new ``uint8`` array with ``2 * pad_size`` added to height and width.

    Raises:
        ValueError: If ``pad_size`` is negative (raised by ``np.pad``).
    """
    image = np.asarray(image)
    # Pad only the two spatial axes; any trailing axes get zero padding.
    pad_width = ((pad_size, pad_size), (pad_size, pad_size)) + ((0, 0),) * (image.ndim - 2)
    # The upstream helper always returned uint8, so keep that contract.
    return np.pad(image, pad_width, mode="symmetric").astype(np.uint8)


def unpad_image(image, pad_size):
    """Inverse of :func:`pad_reflect`: drop ``pad_size`` pixels per side.

    Explicit end indices are used instead of ``-pad_size`` so that
    ``pad_size == 0`` returns the full image rather than an empty slice.

    Args:
        image: Array whose first two axes are height and width.
        pad_size: Border width that was added by :func:`pad_reflect`.

    Returns:
        A view of ``image`` without the border.
    """
    height, width = image.shape[:2]
    return image[pad_size : height - pad_size, pad_size : width - pad_size]
def unpad_patches(image_patches, padding_size):
    """Strip the spatial border added by :func:`pad_patch` from every patch.

    End indices are computed explicitly instead of using ``-padding_size``,
    so ``padding_size == 0`` returns the patches unchanged rather than an
    empty array.

    Args:
        image_patches: Array indexed as (patch, row, column, channel).
        padding_size: Border width to remove from each spatial side.

    Returns:
        A view of ``image_patches`` with the border removed.
    """
    rows, cols = image_patches.shape[1:3]
    return image_patches[
        :,
        padding_size : rows - padding_size,
        padding_size : cols - padding_size,
        :,
    ]
class WinMLESRGANForImageToImage(WinMLModelForImageToImage):
    """ESRGAN-specialised ``WinMLModelForImageToImage`` with patch-based SR.

    The exported ONNX session is shape-specialised to a fixed patch tensor
    (``patches_size + 2 * padding`` on each spatial side) and an upscale
    factor encoded on :attr:`config.scale`. :meth:`predict` accepts any-size
    PIL image, runs the official Real-ESRGAN patch flow, and returns the
    upscaled PIL image.
    """

    def predict(
        self,
        lr_image: PILImage,
        patches_size: int = 192,
        padding: int = 24,
        pad_size: int = 15,
    ) -> PILImage:
        """Upscale ``lr_image`` patch by patch (port of ``RealESRGAN.model.predict``).

        Follows the upstream flow — reflect-pad, split into overlapping
        patches, run the network, stitch, unpad — with ``self.model(...)``
        replaced by ``self.forward(pixel_values=...).reconstruction``.
        Tensors stay on CPU; no torch device placement is done here.

        Differences from a literal port, all needed for a fixed-shape session:

        * The batch size comes from the session's first input shape; anything
          that is not a positive integer (``None``, ``0``, a symbolic name)
          means one patch per call.
        * A short final chunk is zero-padded up to the batch size and the
          filler rows are dropped from the output.
        * Non-RGB PIL images are converted to RGB first, because stitching
          assumes three channels.

        Args:
            lr_image: Low-resolution input image.
            patches_size: Side of each patch before overlap padding.
            padding: Overlap border added around each patch.
            pad_size: Reflect padding added around the whole image.

        Returns:
            The upscaled image as a PIL image.
        """
        scale = int(self.config.scale)

        # Inferred from the session's input shape; dynamic batch -> 1.
        batch_dim = self.io_config["input_shapes"][0][0]
        try:
            batch_size = int(batch_dim)
        except (TypeError, ValueError):
            batch_size = 1
        if batch_size < 1:
            batch_size = 1

        if isinstance(lr_image, Image.Image) and lr_image.mode != "RGB":
            lr_image = lr_image.convert("RGB")
        lr_image = np.array(lr_image)
        lr_image = pad_reflect(lr_image, pad_size)

        patches, p_shape = split_image_into_overlapping_patches(
            lr_image, patch_size=patches_size, padding_size=padding
        )
        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).detach()

        # Collect per-chunk outputs and concatenate once, instead of growing
        # the result with torch.cat on every iteration.
        chunks = []
        with torch.no_grad():
            for start in range(0, img.shape[0], batch_size):
                batch = img[start : start + batch_size]
                n_valid = batch.shape[0]
                if n_valid < batch_size:
                    # Fill a short final chunk up to the batch size the
                    # session was inferred to expect.
                    filler = batch.new_zeros((batch_size - n_valid, *batch.shape[1:]))
                    batch = torch.cat((batch, filler), 0)
                out = self.forward(pixel_values=batch).reconstruction
                chunks.append(out[:n_valid])
        res = torch.cat(chunks, 0)

        sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)  # noqa: RUF005 — verbatim upstream
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)  # noqa: RUF005 — verbatim upstream
        np_sr_image = stich_together(
            np_sr_image,
            padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape,
            padding_size=padding * scale,
        )
        sr_img = (np_sr_image * 255).astype(np.uint8)
        sr_img = unpad_image(sr_img, pad_size * scale)
        return Image.fromarray(sr_img)
class WinMLModelForImageToImage(WinMLPreTrainedModel):
    """WinML inference wrapper for image-to-image tasks.

    Intended for super-resolution, denoising, JPEG artifact removal and
    similar tasks. The class only adapts inference inputs and outputs;
    export/optimize/compile is orchestrated by the WinMLAutoModel factory.
    """

    def forward(
        self,
        pixel_values: torch.Tensor | np.ndarray,
        **kwargs: Any,
    ) -> ImageReconstructionOutput:
        """Run one inference pass and wrap the result.

        Args:
            pixel_values: Image tensor (B, C, H, W)
            **kwargs: Accepted for HF pipeline compatibility and ignored.

        Returns:
            ImageReconstructionOutput whose ``reconstruction`` is the output
            named ``"reconstruction"``, or the first output when no output
            carries that name.
        """
        raw_outputs = self._run_inference(self._format_inputs(pixel_values=pixel_values))

        # Resolve the fallback up front, like the original ``dict.get`` call
        # whose default argument was evaluated unconditionally.
        first_output = next(iter(raw_outputs.values()))
        picked = raw_outputs.get("reconstruction", first_output)

        return ImageReconstructionOutput(reconstruction=picked)
as ``loader_config_overrides``.""" + from winml.modelkit.commands.export import export + + overrides_file = tmp_path / "overrides.json" + overrides_file.write_text(json.dumps({"scale": 2})) + + runner.invoke( + export, + [ + "--model", + "m", + "--output", + str(tmp_path / "m.onnx"), + "--loader-config-overrides", + str(overrides_file), + ], + obj={"debug": False}, + ) + + kwargs = mock_load_hf_model.call_args.kwargs + assert kwargs["loader_config_overrides"] == {"scale": 2} + + def test_flag_invalid_json_raises( + self, + runner: CliRunner, + mock_load_hf_model: MagicMock, + tmp_path: Path, + ) -> None: + """Malformed JSON exits non-zero with a clear ``Invalid JSON`` message.""" + from winml.modelkit.commands.export import export + + overrides_file = tmp_path / "bad.json" + overrides_file.write_text("{ not json }") + + result = runner.invoke( + export, + [ + "--model", + "m", + "--output", + str(tmp_path / "m.onnx"), + "--loader-config-overrides", + str(overrides_file), + ], + obj={"debug": False}, + ) + + assert result.exit_code != 0 + assert "Invalid JSON" in result.output + + def test_flag_non_object_raises( + self, + runner: CliRunner, + mock_load_hf_model: MagicMock, + tmp_path: Path, + ) -> None: + """A top-level JSON array (not an object) is rejected.""" + from winml.modelkit.commands.export import export + + overrides_file = tmp_path / "arr.json" + overrides_file.write_text("[1, 2, 3]") + + result = runner.invoke( + export, + [ + "--model", + "m", + "--output", + str(tmp_path / "m.onnx"), + "--loader-config-overrides", + str(overrides_file), + ], + obj={"debug": False}, + ) + + assert result.exit_code != 0 + assert "must contain a JSON object" in result.output + + def test_cli_deep_merges_with_build_config( + self, + runner: CliRunner, + mock_export_onnx: MagicMock, + mock_load_hf_model: MagicMock, + tmp_path: Path, + ) -> None: + """``-c`` build config supplies a base; CLI flag deep-merges on top + (CLI wins on conflicts, sibling keys preserved).""" + 
from winml.modelkit.commands.export import export + + build_cfg_file = tmp_path / "build.json" + build_cfg_file.write_text( + json.dumps( + { + "loader": { + "task": "image-to-image", + "loader_config_overrides": { + "scale": 4, + "from_build": "stays", + "vision_config": {"image_size": 320}, + }, + } + } + ) + ) + + cli_overrides_file = tmp_path / "cli.json" + cli_overrides_file.write_text( + json.dumps( + { + "scale": 2, # overrides build_cfg's 4 + "vision_config": {"hidden_size": 128}, # merges with build's image_size + } + ) + ) + + runner.invoke( + export, + [ + "--model", + "m", + "--output", + str(tmp_path / "m.onnx"), + "--config", + str(build_cfg_file), + "--loader-config-overrides", + str(cli_overrides_file), + ], + obj={"debug": False}, + ) + + merged = mock_load_hf_model.call_args.kwargs["loader_config_overrides"] + assert merged["scale"] == 2 # CLI wins on conflict + assert merged["from_build"] == "stays" # build-config-only key preserved + assert merged["vision_config"] == { + "image_size": 320, + "hidden_size": 128, + } # deep-merged + + def test_build_config_overrides_used_without_cli_flag( + self, + runner: CliRunner, + mock_export_onnx: MagicMock, + mock_load_hf_model: MagicMock, + tmp_path: Path, + ) -> None: + """``loader.loader_config_overrides`` in --config is honored even with + no --loader-config-overrides flag.""" + from winml.modelkit.commands.export import export + + build_cfg_file = tmp_path / "build.json" + build_cfg_file.write_text( + json.dumps( + { + "loader": { + "task": "image-to-image", + "loader_config_overrides": {"scale": 8}, + } + } + ) + ) + + runner.invoke( + export, + [ + "--model", + "m", + "--output", + str(tmp_path / "m.onnx"), + "--config", + str(build_cfg_file), + ], + obj={"debug": False}, + ) + + assert mock_load_hf_model.call_args.kwargs["loader_config_overrides"] == { + "scale": 8 + } + + class TestExportWarnings: """Test export warning messages for unsupported options.""" diff --git a/tests/unit/eval/test_eval.py 
b/tests/unit/eval/test_eval.py index e05aa1abc..daeead553 100644 --- a/tests/unit/eval/test_eval.py +++ b/tests/unit/eval/test_eval.py @@ -75,8 +75,8 @@ def test_infer_from_model_id(self): return_value=fake_hf_config, ), patch( - "winml.modelkit.loader.task._detect_task_from_config", - return_value="image-classification", + "winml.modelkit.loader.task._detect_task_and_class_from_config", + return_value=("image-classification", MagicMock()), ), ): assert _resolve_task(config) == "image-classification" diff --git a/tests/unit/loader/test_apply_loader_config_overrides.py b/tests/unit/loader/test_apply_loader_config_overrides.py new file mode 100644 index 000000000..0dbd0558a --- /dev/null +++ b/tests/unit/loader/test_apply_loader_config_overrides.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for ``apply_loader_config_overrides``. + +The helper deep-merges an overrides dict onto a HF :class:`PretrainedConfig` +by going through ``to_dict`` / ``from_dict``. Tests exercise: + +* No-op behaviour for empty / ``None`` overrides. +* Simple scalar overrides. +* Recursive merge into nested sub-configs (CLIP-style). +* Non-mutation of the original config. +* Concrete return type matches the input config type. 
+""" + +from __future__ import annotations + +import pytest + +from winml.modelkit.loader.config import apply_loader_config_overrides + + +@pytest.fixture +def bert_config(): + from transformers import BertConfig + + return BertConfig(hidden_size=768, num_attention_heads=12) + + +@pytest.fixture +def clip_config(): + from transformers import CLIPConfig + + return CLIPConfig() + + +class TestNoOp: + """``None`` / empty overrides return the same instance unchanged.""" + + def test_none_overrides_returns_same_instance(self, bert_config): + result = apply_loader_config_overrides(bert_config, None) + assert result is bert_config + + def test_empty_dict_returns_same_instance(self, bert_config): + result = apply_loader_config_overrides(bert_config, {}) + assert result is bert_config + + +class TestScalarOverride: + """Single-level scalar patches.""" + + def test_scalar_value_applied(self, bert_config): + result = apply_loader_config_overrides(bert_config, {"hidden_size": 1024}) + assert result.hidden_size == 1024 + + def test_does_not_mutate_original(self, bert_config): + _ = apply_loader_config_overrides(bert_config, {"hidden_size": 1024}) + assert bert_config.hidden_size == 768 + + def test_returned_config_is_same_concrete_type(self, bert_config): + result = apply_loader_config_overrides(bert_config, {"hidden_size": 512}) + assert type(result) is type(bert_config) + + def test_returned_config_is_new_instance(self, bert_config): + result = apply_loader_config_overrides(bert_config, {"hidden_size": 512}) + assert result is not bert_config + + def test_untouched_fields_preserved(self, bert_config): + result = apply_loader_config_overrides(bert_config, {"hidden_size": 1024}) + # num_attention_heads was set on the original; should round-trip + assert result.num_attention_heads == bert_config.num_attention_heads + + +class TestNestedOverride: + """Recursive merge into nested :class:`PretrainedConfig` attributes.""" + + def test_nested_dict_recurses_into_subconfig(self, 
class TestESRGANUseCase:
    """Sanity check against the real ESRGAN config, the feature's motivating case."""

    def test_scale_override_on_esrgan(self):
        # Importing the package triggers HF registrations that ESRGANConfig depends on.
        import winml.modelkit.models.hf  # noqa: F401
        from winml.modelkit.models.hf.esrgan import ESRGANConfig

        original = ESRGANConfig()
        default_scale = 4
        assert original.scale == default_scale

        patched = apply_loader_config_overrides(original, {"scale": 2})

        assert patched.scale == 2
        assert isinstance(patched, ESRGANConfig)
        # The input config must be left untouched.
        assert original.scale == default_scale
class TestESRGANConfig:
    """Checks on ESRGANConfig defaults, overrides and serialization."""

    def test_model_type(self, default_config: ESRGANConfig) -> None:
        """model_type is the uppercase string 'ESRGAN' (it feeds HF's
        case-sensitive repo-name fallback, e.g. 'ai-forever/Real-ESRGAN')."""
        expected_type = "ESRGAN"
        assert default_config.model_type == expected_type

    def test_default_params(self, default_config: ESRGANConfig) -> None:
        """A default-constructed config carries the expected RRDBNet values."""
        expected_defaults = {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_block": 23,
            "num_grow_ch": 32,
            "scale": 4,
            "weight_file_format": "RealESRGAN_x{scale}.pth",
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(default_config, field_name) == expected

    def test_custom_params(self) -> None:
        """Constructor overrides land on the config; other fields keep defaults."""
        custom = ESRGANConfig(scale=2, num_block=6)
        assert (custom.scale, custom.num_block) == (2, 6)
        # Fields that were not overridden
        assert (custom.num_in_ch, custom.num_feat) == (3, 64)

    def test_weight_file_format_renders_default_filename(self) -> None:
        """The default template renders the upstream ``RealESRGAN_xN.pth`` name."""
        for factor in (2, 4, 8):
            config = ESRGANConfig(scale=factor)
            rendered = config.weight_file_format.format(scale=config.scale)
            assert rendered == f"RealESRGAN_x{factor}.pth"

    def test_weight_file_format_override(self) -> None:
        """A caller-supplied template is stored verbatim and renders with the scale."""
        template = "fork_x{scale}_v2.bin"
        config = ESRGANConfig(scale=4, weight_file_format=template)
        assert config.weight_file_format == template
        assert config.weight_file_format.format(scale=config.scale) == "fork_x4_v2.bin"

    def test_config_serialization_roundtrip(self, tmp_path) -> None:
        """save_pretrained followed by from_pretrained keeps every field."""
        saved = ESRGANConfig(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=32,
            num_block=4,
            num_grow_ch=16,
            scale=2,
            weight_file_format="custom_{scale}x.bin",
        )
        saved.save_pretrained(str(tmp_path))
        reloaded = ESRGANConfig.from_pretrained(str(tmp_path))

        compared_fields = (
            "model_type",
            "num_in_ch",
            "num_out_ch",
            "num_feat",
            "num_block",
            "num_grow_ch",
            "scale",
            "weight_file_format",
        )
        for field_name in compared_fields:
            assert getattr(reloaded, field_name) == getattr(saved, field_name)
torch.randn(1, 3, 16, 16) + with torch.no_grad(): + out = model(pixel_values=x, return_dict=False) + + assert isinstance(out, tuple) + assert len(out) == 1 + assert out[0].shape[1] == 3 # channels + + def test_save_and_load_pretrained(self, fast_config: ESRGANConfig, tmp_path) -> None: + """save_pretrained -> from_pretrained roundtrip on a local dir. + + Verifies the overridden ``from_pretrained`` still delegates to + :class:`PreTrainedModel` for local directories with a ``config.json``. + """ + model = ESRGANForImageSuperResolution(fast_config) + model.eval() + + x = torch.randn(1, 3, 16, 16) + with torch.no_grad(): + original_out = model(pixel_values=x) + + save_dir = str(tmp_path / "esrgan_model") + model.save_pretrained(save_dir) + loaded = ESRGANForImageSuperResolution.from_pretrained(save_dir) + loaded.eval() + + with torch.no_grad(): + loaded_out = loaded(pixel_values=x) + + assert torch.allclose( + original_out.reconstruction, + loaded_out.reconstruction, + atol=1e-6, + ) diff --git a/tests/unit/models/esrgan/test_onnx_config.py b/tests/unit/models/esrgan/test_onnx_config.py new file mode 100644 index 000000000..c0a55ae03 --- /dev/null +++ b/tests/unit/models/esrgan/test_onnx_config.py @@ -0,0 +1,52 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Tests for ESRGAN ONNX export config registration.""" + +from __future__ import annotations + +from optimum.exporters.tasks import TasksManager + +# Trigger registration via import side effects +import winml.modelkit.models.hf as _hf # noqa: F401 +from winml.modelkit.models.hf.esrgan import ESRGANConfig, ESRGANIOConfig + + +class TestESRGANOnnxConfigRegistration: + """Verify ESRGAN OnnxConfig is reachable through Optimum's TasksManager.""" + + def test_config_registered_for_image_to_image(self) -> None: + config_constructor = TasksManager.get_exporter_config_constructor( + exporter="onnx", + model_type="ESRGAN", + task="image-to-image", + library_name="transformers", + ) + actual_class_name = config_constructor.func.__name__ + assert actual_class_name == "ESRGANIOConfig" + + def test_inputs_have_pixel_values(self) -> None: + config = ESRGANConfig() + io_config = ESRGANIOConfig(config) + inputs = io_config.inputs + assert "pixel_values" in inputs + # Dynamic axes: batch (0), height (2), width (3) + assert 0 in inputs["pixel_values"] + assert 2 in inputs["pixel_values"] + assert 3 in inputs["pixel_values"] + + def test_outputs_have_reconstruction(self) -> None: + config = ESRGANConfig() + io_config = ESRGANIOConfig(config) + outputs = io_config.outputs + assert "reconstruction" in outputs + assert 0 in outputs["reconstruction"] + + def test_image_to_image_in_supported_tasks(self) -> None: + """`get_supported_tasks_for_model_type` lists image-to-image for ESRGAN.""" + supported = TasksManager.get_supported_tasks_for_model_type( + "ESRGAN", exporter="onnx", library_name="transformers", + ) + tasks = list(supported.keys()) if isinstance(supported, dict) else list(supported) + assert "image-to-image" in tasks diff --git a/tests/unit/models/winml/test_winml_esrgan.py b/tests/unit/models/winml/test_winml_esrgan.py new file mode 100644 index 000000000..2c54bd1ec --- /dev/null +++ 
b/tests/unit/models/winml/test_winml_esrgan.py @@ -0,0 +1,218 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for ``WinMLESRGANForImageToImage`` and its patch helpers. + +The class adds a Real-ESRGAN-style ``predict(lr_image) -> PIL.Image`` +method on top of the generic image-to-image runtime model. We test: + +* The task-class resolver picks the specialised subclass for + ``("esrgan", "image-to-image")``. +* The helper functions copied from upstream's ``utils.py`` round-trip on + shapes that are not patch-aligned. +* ``predict`` calls the underlying ``forward`` for every patch, batches + per ``batch_size``, threads ``self.config.scale`` correctly, and + returns a PIL image of size ``(W * scale, H * scale)``. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +from PIL import Image + + +@pytest.fixture +def patches_size() -> int: + return 192 + + +@pytest.fixture +def padding() -> int: + return 24 + + +@pytest.fixture +def pad_size() -> int: + return 15 + + +# ============================================================================= +# Resolver / inheritance +# ============================================================================= + + +class TestResolver: + def test_specialised_class_registered(self): + from winml.modelkit.models.winml import ( + WINML_MODEL_CLASS_MAPPING, + WinMLESRGANForImageToImage, + get_winml_class, + ) + + assert WINML_MODEL_CLASS_MAPPING[("esrgan", "image-to-image")] == ( + "WinMLESRGANForImageToImage" + ) + assert get_winml_class("esrgan", "image-to-image") is WinMLESRGANForImageToImage + # Resolver normalises mixed case + assert get_winml_class("ESRGAN", "image-to-image") is WinMLESRGANForImageToImage 
+ + def test_subclasses_generic_image_to_image(self): + from winml.modelkit.models.winml import ( + WinMLESRGANForImageToImage, + WinMLModelForImageToImage, + ) + + assert issubclass(WinMLESRGANForImageToImage, WinMLModelForImageToImage) + + +# ============================================================================= +# Patch helpers — these are copied verbatim from upstream, so we test the +# shape contracts we depend on rather than the implementation details. +# ============================================================================= + + +class TestPatchHelpers: + def test_pad_reflect_roundtrip(self, pad_size): + from winml.modelkit.models.winml.esrgan import pad_reflect, unpad_image + + rng = np.random.default_rng(0) + img = rng.integers(0, 255, size=(40, 50, 3), dtype=np.uint8) + padded = pad_reflect(img, pad_size) + assert padded.shape == (40 + 2 * pad_size, 50 + 2 * pad_size, 3) + # Centre region is the original image, untouched + np.testing.assert_array_equal(padded[pad_size:-pad_size, pad_size:-pad_size, :], img) + # unpad_image inverts pad_reflect exactly + np.testing.assert_array_equal(unpad_image(padded, pad_size), img) + + def test_split_and_stitch_roundtrip_with_extension(self, patches_size, padding): + """Image whose H/W are not multiples of patch_size is edge-extended, + split, then perfectly stitched back to its original H/W.""" + from winml.modelkit.models.winml.esrgan import ( + split_image_into_overlapping_patches, + stich_together, + ) + + rng = np.random.default_rng(1) + # 369 = 192 + 177, 250 = 192 + 58 — both require x_extend/y_extend > 0 + img = rng.integers(0, 255, size=(369, 250, 3), dtype=np.uint8) + patches, p_shape = split_image_into_overlapping_patches( + img, patch_size=patches_size, padding_size=padding + ) + # Each patch is patches_size + 2 * padding on a side + assert patches.shape[1] == patches_size + 2 * padding + assert patches.shape[2] == patches_size + 2 * padding + # scale=1 stitch: same shapes, no upscale → must 
recover the original + reconstructed = stich_together( + patches.astype(np.float64) / 255.0, + padded_image_shape=p_shape, + target_shape=img.shape, + padding_size=padding, + ) + # Convert back to uint8 for exact comparison + recovered = (reconstructed * 255).round().astype(np.uint8) + np.testing.assert_array_equal(recovered, img) + + +# ============================================================================= +# predict() — mock the inherited forward() so we don't need an ONNX session. +# ============================================================================= + + +def _make_model_with_fake_forward(scale: int, batch: int | None = 1): + """Construct a WinMLESRGANForImageToImage that skips real init. + + Forward returns zeros sized for the scale; lets us assert shape / + invocation count without a real ORT session. ``batch`` controls the + ONNX input batch dim that ``predict`` reads from ``self.io_config`` + (use ``None`` to simulate a dynamic-batch export). + """ + from winml.modelkit.models.winml import WinMLESRGANForImageToImage + from winml.modelkit.models.winml.image_to_image import ImageReconstructionOutput + + instance = WinMLESRGANForImageToImage.__new__(WinMLESRGANForImageToImage) + instance.config = SimpleNamespace(scale=scale) + # ``io_config`` is a read-only property on the base class that delegates + # to ``_session.io_config``; stub the session so the property returns + # the shape we control here. 
+ instance._session = SimpleNamespace(io_config={"input_shapes": [[batch, 3, 240, 240]]}) + + def fake_forward(pixel_values, **_kw): + n, c, h, w = pixel_values.shape + return ImageReconstructionOutput( + reconstruction=torch.zeros(n, c, h * scale, w * scale, dtype=torch.float32) + ) + + instance.forward = MagicMock(side_effect=fake_forward) + return instance + + +class TestPredict: + @pytest.mark.parametrize("scale", [2, 4, 8]) + def test_predict_returns_pil_image_with_upscaled_size(self, scale: int): + model = _make_model_with_fake_forward(scale=scale) + lr = Image.new("RGB", (50, 40), color=(127, 127, 127)) # PIL size = (W, H) + + sr = model.predict(lr) + + assert isinstance(sr, Image.Image) + assert sr.size == (50 * scale, 40 * scale) + assert sr.mode == "RGB" + + def test_predict_reads_scale_from_config(self): + # If predict didn't read self.config.scale, it would default to + # the model's class-level scale (no such default exists) and the + # output shape would not match. + model = _make_model_with_fake_forward(scale=4) + lr = Image.new("RGB", (32, 32)) + assert model.predict(lr).size == (128, 128) + + def test_predict_invokes_forward_at_least_once_with_correct_patch_shape( + self, + patches_size, + padding, + ): + model = _make_model_with_fake_forward(scale=2) + lr = Image.new("RGB", (32, 32)) + model.predict(lr) + + # First call's pixel_values is the first batch of patches; each patch + # is shape-specialised to patches_size + 2 * padding on each spatial side. 
+ assert model.forward.call_count >= 1 + first_call_pv = model.forward.call_args_list[0].kwargs["pixel_values"] + assert isinstance(first_call_pv, torch.Tensor) + assert first_call_pv.shape[1] == 3 + assert first_call_pv.shape[2] == patches_size + 2 * padding + assert first_call_pv.shape[3] == patches_size + 2 * padding + + def test_predict_infers_batch_size_from_session_input_shape(self): + """batch_size is no longer a kwarg — it's read from io_config[input_shapes].""" + # Static batch=4 export + model = _make_model_with_fake_forward(scale=2, batch=4) + # 600x400 LR with patches_size=192, padding=24, pad_size=15: + # post-reflect-pad → 630x430, after extending to multiples of 192: + # 768x576 → 4 * 3 = 12 patches. With inferred batch_size=4 → 3 forward calls. + lr = Image.new("RGB", (600, 400)) + model.predict(lr) + assert model.forward.call_count == 3 + + def test_predict_defaults_to_batch_1_when_dynamic(self): + """When the ONNX export left the batch dim dynamic, fall back to 1.""" + model = _make_model_with_fake_forward(scale=2, batch=None) + lr = Image.new("RGB", (600, 400)) + model.predict(lr) + # 12 patches, batch=1 -> 12 forward calls + assert model.forward.call_count == 12 + + def test_predict_supports_arbitrary_input_sizes(self): + """Non-square, non-patch-multiple sizes must work end-to-end.""" + model = _make_model_with_fake_forward(scale=2) + for w, h in [(100, 80), (333, 217), (1, 1)]: + lr = Image.new("RGB", (w, h)) + sr = model.predict(lr) + assert sr.size == (w * 2, h * 2)