From a3088f0b5ab9a323f03b090bae7e7264ba7411a0 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Tue, 10 Mar 2026 16:18:00 +0100 Subject: [PATCH] First changes to add a validator decorator. So far only modified for accelerator. --- pyaml/accelerator.py | 115 +++++++++++++++++++----------- pyaml/configuration/factory.py | 74 +++++++------------ pyaml/configuration/validation.py | 53 ++++++++++++++ 3 files changed, 151 insertions(+), 91 deletions(-) create mode 100644 pyaml/configuration/validation.py diff --git a/pyaml/accelerator.py b/pyaml/accelerator.py index 09bf35f9..1965fa68 100644 --- a/pyaml/accelerator.py +++ b/pyaml/accelerator.py @@ -3,6 +3,7 @@ """ import os +from typing import Optional from pydantic import BaseModel, ConfigDict, Field @@ -11,6 +12,7 @@ from .common.exception import PyAMLConfigException from .configuration.factory import Factory from .configuration.fileloader import load, set_root_folder +from .configuration.validation import validator from .control.controlsystem import ControlSystem from .lattice.simulator import Simulator @@ -36,13 +38,13 @@ class ConfigModel(BaseModel): can access several control systems simulators : list[Simulator], optional Simulator list - data_folder : str + data_folder : str, optional Data folder arrays : list[ArrayConfig], optional Element family description : str , optional Acceleration description - devices : list[.common.element.Element] + devices : list[Element], optional Element list """ @@ -51,51 +53,72 @@ class ConfigModel(BaseModel): facility: str machine: str energy: float - controls: list[ControlSystem] = None - simulators: list[Simulator] = None - data_folder: str - description: str | None = None - arrays: list[ArrayConfig] = Field(default=None, repr=False) - devices: list[Element] = Field(repr=False) + controls: Optional[list[ControlSystem]] = None + simulators: Optional[list[Simulator]] = None + data_folder: Optional[str] = None + description: Optional[str | None] = None + arrays: Optional[list[ArrayConfig]] = Field(default=None, repr=False) + devices: Optional[list[Element]] = Field(default=None, repr=False) +@validator(ConfigModel) class Accelerator(object): - """PyAML top level class""" - - def __init__(self, cfg: ConfigModel): - self._cfg = cfg + """Accelerator class""" + + def __init__( + self, + facility: str, + machine: str, + energy: float, + controls: Optional[list[ControlSystem]] = None, + simulators: Optional[list[Simulator]] = None, + data_folder: Optional[str] = None, + description: Optional[str | None] = None, + arrays: Optional[list[ArrayConfig]] = None, + devices: Optional[list[Element]] = None, + ): __design = None __live = None - if cfg.controls is not None: - for c in cfg.controls: + self._facility = facility + self._machine = machine + self._energy = energy + self._controls = controls + self._simulators = simulators + self._data_folder = data_folder + self._description = description + self._arrays = arrays + self._devices = devices + + if self._controls is not None: + for c in self._controls: if c.name() == "live": self.__live = c else: # Add as dynacmic attribute setattr(self, c.name(), c) - c.fill_device(cfg.devices) + c.fill_device(self._devices) - if cfg.simulators is not None: - for s in cfg.simulators: + if self._simulators is not None: + for s in self._simulators: if s.name() == "design": self.__design = s else: # Add as dynacmic attribute setattr(self, s.name(), s) - s.fill_device(cfg.devices) + s.fill_device(self._devices) - if cfg.arrays is not None: - for a in cfg.arrays: - if cfg.simulators is not None: - for s in cfg.simulators: + if self._arrays is not None: + for a in self._arrays: + if self._simulators is not None: + for s in self._simulators: a.fill_array(s) - if cfg.controls is not None: - for c in cfg.controls: + if self._controls is not None: + for c in self._controls: a.fill_array(c) - if cfg.energy is not None: - self.set_energy(cfg.energy) + if self._energy is not None: + self.set_energy(self._energy) self.post_init() @@ -108,29 +131,29 @@ def set_energy(self, E: float): E : float Energy value to set in eV """ - if self._cfg.simulators is not None: - for s in self._cfg.simulators: + if self._simulators is not None: + for s in self._simulators: s.set_energy(E) - if self._cfg.controls is not None: - for c in self._cfg.controls: + if self._controls is not None: + for c in self._controls: c.set_energy(E) def post_init(self): """ Method triggered after all initialisations are done """ - if self._cfg.simulators is not None: - for s in self._cfg.simulators: + if self._simulators is not None: + for s in self._simulators: s.post_init() - if self._cfg.controls is not None: - for c in self._cfg.controls: + if self._controls is not None: + for c in self._controls: c.post_init() def get_description(self) -> str: """ Returns the description of the accelerator """ - return self._cfg.description + return self._description @property def live(self) -> ControlSystem: @@ -156,11 +179,14 @@ def design(self) -> Simulator: """ return self.__design - def __repr__(self): - return repr(self._cfg).replace("ConfigModel", self.__class__.__name__) + # TODO: make this generic when no config model might exist + # def __repr__(self): + # return repr(self._cfg).replace("ConfigModel", self.__class__.__name__) @staticmethod - def from_dict(config_dict: dict, ignore_external=False) -> "Accelerator": + def from_dict( + config_dict: dict, validate=True, ignore_external=False + ) -> "Accelerator": """ Construct an accelerator from a dictionary. @@ -168,6 +194,8 @@ def from_dict(config_dict: dict, ignore_external=False) -> "Accelerator": ---------- config_dict : str Dictionary containing accelerator config + validate : bool + Validate the input ignore_external: bool Ignore external modules and return None for object that cannot be created. pydantic schema that support that an @@ -179,11 +207,14 @@ def from_dict(config_dict: dict, ignore_external=False) -> "Accelerator": config_dict.pop("controls", None) # Ensure factory is clean before building a new accelerator Factory.clear() - return Factory.depth_first_build(config_dict, ignore_external) + return Factory.depth_first_build(config_dict, validate, ignore_external) @staticmethod def load( - filename: str, use_fast_loader: bool = False, ignore_external=False + filename: str, + validate: bool = True, + use_fast_loader: bool = False, + ignore_external=False, ) -> "Accelerator": """ Load an accelerator from a config file. @@ -192,6 +223,8 @@ def load( ---------- filename : str Configuration file name, yaml or json. + validate : bool + Validate the input. use_fast_loader : bool Use fast yaml loader. When specified, no line number are reported in case of error, @@ -209,4 +242,4 @@ def load( rootfolder = os.path.abspath(os.path.dirname(filename)) set_root_folder(rootfolder) config_dict = load(os.path.basename(filename), None, use_fast_loader) - return Accelerator.from_dict(config_dict) + return Accelerator.from_dict(config_dict, validate, ignore_external) diff --git a/pyaml/configuration/factory.py b/pyaml/configuration/factory.py index cc79b89c..efe0a6d6 100644 --- a/pyaml/configuration/factory.py +++ b/pyaml/configuration/factory.py @@ -45,32 +45,7 @@ def remove_strategy(self, strategy: BuildStrategy): """Register a plugin-based strategy for object creation.""" self._strategies.remove(strategy) - def handle_validation_error( - self, e, type_str: str, location_str: str, field_locations: dict - ): - # Handle pydantic errors - globalMessage = "" - for err in e.errors(): - msg = err["msg"] - field = "" - if len(err["loc"]) == 2: - field, fieldIdx = err["loc"] - message = f"'{field}.{fieldIdx}': {msg}" - else: - field = err["loc"][0] - message = f"'{field}': {msg}" - if field_locations and field in field_locations: - file, line, col = field_locations[field] - loc = f"{file} at line {line}, colum {col}" - message += f" {loc}" - globalMessage += message - globalMessage += ", " - # Discard pydantic stack trace - raise PyAMLConfigException( - f"{globalMessage} for object: '{type_str}' {location_str}" - ) from None - - def build_object(self, d: dict, ignore_external: bool = False): + def build_object(self, d: dict, validate=True, ignore_external: bool = False): """Build an object from the dict""" location = d.pop("__location__", None) field_locations = d.pop("__fieldlocations__", None) @@ -112,12 +87,6 @@ def build_object(self, d: dict, ignore_external: bool = False): ) from e # Default loading strategy - # Get the config object - config_cls = getattr(module, "ConfigModel", None) - if config_cls is None: - raise PyAMLConfigException( - f"ConfigModel class '{type_str}.ConfigModel' not found {location_str}" - ) # Get the class name cls_name = getattr(module, "PYAMLCLASS", None) @@ -126,28 +95,33 @@ def build_object(self, d: dict, ignore_external: bool = False): f"PYAMLCLASS definition not found in '{type_str}' {location_str}" ) - try: - # Validate the model - cfg = config_cls.model_validate(d) - except ValidationError as e: - self.handle_validation_error(e, type_str, location_str, field_locations) - - # Construct and return the object - elem_cls = getattr(module, cls_name, None) - if elem_cls is None: + # Get the class + cls = getattr(module, cls_name, None) + if cls is None: raise PyAMLConfigException(f"Unknown element class '{type_str}.{cls_name}'") - try: - obj = elem_cls(cfg) - self.register_element(obj) - except Exception as e: - raise PyAMLConfigException( - f"{str(e)} when creating '{type_str}.{cls_name}' {location_str}" - ) from e + # Valide/not validate and create the object + if validate: + try: + obj = cls.from_validated(d) + except Exception as e: + raise PyAMLConfigException( + f"Error creating {type_str}.{cls_name} at {location_str}: {e}" + ) from e + else: + try: + obj = cls(**d) + except Exception as e: + raise PyAMLConfigException( + f"{str(e)} when creating '{type_str}.{cls_name}' {location_str}" + ) from e + + # Register the element + self.register_element(obj) return obj - def depth_first_build(self, d, ignore_external: bool): + def depth_first_build(self, d, validate: bool, ignore_external: bool): """ Main factory function (Depth-first factory) @@ -177,7 +151,7 @@ def depth_first_build(self, d, ignore_external: bool): d[key] = obj # We are now on leaf (no nested object), we can construct - return self.build_object(d, ignore_external) + return self.build_object(d, validate, ignore_external) raise PyAMLConfigException( "Unexpected element found. 'dict' or 'list' expected " diff --git a/pyaml/configuration/validation.py b/pyaml/configuration/validation.py new file mode 100644 index 00000000..eb189210 --- /dev/null +++ b/pyaml/configuration/validation.py @@ -0,0 +1,53 @@ +import functools +import inspect +import warnings +from typing import Type + +from pydantic import BaseModel, ValidationError + +from ..common.exception import PyAMLConfigException + + +def validator(model: Type[BaseModel]): + def decorator(cls): + # TODO: add check so input model is of right type + + # Add validation model + cls._validation_model = model + + @classmethod + def validate(cls, data: dict) -> BaseModel: + if cls._validation_model is None: + raise PyAMLConfigException( + f"No validation model has been specified for " + f"{cls.__module__}.{cls.__name__} so validation is not possible." + ) + + try: + validated = cls._validation_model(**data) + except ValidationError as e: + errors = handle_validation_error(e) + raise PyAMLConfigException(errors) from None + + return validated + + @classmethod + def from_validated(cls, data: dict): + validated = cls.validate(data) + return cls(**validated.model_dump()) + + cls.validate = validate + cls.from_validated = from_validated + + return cls + + return decorator + + +def handle_validation_error(e): + errors = [] + for err in e.errors(): + field = ".".join(str(x) for x in err["loc"]) + errors.append(f"{field}: {err['msg']}") + + return errors