Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 74 additions & 41 deletions pyaml/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import os
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

Expand All @@ -11,6 +12,7 @@
from .common.exception import PyAMLConfigException
from .configuration.factory import Factory
from .configuration.fileloader import load, set_root_folder
from .configuration.validation import validator
from .control.controlsystem import ControlSystem
from .lattice.simulator import Simulator

Expand All @@ -36,13 +38,13 @@ class ConfigModel(BaseModel):
can access several control systems
simulators : list[Simulator], optional
Simulator list
data_folder : str
data_folder : str, optional
Data folder
arrays : list[ArrayConfig], optional
Element family
description : str , optional
Acceleration description
devices : list[.common.element.Element]
devices : list[Element], optional
Element list
"""

Expand All @@ -51,51 +53,72 @@ class ConfigModel(BaseModel):
facility: str
machine: str
energy: float
controls: list[ControlSystem] = None
simulators: list[Simulator] = None
data_folder: str
description: str | None = None
arrays: list[ArrayConfig] = Field(default=None, repr=False)
devices: list[Element] = Field(repr=False)
controls: Optional[list[ControlSystem]] = None
simulators: Optional[list[Simulator]] = None
data_folder: Optional[str] = None
description: Optional[str | None] = None
arrays: Optional[list[ArrayConfig]] = Field(default=None, repr=False)
devices: Optional[list[Element]] = Field(default=None, repr=False)


@validator(ConfigModel)
class Accelerator(object):
"""PyAML top level class"""

def __init__(self, cfg: ConfigModel):
self._cfg = cfg
"""Accelerator class"""

def __init__(
self,
facility: str,
machine: str,
energy: float,
controls: Optional[list[ControlSystem]] = None,
simulators: Optional[list[Simulator]] = None,
data_folder: Optional[str] = None,
description: Optional[str | None] = None,
arrays: Optional[list[ArrayConfig]] = None,
devices: Optional[list[Element]] = None,
):
__design = None
__live = None

if cfg.controls is not None:
for c in cfg.controls:
self._facility = facility
self._machine = machine
self._energy = energy
self._controls = controls
self._simulators = simulators
self._data_folder = data_folder
self._description = description
self._arrays = arrays
self._devices = devices

if self._controls is not None:
for c in self._controls:
if c.name() == "live":
self.__live = c
else:
# Add as dynacmic attribute
setattr(self, c.name(), c)
c.fill_device(cfg.devices)
c.fill_device(self._devices)

if cfg.simulators is not None:
for s in cfg.simulators:
if self._simulators is not None:
for s in self._simulators:
if s.name() == "design":
self.__design = s
else:
# Add as dynacmic attribute
setattr(self, s.name(), s)
s.fill_device(cfg.devices)
s.fill_device(self._devices)

if cfg.arrays is not None:
for a in cfg.arrays:
if cfg.simulators is not None:
for s in cfg.simulators:
if self._arrays is not None:
for a in self._arrays:
if self._simulators is not None:
for s in self._simulators:
a.fill_array(s)
if cfg.controls is not None:
for c in cfg.controls:
if self._controls is not None:
for c in self._controls:
a.fill_array(c)

if cfg.energy is not None:
self.set_energy(cfg.energy)
if self._energy is not None:
self.set_energy(self._energy)

self.post_init()

Expand All @@ -108,29 +131,29 @@ def set_energy(self, E: float):
E : float
Energy value to set in eV
"""
if self._cfg.simulators is not None:
for s in self._cfg.simulators:
if self._simulators is not None:
for s in self._simulators:
s.set_energy(E)
if self._cfg.controls is not None:
for c in self._cfg.controls:
if self._controls is not None:
for c in self._controls:
c.set_energy(E)

def post_init(self):
"""
Method triggered after all initialisations are done
"""
if self._cfg.simulators is not None:
for s in self._cfg.simulators:
if self._simulators is not None:
for s in self._simulators:
s.post_init()
if self._cfg.controls is not None:
for c in self._cfg.controls:
if self._controls is not None:
for c in self._controls:
c.post_init()

def get_description(self) -> str:
"""
Returns the description of the accelerator
"""
return self._cfg.description
return self._description

@property
def live(self) -> ControlSystem:
Expand All @@ -156,18 +179,23 @@ def design(self) -> Simulator:
"""
return self.__design

def __repr__(self):
return repr(self._cfg).replace("ConfigModel", self.__class__.__name__)
# TODO: make this generic when no config model might exist
# def __repr__(self):
# return repr(self._cfg).replace("ConfigModel", self.__class__.__name__)

@staticmethod
def from_dict(config_dict: dict, ignore_external=False) -> "Accelerator":
def from_dict(
config_dict: dict, validate=True, ignore_external=False
) -> "Accelerator":
"""
Construct an accelerator from a dictionary.

Parameters
----------
config_dict : str
Dictionary containing accelerator config
validate : bool
Validate the input
ignore_external: bool
Ignore external modules and return None for object that
cannot be created. pydantic schema that support that an
Expand All @@ -179,11 +207,14 @@ def from_dict(config_dict: dict, ignore_external=False) -> "Accelerator":
config_dict.pop("controls", None)
# Ensure factory is clean before building a new accelerator
Factory.clear()
return Factory.depth_first_build(config_dict, ignore_external)
return Factory.depth_first_build(config_dict, validate, ignore_external)

@staticmethod
def load(
filename: str, use_fast_loader: bool = False, ignore_external=False
filename: str,
validate: bool = True,
use_fast_loader: bool = False,
ignore_external=False,
) -> "Accelerator":
"""
Load an accelerator from a config file.
Expand All @@ -192,6 +223,8 @@ def load(
----------
filename : str
Configuration file name, yaml or json.
validate : bool
Validate the input.
use_fast_loader : bool
Use fast yaml loader. When specified,
no line number are reported in case of error,
Expand All @@ -209,4 +242,4 @@ def load(
rootfolder = os.path.abspath(os.path.dirname(filename))
set_root_folder(rootfolder)
config_dict = load(os.path.basename(filename), None, use_fast_loader)
return Accelerator.from_dict(config_dict)
return Accelerator.from_dict(config_dict, validate, ignore_external)
74 changes: 24 additions & 50 deletions pyaml/configuration/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,7 @@ def remove_strategy(self, strategy: BuildStrategy):
"""Register a plugin-based strategy for object creation."""
self._strategies.remove(strategy)

def handle_validation_error(
self, e, type_str: str, location_str: str, field_locations: dict
):
# Handle pydantic errors
globalMessage = ""
for err in e.errors():
msg = err["msg"]
field = ""
if len(err["loc"]) == 2:
field, fieldIdx = err["loc"]
message = f"'{field}.{fieldIdx}': {msg}"
else:
field = err["loc"][0]
message = f"'{field}': {msg}"
if field_locations and field in field_locations:
file, line, col = field_locations[field]
loc = f"{file} at line {line}, colum {col}"
message += f" {loc}"
globalMessage += message
globalMessage += ", "
# Discard pydantic stack trace
raise PyAMLConfigException(
f"{globalMessage} for object: '{type_str}' {location_str}"
) from None

def build_object(self, d: dict, ignore_external: bool = False):
def build_object(self, d: dict, validate=True, ignore_external: bool = False):
"""Build an object from the dict"""
location = d.pop("__location__", None)
field_locations = d.pop("__fieldlocations__", None)
Expand Down Expand Up @@ -112,12 +87,6 @@ def build_object(self, d: dict, ignore_external: bool = False):
) from e

# Default loading strategy
# Get the config object
config_cls = getattr(module, "ConfigModel", None)
if config_cls is None:
raise PyAMLConfigException(
f"ConfigModel class '{type_str}.ConfigModel' not found {location_str}"
)

# Get the class name
cls_name = getattr(module, "PYAMLCLASS", None)
Expand All @@ -126,28 +95,33 @@ def build_object(self, d: dict, ignore_external: bool = False):
f"PYAMLCLASS definition not found in '{type_str}' {location_str}"
)

try:
# Validate the model
cfg = config_cls.model_validate(d)
except ValidationError as e:
self.handle_validation_error(e, type_str, location_str, field_locations)

# Construct and return the object
elem_cls = getattr(module, cls_name, None)
if elem_cls is None:
# Get the class
cls = getattr(module, cls_name, None)
if cls is None:
raise PyAMLConfigException(f"Unknown element class '{type_str}.{cls_name}'")

try:
obj = elem_cls(cfg)
self.register_element(obj)
except Exception as e:
raise PyAMLConfigException(
f"{str(e)} when creating '{type_str}.{cls_name}' {location_str}"
) from e
# Valide/not validate and create the object
if validate:
try:
obj = cls.from_validated(d)
except Exception as e:
raise PyAMLConfigException(
f"Error creating {type_str}.{cls_name} at {location_str}: {e}"
) from e
else:
try:
obj = cls(**d)
except Exception as e:
raise PyAMLConfigException(
f"{str(e)} when creating '{type_str}.{cls_name}' {location_str}"
) from e

# Register the element
self.register_element(obj)

return obj

def depth_first_build(self, d, ignore_external: bool):
def depth_first_build(self, d, validate: bool, ignore_external: bool):
"""
Main factory function (Depth-first factory)

Expand Down Expand Up @@ -177,7 +151,7 @@ def depth_first_build(self, d, ignore_external: bool):
d[key] = obj

# We are now on leaf (no nested object), we can construct
return self.build_object(d, ignore_external)
return self.build_object(d, validate, ignore_external)

raise PyAMLConfigException(
"Unexpected element found. 'dict' or 'list' expected "
Expand Down
53 changes: 53 additions & 0 deletions pyaml/configuration/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import functools
import inspect
import warnings
from typing import Type

from pydantic import BaseModel, ValidationError

from ..common.exception import PyAMLConfigException


def validator(model: Type[BaseModel]):
def decorator(cls):
# TODO: add check so input model is of right type

# Add validation model
cls._validation_model = model

@classmethod
def validate(cls, data: dict) -> BaseModel:
if cls._validation_model is None:
raise PyAMLConfigException(
f"No validation model has been specified for "
f"{cls.__module__}.{cls.__name__} so validation is not possible."
)

try:
validated = cls._validation_model(**data)
except ValidationError as e:
errors = handle_validation_error(e)
raise PyAMLConfigException(errors) from None

return validated

@classmethod
def from_validated(cls, data: dict):
validated = cls.validate(data)
return cls(**validated.model_dump())

cls.validate = validate
cls.from_validated = from_validated

return cls

return decorator


def handle_validation_error(e):
errors = []
for err in e.errors():
field = ".".join(str(x) for x in err["loc"])
errors.append(f"{field}: {err['msg']}")

return errors
Loading