diff --git a/pyaml/lattice/attribute_linker.py b/pyaml/lattice/attribute_linker.py new file mode 100644 index 00000000..a9fa36cb --- /dev/null +++ b/pyaml/lattice/attribute_linker.py @@ -0,0 +1,68 @@ +import at +from pydantic import ConfigDict + +from pyaml.lattice.element import Element +from pyaml.lattice.lattice_elements_linker import LinkerIdentifier, LinkerConfigModel, LatticeElementsLinker + +PYAMLCLASS = "PyAtAttributeElementsLinker" + + +class ConfigModel(LinkerConfigModel): + """Base configuration model for linker definitions. + + This class defines the configuration structure used to instantiate + a specific linking strategy. Each concrete implementation of a + `LatticeElementsLinker` may define its own subclass extending this model + to include additional configuration parameters. + + Attributes + ---------- + model_config : ConfigDict + Pydantic configuration allowing arbitrary field types and forbidding + unexpected extra keys. + """ + model_config = ConfigDict(arbitrary_types_allowed=True,extra="forbid") + attribute_name: str + + +class PyAtAttributeIdentifier(LinkerIdentifier): + """Abstract base class for identifiers used to match PyAML and PyAT elements. + + The identifier acts as an intermediate representation between the PyAML + configuration and the PyAT lattice. Its exact structure depends on the + linking strategy (e.g., family name, element index, or user-defined tag). + + Subclasses should define the fields and logic necessary to represent + a unique reference to one or more PyAT elements. + """ + + def __init__(self, attribute_name:str, identifier): + self.attribute_name = attribute_name + self.identifier = identifier + + def __repr__(self): + return f"{self.attribute_name}={self.identifier}" + + +class PyAtAttributeElementsLinker(LatticeElementsLinker): + """Abstract base class defining the interface for PyAT–PyAML element linking. + + Implementations of this class define how PyAML elements are matched + to PyAT elements based on a given linking strategy (e.g., by family name, + by index, or by a custom attribute). + + Parameters + ---------- + config_model : ConfigModel + The configuration model for the linking strategy. + """ + + def __init__(self, config_model:ConfigModel): + super().__init__(config_model) + + def get_element_identifier(self, element: Element) -> LinkerIdentifier: + return PyAtAttributeIdentifier(self.linker_config_model.attribute_name, element.name) + + def _test_at_element(self, identifier: PyAtAttributeIdentifier, element: at.Element) -> bool: + attr_value = getattr(element, identifier.attribute_name, None) + return attr_value == identifier.identifier diff --git a/pyaml/lattice/lattice_elements_linker.py b/pyaml/lattice/lattice_elements_linker.py new file mode 100644 index 00000000..103000af --- /dev/null +++ b/pyaml/lattice/lattice_elements_linker.py @@ -0,0 +1,140 @@ +from abc import ABCMeta, abstractmethod +from typing import Iterable + +import at +from at import Lattice +from pydantic import BaseModel, ConfigDict + +from pyaml import PyAMLException +from pyaml.lattice.element import Element + + +class LinkerConfigModel(BaseModel): + """Base configuration model for linker definitions. + + This class defines the configuration structure used to instantiate + a specific linking strategy. Each concrete implementation of a + `LatticeElementsLinker` may define its own subclass extending this model + to include additional configuration parameters. + + Attributes + ---------- + model_config : ConfigDict + Pydantic configuration allowing arbitrary field types and forbidding + unexpected extra keys. + """ + model_config = ConfigDict(arbitrary_types_allowed=True,extra="forbid") + + +class LinkerIdentifier(metaclass=ABCMeta): + """Abstract base class for identifiers used to match PyAML and PyAT elements. + + The identifier acts as an intermediate representation between the PyAML + configuration and the PyAT lattice. Its exact structure depends on the + linking strategy (e.g., family name, element index, or user-defined tag). + + Subclasses should define the fields and logic necessary to represent + a unique reference to one or more PyAT elements. + """ + pass + + +class LatticeElementsLinker(metaclass=ABCMeta): + """Abstract base class defining the interface for PyAT–PyAML element linking. + + Implementations of this class define how PyAML elements are matched + to PyAT elements based on a given linking strategy (e.g., by family name, + by index, or by a custom attribute). + + Parameters + ---------- + linker_config_model : LinkerConfigModel + The configuration model for the linking strategy. + + Attributes + ---------- + lattice : at.Lattice + Reference to the PyAT lattice handled by this linker. + """ + + def __init__(self, linker_config_model:LinkerConfigModel): + self.linker_config_model = linker_config_model + self.lattice:Lattice = None + + def set_lattice(self, lattice:Lattice): + self.lattice = lattice + + @abstractmethod + def _test_at_element(self, identifier: LinkerIdentifier, element:at.Element) -> bool: + pass + + @abstractmethod + def get_element_identifier(self, element:Element) -> LinkerIdentifier: + pass + + def _iter_matches(self, identifier: LinkerIdentifier) -> Iterable[at.Element]: + """Yield all elements in the lattice whose matches the identifier.""" + for elem in self.lattice: + if self._test_at_element(identifier, elem): + yield elem + + def get_at_elements(self,element_id:LinkerIdentifier|list[LinkerIdentifier]) -> list[at.Element]: + """Return a list of PyAT elements matching the given identifiers. + + This method should resolve one or multiple PyAML identifiers + into their corresponding PyAT elements according to the specific + linking strategy implemented. + + Parameters + ---------- + element_id : LinkerIdentifier or list of LinkerIdentifier + One or several identifiers describing which PyAT elements + to retrieve. + + Returns + ------- + list of at.Element + The list of matching PyAT elements found in the lattice. + + Raises + ------ + PyAMLException + If no element matches the given identifier(s). + """ + if isinstance(element_id, LinkerIdentifier): + identifiers = [element_id] + else: + identifiers = element_id + + results: list[at.Element] = [] + for ident in identifiers: + results.extend(self._iter_matches(ident)) + + if not results: + raise PyAMLException( + f"No PyAT elements found for identifier(s): " + f"{', '.join(i.__repr__() for i in identifiers)}" + ) + return results + + def get_at_element(self, element_id:LinkerIdentifier) -> at.Element: + """Return a single PyAT element matching the given identifier. + + Parameters + ---------- + element_id : LinkerIdentifier + Identifier describing the PyAT element to retrieve. + + Returns + ------- + at.Element + The PyAT element matching the identifier. + + Raises + ------ + PyAMLException + If no element matches the identifier. + """ + for elem in self._iter_matches(element_id): + return elem + raise PyAMLException(f"No PyAT element found for FamName: {element_id.__repr__()}") \ No newline at end of file diff --git a/pyaml/lattice/simulator.py b/pyaml/lattice/simulator.py index 33bad69c..450ce745 100644 --- a/pyaml/lattice/simulator.py +++ b/pyaml/lattice/simulator.py @@ -1,5 +1,8 @@ from pydantic import BaseModel,ConfigDict import at + +from .attribute_linker import PyAtAttributeElementsLinker, ConfigModel as PyAtAttrLinkerConfigModel +from .lattice_elements_linker import LatticeElementsLinker from ..configuration import get_root_folder from .element import Element from pathlib import Path @@ -22,6 +25,8 @@ class ConfigModel(BaseModel): """AT lattice file""" mat_key: str = None """AT lattice ring name""" + linker: LatticeElementsLinker = None + """The linker configuration model""" class Simulator(ElementHolder): """ @@ -31,6 +36,7 @@ class Simulator(ElementHolder): def __init__(self, cfg: ConfigModel): super().__init__() self._cfg = cfg + self._linker = cfg.linker if cfg.linker else PyAtAttributeElementsLinker(PyAtAttrLinkerConfigModel(attribute_name="FamName")) path:Path = get_root_folder() / cfg.lattice if(self._cfg.mat_key is None): @@ -38,6 +44,8 @@ def __init__(self, cfg: ConfigModel): else: self.ring = at.load_lattice(path,mat_key=f"{self._cfg.mat_key}") + self._linker.set_lattice(self.ring) + def name(self) -> str: return self._cfg.name @@ -54,23 +62,24 @@ def fill_device(self,elements:list[Element]): for e in elements: # Need conversion to physics unit to work with simulator if isinstance(e,Magnet): - current = RWHardwareScalar(self.get_at_elems(e.name),e.polynom,e.model) if e.model.has_physics() else None - strength = RWStrengthScalar(self.get_at_elems(e.name),e.polynom,e.model) if e.model.has_physics() else None + current = RWHardwareScalar(self.get_at_elems(e),e.polynom,e.model) if e.model.has_physics() else None + strength = RWStrengthScalar(self.get_at_elems(e),e.polynom,e.model) if e.model.has_physics() else None # Create a unique ref for this simulator m = e.attach(strength,current) self.add_magnet(str(m),m) elif isinstance(e,CombinedFunctionMagnet): self.add_magnet(str(e),e) - currents = RWHardwareArray(self.get_at_elems(e.name),e.polynoms,e.model) if e.model.has_physics() else None - strengths = RWStrengthArray(self.get_at_elems(e.name),e.polynoms,e.model) if e.model.has_physics() else None + currents = RWHardwareArray(self.get_at_elems(e),e.polynoms,e.model) if e.model.has_physics() else None + strengths = RWStrengthArray(self.get_at_elems(e),e.polynoms,e.model) if e.model.has_physics() else None # Create unique refs of each function for this simulator ms = e.attach(strengths,currents) for m in ms: self.add_magnet(str(m),m) self.add_magnet(str(m),m) - def get_at_elems(self,elementName:str) -> list[at.Element]: - elementList = [e for e in self.ring if e.FamName == elementName] - if not elementList: - raise Exception(f"{elementName} not found in lattice:{self._cfg.lattice}") - return elementList + def get_at_elems(self,element:Element) -> list[at.Element]: + identifier = self._linker.get_element_identifier(element) + element_list = self._linker.get_at_elements(identifier) + if not element_list: + raise Exception(f"{identifier} not found in lattice:{self._cfg.lattice}") + return element_list diff --git a/tests/config/sr-attribute-linker.yaml b/tests/config/sr-attribute-linker.yaml new file mode 100644 index 00000000..e43a7f38 --- /dev/null +++ b/tests/config/sr-attribute-linker.yaml @@ -0,0 +1,28 @@ +type: pyaml.pyaml +instruments: + - type: pyaml.instrument + name: sr + energy: 6e9 + simulators: + - type: pyaml.lattice.simulator + lattice: sr/lattices/ebs.mat + name: design + linker: + type: pyaml.lattice.attribute_linker + attribute_name: FamName # equivalent to the default linker + data_folder: /data/store + arrays: + - type: pyaml.arrays.hcorrector + name: HCORR + elements: + - SH1A-C01-H + - SH1A-C02-H + - type: pyaml.arrays.vcorrector + name: VCORR + elements: + - SH1A-C01-V + - SH1A-C02-V + devices: + - sr/quadrupoles/QF1AC01.yaml + - sr/correctors/SH1AC01.yaml + - sr/correctors/SH1AC02.yaml \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 6265f020..be3675eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import types + +import at import pytest import subprocess import sys @@ -165,3 +167,27 @@ def register_mock_strategy(): Factory.register_strategy(strategy) yield Factory.remove_strategy(strategy) + + +# ----------------------- +# Linkers fixtures +# ----------------------- + + +@pytest.fixture +def lattice_with_famnames() -> at.Lattice: + """Lattice with duplicate FamName to test multi-match and first-element behavior.""" + qf1 = at.elements.Quadrupole('QF_1', 0.2); qf1.FamName = 'QF' + qf2 = at.elements.Quadrupole('QF_2', 0.25); qf2.FamName = 'QF' + qd1 = at.elements.Quadrupole('QD_1', 0.3); qd1.FamName = 'QD' + return at.Lattice([qf1, qf2, qd1], energy=3e9) + + +@pytest.fixture +def lattice_with_custom_attr() -> at.Lattice: + """Lattice where a custom attribute (e.g., 'Tag') is set on elements.""" + d1 = at.elements.Drift('D1', 1.0); setattr(d1, "Tag", "D1") + qf = at.elements.Quadrupole('QF', 0.2); setattr(qf, "Tag", "QF") + qf2 = at.elements.Quadrupole('QF2', 0.2); setattr(qf2, "Tag", "QF") + qd = at.elements.Quadrupole('QD', 0.3); setattr(qd, "Tag", "QD") + return at.Lattice([d1, qf, qf2, qd], energy=3e9) diff --git a/tests/test_linkers.py b/tests/test_linkers.py new file mode 100644 index 00000000..007d66d6 --- /dev/null +++ b/tests/test_linkers.py @@ -0,0 +1,88 @@ +import pytest + +from pyaml import PyAMLException +from pyaml.instrument import Instrument +from pyaml.lattice.element_holder import MagnetType + +from pyaml.lattice.attribute_linker import ( + PyAtAttributeElementsLinker, + PyAtAttributeIdentifier, + ConfigModel as AttrConfigModel, +) +from pyaml.pyaml import PyAML, pyaml + + +# ----------------------- +# Dummy PyAML Element +# ----------------------- + +class DummyPyAMLElement: + """Minimal stand-in for a PyAML Element: only provides .name.""" + def __init__(self, name: str): + self.name = name + + +def test_conf_with_linker(): + ml:PyAML = pyaml("tests/config/sr-attribute-linker.yaml") + sr:Instrument = ml.get('sr') + assert sr is not None + magnet = sr.design.get_magnet(MagnetType.HCORRECTOR,"SH1A-C01-H") + assert magnet is not None + + +# ----------------------- +# PyAtAttributeElementsLinker tests +# ----------------------- + +def test_attribute_identifier_from_pyaml_name(lattice_with_custom_attr): + # We bind to AT element attribute 'Tag'; identifier value comes from PyAML element .name + linker = PyAtAttributeElementsLinker(AttrConfigModel(attribute_name="Tag")) + linker.set_lattice(lattice_with_custom_attr) + pyaml_elem = DummyPyAMLElement(name="QF") # identifier="QF" + ident = linker.get_element_identifier(pyaml_elem) + assert isinstance(ident, PyAtAttributeIdentifier) + assert ident.attribute_name == "Tag" + assert ident.identifier == "QF" + + +def test_attribute_get_at_elements_all_matches(lattice_with_custom_attr): + linker = PyAtAttributeElementsLinker(AttrConfigModel(attribute_name="Tag")) + linker.set_lattice(lattice_with_custom_attr) + ident = PyAtAttributeIdentifier("Tag", "QF") + matches = linker.get_at_elements(ident) + # There are two elements with Tag == "QF" + assert len(matches) == 2 + assert all(getattr(e, "Tag", None) == "QF" for e in matches) + + +def test_attribute_get_at_element_first_match(lattice_with_custom_attr): + linker = PyAtAttributeElementsLinker(AttrConfigModel(attribute_name="Tag")) + linker.set_lattice(lattice_with_custom_attr) + ident = PyAtAttributeIdentifier("Tag", "QD") + first = linker.get_at_element(ident) + assert getattr(first, "Tag", None) == "QD" + # Ensure it's the first with Tag == QD in lattice order + for e in lattice_with_custom_attr: + if getattr(e, "Tag", None) == "QD": + assert first is e + break + + +def test_attribute_no_match_raises(lattice_with_custom_attr): + linker = PyAtAttributeElementsLinker(AttrConfigModel(attribute_name="Tag")) + linker.set_lattice(lattice_with_custom_attr) + ident = PyAtAttributeIdentifier("Tag", "ZZ") + with pytest.raises(PyAMLException): + _ = linker.get_at_elements(ident) + with pytest.raises(PyAMLException): + _ = linker.get_at_element(ident) + + +def test_attribute_multiple_identifiers_accumulate(lattice_with_custom_attr): + linker = PyAtAttributeElementsLinker(AttrConfigModel(attribute_name="Tag")) + linker.set_lattice(lattice_with_custom_attr) + ids = [PyAtAttributeIdentifier("Tag", "QF"), PyAtAttributeIdentifier("Tag", "QD")] + res = linker.get_at_elements(ids) + tags = [getattr(e, "Tag", None) for e in res] + assert tags.count("QF") == 2 and tags.count("QD") == 1 + assert len(res) == 3