Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ jobs:
pip install accelerator-toolbox
pip install pyaml
pip install flake8 pytest
pip install ruamel.yaml
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
Expand Down
78 changes: 53 additions & 25 deletions pyaml/configuration/fileloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,69 @@

logger = logging.getLogger(__name__)

#TODO
#Implement cycle detection in case of wrong yaml/json link that creates a cycle
accepted_suffixes = [".yaml", ".yml", ".json"]

def load(filename:str) -> Union[dict,list]:
class PyAMLConfigCyclingException(PyAMLException):

def __init__(self, error_filename:str, path_stack:list[Path]):
self.error_filename = error_filename
parent_file_stack = [parent_path.name for parent_path in path_stack]
super().__init__(f"Circular file inclusion of {error_filename}. File list before reaching it: {parent_file_stack}")
pass

def load(filename:str, paths_stack:list=None) -> Union[dict,list]:
"""Load recursively a configuration setup"""
if filename.endswith(".yaml") or filename.endswith(".yml"):
l = YAMLLoader(filename)
l = YAMLLoader(filename, paths_stack)
elif filename.endswith(".json"):
l = JSONLoader(filename)
l = JSONLoader(filename, paths_stack)
else:
raise PyAMLException(f"{filename} File format not supported (only .yaml .yml or .json)")
return l.load(filename)
return l.load()

# Expand condition
def hasToExpand(value):
return isinstance(value,str) and any(value.endswith(suffix) for suffix in accepted_suffixes)


# Loader base class (nested files expansion)
class Loader:

def __init__(self, filename:str):
self.suffixes = []
def __init__(self, filename:str, parent_path_stack:list[Path]):
self.path:Path = get_root_folder() / filename
self.files_stack:list[Path] = []
if parent_path_stack:
if any(self.path.samefile(parent_path) for parent_path in parent_path_stack):
raise PyAMLConfigCyclingException(filename, parent_path_stack)
self.files_stack.extend(parent_path_stack)
self.files_stack.append(self.path)

# Expand condition
def hasToExpand(self,value):
return isinstance(value,str) and any(value.endswith(suffix) for suffix in self.suffixes)

# Recursively expand a dict
def expand_dict(self,d:dict):
for key, value in d.items():
if self.hasToExpand(value):
d[key] = load(value)
else:
self.expand(value)
try:
if hasToExpand(value):
d[key] = load(value, self.files_stack)
else:
self.expand(value)
except PyAMLConfigCyclingException as pyaml_ex:
location = d.pop('__location__', None)
field_locations = d.pop('__fieldlocations__', None)
location_str = ""
if location:
file, line, col = location
if field_locations and key in field_locations:
location = field_locations[key]
file, line, col = location
location_str = f" in {file} at line {line}, column {col}"
raise PyAMLException(f"Circular file inclusion of {pyaml_ex.error_filename}{location_str}") from pyaml_ex

# Recursively expand a list
def expand_list(self,l:list):
for idx,value in enumerate(l):
if self.hasToExpand(value):
l[idx] = load(value)
if hasToExpand(value):
l[idx] = load(value, self.files_stack)
else:
self.expand(value)

Expand All @@ -65,7 +91,7 @@ def expand(self,obj: Union[dict,list]):

# Load a file
def load(self) -> Union[dict,list]:
raise Exception(str(self.path) + ": load() method not implemented")
raise PyAMLException(str(self.path) + ": load() method not implemented")

class SafeLineLoader(SafeLoader):

Expand Down Expand Up @@ -93,22 +119,24 @@ def construct_mapping(self, node, deep=False):

# YAML loader
class YAMLLoader(Loader):

def load(self,fileName:str) -> Union[dict,list]:
self.path:Path = get_root_folder() / fileName
self.suffixes = [".yaml",".yml"]
def __init__(self, filename: str, parent_paths_stack:list):
super().__init__(filename, parent_paths_stack)

def load(self) -> Union[dict,list]:
logger.log(logging.DEBUG, f"Loading YAML file '{self.path}'")
with open(self.path) as file:
try:
return self.expand(yaml.load(file,Loader=SafeLineLoader))
except yaml.YAMLError as e:
raise Exception(self.path + ": " + str(e))
raise PyAMLException(str(self.path) + ": " + str(e)) from e

# JSON loader
class JSONLoader(Loader):
def __init__(self, filename: str, parent_paths_stack:list):
super().__init__(filename, parent_paths_stack)

def load(self,fileName:str) -> Union[dict,list]:
def load(self) -> Union[dict,list]:
logger.log(logging.DEBUG, f"Loading JSON file '{self.path}'")
self.suffixes = [".json"]
with open(self.path) as file:
try:
return self.expand(json.load(file))
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ dependencies = [
"scipy>=1.7.3",
"accelerator-toolbox>=0.6.1",
"PyYAML>=6.0.2",
"pydantic>=2.11.7",
"ruamel.yaml"
"pydantic>=2.11.7"
]

[project.optional-dependencies]
Expand Down
32 changes: 32 additions & 0 deletions tests/config/bad_conf_cycles.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"type": "pyaml.pyaml",
"instruments":[
{
"type": "pyaml.instrument",
"name": "sr",
"energy": 6e9,
"simulators": ["../config/bad_conf_cycles.json"],
"data_folder": "/data/store",
"arrays": [
{
"type": "pyaml.arrays.hcorrector",
"name": "HCORR",
"elements": [
"SH1A-C01-H",
"SH1A-C02-H"
]
},
{
"type": "pyaml.arrays.vcorrector",
"name": "VCORR",
"elements": ["SH1A-C01-V", "SH1A-C02-V"]
}
],
"devices": [
"sr/quadrupoles/QF1AC01.yaml",
"sr/correctors/SH1AC01.yaml",
"sr/correctors/SH1AC02.yaml"
]
}
]
}
26 changes: 26 additions & 0 deletions tests/config/bad_conf_cycles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
type: pyaml.pyaml
instruments:
- type: pyaml.instrument
name: sr
energy: 6e9
simulators:
- type: pyaml.lattice.simulator
lattice: sr/lattices/ebs.mat
name: design
data_folder: /data/store
arrays:
- type: pyaml.arrays.hcorrector
name: HCORR
elements:
- SH1A-C01-H
- SH1A-C02-H
- type: pyaml.arrays.vcorrector
name: VCORR
elements:
- SH1A-C01-V
- SH1A-C02-V
devices:
- ../config/bad_conf_cycles.yml # Cycle here
- sr/quadrupoles/QF1AC01.yaml
- sr/correctors/SH1AC01.yaml
- sr/correctors/SH1AC02.yaml
28 changes: 22 additions & 6 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pyaml import PyAMLConfigException
from pyaml import PyAMLConfigException, PyAMLException
from pyaml.configuration.factory import Factory
from pyaml.pyaml import PyAML, pyaml
from tests.conftest import MockElement
Expand Down Expand Up @@ -28,9 +28,25 @@ def test_factory_with_custom_strategy():
assert obj.name == "custom_injected"


def test_error_location():

@pytest.mark.parametrize("test_file", [
"tests/config/bad_conf.yml",
])
def test_error_location(test_file):
with pytest.raises(PyAMLConfigException) as exc:
ml: PyAML = pyaml("tests/config/bad_conf.yml")

assert "at line 7, column 9" in str(exc.value)
ml: PyAML = pyaml(test_file)
print(str(exc.value))
test_file_names = test_file.split("/")
test_file_name = test_file_names[len(test_file_names)-1]
assert f"{test_file_name} at line 7, column 9" in str(exc.value)

@pytest.mark.parametrize("test_file", [
"tests/config/bad_conf_cycles.yml",
"tests/config/bad_conf_cycles.json",
])
def test_error_cycles(test_file):
with pytest.raises(PyAMLException) as exc:
ml: PyAML = pyaml(test_file)

assert "Circular file inclusion of " in str(exc.value)
if not test_file.endswith(".json"):
assert "at line 23" in str(exc.value)
Loading