diff --git a/pyaml/common/element.py b/pyaml/common/element.py index 11ab1f02e..15e566faa 100644 --- a/pyaml/common/element.py +++ b/pyaml/common/element.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict @@ -12,21 +12,45 @@ def __pyaml_repr__(obj): """ Returns a string representation of a pyaml object """ - if hasattr(obj, "_cfg"): + + cls_name = obj.__class__.__name__ + + # Keep the old behavior when _cfg exists + cfg = getattr(obj, "_cfg", None) + if cfg is not None: if isinstance(obj, Element): - return repr(obj._cfg).replace( + return repr(cfg).replace( "ConfigModel(", - obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ", + f"{cls_name}(peer={obj.attached_to()!r}, ", + 1, ) - else: - # no peer - return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__) - else: - # Object is not yet fully constructed - if isinstance(obj, Element): - return f"{obj.__class__.__name__}: {obj.get_name()}" - else: - return f"{obj.__class__.__name__}" + return repr(cfg).replace("ConfigModel", cls_name, 1) + + # Generic fallback when there is no _cfg + attrs = {} + + # Instance attributes + for k, v in obj.__dict__.items(): + # Exclude private attributes + if not k.startswith("_"): + attrs[k] = v + + # Properties + for name, attr in vars(type(obj)).items(): + if isinstance(attr, property): + try: + attrs[name] = getattr(obj, name) + except Exception as e: + attrs[name] = f"" + + if isinstance(obj, Element) and "name" not in attrs: + try: + attrs["name"] = obj.get_name() + except Exception as e: + attrs["name"] = f"" + + parts = ", ".join(f"{k}={v!r}" for k, v in attrs.items()) + return f"{cls_name}({parts})" if parts else cls_name class ElementConfigModel(BaseModel): @@ -57,39 +81,64 @@ class ElementConfigModel(BaseModel): lattice_names: str | None = None -class Element(object): +class Element: """ Class providing access to one element of a physical or simulated lattice - - 
Attributes: - name: str - The unique name identifying the element in the configuration file """ - def __init__(self, name: str): - self._name: str = name - self._peer: "ElementHolder" = None # Peer: ControlSystem, Simulator + def __init__( + self, + name: str, + lattice_names: str | None = None, + description: str | None = None, + ): + self._name = name + self._lattice_names = lattice_names + self._description = description + self._peer: ElementHolder | None = None - def get_name(self) -> str: + def _cfg_value(self, attr: str, fallback: Any) -> Any: """ - Returns the name of the element + Return an attribute from _cfg if available, otherwise fallback. """ - return self._name + cfg = getattr(self, "_cfg", None) + if cfg is not None: + value = getattr(cfg, attr, None) + if value is not None: + return value + return fallback - def get_lattice_names(self) -> str: - """ - Returns the name of associated lattice element(s) - """ - if not hasattr(self, "_cfg"): - return self._name - else: - return self._cfg.lattice_names + @property + def name(self) -> str: + return self._cfg_value("name", self._name) + + @property + def lattice_names(self) -> str: + cfg = getattr(self, "_cfg", None) + + if cfg is not None and cfg.lattice_names is not None: + return cfg.lattice_names + + if self._lattice_names is not None: + return self._lattice_names + + return self.name - def get_description(self) -> str: + @property + def description(self) -> str | None: + return self._cfg_value("description", self._description) + + def get_name(self) -> str: """ - Returns the description of the element + Returns the name of the element """ - return self._cfg.description + return self.name + + def get_lattice_names(self) -> str | None: + return self.lattice_names + + def get_description(self) -> str | None: + return self.description def set_energy(self, E: float): """ diff --git a/pyaml/common/element_holder.py b/pyaml/common/element_holder.py index 22df8bc13..a06a06935 100644 --- 
a/pyaml/common/element_holder.py +++ b/pyaml/common/element_holder.py @@ -272,7 +272,7 @@ def get_rf_plant(self, name: str) -> RFPlant: def add_rf_plant(self, rf: RFPlant): self.__add(self.__RFPLANT, rf) - def add_rf_transnmitter(self, rf: RFTransmitter): + def add_rf_transmitter(self, rf: RFTransmitter): self.__add(self.__RFTRANSMITTER, rf) def get_rf_trasnmitter(self, name: str) -> RFTransmitter: diff --git a/pyaml/configuration/factory.py b/pyaml/configuration/factory.py index df63fbe78..e41d28345 100644 --- a/pyaml/configuration/factory.py +++ b/pyaml/configuration/factory.py @@ -161,7 +161,7 @@ class BuildInfo: ---------- module : ModuleType Imported module containing the object class and validation model. - config_cls : type[BaseModel] + config_cls : type[BaseModel], optional Pydantic model used to validate the configuration. class_str : str Name of the class to instantiate. @@ -174,7 +174,7 @@ class BuildInfo: """ module: ModuleType - config_cls: type[BaseModel] + config_cls: type[BaseModel] | None class_str: str field_locations: dict | None location_str: str @@ -248,8 +248,6 @@ def resolve_build_info(data: dict, ignore_external: bool) -> BuildInfo | None: # Get the validation class config_cls = getattr(module, validation_class_str, None) - if config_cls is None: - raise PyAMLConfigException(f"No validation class for '{module.__name__}.{class_str}' {location_str}") return BuildInfo( module=module, @@ -456,6 +454,8 @@ def _construct_element( try: if control_modes is None: + if isinstance(cfg, dict): + return elem_cls(**cfg) return elem_cls(cfg) return UnboundElement(elem_cls, module_name, control_modes, cfg) @@ -495,9 +495,11 @@ def build_object(self, data: dict, ignore_external: bool = False): cleaned_data, control_modes = self._strip_build_metadata(data) - # Validate the model try: - cfg = config_cls.model_validate(cleaned_data) + if config_cls is not None: + cfg = config_cls.model_validate(cleaned_data) + else: + cfg = cleaned_data except ValidationError 
as e: handle_validation_error(e, module.__name__, location_str, field_locations) diff --git a/pyaml/control/controlsystem.py b/pyaml/control/controlsystem.py index a622ebd8c..fdde41718 100644 --- a/pyaml/control/controlsystem.py +++ b/pyaml/control/controlsystem.py @@ -191,19 +191,19 @@ def fill_device(self, elements: list[Element]): elif isinstance(e, RFPlant): attachedTrans: list[RFTransmitter] = [] - if e._cfg.transmitters: - for t in e._cfg.transmitters: - vDev = self.get_device_access(t._cfg.voltage) - pDev = self.get_device_access(t._cfg.phase) + if e.transmitters: + for t in e.transmitters: + vDev = self.get_device_access(t.voltage_name) + pDev = self.get_device_access(t.phase_name) voltage = RWRFVoltageScalar(t, vDev) phase = RWRFPhaseScalar(t, pDev) nt = t.attach(self, voltage, phase) - self.add_rf_transnmitter(nt) + self.add_rf_transmitter(nt) attachedTrans.append(nt) - fDev = self.get_device_access(e._cfg.masterclock) + fDev = self.get_device_access(e.masterclock) frequency = RWRFFrequencyScalar(e, fDev) - voltage = RWTotalVoltage(attachedTrans) if e._cfg.transmitters else None + voltage = RWTotalVoltage(attachedTrans) if e.transmitters else None ne = e.attach(self, frequency, voltage) self.add_rf_plant(ne) diff --git a/pyaml/lattice/simulator.py b/pyaml/lattice/simulator.py index a877588aa..44a92df92 100644 --- a/pyaml/lattice/simulator.py +++ b/pyaml/lattice/simulator.py @@ -200,13 +200,13 @@ def fill_device(self, elements: list[Element]): self.add_bpm(e) elif isinstance(e, RFPlant): - if e._cfg.transmitters: + if e.transmitters: cavs: list[at.Element] = [] harmonics: list[float] = [] attachedTrans: list[RFTransmitter] = [] - for t in e._cfg.transmitters: + for t in e.transmitters: cavsPerTrans: list[at.Element] = [] - for c in t._cfg.cavities: + for c in t.cavities: # Expect unique name for cavities cav = self.get_at_elems(Element(c)) if len(cav) > 1: @@ -214,11 +214,11 @@ def fill_device(self, elements: list[Element]): if len(cav) == 0: raise 
PyAMLException(f"RF transmitter {t.get_name()}, No cavity found") cavsPerTrans.append(cav[0]) - harmonics.append(t._cfg.harmonic) + harmonics.append(t.harmonic) voltage = RWRFVoltageScalar(cavsPerTrans) phase = RWRFPhaseScalar(cavsPerTrans) nt = t.attach(self, voltage, phase) - self.add_rf_transnmitter(nt) + self.add_rf_transmitter(nt) cavs.extend(cavsPerTrans) attachedTrans.append(nt) diff --git a/pyaml/rf/rf_plant.py b/pyaml/rf/rf_plant.py index 87e8706a7..a9ad488b9 100644 --- a/pyaml/rf/rf_plant.py +++ b/pyaml/rf/rf_plant.py @@ -1,49 +1,46 @@ -import numpy as np -from pydantic import BaseModel, ConfigDict - -try: - from typing import Self # Python 3.11+ -except ImportError: - from typing_extensions import Self # Python 3.10 and earlier +import copy +from typing import Self from .. import PyAMLException from ..common import abstract -from ..common.element import Element, ElementConfigModel -from ..control.deviceaccess import DeviceAccess +from ..common.element import Element +from ..validation import DynamicValidation from .rf_transmitter import RFTransmitter # Define the main class name for this module PYAMLCLASS = "RFPlant" -class ConfigModel(ElementConfigModel): - masterclock: str | None = None - """Device to apply main RF frequency""" - transmitters: list[RFTransmitter] | None = None - """List of RF trasnmitters""" - - -class RFPlant(Element): +class RFPlant(Element, DynamicValidation): """ Main RF object """ - def __init__(self, cfg: ConfigModel): - super().__init__(cfg.name) - self._cfg = cfg + def __init__( + self, + name: str, + masterclock: str | None = None, + transmitters: list[RFTransmitter] | None = None, + lattice_names: str | None = None, + description: str | None = None, + ): + super().__init__(name, lattice_names, description) + + self.masterclock = masterclock + self.transmitters = transmitters self.__frequency = None self.__voltage = None @property def frequency(self) -> abstract.ReadWriteFloatScalar: if self.__frequency is None: - raise 
PyAMLException(f"{str(self)} has no masterclock device defined") + raise PyAMLException(f"{str(self.name)} has no masterclock device defined") return self.__frequency @property def voltage(self) -> abstract.ReadWriteFloatScalar: if self.__voltage is None: - raise PyAMLException(f"{str(self)} has no trasmitter device defined") + raise PyAMLException(f"{str(self.name)} has no transmitter device defined") return self.__voltage def attach( @@ -53,7 +50,7 @@ def attach( voltage: abstract.ReadWriteFloatScalar, ) -> Self: # Attach frequency attribute and returns a new reference - obj = self.__class__(self._cfg) + obj = copy.copy(self) obj.__frequency = frequency obj.__voltage = voltage obj._peer = peer @@ -76,19 +73,19 @@ def get(self) -> float: sum = 0 # Count only fundamental harmonic for t in self.__trans: - if t._cfg.harmonic == 1.0: + if t.harmonic == 1.0: sum += t.voltage.get() return sum def set(self, value: float): # Assume that sum of transmitter (fundamental harmonic) distribution is 1 for t in self.__trans: - if t._cfg.harmonic == 1.0: - v = value * t._cfg.distribution + if t.harmonic == 1.0: + v = value * t.distribution t.voltage.set(v) def set_and_wait(self, value: float): raise NotImplementedError("Not implemented yet.") def unit(self) -> str: - return self.__trans[0]._cfg.phase.unit() + return self.__trans[0].phase_device_access.unit() diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index 8eb942a6b..2dcdd5098 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -1,56 +1,38 @@ -import numpy as np -from pydantic import BaseModel, ConfigDict - -try: - from typing import Self # Python 3.11+ -except ImportError: - from typing_extensions import Self # Python 3.10 and earlier +import copy +from typing import Self from .. 
import PyAMLException from ..common import abstract -from ..common.element import Element, ElementConfigModel -from ..control.deviceaccess import DeviceAccess +from ..common.element import Element +from ..validation import DynamicValidation # Define the main class name for this module PYAMLCLASS = "RFTransmitter" -class ConfigModel(ElementConfigModel): - """ - Configuration model for RF Transmitter. - - Attributes - ---------- - voltage : str or None, optional - Device to apply cavity voltage - phase : str or None, optional - Device to apply cavity phase - cavities : list[str] - List of cavity names connected to this transmitter - harmonic : float, optional - Harmonic frequency ratio, 1.0 for main frequency, by default 1.0 - distribution : float, optional - RF distribution (Part of the total RF voltage powered by this transmitter), - by default 1.0 - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - voltage: str | None = None - phase: str | None = None - cavities: list[str] - harmonic: float = 1.0 - distribution: float = 1.0 - - -class RFTransmitter(Element): +class RFTransmitter(Element, DynamicValidation): """ Class that handle a RF transmitter """ - def __init__(self, cfg: ConfigModel): - super().__init__(cfg.name) - self._cfg = cfg + def __init__( + self, + name: str, + cavities: list[str], + voltage: str | None = None, + phase: str | None = None, + harmonic: float = 1.0, + distribution: float = 1.0, + lattice_names: str | None = None, + description: str | None = None, + ): + super().__init__(name, lattice_names, description) + self.voltage_name = voltage + self.phase_name = phase + self.cavities = cavities + self.harmonic = harmonic + self.distribution = distribution + self.__voltage = None self.__phase = None @@ -70,7 +52,7 @@ def voltage(self) -> abstract.ReadWriteFloatScalar: If transmitter is unattached or has no voltage device defined """ if self.__voltage is None: - raise PyAMLException(f"{str(self)} is unattached or has 
no voltage device defined") + raise PyAMLException(f"{str(self.name)} is unattached or has no voltage device defined") return self.__voltage @property @@ -89,7 +71,7 @@ def phase(self) -> abstract.ReadWriteFloatScalar: If transmitter is unattached or has no phase device defined """ if self.__phase is None: - raise PyAMLException(f"{str(self)} is unattached or has no phase device defined") + raise PyAMLException(f"{str(self.name)} is unattached or has no phase device defined") return self.__phase def attach( @@ -116,7 +98,7 @@ def attach( A new attached instance of RFTransmitter """ # Attach voltage and phase attribute and returns a new reference - obj = self.__class__(self._cfg) + obj = copy.copy(self) obj.__voltage = voltage obj.__phase = phase obj._peer = peer diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py new file mode 100644 index 000000000..1d90a08b6 --- /dev/null +++ b/pyaml/validation/__init__.py @@ -0,0 +1,18 @@ +""" +PyAML validation subpackage. +""" + +from .generator import SchemaGenerator +from .models import ConfigurationSchema, DynamicValidation, StaticValidation +from .registry import SchemaRegistry, register_schema +from .validator import SchemaValidator + +__all__ = [ + "ConfigurationSchema", + "DynamicValidation", + "SchemaRegistry", + "SchemaValidator", + "SchemaGenerator", + "StaticValidation", + "register_schema", +] diff --git a/pyaml/validation/generator.py b/pyaml/validation/generator.py new file mode 100644 index 000000000..874ee4afa --- /dev/null +++ b/pyaml/validation/generator.py @@ -0,0 +1,262 @@ +"""Module for generating JSON Schema from registered configuration schemas.""" + +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Any + +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from pydantic_core import core_schema + +from .models import ConfigurationSchema +from .registry import SchemaRegistry + +logger = logging.getLogger(__name__) + 
+ +METADATA_KEYS = ( + "title", + "description", + "examples", + "deprecated", + "readOnly", + "writeOnly", +) + +CLASS_ALIAS = "class" + + +class SchemaGenerator: + """ + Generate JSON Schemas for registered configuration models. + + This class provides convenience methods for generating and exporting + JSON Schemas from models registered in the ``SchemaRegistry``. Schema + generation is delegated to a custom Pydantic JSON Schema generator + that adds support for registry-aware polymorphism. + + Configuration base classes with registered subclasses are represented + as ``anyOf`` unions over their concrete implementations, allowing + generated schemas to describe all valid registered configuration types. + + Primitive unions such as ``str | None`` are emitted using compact + ``type: [...]`` representations when supported by Pydantic. + """ + + _registry = SchemaRegistry() + + @classmethod + def generate(cls, class_path: str) -> dict[str, Any]: + """ + Generate a JSON Schema for a registered configuration schema. + + The schema is generated using a custom Pydantic JSON Schema generator + that expands registered configuration subclasses into ``anyOf`` unions + and preserves compact representations for primitive unions such as + ``str | None``. + + Parameters + ---------- + class_path : str + Registry key identifying the configuration schema class. + + Returns + ------- + dict[str, Any] + Generated JSON Schema for the requested configuration schema. + + Raises + ------ + KeyError + If no schema is registered for the given class path. 
+ """ + + schema_cls = cls._registry.get(class_path) + + logger.debug("Generating schema for %s.", schema_cls) + + if schema_cls is None: + raise KeyError(f"No schema registered for '{class_path}'") + + return schema_cls.model_json_schema( + by_alias=True, + union_format="primitive_type_array", + schema_generator=RegistryJsonSchema, + ) + + @classmethod + def save( + cls, + class_path: str, + filename: str | Path, + *, + indent: int = 2, + ) -> Path: + """ + Generate JSON Schema and save it to a file. + + Parameters + ---------- + class_path : str + Registered class path to generate schema for. + filename : str or Path + Output filename. + indent : int, optional + JSON indentation level. Default: 2. + + Returns + ------- + Path + Path to the written file. + """ + schema = cls.generate(class_path) + + path = Path(filename) + + with path.open("w", encoding="utf-8") as file: + json.dump(schema, file, indent=indent) + + return path + + +class RegistryJsonSchema(GenerateJsonSchema): + """ + Custom Pydantic JSON Schema generator for configuration schemas. + + This generator extends the default Pydantic schema generation to support + registry-aware polymorphism for ``ConfigurationSchema`` subclasses. + + For configuration base classes with registered subclasses, the generated + schema is replaced by an ``anyOf`` union over all registered concrete + subclasses. Human-facing schema metadata such as titles and descriptions + are preserved from the original schema. + + In addition, registry-driven unions are emitted as ``anyOf`` rather than + ``oneOf``; see the notes on ``model_schema`` for the rationale. + Primitive unions such as ``str | None`` continue to use compact + ``type: [...]`` representations when supported by Pydantic. + """ + + _registry = SchemaRegistry() + + def model_schema(self, schema: core_schema.ModelSchema) -> dict[str, Any]: + """ + Generate a JSON Schema for a Pydantic model. 
+ + For ``ConfigurationSchema`` subclasses, the generated schema may be + transformed into a polymorphic schema based on the registered schema + registry: + + - If the model defines a ``class`` field, all registered aliases + corresponding to the model are added as allowed literal values. + - If registered subclasses exist, the schema is replaced by an + ``anyOf`` union containing the schemas of all registered subclasses. + + Metadata fields from the original schema, such as titles and + descriptions, are preserved in the merged schema. + + Parameters + ---------- + schema : core_schema.ModelSchema + Pydantic core schema describing the model. + + Returns + ------- + dict[str, Any] + Generated JSON Schema for the model or polymorphic union schema. + + Notes + ----- + The generated polymorphic schema uses ``anyOf`` instead of ``oneOf`` + because nested ``oneOf`` unions may lead to ambiguous validation in + downstream JSON Schema tooling when subclass schemas contain nullable + or overlapping branches. + """ + + base_schema = super().model_schema(schema) + model_cls = schema.get("cls") + logging.debug(f"Base schema is extracted from {model_cls}.") + + if not isinstance(model_cls, type) or not issubclass(model_cls, ConfigurationSchema): + return base_schema + + # If the baseschema has a class field, add literal for all keys. 
+ properties = base_schema.get("properties") + if isinstance(properties, dict) and CLASS_ALIAS in properties and isinstance(properties[CLASS_ALIAS], dict): + logging.debug(f"Adding list of classes to: {model_cls}.") + + # Find keys that correspond to the same schema + base_keys = sorted(key for key, schema_cls in self._registry.items() if schema_cls is model_cls) + + base_schema = deepcopy(base_schema) + properties = base_schema["properties"] + self._add_literals_to_class_path(properties[CLASS_ALIAS], base_keys) + + # Get subclasses in registry sorted by module name + subclasses = sorted( + { + schema_cls + for _, schema_cls in self._registry.items() + if isinstance(schema_cls, type) and issubclass(schema_cls, model_cls) and schema_cls is not model_cls + }, + key=lambda cls: f"{cls.__module__}.{cls.__name__}", + ) + logging.debug(f"Subclasses found in registry: {subclasses}.") + + if not subclasses: + return base_schema + + # Generate schemas of subclasses + subschemas = [self.generate_inner(item.__pydantic_core_schema__) for item in subclasses] + + # TODO: get the schemas to work when using oneOf instead + merged: dict[str, Any] = {"anyOf": subschemas} + + for key in METADATA_KEYS: + if key in base_schema and key not in merged: + merged[key] = deepcopy(base_schema[key]) + + return merged + + @staticmethod + def _add_literals_to_class_path(schema: dict[str, Any], literals: list[str]) -> None: + """ + Add allowed literal values to a JSON Schema string field. + + The provided literals are merged with any existing ``enum`` values in + the schema while preserving insertion order and removing duplicates. + + If the resulting set contains only a single value, the schema is + simplified by replacing ``enum`` with ``const``. + + The schema is modified in place. + + Parameters + ---------- + schema : dict[str, Any] + JSON Schema fragment representing a string-like field. + literals : list[str] + Literal values to add to the schema. 
+ + Notes + ----- + Only schemas representing string values or existing enumerations are + modified. Empty literal lists are ignored. + """ + + if not literals: + return + + # Add registry keys as literals + if schema.get("type") == "string" or "enum" in schema: + existing = schema.get("enum", []) + merged = list(dict.fromkeys([*existing, *literals])) + schema["enum"] = merged + + # If only one value exists use const + if len(merged) == 1: + schema["const"] = merged[0] + schema.pop("enum", None) + + return diff --git a/pyaml/validation/models.py b/pyaml/validation/models.py new file mode 100644 index 000000000..43cb70695 --- /dev/null +++ b/pyaml/validation/models.py @@ -0,0 +1,219 @@ +"""Base datamodels for configuration.""" + +import inspect +import logging +from typing import Any, get_type_hints + +from pydantic import BaseModel, ConfigDict, Field, create_model + +logger = logging.getLogger(__name__) + + +class PyAMLBaseModel(BaseModel): + """ + Base model for pyAML. + + Overrides ``model_dump()`` and ``model_dump_json()`` to enable + ``serialize_as_any=True`` by default. This ensures that fields are + serialized according to their runtime type rather than their declared + annotation type. + """ + + def model_dump(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + def model_dump_json(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump_json(**kwargs) + + +class ConfigurationSchema(PyAMLBaseModel): + """ + Base model for configuration schemas. + + Provides common fields and functionality for schemas which are to be registered in the :class:`SchemaRegistry`. + """ + + model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, arbitrary_types_allowed=False, extra="forbid") + + class_path: str = Field( + description="Fully qualified class path.", + alias="class", + ) + + +class ValidationSchema(PyAMLBaseModel): + """ + Base model for validation schemas. 
+ + Provides common fields and functionality for schemas used to validate arguments during object creation. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + +class ValidationMeta(type): + """ + Metaclass that validates constructor arguments using a Pydantic model. + + Classes using this metaclass must define a ``validation_model`` + attribute containing a subclass of :class:`pydantic.BaseModel`. + Before an instance is created, the supplied arguments are bound to + the ``__init__`` signature and validated against the model. + + Both positional and keyword arguments are validated before + ``__init__`` is executed. + """ + + def __call__(cls, *args: Any, **kwargs: Any): + """ + Create an instance after validating constructor arguments. + + The supplied arguments are bound to the class ``__init__`` signature, + default values are applied, and the resulting argument mapping is + validated using ``validation_model``. The validated values are then + passed to the constructor. + + Raises + ------ + TypeError + If the class does not define ``validation_model``. + + ValidationError + If the supplied arguments do not conform to the validation + model. 
+ """ + + validation_model = getattr(cls, "validation_model", None) + + if validation_model is None: + raise TypeError(f"{cls.__name__} must define validation_model.") + + # Inspect the signature of the class + signature = inspect.signature(cls.__init__) + + # Map arguments to parameters + bound = signature.bind(None, *args, **kwargs) + + # Include default arguments + bound.apply_defaults() + + # Remove self from list + bound.arguments.pop("self", None) + arguments = dict(bound.arguments) + + # Validate the model + logger.debug("Validating input against schema: %s", validation_model.model_fields) + validated = validation_model.model_validate(arguments) + + # Return the object + return super().__call__(**validated.model_dump()) + + +class DynamicValidation(metaclass=ValidationMeta): + """ + Base class that generates a validation schema from the constructor + signature. + + When a subclass is defined, a schema derived from + :class:`ValidationSchema` is generated automatically and assigned to + ``validation_model``. The generated schema is used by + :class:`ValidationMeta` to validate constructor arguments before + instance creation. + + Subclasses must not define ``validation_model`` manually. + """ + + validation_model: type[ValidationSchema] | None = None + + def __init_subclass__(cls, **kwargs): + """ + Generate and attach a validation schema for the subclass. + + A schema derived from :class:`ValidationSchema` is created from the + subclass's ``__init__`` signature and assigned to + ``validation_model``. Defining ``validation_model`` explicitly is + not permitted and results in a :class:`TypeError`. 
+ """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is not None: + raise TypeError(f"{cls.__name__} may not define validation_model manually.") + + cls.validation_model = cls._build_validation_model() + + @classmethod + def _build_validation_model(cls) -> type[ValidationSchema]: + """ + Build a validation schema from the constructor signature. + + The generated schema contains one field for each parameter in the + subclass's ``__init__`` method, excluding ``self``, ``*args`` and + ``**kwargs``. Field types are obtained from the constructor's type + annotations and default values are preserved. + + Returns + ------- + type[ValidationSchema] + A dynamically generated subclass of :class:`ValidationSchema` + representing the constructor arguments accepted by the subclass. + """ + + logger.debug("Building validation schema for %s.", f"{cls.__module__}.{cls.__name__}") + + signature = inspect.signature(cls.__init__) + type_hints = get_type_hints(cls.__init__) + + fields: dict[str, tuple[Any, Any]] = {} + + # Skip *args and **kwargs + for name, param in signature.parameters.items(): + if name == "self": + continue + + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + annotation = type_hints.get(name, Any) + default = param.default if param.default is not inspect._empty else ... + fields[name] = (annotation, default) + + model = create_model(f"{cls.__name__}ValidationSchema", **fields, __base__=ValidationSchema) + + logger.debug("Created model: %s", model.model_fields) + + return model + + +class StaticValidation(metaclass=ValidationMeta): + """ + Base class for explicit constructor validation. + + Subclasses must define a ``validation_model`` attribute containing a + subclass of :class:`pydantic.BaseModel`. The model is used by + :class:`ValidationMeta` to validate constructor arguments before + instance creation. 
+ """ + + validation_model: type[BaseModel] + + def __init_subclass__(cls, **kwargs): + """ + Verify that the subclass defines a validation model. + + Raises + ------ + TypeError + If the subclass does not define a ``validation_model`` + attribute. + """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is None: + raise TypeError(f"{cls.__name__} must define validation_model.") diff --git a/pyaml/validation/registry.py b/pyaml/validation/registry.py new file mode 100644 index 000000000..ea1239f75 --- /dev/null +++ b/pyaml/validation/registry.py @@ -0,0 +1,350 @@ +"""Registry for schemas.""" + +import importlib +import logging +import pkgutil +from collections.abc import ItemsView, Iterator, KeysView, ValuesView +from typing import Callable, Type, TypeVar + +from .models import ConfigurationSchema + +logger = logging.getLogger(__name__) + + +class SchemaRegistry: + """ + Singleton registry for dynamically registered schemas. + + The registry is used to validate data and produce + jsonschemas for dynamic nested models. + """ + + _instance: "SchemaRegistry | None" = None + _schemas: dict[str, Type[ConfigurationSchema]] + + def __new__(cls) -> "SchemaRegistry": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._schemas = {} + return cls._instance + + # ========================================================== + # Registration + # ========================================================== + + def register( + self, + class_path: str, + schema: type[ConfigurationSchema], + ) -> None: + """Register a schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + schema : type[ConfigurationSchema] + Schema class used for validation. Must inherit from + :class:`ConfigurationSchema`. + + Raises + ------ + TypeError + If ``schema`` is not a subclass of + :class:`ConfigurationSchema`. 
+ ValueError + If a different schema has already been registered for + ``class_path``. + """ + existing = self._schemas.get(class_path) + if existing is not None and existing is not schema: + raise ValueError(f"{class_path} already registered with a different schema.") + + if not isinstance(schema, type) or not issubclass(schema, ConfigurationSchema): + raise TypeError(f"{schema!r} must inherit from ConfigurationSchema.") + + self._schemas[class_path] = schema + + def discover(self) -> None: + """Discover and register schemas. + + This imports every module under the root package so that classes + decorated with :func:`register_schema` are registered on import. + Legacy schemas from ``pyproject.toml`` are not handled by this method. + """ + + # Import package modules so schema registration runs. + root_package = __package__.split(".")[0] + package = importlib.import_module(root_package) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, + package.__name__ + ".", + ): + importlib.import_module(module_name) + + def unregister( + self, + class_path: str, + ) -> None: + """Unregister a schema. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Raises + ------ + KeyError + If no schema has been registered for ``class_path``. + """ + try: + del self._schemas[class_path] + + except KeyError: + raise KeyError(f"No schema registered for '{class_path}'") from None + + def clear(self) -> None: + """Remove all registered schemas. + + This clears the registry in place. 
+ """ + self._schemas.clear() + + def __repr__( + self, + ) -> str: + """Return a string representation of the registry.""" + if not self._schemas: + return f"{self.__class__.__name__}({{}})" + + lines = [f"{self.__class__.__name__}("] + + for class_path, schema in sorted(self._schemas.items()): + lines.append(f" {class_path!r}: {schema.__module__}.{schema.__name__},") + + lines.append(")") + + return "\n".join(lines) + + # ========================================================== + # Lookup + # ========================================================== + + def __getitem__( + self, + class_path: str, + ) -> Type[ConfigurationSchema]: + """Return the registered schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + Type[ConfigurationSchema] + Registered schema class. + + Raises + ------ + KeyError + If no schema has been registered for ``class_path``. + """ + + try: + return self._schemas[class_path] + + except KeyError: + raise KeyError(f"No schema registered for '{class_path}.'") from None + + def get( + self, + class_path: str, + ) -> type[ConfigurationSchema] | None: + """Return the registered schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + type[ConfigurationSchema] | None + Registered schema class, or ``None`` if no schema is + registered for ``class_path``. + """ + return self._schemas.get(class_path) + + # ========================================================== + # Contents + # ========================================================== + + def __contains__( + self, + class_path: str, + ) -> bool: + """Return whether a schema is registered for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + bool + ``True`` if a schema is registered for ``class_path``, + otherwise ``False``. 
+ """ + return class_path in self._schemas + + def items( + self, + ) -> ItemsView[str, Type[ConfigurationSchema]]: + """Return a view of registered schema items. + + Returns + ------- + ItemsView[str, Type[ConfigurationSchema]] + View of registered ``(class_path, schema)`` pairs. + """ + return self._schemas.items() + + def keys( + self, + ) -> KeysView[str]: + """Return a view of registered class paths. + + Returns + ------- + KeysView[str] + View of registered class paths. + """ + return self._schemas.keys() + + def values( + self, + ) -> ValuesView[Type[ConfigurationSchema]]: + """Return a view of registered schemas. + + Returns + ------- + ValuesView[Type[ConfigurationSchema]] + View of registered schema classes. + """ + return self._schemas.values() + + def __len__( + self, + ) -> int: + """Return the number of registered schemas. + + Returns + ------- + int + Number of registered schemas. + """ + return len(self._schemas) + + def __iter__( + self, + ) -> Iterator[str]: + """Iterate over registered class paths. + + Returns + ------- + Iterator[str] + Iterator over registered class paths. + """ + return iter(self._schemas) + + # ========================================================== + # Updating + # ========================================================== + + def update( + self, + class_path: str, + schema: type[ConfigurationSchema], + ) -> None: + """Replace the schema registered for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + schema : type[ConfigurationSchema] + Schema class used for validation. Must inherit from + :class:`ConfigurationSchema`. + + Raises + ------ + TypeError + If ``schema`` is not a subclass of + :class:`ConfigurationSchema`. + KeyError + If no schema has been registered for ``class_path``. 
+ """ + if not isinstance(schema, type) or not issubclass(schema, ConfigurationSchema): + raise TypeError(f"{schema!r} must inherit from ConfigurationSchema.") + + if class_path not in self._schemas: + raise KeyError(f"{class_path} is not registered.") + + self._schemas[class_path] = schema + + +# ========================================================== +# Decorator to register schemas +# ========================================================== + +ModelT = TypeVar("ModelT", bound=ConfigurationSchema) +ClassT = TypeVar("ClassT") + + +def register_schema( + schema: Type[ModelT], +) -> Callable[[Type[ClassT]], Type[ClassT]]: + """Register a runtime class with a Pydantic schema. + + Parameters + ---------- + schema : Type[ModelT] + Schema class to register. Must inherit from + :class:`ConfigurationSchema`. + + Returns + ------- + Callable[[Type[ClassT]], Type[ClassT]] + Decorator that registers the decorated class with ``schema``. + + Examples + -------- + >>> @register_schema(MySchema) + ... class MyClass: + ... pass + """ + + if not (isinstance(schema, type) and issubclass(schema, ConfigurationSchema)): + raise TypeError("register_schema must be called with a schema class, e.g. 
@register_schema(MySchema)") + + registry = SchemaRegistry() + + def decorator( + cls: Type[ClassT], + ) -> Type[ClassT]: + class_path = f"{cls.__module__}.{cls.__name__}" + + logger.debug("Register schema for %s.", class_path) + + registry.register( + class_path=class_path, + schema=schema, + ) + + return cls + + return decorator diff --git a/pyaml/validation/validator.py b/pyaml/validation/validator.py new file mode 100644 index 000000000..da5ed5500 --- /dev/null +++ b/pyaml/validation/validator.py @@ -0,0 +1,144 @@ +"""Module for schema validation.""" + +import logging +import warnings +from typing import Any + +from pydantic import ValidationError + +from .models import ConfigurationSchema +from .registry import SchemaRegistry + +logger = logging.getLogger(__name__) + + +class SchemaValidator: + """Recursive validator for configuration dictionaries. + + The validator traverses nested configuration data structures and + converts dictionaries representing configuration objects into + validated Pydantic schema models. + + Validation is performed recursively: + + - Lists are traversed element-by-element + - Dictionaries are recursively validated + - Dictionaries matching configuration schemas are converted into + validated schema models + - Dictionaries with unknown schemas are left unchanged + + Schema lookup is performed through the :class:`SchemaRegistry`. + """ + + _registry = SchemaRegistry() + + @classmethod + def validate( + cls, + data: dict[str, Any], + ) -> ConfigurationSchema: + """Validate configuration data recursively. + + Parameters + ---------- + data : dict[str, Any] + Configuration dictionary to validate. + + Returns + ------- + ConfigurationSchema + Fully validated top-level configuration model. + + Raises + ------ + TypeError + If the validated top-level object is not a + :class:`ConfigurationSchema`. 
+ """ + validated = cls._recursive_validate(data) + + if not isinstance(validated, ConfigurationSchema): + raise TypeError("Top-level configuration did not validate to a ConfigurationSchema.") + + return validated + + @classmethod + def _recursive_validate(cls, obj: Any) -> Any: + """Recursively validate nested configuration objects. + + Lists are traversed recursively element-by-element. Dictionaries + are recursively traversed and then interpreted as configuration + objects when possible. + + If a dictionary corresponds to a registered configuration schema, + it is converted into a validated schema model. Otherwise, the + dictionary is returned unchanged. + + Parameters + ---------- + obj : Any + Object to validate recursively. + + Returns + ------- + Any + Validated object. This may be: + + - A validated configuration model + - A recursively validated list + - A recursively validated dictionary + - The original object if no validation applies + """ + if isinstance(obj, list): + return [cls._recursive_validate(item) for item in obj] + + if not isinstance(obj, dict): + return obj + + logger.debug("Validating dict with keys: %s", list(obj)) + validated_dict = {key: cls._recursive_validate(value) for key, value in obj.items()} + + # Check if the dict is a configuration object + config = cls._parse_configuration(validated_dict) + if config is None: + return validated_dict + + class_path = config.class_path + schema = cls._registry.get(class_path) + + if schema is None: + warnings.warn( + f"Unknown schema for '{class_path}' so cannot validate. Leaving data as raw dict.", + stacklevel=2, + ) + return validated_dict + + return schema.model_validate(validated_dict) + + @classmethod + def _parse_configuration( + cls, + validated_dict: dict[str, Any], + ) -> ConfigurationSchema | None: + """Parse a dictionary as configuration metadata. + + Parameters + ---------- + validated_dict : dict[str, Any] + Dictionary to interpret as configuration metadata. 
+ + Returns + ------- + ConfigurationSchema | None + Parsed configuration model if validation succeeds, + otherwise ``None``. + """ + try: + return ConfigurationSchema.model_validate( + validated_dict, + extra="allow", + ) + except ValidationError: + logger.debug("Could not validate against ConfigurationSchema.") + + return None diff --git a/tests/rf/test_rf.py b/tests/rf/test_rf.py index 5752d4d07..964e25e7b 100644 --- a/tests/rf/test_rf.py +++ b/tests/rf/test_rf.py @@ -107,7 +107,7 @@ def test_rf_multi_notrans(install_test_package): RF.frequency.set(3.523e8) with pytest.raises(PyAMLException) as exc: RF.voltage.set(10e6) - assert "has no trasmitter device defined" in str(exc) + assert "has no transmitter device defined" in str(exc) # Check that frequency and voltage has been applied on the masterclock device assert np.isclose(RF.frequency.get(), 3.523e8) diff --git a/tests/validation/test_generator.py b/tests/validation/test_generator.py new file mode 100644 index 000000000..e32521f96 --- /dev/null +++ b/tests/validation/test_generator.py @@ -0,0 +1,204 @@ +"""Tests of the schema generator.""" + +import json +import re +from collections.abc import Generator +from pathlib import Path + +import pytest +from pydantic import Field + +from pyaml.validation import ( + ConfigurationSchema, + SchemaGenerator, + SchemaRegistry, +) +from pyaml.validation.generator import RegistryJsonSchema + +# ========================================================== +# Dummy schemas +# ========================================================== + + +class DummySchema(ConfigurationSchema): + pass + + +class OtherSchema(ConfigurationSchema): + pass + + +class ParentSchema(ConfigurationSchema): + """Parent schema used to test inheritance.""" + + pass + + +class ChildSchemaA(ParentSchema): + a: int = 1 + + +class ChildSchemaB(ParentSchema): + b: str = "x" + + +class ContainerSchema(ConfigurationSchema): + model: ChildSchemaA | None = Field( + default=None, + description="Container schema 
used for testing.", + ) + + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +@pytest.fixture +def registry() -> SchemaRegistry: + return SchemaRegistry() + + +# ========================================================== +# Generate +# ========================================================== + + +def test_generate_raises_clean_keyerror_for_missing_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + with pytest.raises( + KeyError, + match=re.escape(f"No schema registered for '{class_path}'"), + ): + SchemaGenerator.generate(class_path) + + +def test_generate_returns_schema_for_registered_class(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + schema = SchemaGenerator.generate(class_path) + + assert schema["title"] == "DummySchema" + + +# ========================================================== +# Save +# ========================================================== + + +def test_save_writes_schema_to_file(registry: SchemaRegistry, tmp_path: Path): + class_path = "pkg.module.Class" + registry.register(class_path, DummySchema) + + filename = tmp_path / "schema.json" + + result = SchemaGenerator.save(class_path, filename, indent=2) + + assert result == filename + assert json.loads(filename.read_text(encoding="utf-8")) == SchemaGenerator.generate(class_path) + + +# ========================================================== +# Registry-aware polymorphism +# ========================================================== + + +def test_generate_replaces_parent_schema_with_registered_subclasses( + registry: SchemaRegistry, +): + registry.register("pkg.module.Parent", ParentSchema) + 
registry.register("pkg.module.ChildA", ChildSchemaA) + registry.register("pkg.module.ChildB", ChildSchemaB) + + schema = SchemaGenerator.generate("pkg.module.Parent") + + child_refs = {item["$ref"] for item in schema.get("anyOf", [])} + + assert "#/$defs/ChildSchemaA" in child_refs + assert "#/$defs/ChildSchemaB" in child_refs + + +def test_model_schema_preserves_metadata_from_parent_schema( + registry: SchemaRegistry, +): + registry.register("pkg.module.Parent", ParentSchema) + registry.register("pkg.module.ChildA", ChildSchemaA) + registry.register("pkg.module.ChildB", ChildSchemaB) + + generator = RegistryJsonSchema() + base_schema = ParentSchema.__pydantic_core_schema__ + + schema = generator.model_schema(base_schema) + + assert schema["title"] == "ParentSchema" + + +# ========================================================== +# Literals +# ========================================================== + + +def test_add_literals_to_class_path_ignores_empty_literal_list(): + schema = {"type": "string"} + + RegistryJsonSchema._add_literals_to_class_path(schema, []) + + assert schema == {"type": "string"} + + +def test_add_literals_to_class_path_replaces_single_value_with_const(): + schema = {"type": "string"} + + RegistryJsonSchema._add_literals_to_class_path( + schema, + ["pkg.module.Parent"], + ) + + assert schema["const"] == "pkg.module.Parent" + assert "enum" not in schema + + +def test_add_literals_to_class_path_merges_existing_enum_and_removes_duplicates(): + schema = {"type": "string", "enum": ["pkg.module.Parent", "pkg.module.ChildA"]} + + RegistryJsonSchema._add_literals_to_class_path( + schema, + ["pkg.module.ChildA", "pkg.module.ChildB", "pkg.module.Parent"], + ) + + assert schema["enum"] == [ + "pkg.module.Parent", + "pkg.module.ChildA", + "pkg.module.ChildB", + ] + assert "const" not in schema + + +def test_add_literals_to_class_path_does_not_modify_non_string_schema_without_enum(): + schema = {"type": "integer"} + + 
RegistryJsonSchema._add_literals_to_class_path(schema, ["pkg.module.Parent"]) + + assert schema == {"type": "integer"} + + +def test_add_literals_to_class_path_updates_existing_enum_even_without_string_type(): + schema = {"enum": ["pkg.module.Parent"]} + + RegistryJsonSchema._add_literals_to_class_path(schema, ["pkg.module.ChildA"]) + + assert schema["enum"] == ["pkg.module.Parent", "pkg.module.ChildA"] + assert "const" not in schema diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py new file mode 100644 index 000000000..b2ad86db2 --- /dev/null +++ b/tests/validation/test_models.py @@ -0,0 +1,269 @@ +"""Tests of the configuration models.""" + +import json + +import pytest +from pydantic import BaseModel, ValidationError +from pydantic.errors import PydanticSchemaGenerationError + +from pyaml.validation import ConfigurationSchema, DynamicValidation, StaticValidation +from pyaml.validation.models import PyAMLBaseModel, ValidationSchema + +# ========================================================== +# PyAMLBaseModel +# ========================================================== + + +def test_model_dump_serializes_subclass_fields(): + class Device(PyAMLBaseModel): + name: str + + class Magnet(Device): + type: str + + class Accelerator(PyAMLBaseModel): + device: Device + + accelerator = Accelerator(device=Magnet(name="QF", type="Quadrupole")) + + dumped = accelerator.model_dump() + + assert dumped == {"device": {"name": "QF", "type": "Quadrupole"}} + + +def test_model_dump_json_serializes_subclass_fields(): + class Device(PyAMLBaseModel): + name: str + + class Magnet(Device): + type: str + + class Accelerator(PyAMLBaseModel): + device: Device + + accelerator = Accelerator(device=Magnet(name="QF", type="Quadrupole")) + + dumped_json = accelerator.model_dump_json() + dumped = json.loads(dumped_json) + + assert dumped == {"device": {"name": "QF", "type": "Quadrupole"}} + + +# ========================================================== +# 
ConfigurationSchema +# ========================================================== + + +def test_configuration_schema_accepts_alias_class(): + schema = ConfigurationSchema.model_validate({"class": "pkg.module.Class"}) + + assert schema.class_path == "pkg.module.Class" + + +def test_configuration_schema_accepts_field_name_class_path(): + schema = ConfigurationSchema.model_validate({"class_path": "pkg.module.Class"}) + + assert schema.class_path == "pkg.module.Class" + + +def test_configuration_schema_forbids_extra_fields(): + with pytest.raises(ValidationError) as exc_info: + ConfigurationSchema.model_validate( + { + "class": "pkg.module.Class", + "unexpected": "value", + } + ) + + assert "extra_forbidden" in str(exc_info.value) + + +def test_configuration_schema_do_not_allow_arbitrary_types(): + class ArbitraryType: + pass + + with pytest.raises(PydanticSchemaGenerationError): + + class TestModel(ConfigurationSchema): + value: ArbitraryType + + +def test_configuration_schema_dump_uses_alias_when_requested(): + schema = ConfigurationSchema.model_validate({"class": "pkg.module.Class"}) + + dumped = schema.model_dump(by_alias=True) + + assert dumped == {"class": "pkg.module.Class"} + + +# ========================================================== +# ValidationSchema +# ========================================================== + + +class DummyDevice: + pass + + +class DummySchema(ValidationSchema): + device: DummyDevice + + +def test_validation_schema_allows_arbitrary_types(): + device = DummyDevice() + + schema = DummySchema.model_validate({"device": device}) + + assert schema.device is device + + +def test_validation_schema_forbids_extra_fields(): + with pytest.raises(ValidationError): + DummySchema.model_validate({"device": DummyDevice(), "extra_field": 123}) + + +# ========================================================== +# DynamicValidation +# ========================================================== + + +def 
test_dynamic_validation_builds_schema_from_init_signature(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + assert issubclass(MyClass.validation_model, ValidationSchema) + assert list(MyClass.validation_model.model_fields) == ["name", "count"] + + name_field = MyClass.validation_model.model_fields["name"] + count_field = MyClass.validation_model.model_fields["count"] + + assert name_field.annotation is str + assert name_field.is_required() + + assert count_field.annotation is int + assert count_field.default == 0 + + +def test_dynamic_validation_accepts_positional_and_keyword_arguments(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + obj1 = MyClass("test", 1) + obj2 = MyClass(name="test", count=1) + obj3 = MyClass("test") + + assert obj1.name == "test" + assert obj1.count == 1 + + assert obj2.name == "test" + assert obj2.count == 1 + + assert obj3.name == "test" + assert obj3.count == 0 + + +def test_dynamic_validation_coerces_and_rejects_invalid_input(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + obj = MyClass(name="test", count="12") + assert obj.count == 12 + + with pytest.raises(ValidationError): + MyClass(name="test", count="not-an-int") + + +def test_dynamic_validation_rejects_manual_validation_model(): + class ManualModel(BaseModel): + name: str + + with pytest.raises(TypeError, match="may not define validation_model manually"): + + class Broken(DynamicValidation): + validation_model = ManualModel + + def __init__(self, name: str): + self.name = name + + +# ========================================================== +# StaticValidation +# ========================================================== + + +def test_static_validation_accepts_explicit_basemodel(): + class ExampleSchema(BaseModel): + name: str + count: int = 
0 + + class Example(StaticValidation): + validation_model = ExampleSchema + + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + obj1 = Example("test", 1) + obj2 = Example(name="test", count=1) + obj3 = Example("test") + + assert obj1.name == "test" + assert obj1.count == 1 + + assert obj2.name == "test" + assert obj2.count == 1 + + assert obj3.name == "test" + assert obj3.count == 0 + + +def test_static_validation_validates_and_coerces_input(): + class ExampleSchema(BaseModel): + name: str + count: int + + class Example(StaticValidation): + validation_model = ExampleSchema + + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + obj = Example(name="test", count="12") + assert obj.count == 12 + + with pytest.raises(ValidationError): + Example(name="test", count="not-an-int") + + +def test_static_validation_inherits_validation_model(): + class ParentSchema(BaseModel): + name: str + + class Parent(StaticValidation): + validation_model = ParentSchema + + def __init__(self, name: str): + self.name = name + + class Child(Parent): + def __init__(self, name: str): + super().__init__(name) + + obj = Child("test") + assert obj.name == "test" + assert Child.validation_model is ParentSchema + + +def test_static_validation_requires_a_validation_model(): + with pytest.raises(TypeError, match="must define validation_model"): + + class Broken(StaticValidation): + def __init__(self, name: str): + self.name = name diff --git a/tests/validation/test_registry.py b/tests/validation/test_registry.py new file mode 100644 index 000000000..ce68aa822 --- /dev/null +++ b/tests/validation/test_registry.py @@ -0,0 +1,347 @@ +"""Tests of the schema registry.""" + +import re +from collections.abc import Generator + +import pytest + +from pyaml.validation import ConfigurationSchema, SchemaRegistry, register_schema + +# ========================================================== +# Dummy schemas +# 
========================================================== + + +class DummySchema(ConfigurationSchema): + pass + + +class OtherSchema(ConfigurationSchema): + pass + + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +@pytest.fixture +def registry() -> SchemaRegistry: + return SchemaRegistry() + + +# ========================================================== +# Singleton behaviour +# ========================================================== + + +def test_singleton_returns_same_instance(): + assert SchemaRegistry() is SchemaRegistry() + + +# ========================================================== +# Registration +# ========================================================== + + +def test_register_stores_schema(registry: SchemaRegistry): + registry.register("pkg.module.Class", DummySchema) + + assert registry["pkg.module.Class"] is DummySchema + + +def test_register_allows_same_schema_for_existing_class_path(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + registry.register(class_path, DummySchema) + + assert registry[class_path] is DummySchema + assert len(registry) == 1 + + +def test_register_raises_valueerror_for_different_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + with pytest.raises( + ValueError, + match=re.escape(f"{class_path} already registered with a different schema."), + ): + registry.register(class_path, OtherSchema) + + +def test_register_raises_typeerror_for_invalid_schema(registry: SchemaRegistry): + with pytest.raises( + TypeError, + match=re.escape("must inherit from ConfigurationSchema"), + ): + 
registry.register( + "pkg.module.Class", + object, # type: ignore[arg-type] + ) + + +# ========================================================== +# Unregistering +# ========================================================== + + +def test_unregister_removes_registered_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + registry.unregister(class_path) + + assert class_path not in registry + + +def test_unregister_raises_clean_keyerror_for_missing_schema( + registry: SchemaRegistry, +): + class_path = "pkg.module.Class" + + with pytest.raises( + KeyError, + match=re.escape(f"No schema registered for '{class_path}'"), + ): + registry.unregister(class_path) + + +def test_unregister_removes_only_requested_schema(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + registry.unregister("pkg.module.ClassA") + + assert "pkg.module.ClassA" not in registry + assert registry["pkg.module.ClassB"] is OtherSchema + assert len(registry) == 1 + + +# ========================================================== +# Clearing +# ========================================================== + + +def test_clear_removes_all_registered_schemas(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + registry.clear() + + assert len(registry) == 0 + assert "pkg.module.ClassA" not in registry + assert "pkg.module.ClassB" not in registry + + +def test_clear_on_empty_registry_keeps_registry_empty( + registry: SchemaRegistry, +): + registry.clear() + + assert len(registry) == 0 + + +def test_clear_allows_new_registrations_afterwards( + registry: SchemaRegistry, +): + registry.register("pkg.module.Class", DummySchema) + + registry.clear() + + registry.register("pkg.module.OtherClass", OtherSchema) + + assert len(registry) == 1 + assert registry["pkg.module.OtherClass"] is 
OtherSchema + + +# ========================================================== +# Lookup +# ========================================================== + + +def test_getitem_raises_clean_keyerror_for_missing_schema(registry: SchemaRegistry): + with pytest.raises(KeyError, match=r"No schema registered for 'pkg\.module\.Class.'"): + _ = registry["pkg.module.Class"] + + +def test_get_returns_registered_schema(registry: SchemaRegistry): + registry.register("pkg.module.Class", DummySchema) + + assert registry.get("pkg.module.Class") is DummySchema + + +def test_get_returns_none_for_missing_schema(registry: SchemaRegistry): + assert registry.get("pkg.module.Class") is None + + +# ========================================================== +# Contents +# ========================================================== + + +def test_items_returns_registered_items(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + items = registry.items() + + assert ("pkg.module.ClassA", DummySchema) in items + assert ("pkg.module.ClassB", OtherSchema) in items + assert len(items) == 2 + + +def test_keys_returns_registered_class_paths(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + keys = registry.keys() + + assert "pkg.module.ClassA" in keys + assert "pkg.module.ClassB" in keys + assert len(keys) == 2 + + +def test_values_returns_registered_schemas(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + values = registry.values() + + assert DummySchema in values + assert OtherSchema in values + assert len(values) == 2 + + +def test_iter_returns_registered_class_paths(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + class_paths = list(iter(registry)) + + 
assert "pkg.module.ClassA" in class_paths
+    assert "pkg.module.ClassB" in class_paths
+    assert len(class_paths) == 2
+
+
+# ==========================================================
+# Updating
+# ==========================================================
+
+
+def test_update_replaces_registered_schema(registry: SchemaRegistry):
+    registry.register("pkg.module.Class", DummySchema)
+
+    registry.update("pkg.module.Class", OtherSchema)
+
+    assert registry["pkg.module.Class"] is OtherSchema
+
+
+def test_update_raises_keyerror_for_missing_schema(registry: SchemaRegistry):
+    class_path = "pkg.module.Class"
+
+    with pytest.raises(
+        KeyError,
+        match=re.escape(f"{class_path} is not registered."),
+    ):
+        registry.update(class_path, DummySchema)
+
+
+def test_update_raises_typeerror_for_invalid_schema(registry: SchemaRegistry):
+    with pytest.raises(
+        TypeError,
+        match=r"must inherit from ConfigurationSchema",
+    ):
+        registry.update(
+            "pkg.module.Class",
+            object,  # type: ignore[arg-type]
+        )
+
+
+# ==========================================================
+# Representation
+# ==========================================================
+
+
+def test_repr_returns_empty_registry_representation(
+    registry: SchemaRegistry,
+):
+    assert repr(registry) == "SchemaRegistry({})"
+
+
+def test_repr_returns_registered_schemas(
+    registry: SchemaRegistry,
+):
+    registry.register("pkg.module.ClassA", DummySchema)
+    registry.register("pkg.module.ClassB", OtherSchema)
+
+    result = repr(registry)
+
+    assert result.startswith("SchemaRegistry(")
+    assert "'pkg.module.ClassA'" in result
+    assert "'pkg.module.ClassB'" in result
+
+    assert f"{DummySchema.__module__}.{DummySchema.__name__}" in result
+    assert f"{OtherSchema.__module__}.{OtherSchema.__name__}" in result
+
+    assert result.endswith(")")
+
+
+def test_repr_sorts_registered_class_paths(
+    registry: SchemaRegistry,
+):
+    registry.register("pkg.module.ZClass", DummySchema)
+    registry.register("pkg.module.AClass", OtherSchema)
+
+    result = repr(registry)
+
+    assert result.index("'pkg.module.AClass'") < result.index("'pkg.module.ZClass'")
+
+
+# ==========================================================
+# Register schema decorator
+# ==========================================================
+
+
+def test_register_schema_registers_the_decorated_class(
+    registry: SchemaRegistry,
+):
+    @register_schema(DummySchema)
+    class DecoratedClass:
+        pass
+
+    class_path = f"{DecoratedClass.__module__}.{DecoratedClass.__name__}"
+
+    assert registry[class_path] is DummySchema
+
+
+def test_register_schema_can_register_multiple_classes_with_same_schema(
+    registry: SchemaRegistry,
+):
+    @register_schema(DummySchema)
+    class FirstClass:
+        pass
+
+    @register_schema(DummySchema)
+    class SecondClass:
+        pass
+
+    first_path = f"{FirstClass.__module__}.{FirstClass.__name__}"
+    second_path = f"{SecondClass.__module__}.{SecondClass.__name__}"
+
+    assert registry[first_path] is DummySchema
+    assert registry[second_path] is DummySchema
+    assert len(registry) == 2
diff --git a/tests/validation/test_validator.py b/tests/validation/test_validator.py
new file mode 100644
index 000000000..e7eec6e95
--- /dev/null
+++ b/tests/validation/test_validator.py
@@ -0,0 +1,192 @@
+"""Tests of the schema validator."""
+
+from collections.abc import Generator
+
+import pytest
+
+from pyaml.validation import (
+    ConfigurationSchema,
+    SchemaRegistry,
+    SchemaValidator,
+)
+
+# ==========================================================
+# Dummy schemas
+# ==========================================================
+
+
+class DummySchema(ConfigurationSchema):
+    value: int | None = None
+
+
+class OtherSchema(ConfigurationSchema):
+    name: str | None = None
+    children: list[DummySchema] | None = None
+
+
+# ==========================================================
+# Fixtures
+# ==========================================================
+
+
+@pytest.fixture(autouse=True)
+def clear_registry() -> Generator[None, None, None]:
+    """Ensure the registry is empty for each test."""
+
+    registry = SchemaRegistry()
+    registry.clear()
+    yield
+    registry.clear()
+
+
+@pytest.fixture
+def registry() -> SchemaRegistry:
+    return SchemaRegistry()
+
+
+# ==========================================================
+# Recursive validation
+# ==========================================================
+
+
+def test_recursive_validate_returns_validated_schema(
+    registry: SchemaRegistry,
+):
+    registry.register("pkg.module.Class", DummySchema)
+
+    data = {
+        "class_path": "pkg.module.Class",
+        "value": 42,
+    }
+
+    result = SchemaValidator._recursive_validate(data)
+
+    assert isinstance(result, DummySchema)
+    assert result.class_path == "pkg.module.Class"
+    assert result.value == 42
+
+
+def test_recursive_validate_recurses_through_nested_lists_and_dicts(
+    registry: SchemaRegistry,
+):
+    registry.register("pkg.module.ClassA", DummySchema)
+    registry.register("pkg.module.ClassB", OtherSchema)
+
+    data = {
+        "class_path": "pkg.module.ClassB",
+        "name": "dummy",
+        "children": [
+            {
+                "class_path": "pkg.module.ClassA",
+                "value": "42",
+            },
+            {
+                "class_path": "pkg.module.ClassA",
+                "value": "73",
+            },
+        ],
+    }
+
+    result = SchemaValidator._recursive_validate(data)
+
+    assert isinstance(result, OtherSchema)
+
+    assert isinstance(result.children[0], DummySchema)
+    assert result.children[0].value == 42
+
+    assert isinstance(result.children[1], DummySchema)
+    assert result.children[1].value == 73
+
+
+def test_recursive_validate_leaves_plain_dicts_unchanged():
+    data = {
+        "plain": "dict",
+    }
+
+    result = SchemaValidator._recursive_validate(data)
+
+    assert result == data
+
+
+def test_recursive_validate_leaves_non_container_values_unchanged():
+    assert SchemaValidator._recursive_validate("text") == "text"
+    assert SchemaValidator._recursive_validate(123) == 123
+    assert SchemaValidator._recursive_validate(True) is True
+    assert SchemaValidator._recursive_validate(None) is None
+
+
+def test_recursive_validate_warns_for_unknown_schema(
+    registry: SchemaRegistry,
+):
+    data = {
+        "class_path": "pkg.module.Unknown",
+        "value": 42,
+    }
+
+    with pytest.warns(
+        UserWarning,
+        match=r"Unknown schema for 'pkg\.module\.Unknown' so cannot validate\. Leaving data as raw dict\.",
+    ):
+        result = SchemaValidator._recursive_validate(data)
+
+    assert result == data
+
+
+# ==========================================================
+# Configuration parsing
+# ==========================================================
+
+
+def test_parse_configuration_returns_configuration_schema():
+    data = {
+        "class_path": "pkg.module.Class",
+    }
+
+    result = SchemaValidator._parse_configuration(data)
+
+    assert isinstance(result, ConfigurationSchema)
+    assert result.class_path == "pkg.module.Class"
+
+
+def test_parse_configuration_returns_none_for_non_configuration_dict():
+    data = {
+        "plain": "dict",
+    }
+
+    result = SchemaValidator._parse_configuration(data)
+
+    assert result is None
+
+
+# ==========================================================
+# Top-level validation
+# ==========================================================
+
+
+def test_validate_returns_validated_configuration_schema(
+    registry: SchemaRegistry,
+):
+    registry.register("pkg.module.Class", DummySchema)
+
+    data = {
+        "class_path": "pkg.module.Class",
+        "value": 42,
+    }
+
+    result = SchemaValidator.validate(data)
+
+    assert isinstance(result, DummySchema)
+    assert result.class_path == "pkg.module.Class"
+    assert result.value == 42
+
+
+def test_validate_raises_typeerror_for_non_configuration_dict():
+    data = {
+        "plain": "dict",
+    }
+
+    with pytest.raises(
+        TypeError,
+        match=r"Top-level configuration did not validate to a ConfigurationSchema\.",
+    ):
+        SchemaValidator.validate(data)