From 0fb0149dd0eec507fdbc16566a1a78f79af637f5 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Mon, 8 Jun 2026 14:31:21 +0200 Subject: [PATCH] Add __get_pydantic_core_schema__ to Element. --- pyaml/common/element.py | 171 +++++++++++++++++++++++----------- pyaml/validation/validator.py | 63 +++++++++++++ 2 files changed, 180 insertions(+), 54 deletions(-) create mode 100644 pyaml/validation/validator.py diff --git a/pyaml/common/element.py b/pyaml/common/element.py index 11ab1f02e..c422c71ff 100644 --- a/pyaml/common/element.py +++ b/pyaml/common/element.py @@ -1,62 +1,90 @@ +import warnings from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict +from ..validation.validator import add_schema from .exception import PyAMLException if TYPE_CHECKING: from ..common.element_holder import ElementHolder -def __pyaml_repr__(obj): - """ - Returns a string representation of a pyaml object - """ - if hasattr(obj, "_cfg"): - if isinstance(obj, Element): - return repr(obj._cfg).replace( - "ConfigModel(", - obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ", - ) - else: - # no peer - return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__) - else: - # Object is not yet fully constructed - if isinstance(obj, Element): - return f"{obj.__class__.__name__}: {obj.get_name()}" - else: - return f"{obj.__class__.__name__}" +# def __pyaml_repr__(obj): +# """ +# Returns a string representation of a pyaml object +# """ +# if hasattr(obj, "_cfg"): +# if isinstance(obj, Element): +# return repr(obj._cfg).replace( +# "ConfigModel(", +# obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ", +# ) +# else: +# # no peer +# return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__) +# else: +# # Object is not yet fully constructed +# if isinstance(obj, Element): +# return f"{obj.__class__.__name__}: {obj.get_name()}" +# else: +# return f"{obj.__class__.__name__}" -class ElementConfigModel(BaseModel): +def __pyaml_repr__(obj): """ - Base class for element configuration. - - Parameters - ---------- - name : str - The name of the PyAML element. - description : str, optional - Description of the element. - lattice_names : str or None, optional - The name(s) of the associated element(s) in the lattice. By default, - the PyAML element name is used. lattice_name accept the following - syntax: - - list(name,[name]) : Element names - - [name]@idx[,idx] : Element indices in the subset formed by name. - - [name]#start_idx..end_idx : Element range in the subset formed by name. - In the above syntax, if the name is not specficied, the whole set - of lattice element is used for indexing. + Returns a string representation of a pyaml object """ - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - name: str - description: str | None = None - lattice_names: str | None = None - - + attrs = {} + + # Instance attributes + for k, v in obj.__dict__.items(): + # Exclude private attributes + if not k.startswith("_"): + attrs[k] = v + + # Properties + for name, attr in vars(type(obj)).items(): + if isinstance(attr, property): + try: + attrs[name] = getattr(obj, name) + except Exception as e: + attrs[name] = f"" + + parts = ", ".join(f"{k}={v!r}" for k, v in attrs.items()) + return f"{obj.__class__.__name__}({parts})" + + +# class ElementConfigModel(BaseModel): +# """ +# Base class for element configuration. + +# Parameters +# ---------- +# name : str +# The name of the PyAML element. +# description : str, optional +# Description of the element. +# lattice_names : str or None, optional +# The name(s) of the associated element(s) in the lattice. By default, +# the PyAML element name is used. lattice_name accept the following +# syntax: +# - list(name,[name]) : Element names +# - [name]@idx[,idx] : Element indices in the subset formed by name. +# - [name]#start_idx..end_idx : Element range in the subset formed by name. +# In the above syntax, if the name is not specficied, the whole set +# of lattice element is used for indexing. +# """ + +# model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + +# name: str +# description: str | None = None +# lattice_names: str | None = None + + +@add_schema class Element(object): """ Class providing access to one element of a physical or simulated lattice @@ -66,30 +94,65 @@ class Element(object): The unique name identifying the element in the configuration file """ - def __init__(self, name: str): + def __init__( + self, + name: str, + lattice_names: str | None = None, + description: str | None = None, + ): self._name: str = name + + # If no lattice names are given put it to the name of the element + if lattice_names: + self._lattice_names = lattice_names + else: + self._lattice_names = self._name + + self.description = description + self._peer: "ElementHolder" = None # Peer: ControlSystem, Simulator - def get_name(self) -> str: - """ - Returns the name of the element - """ + @property + def name(self) -> str: return self._name + # TODO: implement name setter -> this requires checking so the name is unique + + def get_name(self) -> str: + warnings.warn( + "get_name() is deprecated; use .name instead", + DeprecationWarning, + stacklevel=2, + ) + return self.name + + @property + def lattice_names(self) -> str: + return self._lattice_names + + # TODO: implement lattice_names setter -> this requires validation of the forma + def get_lattice_names(self) -> str: """ Returns the name of associated lattice element(s) """ - if not hasattr(self, "_cfg"): - return self._name - else: - return self._cfg.lattice_names + warnings.warn( + "get_lattice_names() is deprecated; use .lattice_names instead", + DeprecationWarning, + stacklevel=2, + ) + return self.lattice_names def get_description(self) -> str: """ Returns the description of the element """ - return self._cfg.description + warnings.warn( + "get_description() is deprecated; use .description instead", + DeprecationWarning, + stacklevel=2, + ) + return self.description def set_energy(self, E: float): """ diff --git a/pyaml/validation/validator.py b/pyaml/validation/validator.py new file mode 100644 index 000000000..c52a6d0c8 --- /dev/null +++ b/pyaml/validation/validator.py @@ -0,0 +1,63 @@ +import inspect +from typing import Any, get_type_hints + +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + + +def add_schema(cls): + # Get the attributes + sig = inspect.signature(cls.__init__) + + # Filter out self, *args and **kwargs + params = [ + p + for p in sig.parameters.values() + if p.name != "self" + and p.kind + not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ) + ] + + # Get the type annotations + hints = get_type_hints(cls.__init__, include_extras=True) + + for p in params: + if p.name not in hints: + raise TypeError(f"{cls.__name__}.__init__ parameter '{p.name}' must be annotated.") + + @classmethod + def __get_pydantic_core_schema__(target_cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + fields: dict[str, core_schema.TypedDictField] = {} + + for p in params: + annotation = hints.get(p.name) + field_schema = handler.generate_schema(annotation) + + fields[p.name] = core_schema.typed_dict_field( + field_schema, + required=(p.default is inspect._empty), + ) + + typed_dict = core_schema.typed_dict_schema(fields) + + # Validate and create an object of the class + # This is required to handle nested objects + def validate(value, inner_validator): + # Allow already-created instances to pass through unchanged + if isinstance(value, target_cls): + return value + + # Validate the data + data = inner_validator(value) + + # Create an object of the class + return target_cls(**data) + + return core_schema.no_info_wrap_validator_function(validate, typed_dict) + + # Add the method on the class + cls.__get_pydantic_core_schema__ = __get_pydantic_core_schema__ + return cls