diff --git a/pyPRMS/control/Control.py b/pyPRMS/control/Control.py index 051889e..a2eda65 100644 --- a/pyPRMS/control/Control.py +++ b/pyPRMS/control/Control.py @@ -1,19 +1,12 @@ #!/usr/bin/env python3 -import io import numpy as np import operator import pandas as pd # type: ignore -import pkgutil import re -import xml.etree.ElementTree as xmlET from typing import Dict, List, Optional, Sequence, Union # OrderedDict as OrderedDictType, -from networkx.utils.misc import check_create_using - -# from rich import pretty -# from rich.console import Console from rich.table import Table from .ControlVariable import ControlVariable @@ -21,13 +14,10 @@ from ..constants import (ctl_order, ctl_implicit_modules, internal_module_map, MetaDataType, VAR_DELIM, PTYPE_TO_PRMS_TYPE) from ..base.console import get_console_instance +from pyPRMS.prms_helpers import cond_check con = None -cond_check = {'=': operator.eq, - '>': operator.gt, - '<': operator.lt} - class Control(object): """ Class object for a collection of control variables. @@ -97,7 +87,7 @@ def cbh_files(self) -> List[str]: """ # List of control variables that specify possible CBH files - ctl_cbh_files = ['albebo_day', 'cloud_cover_day', 'humidity_day', 'potet_day', 'precip_day', + ctl_cbh_files = ['albedo_day', 'cloud_cover_day', 'humidity_day', 'potet_day', 'precip_day', 'swrad_day', 'tmax_day', 'tmin_day', 'transp_day', 'windspeed_day'] cbh_files = [] @@ -199,7 +189,6 @@ def add(self, name: str): # , meta=None): """Add a control variable by name. :param name: Name of the control variable - :param datatype: The datatype of the control variable :raises ControlError: if control variable already exists """ @@ -296,59 +285,55 @@ def write(self, filename: str): :param filename: Name of control file to create """ - outfile = open(filename, 'w') - - if self.__header is not None: - for hh in self.__header: - outfile.write(f'{hh}\n') - - order = ['datatype', 'values'] - - # Get set of variables in ctl_order that are missing from control_vars - setdiff = set(self.__control_vars.keys()).difference(set(ctl_order)) - - # Add missing control variables (setdiff) in ctl_order to the end of the list - ctl_order.extend(list(setdiff)) - - for kk in ctl_order: - if self.exists(kk): - cvar = self.get(kk) - - outfile.write(f'{VAR_DELIM}\n') - outfile.write(f'{kk}\n') - - for item in order: - if cvar.meta['datatype'] == 'datetime': - date_tmp = [int(xx) for xx in re.split(r'[-T:.]+', str(cvar.values))[0:6]] - - if item == 'datatype': - outfile.write(f'{len(date_tmp)}\n') - outfile.write(f'{PTYPE_TO_PRMS_TYPE[cvar.meta["datatype"]]}\n') - if item == 'values': - for cval in date_tmp: - outfile.write(f'{cval}\n') - else: - if item == 'datatype': - outfile.write(f'{cvar.size}\n') - outfile.write(f'{PTYPE_TO_PRMS_TYPE[cvar.meta["datatype"]]}\n') - if item == 'values': - if cvar.meta['context'] == 'scalar': - # Single-values (e.g. int, float, str) - # print(type(cvar.values)) - if isinstance(cvar.values, np.bytes_): - print("BYTES") - outfile.write(f'{cvar.values.decode()}\n') + with open(filename, 'w') as outfile: + if self.__header is not None: + for hh in self.__header: + outfile.write(f'{hh}\n') + + order = ['datatype', 'values'] + + # Get set of variables in ctl_order that are missing from control_vars + setdiff = set(self.__control_vars.keys()).difference(set(ctl_order)) + + # Add missing control variables (setdiff) in ctl_order to the end of the list + ctl_order.extend(list(setdiff)) + + for kk in ctl_order: + if self.exists(kk): + cvar = self.get(kk) + + outfile.write(f'{VAR_DELIM}\n') + outfile.write(f'{kk}\n') + + for item in order: + if cvar.meta['datatype'] == 'datetime': + date_tmp = [int(xx) for xx in re.split(r'[-T:.]+', str(cvar.values))[0:6]] + + if item == 'datatype': + outfile.write(f'{len(date_tmp)}\n') + outfile.write(f'{PTYPE_TO_PRMS_TYPE[cvar.meta["datatype"]]}\n') + if item == 'values': + for cval in date_tmp: + outfile.write(f'{cval}\n') + else: + if item == 'datatype': + outfile.write(f'{cvar.size}\n') + outfile.write(f'{PTYPE_TO_PRMS_TYPE[cvar.meta["datatype"]]}\n') + if item == 'values': + if cvar.meta['context'] == 'scalar': + # Single-values (e.g. int, float, str) + if isinstance(cvar.values, np.bytes_): + print("BYTES") + outfile.write(f'{cvar.values.decode()}\n') + else: + outfile.write(f'{cvar.values}\n') else: - outfile.write(f'{cvar.values}\n') - else: - # Multiple-values - if isinstance(cvar.values, np.ndarray): - for cval in cvar.values: - outfile.write(f'{cval}\n') - else: - outfile.write(f'{cvar.values}\n') - - outfile.close() + # Multiple-values + if isinstance(cvar.values, np.ndarray): + for cval in cvar.values: + outfile.write(f'{cval}\n') + else: + outfile.write(f'{cvar.values}\n') def write_metadata_csv(self, filename: str, sep: str = '\t'): """Writes the control metadata to a CSV file""" @@ -411,7 +396,7 @@ def _check_condition(self, cstr: str) -> bool: def _read(self): """Abstract function for reading. """ - assert False, 'Control._read() must be defined by child class' + raise NotImplementedError('Control._read() must be defined by child class') def _preload_metadata(self): # Create an entry for each variable in the control section of diff --git a/pyPRMS/control/ControlFile.py b/pyPRMS/control/ControlFile.py index 670ef5a..b614f70 100644 --- a/pyPRMS/control/ControlFile.py +++ b/pyPRMS/control/ControlFile.py @@ -34,12 +34,10 @@ def __init__(self, filename: Union[str, Path], self.__isloaded = False self.__include_missing = include_missing - if isinstance(filename, str): - filename = Path(filename) self.filename = filename @property - def filename(self) -> Union[str, Path]: + def filename(self) -> Path: """Get control filename. :returns: Name of control file @@ -54,7 +52,7 @@ def filename(self, filename: Union[str, Path]): """ self.__isloaded = False - self.__filename = filename + self.__filename = filename if isinstance(filename, Path) else Path(filename) self._read() def _read(self): diff --git a/pyPRMS/control/ControlVariable.py b/pyPRMS/control/ControlVariable.py index 9b2cb50..a313bd9 100644 --- a/pyPRMS/control/ControlVariable.py +++ b/pyPRMS/control/ControlVariable.py @@ -2,7 +2,7 @@ import datetime import numpy as np -import re +import numpy.typing as npt from typing import Callable, Dict, List, Optional, Sequence, Union from ..constants import NEW_PTYPE_TO_DTYPE @@ -18,13 +18,15 @@ class ControlVariable(object): # Create date: 2019-04-18 def __init__(self, name: str, - value = None, + value: Optional[Union[npt.NDArray, np.int32, np.float32, np.float64, np.str_]] = None, meta: Optional[Dict] = None, strict: Optional[bool] = True): """Initialize a control variable object. :param name: Name of control variable + :param value: Value(s) of control variable :param meta: Metadata of the control variable + :param strict: Enforce use of valid control variable metadata """ self.__name = name @@ -172,37 +174,22 @@ def value_meaning(self) -> Union[str, None]: # return meaning.get(self.values, meaning.get(str(self.values), None)) def _value_meaning_test(self, key, src_dict): - try: + """Look up the meaning for a value in a dictionary that may use + direct keys or conditional keys (e.g. '>0', '<5'). + """ + # Direct lookup by value + if key in src_dict: return src_dict[key] - except KeyError: - # Maybe the key is a string - try: - return src_dict[str(key)] - except KeyError: - # Maybe one of the keys is a conditional? - patterns = ['[><]'] - regex = [re.compile('^' + pat).match for pat in patterns] - - tt = {kk: vv for kk, vv in src_dict.items() - if any (reg(kk) for reg in regex)} - - if len(tt) > 0: - # So there is a conditional - for mm in tt: - # print(mm.split()) - if cond_check[mm[0]](key, int(mm[1:])): - return src_dict[mm] - # print(f'{mm}: {src_dict[mm]}') - raise ValueError('Invalid control value') - - - # try: - # if 'valid_values' in self.meta: - # # We want a KeyError here if the key is missing - # return self.meta['valid_values'][self.values] - # - # return None - # except KeyError: - # # Try again but return None if the key is still missing - # return self.meta['valid_values'].get(str(self.values), None) - # return None \ No newline at end of file + + # Try string representation of the key + str_key = str(key) + if str_key in src_dict: + return src_dict[str_key] + + # Check conditional keys (e.g. ">0", "<100") + for cond_key, meaning in src_dict.items(): + if cond_key and cond_key[0] in '><': + if cond_check[cond_key[0]](key, int(cond_key[1:])): + return meaning + + raise ValueError('Invalid control value') \ No newline at end of file diff --git a/tests/func/test_Control.py b/tests/func/test_Control.py index f98bcb4..33043d9 100644 --- a/tests/func/test_Control.py +++ b/tests/func/test_Control.py @@ -199,7 +199,7 @@ def test_to_dict(self, control_object): def test_control_read_method_is_abstract(self, control_object): """The Control class _read() method is abstract""" - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): control_object._read() def test_default_header(self, control_object): @@ -226,7 +226,8 @@ def test_set_header_with_none(self, control_object): def test_cbh_files(self, control_object): """Check the default set of CBH files""" - expected = ['cloudcover.day', + expected = ['albedo.day', + 'cloudcover.day', 'humidity.day', 'potet.day', 'precip.day',