Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 117 additions & 54 deletions pyaml/common/element.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,90 @@
import warnings
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict

from ..validation.validator import add_schema
from .exception import PyAMLException

if TYPE_CHECKING:
from ..common.element_holder import ElementHolder


def __pyaml_repr__(obj):
"""
Returns a string representation of a pyaml object
"""
if hasattr(obj, "_cfg"):
if isinstance(obj, Element):
return repr(obj._cfg).replace(
"ConfigModel(",
obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ",
)
else:
# no peer
return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__)
else:
# Object is not yet fully constructed
if isinstance(obj, Element):
return f"{obj.__class__.__name__}: {obj.get_name()}"
else:
return f"{obj.__class__.__name__}"
# def __pyaml_repr__(obj):
# """
# Returns a string representation of a pyaml object
# """
# if hasattr(obj, "_cfg"):
# if isinstance(obj, Element):
# return repr(obj._cfg).replace(
# "ConfigModel(",
# obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ",
# )
# else:
# # no peer
# return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__)
# else:
# # Object is not yet fully constructed
# if isinstance(obj, Element):
# return f"{obj.__class__.__name__}: {obj.get_name()}"
# else:
# return f"{obj.__class__.__name__}"


class ElementConfigModel(BaseModel):
def __pyaml_repr__(obj):
"""
Base class for element configuration.

Parameters
----------
name : str
The name of the PyAML element.
description : str, optional
Description of the element.
lattice_names : str or None, optional
The name(s) of the associated element(s) in the lattice. By default,
the PyAML element name is used. lattice_name accept the following
syntax:
- list(name,[name]) : Element names
- [name]@idx[,idx] : Element indices in the subset formed by name.
- [name]#start_idx..end_idx : Element range in the subset formed by name.
In the above syntax, if the name is not specficied, the whole set
of lattice element is used for indexing.
Returns a string representation of a pyaml object
"""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

name: str
description: str | None = None
lattice_names: str | None = None


attrs = {}

# Instance attributes
for k, v in obj.__dict__.items():
# Exclude private attributes
if not k.startswith("_"):
attrs[k] = v

# Properties
for name, attr in vars(type(obj)).items():
if isinstance(attr, property):
try:
attrs[name] = getattr(obj, name)
except Exception as e:
attrs[name] = f"<error: {e}>"

parts = ", ".join(f"{k}={v!r}" for k, v in attrs.items())
return f"{obj.__class__.__name__}({parts})"


# class ElementConfigModel(BaseModel):
# """
# Base class for element configuration.

# Parameters
# ----------
# name : str
# The name of the PyAML element.
# description : str, optional
# Description of the element.
# lattice_names : str or None, optional
# The name(s) of the associated element(s) in the lattice. By default,
# the PyAML element name is used. lattice_name accept the following
# syntax:
# - list(name,[name]) : Element names
# - [name]@idx[,idx] : Element indices in the subset formed by name.
# - [name]#start_idx..end_idx : Element range in the subset formed by name.
# In the above syntax, if the name is not specficied, the whole set
# of lattice element is used for indexing.
# """

# model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

# name: str
# description: str | None = None
# lattice_names: str | None = None


@add_schema
class Element(object):
"""
Class providing access to one element of a physical or simulated lattice
Expand All @@ -66,30 +94,65 @@ class Element(object):
The unique name identifying the element in the configuration file
"""

def __init__(self, name: str):
def __init__(
self,
name: str,
lattice_names: str | None = None,
description: str | None = None,
):
self._name: str = name

# If no lattice names are given put it to the name of the element
if lattice_names:
self._lattice_names = lattice_names
else:
self._lattice_names = self._name

self.description = description

self._peer: "ElementHolder" = None # Peer: ControlSystem, Simulator

def get_name(self) -> str:
"""
Returns the name of the element
"""
@property
def name(self) -> str:
return self._name

# TODO: implement name setter -> this requires checking so the name is unique

def get_name(self) -> str:
warnings.warn(
"get_name() is deprecated; use .name instead",
DeprecationWarning,
stacklevel=2,
)
return self.name

@property
def lattice_names(self) -> str:
return self._lattice_names

# TODO: implement lattice_names setter -> this requires validation of the forma

def get_lattice_names(self) -> str:
"""
Returns the name of associated lattice element(s)
"""
if not hasattr(self, "_cfg"):
return self._name
else:
return self._cfg.lattice_names
warnings.warn(
"get_lattice_names() is deprecated; use .lattice_names instead",
DeprecationWarning,
stacklevel=2,
)
return self.lattice_names

def get_description(self) -> str:
"""
Returns the description of the element
"""
return self._cfg.description
warnings.warn(
"get_description() is deprecated; use .description instead",
DeprecationWarning,
stacklevel=2,
)
return self.description

def set_energy(self, E: float):
"""
Expand Down
63 changes: 63 additions & 0 deletions pyaml/validation/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import inspect
from typing import Any, get_type_hints

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema


def add_schema(cls):
# Get the attributes
sig = inspect.signature(cls.__init__)

# Filter out self, *args and **kwargs
params = [
p
for p in sig.parameters.values()
if p.name != "self"
and p.kind
not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
)
]

# Get the type annotations
hints = get_type_hints(cls.__init__, include_extras=True)

for p in params:
if p.name not in hints:
raise TypeError(f"{cls.__name__}.__init__ parameter '{p.name}' must be annotated.")

@classmethod
def __get_pydantic_core_schema__(target_cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
fields: dict[str, core_schema.TypedDictField] = {}

for p in params:
annotation = hints.get(p.name)
field_schema = handler.generate_schema(annotation)

fields[p.name] = core_schema.typed_dict_field(
field_schema,
required=(p.default is inspect._empty),
)

typed_dict = core_schema.typed_dict_schema(fields)

# Validate and create an object of the class
# This is required to handle nested objects
def validate(value, inner_validator):
# Allow already-created instances to pass through unchanged
if isinstance(value, target_cls):
return value

# Validate the data
data = inner_validator(value)

# Create an object of the class
return target_cls(**data)

return core_schema.no_info_wrap_validator_function(validate, typed_dict)

# Add the method on the class
cls.__get_pydantic_core_schema__ = __get_pydantic_core_schema__
return cls
Loading