diff --git a/CITATION.cff b/CITATION.cff index 717cc8c..73a322d 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -2,7 +2,7 @@ cff-version: 1.2.0 message: If you use this software, please cite it using the metadata from this file. type: software title: 'pycoupler: dynamic model coupling of LPJmL' -version: 1.6.5 +version: 1.7.0 date-released: '2025-09-22' abstract: An LPJmL-Python interface for operating LPJmL in a Python environment and coupling it with Python models, programmes or simple programming scripts. diff --git a/README.md b/README.md index ea7e9f3..ba6d34e 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ a Python environment and coupling it with Python models, programmes or simple programming scripts. *pycoupler* was written to establish a coupled World-Earth modeling framework, [copan:LPJmL](https://github.com/pik-copan/pycopanlpjml), -based on [copan:CORE](https://github.com/pik-copan/pycopancore/) and LPJmL. +based on [copan:CORE](https://github.com/pik-copan/pycopancore) and LPJmL. Coupling with LPJmL is possible on an annual basis, i.e. for each year in which LPJmL is in coupling mode, the desired inputs must be passed via *pycoupler*. 
diff --git a/pycoupler/__init__.py b/pycoupler/__init__.py index df59fcb..a963650 100644 --- a/pycoupler/__init__.py +++ b/pycoupler/__init__.py @@ -11,9 +11,15 @@ read_yaml, ) -from .coupler import LPJmLCoupler +from .coupler import LPJmLCoupler, kill_process_on_port -from .run import run_lpjml, submit_lpjml, check_lpjml +from .run import ( + run_lpjml, + submit_lpjml, + check_lpjml, + start_lpjml, + kill_stale_lpjml_processes, +) from .data import ( LPJmLData, @@ -32,7 +38,6 @@ detect_io_type, ) - __all__ = [ "LpjmlConfig", "CoupledConfig", @@ -40,7 +45,7 @@ "read_yaml", ] -__all__ += [LPJmLCoupler] +__all__ += ["LPJmLCoupler", "kill_process_on_port"] __all__ += [ "LPJmLData", @@ -52,7 +57,13 @@ "get_headersize", ] -__all__ += ["run_lpjml", "submit_lpjml", "check_lpjml"] +__all__ += [ + "run_lpjml", + "start_lpjml", + "submit_lpjml", + "check_lpjml", + "kill_stale_lpjml_processes", +] # noqa: E501 __all__ += [ "get_countries", diff --git a/pycoupler/config.py b/pycoupler/config.py index d005366..de20d4e 100644 --- a/pycoupler/config.py +++ b/pycoupler/config.py @@ -1,4 +1,5 @@ -"""Classes and functions to handle LPJmL configurations and related operations""" +"""Classes and functions to handle LPJmL configurations and related +operations""" import os import sys @@ -19,8 +20,8 @@ class SubConfig: Parameters ---------- config_dict : dict - Dictionary (ideally an LPJmL config dictionary) used to build up a nested - LpjmLConfig class with corresponding fields. + Dictionary (ideally an LPJmL config dictionary) used to build up a + nested LpjmLConfig class with corresponding fields. """ def __init__(self, config_dict): @@ -239,14 +240,16 @@ def set_transient( dependency : str, optional Name of simulation to depend on (e.g., spinup run). temporal_resolution : str or dict, default "annual" - Temporal resolution for outputs. Can be a dict of temporal resolutions - corresponding to outputs or a str to set the same resolution for - all outputs. 
Choose between "annual", "monthly", "daily". + Temporal resolution for outputs. Can be a dict of temporal + resolutions corresponding to outputs or a str to set the same + resolution for all outputs. Choose between "annual", "monthly", + "daily". write_output : list, default [] Output IDs to be written by LPJmL. Check available outputs with get_output_avail(). write_file_format : str, default "cdf" - File format of output files. Choose between "raw", "clm", and "cdf". + File format of output files. Choose between "raw", "clm", and + "cdf". append_output : bool, default True If True, defined output entries are appended. If False, existing outputs are overwritten. @@ -279,13 +282,14 @@ def set_coupled( coupled_input, coupled_output, sim_name="coupled", + *, + coupled_config=None, dependency=None, coupled_year=None, temporal_resolution="annual", write_output=[], write_file_format="cdf", append_output=True, - model_name="copan:CORE", ): """ Set configuration required for coupled model runs. @@ -306,6 +310,8 @@ def set_coupled( Provide output ID as identifier. sim_name : str, default "coupled" Name of the simulation. + coupled_config : str, optional + Path to coupled config file. dependency : str, optional Name of simulation to depend on (e.g., transient run). coupled_year : int, optional @@ -319,12 +325,11 @@ def set_coupled( Output IDs to be written by LPJmL. Check available outputs with get_output_avail(). write_file_format : str, default "cdf" - File format of output files. Choose between "raw", "clm", and "cdf". + File format of output files. Choose between "raw", "clm", and + "cdf". append_output : bool, default True If True, defined output entries are appended. If False, existing outputs are overwritten. - model_name : str, default "copan:CORE" - Name of the coupled model. 
""" self.sim_name = sim_name self.sim_path = create_subdirs(sim_path, self.sim_name) @@ -350,14 +355,19 @@ def set_coupled( ) # set coupling parameters self._set_coupling( - inputs=coupled_input, - outputs=coupled_output, - start_year=coupled_year, - model_name=model_name, + inputs=coupled_input, outputs=coupled_output, start_year=coupled_year ) # set start from directory to start from historic run self._set_startfrom(path=f"{sim_path}/restart", dependency=dependency) + # add coupled config if provided + if coupled_config: + self.add_config(coupled_config) + if hasattr(self.coupled_config, "model") and isinstance( + self.coupled_config.model, str + ): + self.coupled_model = self.coupled_config.model + def _set_output( self, output_path, @@ -401,7 +411,9 @@ def _set_output( outputs.append("grid") # create dict of outputvar names with indexes for iteration - outputvar_names = {ov.name: pos for pos, ov in enumerate(self.outputvar)} + outputvar_names = { + ov.name: pos for pos, ov in enumerate(self.outputvar) + } # noqa: E501 # extract dict of outputvar for manipulation outputvars = self.to_dict()["outputvar"] @@ -423,7 +435,9 @@ def _set_output( isinstance(temporal_resolution, dict) and out.id in temporal_resolution.keys() ): - self.output[pos].file.timestep = temporal_resolution[out.id] + self.output[pos].file.timestep = temporal_resolution[ + out.id + ] # noqa: E501 if out.id not in nonvariable_outputs: self.output[pos].file.fmt = file_format @@ -539,10 +553,9 @@ def _set_grid_explicitly(self, only_all=True): - 1 ) - def _set_coupling( - self, inputs, outputs, start_year=None, model_name="copan:CORE" - ): # noqa - """Coupled settings - no spinup, not write restart file and set sockets""" + def _set_coupling(self, inputs, outputs, start_year=None, model_name="copan"): + """Coupled settings - no spinup, not write restart file and set + sockets""" self.write_restart = False self.nspinup = 0 self.float_grid = True @@ -558,6 +571,8 @@ def _set_input_sockets(self, 
inputs=[]): """Set sockets for inputs and outputs (via corresponding ids)""" for inp in inputs: sock_input = getattr(self.input, inp) + if not hasattr(sock_input, "__dict__"): + continue # skip scalars (e.g. delta_year in LPJmL v6) if "id" not in sock_input.__dict__.keys(): raise ValueError("Please use a config with input ids.") sock_input.__dict__["socket"] = True @@ -568,7 +583,9 @@ def _set_outputsockets(self, outputs=[]): outputs.append("grid") # get names/ids only of outputs that are defined in outputvar - valid_outs = {out.name for out in self.outputvar if out.name in outputs} + valid_outs = { + out.name for out in self.outputvar if out.name in outputs + } # noqa: E501 # check if all outputs are valid nonvalid_outputs = list(set(outputs) - valid_outs) @@ -579,7 +596,9 @@ def _set_outputsockets(self, outputs=[]): ) # get position of valid outputs in config output list output_pos = [ - pos for pos, out in enumerate(self.output) if out.id in valid_outs + pos + for pos, out in enumerate(self.output) + if out.id in valid_outs # noqa: E501 ] # set socket to true for corresponding outputs @@ -601,13 +620,17 @@ def get_input_sockets(self, id_only=False): return [ inp for inp in inputs - if ("socket" in inputs[inp]) and inputs[inp]["socket"] + if isinstance(inputs[inp], dict) + and ("socket" in inputs[inp]) + and inputs[inp]["socket"] ] else: return { inp: inputs[inp] for inp in inputs - if ("socket" in inputs[inp]) and inputs[inp]["socket"] + if isinstance(inputs[inp], dict) + and ("socket" in inputs[inp]) + and inputs[inp]["socket"] } def get_output_sockets(self, id_only=False): @@ -645,7 +668,9 @@ def add_config(self, file_name): """ self.coupled_config = read_yaml(file_name, CoupledConfig) - def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False): + def regrid( + self, sim_path, model_path=None, country_code="BEL", overwrite=False + ): # noqa: E501 """ Regrid LPJmL configuration file to a new country. 
@@ -701,11 +726,11 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) if country in self.input.coord.name: return - country_grid_file = ( - f"{sim_path}/input/{country}_{os.path.basename(self.input.coord.name)}" - ) + country_grid_file = f"{sim_path}/input/{country}_{os.path.basename(self.input.coord.name)}" # noqa: E501 # check if country specific input files already exist - if (not os.path.isfile(country_grid_file) or overwrite) and not hasattr( + if ( + not os.path.isfile(country_grid_file) or overwrite + ) and not hasattr( # noqa: E501 sys, "_called_from_test" ): @@ -742,12 +767,14 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) else f"{self.inpath}/{self.input.lakes.name}" ) # extract country specific lakes file from meta file - if self.input.lakes.fmt == "meta" and not hasattr(sys, "_called_from_test"): + if self.input.lakes.fmt == "meta" and not hasattr( + sys, "_called_from_test" + ): # noqa: E501 lakes_filename = read_json(lakes_fn_string)["filename"] lakes_file = lakes_fn_string lakes_file = ( - f"{lakes_file[:lakes_file.rfind('/')+1]}{lakes_filename}" # noqa + f"{lakes_file[:lakes_file.rfind('/')+1]}{lakes_filename}" # noqa: E501 ) else: lakes_file = lakes_fn_string @@ -757,7 +784,9 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) ) # check if country specific input files already exist - if (not os.path.isfile(country_lakes_file) or overwrite) and not hasattr( + if ( + not os.path.isfile(country_lakes_file) or overwrite + ) and not hasattr( # noqa: E501 sys, "_called_from_test" ): @@ -789,9 +818,9 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) for config_key, config_input in self.input: if ( - config_input.fmt != "clm" + getattr(config_input, "fmt", None) != "clm" or config_key in ["coord", "lakes"] - or (config_input.name == "DUMMYLOCATION") + or getattr(config_input, "name", None) == "DUMMYLOCATION" ): continue @@ 
-807,7 +836,9 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) ) # check if country specific input files already exist - if (not os.path.isfile(country_input_file) or overwrite) and not hasattr( + if ( + not os.path.isfile(country_input_file) or overwrite + ) and not hasattr( # noqa: E501 sys, "_called_from_test" ): @@ -833,12 +864,19 @@ def regrid(self, sim_path, model_path=None, country_code="BEL", overwrite=False) # regrid_cmd.insert(1, additional_arg) run(regrid_cmd, check=True, stdout=open(os.devnull, "wb")) - config_input.fmt = ( - detect_io_type(country_input_file) - if not hasattr(sys, "_called_from_test") - else "clm" - ) + # Keep original fmt ("clm"): regridclm always outputs clm format. + # Re-detecting via detect_io_type() can misclassify (e.g. + # slope files as "text" when first bytes are printable ASCII). config_input.name = country_input_file + if ( + os.path.isfile(country_input_file) + and os.path.getsize(country_input_file) == 0 + and not hasattr(sys, "_called_from_test") + ): + raise OSError( + f"Regridded file '{country_input_file}' is empty. " + "Delete it and run regrid(..., overwrite=True) to regenerate." 
# noqa: E501 + ) self._set_grid_explicitly(only_all=False) @@ -863,12 +901,16 @@ def convert_cdf_to_raw(self, output_id=None): if not os.path.isfile(f"{output_dir}/{grid_name}") and not hasattr( sys, "_called_from_test" ): - run(f"tail -c +44 {grid_file} > {output_dir}/{grid_name}", shell=True) + run( + f"tail -c +44 {grid_file} > {output_dir}/{grid_name}", shell=True + ) # noqa: E501 grid_file = f"{output_dir}/{grid_name}" outputs = [ - out for out in self.get_output(fmt="cdf", id_only=True) if out != "grid" + out + for out in self.get_output(fmt="cdf", id_only=True) + if out != "grid" # noqa: E501 ] if output_id: @@ -962,7 +1004,7 @@ def __repr__(self, sub_repr=0): summary_list = [summary, " (changed)"] summary_list.extend( [ - f" * {torepr}{(20-len(torepr))*' '} {getattr(self, torepr)}" + f" * {torepr}{(20-len(torepr))*' '} {getattr(self, torepr)}" # noqa: E501 for torepr in changed_repr ] ) @@ -1000,7 +1042,10 @@ def __setattr__(self, __name, __value): def parse_config( - file_name="./lpjml_config.json", spin_up=False, macros=None, config_class=None + file_name="./lpjml_config.json", + spin_up=False, + macros=None, + config_class=None, # noqa: E501 ): """ Precompile lpjml_config.json and return LpjmlConfig object or dict. @@ -1056,7 +1101,8 @@ def read_config( Parameters ---------- file_name : str - File name (including relative/absolute path) of the LPJmL configuration. + File name (including relative/absolute path) of the LPJmL + configuration. model_path : str, optional Path to model root directory. If provided, joined with file_name. 
spin_up : bool, default False @@ -1135,12 +1181,9 @@ def __repr__(self, sub_repr=1, order=1): for key, value in self.__dict__.items(): if isinstance(value, SubConfig): - summary += ( - f"""{' ' * sub_repr}* {key}: {value.__repr__( - sub_repr + 1, order + 1 - )}""".strip() - + spacing - ) + summary += f"""{' ' * sub_repr}* {key}: {value.__repr__( + sub_repr + 1, order + 1 + )}""".strip() + spacing else: summary += ( f"{' ' * sub_repr}* {key:<20} {value}".strip() + spacing diff --git a/pycoupler/coupler.py b/pycoupler/coupler.py index c968036..1e7eb83 100644 --- a/pycoupler/coupler.py +++ b/pycoupler/coupler.py @@ -5,12 +5,14 @@ import tempfile import copy import warnings +import subprocess +import atexit +from contextlib import contextmanager import numpy as np import pandas as pd import xarray as xr -from subprocess import run from enum import Enum from pycoupler.config import read_config @@ -22,9 +24,87 @@ read_meta, read_data, read_header, + DEFAULT_NETCDF_FILL_VALUE_INT, ) from pycoupler.utils import get_countries +# Port cleanup utilities ==================================================== # + + +def kill_process_on_port(port): + """Kill any process using the specified port.""" + try: + # Find processes using the port + result = subprocess.run( + ["lsof", "-ti", f":{port}"], capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0 and result.stdout.strip(): + pids = result.stdout.strip().split("\n") + killed_count = 0 + for pid in pids: + if pid.strip(): + try: + kill_result = subprocess.run( + ["kill", "-9", pid.strip()], + timeout=5, + capture_output=True, + ) + if kill_result.returncode == 0: + killed_count += 1 + except subprocess.TimeoutExpired: + # Ignore timeout errors during best-effort port cleanup. 
+ pass + return killed_count + return 0 + except ( + subprocess.TimeoutExpired, + subprocess.CalledProcessError, + FileNotFoundError, + ): + return -1 + + +def cleanup_port_on_exit(port): + """Register a cleanup function for the given port.""" + + def cleanup(): + kill_process_on_port(port) + + atexit.register(cleanup) + + +@contextmanager +def cleanup_port_context(port): + """Context manager that cleans up processes on the port on entry and exit. + + Does not bind or verify the port; use this to ensure stale processes are + killed before and after a block that uses the port. + """ + # Clean up any existing processes on the port first + kill_process_on_port(port) + + try: + yield port + finally: + # Clean up on exit + kill_process_on_port(port) + + +def safe_port_binding(host, port): + """Deprecated. Use :func:`cleanup_port_context` instead. + + The host parameter is ignored. Kept for backward compatibility. + """ + import warnings + + warnings.warn( + "safe_port_binding(host, port) is deprecated; host is ignored. " + "Use cleanup_port_context(port) instead.", + DeprecationWarning, + stacklevel=2, + ) + return cleanup_port_context(port) + # class for testing purposes class test_channel: @@ -206,6 +286,9 @@ def opentdt(host, port): if hasattr(sys, "_called_from_test"): channel = test_channel() else: + # Clean up any existing processes on the port first + kill_process_on_port(port) + # create an INET, STREAMing socket serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -286,13 +369,15 @@ def __init__(self, config_file, version=3, host="localhost", port=2224): if hasattr(self._config, "coupled_host") and hasattr( self._config, "coupled_port" ): - if host != "localhost" or port != 2224: + config_host = self._config.coupled_host + config_port = self._config.coupled_port + if (host, port) != (config_host, config_port): warnings.warn( "Host and port are set in configuration file. 
" "Provided host and port are ignored." ) - host = self._config.coupled_host - port = self._config.coupled_port + host = config_host + port = config_port # initiate socket connection to LPJmL self._init_channel(version, host, port) @@ -439,7 +524,7 @@ def get_historic_years(self, match_period=True): generator Generator for all historic years. """ - start_year = self._sim_year + start_year = self._config.firstyear end_year = self.config.start_coupling if match_period and start_year >= end_year: raise ValueError( @@ -519,7 +604,8 @@ def code_to_name(self, to_iso_alpha_3=False): if static_output == "country" and to_iso_alpha_3: country_dict = get_countries() name_dict = { - idx: country_dict[reg]["code"] for idx, reg in name_dict.items() + idx: country_dict[reg]["code"] + for idx, reg in name_dict.items() # noqa: E501 } getattr(self, f"{static_output}").attrs[ "long_name" @@ -549,18 +635,25 @@ def read_historic_output(self, to_xarray=True): Returns ------- dict or xarray.DataArray - Dictionary with output keys and corresponding output as numpy arrays + Dictionary with output keys and corresponding output as numpy + arrays """ # read all historic outputs - hist_years = list() + output_dict = {} + output_years = list() for year in self.get_historic_years(): - hist_years.append(year) - if year == self._config.outputyear: - output_dict = self.read_output(year=year, to_xarray=False) - elif year > self._config.outputyear: + if year >= self._config.outputyear: output_dict = append_to_dict( output_dict, self.read_output(year=year, to_xarray=False) ) + output_years.append(year) + + if not output_dict: + raise ValueError( + f"No historic output found. outputyear {self._config.outputyear} " + f"was not encountered in get_historic_years() " + f"(e.g., outputyear < firstyear)." 
+ ) for key in output_dict: if key in output_dict: @@ -568,11 +661,11 @@ def read_historic_output(self, to_xarray=True): item[0] for item in self._output_ids.items() if item[1] == key ][0] lpjml_output = self._create_xarray_template( - index, time_length=len(hist_years) + index, time_length=len(output_years) ) lpjml_output.coords["time"] = pd.date_range( - start=str(hist_years[0]), end=str(hist_years[-1] + 1), freq="YE" + start=str(output_years[0]), end=str(output_years[-1] + 1), freq="YE" ) lpjml_output.data = output_dict[key] output_dict[key] = lpjml_output @@ -583,8 +676,20 @@ def read_historic_output(self, to_xarray=True): return output_dict def close(self): - """Close socket channel""" - self._channel.close() + """Close socket channel and clean up port""" + if hasattr(self, "_channel") and self._channel: + self._channel.close() + + # Clean up any processes still using the port + if hasattr(self, "_config") and hasattr(self._config, "coupled_port"): + kill_process_on_port(self._config.coupled_port) + + def __del__(self): + """Destructor to ensure cleanup on object deletion""" + try: + self.close() + except Exception: + pass # Ignore errors during cleanup def send_input(self, input_dict, year): """Send input data of iterated year as dictionary to LPJmL. @@ -592,8 +697,8 @@ def send_input(self, input_dict, year): Parameters ---------- input_dict : dict - Dictionary of input keys and corresponding input in the form of numpy - arrays with dimensions (ncell, nband) + Dictionary of input keys and corresponding input in the form of + numpy arrays with dimensions (ncell, nband) year : int Year for which input data is to be sent. 
@@ -618,7 +723,9 @@ def send_input(self, input_dict, year): # steps (analogous to years left) # Year check - if procvided year matches internal simulation year if year != self._sim_year: - raise ValueError(f"Year {year} not matches simulated year {self._sim_year}") + raise ValueError( + f"Year {year} not matches simulated year {self._sim_year}" + ) # noqa: E501 operations = self.operations_left # Check if read_output operation valid @@ -687,7 +794,9 @@ def read_output(self, year, to_xarray=True): if key not in self._static_ids } # noqa - output_iterations = sum(steps for steps in self._output_steps if steps > 0) + output_iterations = sum( + steps for steps in self._output_steps if steps > 0 + ) # noqa: E501 # Perform subannual (monthly, daily) read_output operation output = self._iterate_operation( @@ -748,7 +857,8 @@ def read_input(self, start_year=None, end_year=None, copy=True): key=key, ) # TODO: add support for input meta data file (if available) - # inputs[key].add_meta() to replace workaround with assign_band_names + # inputs[key].add_meta() to replace workaround with + # assign_band_names # if no start_year and end_year provided and only one year is supplied # ensure years are the same (although they are not - but to avoid @@ -758,7 +868,9 @@ def read_input(self, start_year=None, end_year=None, copy=True): inputs.values(), key=lambda inp: inp.time.item() ).time.item() # noqa for key in inputs.keys(): - inputs[key] = inputs[key].drop("time").assign_coords({"time": [year]}) + inputs[key] = ( + inputs[key].drop("time").assign_coords({"time": [year]}) + ) # noqa: E501 inputs = LPJmLDataSet(inputs) # define longitide and latitude DataArray (workaround to reduce dims to @@ -773,16 +885,23 @@ def read_input(self, start_year=None, end_year=None, copy=True): lats = self._cached_grid["lats"] if start_year and end_year: - kwargs = {"time": [year for year in range(start_year, end_year + 1)]} + kwargs = { + "time": [year for year in range(start_year, end_year + 1)] + } 
# noqa: E501 elif start_year and not end_year: kwargs = { "time": [ - year for year in range(start_year, max(inputs.time.values) + 1) + year + for year in range( + start_year, max(inputs.time.values) + 1 + ) # noqa: E501 ] } elif not start_year and end_year: kwargs = { - "time": [year for year in range(min(inputs.time.values), end_year + 1)] + "time": [ + year for year in range(min(inputs.time.values), end_year + 1) + ] # noqa: E501 } else: kwargs = {} @@ -828,7 +947,9 @@ def assign_config_band_names(self, x, key): if len(band_names) % len(self.config.cftmap) == 0: len_irr_systems = len(band_names) // len(self.config.cftmap) else: - len_irr_systems = len(band_names) // len(self.config.landusemap) + len_irr_systems = len(band_names) // len( + self.config.landusemap + ) # noqa: E501 if len_irr_systems == 2: irr_systems = irr_systems_short @@ -880,7 +1001,9 @@ def _copy_input(self, start_year, end_year): for key in sock_inputs: # check if working on the cluster (workaround by Ciaron) # (might be adjusted to the new cluster coming soon ...) 
- if self.config.inpath and (not sock_inputs[key]["name"].startswith("/")): + if self.config.inpath and ( + not sock_inputs[key]["name"].startswith("/") + ): # noqa: E501 sock_inputs[key][ "name" ] = f"{self.config.inpath}/{sock_inputs[key]['name']}" @@ -914,7 +1037,9 @@ def _copy_input(self, start_year, end_year): # data ends in simulation period cut_start_year = start_year cut_end_year = meta_data.lastyear - elif meta_data.firstyear <= start_year and meta_data.lastyear >= start_year: + elif ( + meta_data.firstyear <= start_year and meta_data.lastyear >= start_year + ): # noqa: E501 # data starts before simulation period, # but simulation is within data period cut_start_year = start_year @@ -927,7 +1052,7 @@ def _copy_input(self, start_year, end_year): f"{temp_dir}/1_{file_name_tmp}", ] if not hasattr(sys, "_called_from_test"): - run(cut_clm_start, stdout=open(os.devnull, "wb")) + subprocess.run(cut_clm_start, stdout=open(os.devnull, "wb")) # predefine cut clm command for reusage # cannot deal with overwriting a temp file with same name @@ -939,7 +1064,7 @@ def _copy_input(self, start_year, end_year): f"{temp_dir}/2_{file_name_tmp}", ] if not hasattr(sys, "_called_from_test"): - run(cut_clm_end, stdout=open(os.devnull, "wb")) + subprocess.run(cut_clm_end, stdout=open(os.devnull, "wb")) # a flag for multi (categorical) band input - if true, set # "-landuse" @@ -958,7 +1083,9 @@ def _copy_input(self, start_year, end_year): if self.config.input.coord.name.startswith("/"): grid_file = self.config.input.coord.name else: - grid_file = f"{self.config.inpath}/{self.config.input.coord.name}" + grid_file = ( + f"{self.config.inpath}/{self.config.input.coord.name}" # noqa: E501 + ) # convert clm input to netcdf files conversion_cmd = [ f"{self._config.model_path}/bin/clm2cdf", @@ -969,12 +1096,10 @@ def _copy_input(self, start_year, end_year): f"{temp_dir}/2_{file_name_tmp}", f"{input_path}/{key}.nc", ] - - if None in conversion_cmd: - conversion_cmd.remove(None) + 
conversion_cmd = [arg for arg in conversion_cmd if arg is not None] if not hasattr(sys, "_called_from_test"): - run(conversion_cmd) + subprocess.run(conversion_cmd) else: return "tested" # remove the temporary clm (binary) files, 1_* is not created in @@ -998,13 +1123,17 @@ def _init_channel(self, version, host, port): def _init_coupling(self): """Initialize coupling""" # initiate simulation time - self._sim_year = min(self._config.outputyear, self._config.start_coupling) + self._sim_year = min( + self._config.outputyear, self._config.start_coupling + ) # noqa: E501 self._year_read_output = None self._year_send_input = None # read amount of LPJml cells self._ncell = read_int(self._channel) - if self._ncell != len(range(self._config.startgrid, self._config.endgrid + 1)): + if self._ncell != len( + range(self._config.startgrid, self._config.endgrid + 1) + ): # noqa: E501 self.close() raise ValueError( f"Invalid number of cells received ({self._ncell}), must be" @@ -1104,7 +1233,8 @@ def _iterate_operation(self, length, fun, token, args=None): # check token self._check_token(token=token) - # execute method on channel and if supplied further method arguments + # execute method on channel and if supplied further method + # arguments output = fun(**args) if args else fun() if args is not None and "output" in args.keys(): args["output"] = output @@ -1123,7 +1253,7 @@ def _check_token(self, token): if received_token is not token: self.close() raise ValueError( - f"Received LPJmLToken {received_token.name} is not {token.name}" + f"Received LPJmLToken {received_token.name} is not {token.name}" # noqa: E501 ) # noqa def _check_index(self, index): @@ -1176,7 +1306,7 @@ def _get_config_input_sockets(self): # check if input is defined in LPJmLInputType (band size required) valid_inputs = { - LPJmLInputType(id=sock_id).id: LPJmLInputType(id=sock_id).nband # noqa + LPJmLInputType(id=sock_id).id: LPJmLInputType(id=sock_id).nband for sock_id in socket_ids if sock_id in input_ids } 
@@ -1230,7 +1360,7 @@ def _init_static_data(self): # Fill array with missing values if self._output_types[static_id].type == int: - tmp_static[:] = -9999 + tmp_static[:] = DEFAULT_NETCDF_FILL_VALUE_INT else: tmp_static[:] = np.nan @@ -1308,9 +1438,9 @@ def _create_xarray_template(self, index, time_length=1): dtype=self._output_types[index].type, ) - # Check if data array is of type integer, use -9999 for nan + # Check if data array is of type integer, use fill value for missing if self._output_types[index].type == int: - output_tmpl[:] = -9999 + output_tmpl[:] = DEFAULT_NETCDF_FILL_VALUE_INT else: output_tmpl[:] = np.nan @@ -1319,7 +1449,9 @@ def _create_xarray_template(self, index, time_length=1): data=output_tmpl, dims=("cell", "band", "time"), coords=dict( - cell=np.arange(self._config.startgrid, self._config.endgrid + 1), + cell=np.arange( + self._config.startgrid, self._config.endgrid + 1 + ), # noqa: E501 lon=(["cell"], self.grid.coords["lon"].values), lat=(["cell"], self.grid.coords["lat"].values), band=np.arange(bands), # [str(i) for i in range(bands)], @@ -1337,7 +1469,9 @@ def _create_xarray_template(self, index, time_length=1): # add meta information to output output_tmpl.add_meta(meta_output) - output_tmpl = output_tmpl.rename(band=f"band ({self._output_ids[index]})") + output_tmpl = output_tmpl.rename( + band=f"band ({self._output_ids[index]})" + ) # noqa: E501 # add meta data to output return output_tmpl @@ -1349,26 +1483,47 @@ def _send_input_data(self, data, year): index = read_int(self._channel) if isinstance(data, LPJmLDataSet): data = data.to_numpy() - elif not isinstance(data[self._input_ids[index]], np.ndarray): + + # Get the input data array + input_name = self._input_ids[index] + input_data = data[input_name] + + # Convert to numpy array if it's not already + if not isinstance(input_data, np.ndarray): + if hasattr(input_data, "values"): + input_data = input_data.values + elif hasattr(input_data, "to_numpy"): + input_data = 
input_data.to_numpy() + else: + input_data = np.array(input_data) + + # Ensure it's a numpy array + if not isinstance(input_data, np.ndarray): self.close() raise TypeError( - "Unsupported object type. Please supply a numpy " - + "array with the dimension of (ncells, nband)." + f"Input data for '{input_name}' could not be converted to numpy array. " # noqa: E501 + + f"Got type: {type(input_data)}" ) - # type check conversion + # type check and conversion if self._input_types[index].type == float: type_check = np.floating elif self._input_types[index].type == int: type_check = np.integer - if not np.issubdtype(data[self._input_ids[index]].dtype, type_check): - self.close() - raise TypeError( - f"Unsupported type: {data[self._input_ids[index]].dtype} " - + "Please supply a numpy array with data type: " - + f"{np.dtype(self._input_types[index].type)}." - ) + if not np.issubdtype(input_data.dtype, type_check): + # Auto-convert float to int when integer input is expected + if self._input_types[index].type == int and np.issubdtype( + input_data.dtype, np.floating + ): + input_data = np.around(input_data).astype(np.int64) + else: + self.close() + raise TypeError( + f"Unsupported type: {input_data.dtype} " + + "Please supply a numpy array with data type: " + + f"{np.dtype(self._input_types[index].type)}." 
+ ) # check received year self._check_year(year) @@ -1376,25 +1531,24 @@ def _send_input_data(self, data, year): if index in self._input_ids.keys(): # get corresponding number of bands from LPJmLInputType class bands = LPJmLInputType(id=index).nband - if not np.shape(data[self._input_ids[index]]) == (self._ncell, bands): + if not np.shape(input_data) == (self._ncell, bands): if ( - bands == 1 - and not np.shape(data[self._input_ids[index]])[0] == self._ncell - ): # noqa + bands == 1 and not np.shape(input_data)[0] == self._ncell + ): # noqa: E501 self.close() raise ValueError( "The dimensions of the supplied data: " - + f"{(self._ncell, bands)} does not match the " - + f"needed dimensions for {self._input_ids[index]}" + + f"{np.shape(input_data)} does not match the " + + f"needed dimensions for {input_name}" + f": {(self._ncell, bands)}." ) # execute sending values method to actually send the input to # socket - self._send_input_values(data[self._input_ids[index]]) + self._send_input_values(input_data) def _send_input_values(self, data): - """Iterate over all values to be sent to the socket. Recursive iteration - with the correct order of bands and cells for inputs. + """Iterate over all values to be sent to the socket. Recursive + iteration with the correct order of bands and cells for inputs. """ dims = data.shape one_band = len(dims) == 1 @@ -1423,7 +1577,7 @@ def _read_output_data(self, output, year, to_xarray=True): steps_as_bands = True else: raise ValueError( - "Subannual output data with more than one band not supported." + "Subannual output data with more than one band not supported." # noqa: E501 ) if not to_xarray: # read and assign corresponding values from socket to numpy array @@ -1445,9 +1599,11 @@ def _read_output_data(self, output, year, to_xarray=True): # as list for appending/extending as list return output - def _read_output_values(self, output, index, dims=None, steps_as_bands=False): - """Iterate over all values to be read from the socket. 
Recursive iteration - with the correct order of cells and bands for outputs. + def _read_output_values( + self, output, index, dims=None, steps_as_bands=False + ): # noqa: E501 + """Iterate over all values to be read from the socket. Recursive + iteration with the correct order of cells and bands for outputs. """ cells, bands = dims[0], dims[1] # Determine if there is only one band @@ -1458,11 +1614,15 @@ def _read_output_values(self, output, index, dims=None, steps_as_bands=False): for cell in range(cells): # Read the value from the socket if one_band: - output[cell] = self._output_types[index].read_fun(self._channel) + output[cell] = self._output_types[index].read_fun( + self._channel + ) # noqa: E501 elif steps_as_bands: output[cell, self._output_count_steps[index]] = self._output_types[ index - ].read_fun(self._channel) + ].read_fun( + self._channel + ) # noqa: E501 else: output[cell, band] = self._output_types[index].read_fun( self._channel @@ -1481,7 +1641,9 @@ def _read_meta_output(self, index=None, output_id=None): output = self._config.output[output_id] elif index is not None: output = [ - out for out in self._config.output if out.id == self._output_ids[index] + out + for out in self._config.output + if out.id == self._output_ids[index] # noqa: E501 ][0] else: raise ValueError("Either index or output_id must be supplied") diff --git a/pycoupler/data.py b/pycoupler/data.py index bb68f1e..057d457 100644 --- a/pycoupler/data.py +++ b/pycoupler/data.py @@ -1,7 +1,9 @@ import os import struct import re +from pathlib import Path from collections.abc import Hashable +from typing import Dict import numpy as np import pandas as pd import xarray as xr @@ -11,6 +13,27 @@ from pycoupler.utils import read_json +DEFAULT_NETCDF_FILL_VALUE = np.float32(-999.0) +DEFAULT_NETCDF_FILL_VALUE_INT = -9999 + + +def _suppress_coordinate_fill(dataset: xr.Dataset) -> None: + """Remove _FillValue metadata from coordinate variables.""" + + for coord_name in dataset.coords: + coord = 
dataset[coord_name] + coord.attrs.pop("_FillValue", None) + coord.encoding["_FillValue"] = None + + +def _sanitize_prefix(value: str | None, default: str = "lpjml") -> str: + """Return a filesystem-friendly prefix.""" + candidate = (value or "").strip() + if not candidate: + candidate = default + safe = re.sub(r"[^\w\-]+", "_", candidate).strip("_") + return safe or default + class LPJmLInputType: """Available Input types loaded from config. @@ -65,6 +88,13 @@ def __init__(self, id=None, name=None): def load_config(cls, config): """Load input types from the provided config.""" cls.__input_types__ = config.input.to_dict() + # Filter to entries with "id" (proper input types); skip scalars like + # delta_year + cls.__input_types__ = { + k: v + for k, v in cls.__input_types__.items() + if isinstance(v, dict) and "id" in v + } cls.names = list(cls.__input_types__.keys()) cls.ids = [value["id"] for value in cls.__input_types__.values()] @@ -84,14 +114,15 @@ def nband(self): @property def type(self): """Get the data type for the specific input""" - if self.name in ["with_tillage", "sdate"]: + if self.name in ["with_tillage", "sdate", "cover_crop"]: return int else: return float @property def has_bands(self): - """Check if multiple bands exist (better check for categorical bands)""" + """Check if multiple bands exist (better check for categorical + bands)""" return bool(self.nband > 1) @@ -123,6 +154,56 @@ def append_to_dict(data_dict, data): return data_dict +def _ensure_cf_metadata( + data_array: xr.DataArray, +) -> Dict[str, str] | None: + """Attach CF-compliant coordinate metadata and extract globals.""" + + da = data_array + + def _ensure(coord_name: str, attrs: Dict[str, str]): + if coord_name in da.coords: + coord = da.coords[coord_name] + for key, value in attrs.items(): + coord.attrs.setdefault(key, value) + + _ensure( + "lat", + { + "standard_name": "latitude", + "long_name": "Latitude", + "units": "degrees_north", + "axis": "Y", + }, + ) + _ensure( + "lon", + { 
+ "standard_name": "longitude", + "long_name": "Longitude", + "units": "degrees_east", + "axis": "X", + }, + ) + if "time" in da.coords: + coord = da.coords["time"] + coord.attrs.setdefault("standard_name", "time") + coord.attrs.setdefault("long_name", "Time") + coord.attrs.setdefault("axis", "T") + coord.attrs.setdefault("units", "years") + coord.attrs.setdefault("calendar", "noleap") + + _ensure( + "area_km2", + {"long_name": "grid cell area", "units": "km2"}, + ) + + if not da.attrs.get("units"): + da.attrs["units"] = "1" + + return da.attrs.pop("_global_attrs", None) + + class LPJmLData(xr.DataArray): """Class for single LPJmL data arrays (input, output, etc.) with meta data and defined dimensions (cell, band, time). @@ -165,12 +246,15 @@ def add_meta(self, meta_data): if meta_data.cellsize_lat != meta_data.cellsize_lon: raise ValueError( - "Cell sizes in latitude and longitude direction must be " "equal!" + "Cell sizes in latitude and longitude direction must be " + "equal!" # noqa: E501 ) else: self.attrs["cellsize"] = meta_data.cellsize_lon - band_dim = next((dim for dim in self.dims if dim.startswith("band")), None) + band_dim = next( + (dim for dim in self.dims if dim.startswith("band")), None + ) # noqa: E501 # TODO assign lat lon to grid object if band_dim is not None and len(self.coords[band_dim]) > 1: if meta_data.variable == "grid": @@ -178,10 +262,14 @@ def add_meta(self, meta_data): elif len(self.coords[band_dim]) == len(meta_data.band_names): self.coords[band_dim] = meta_data.band_names else: - self.coords[band_dim] = np.arange(1, len(self.coords[band_dim]) + 1) + self.coords[band_dim] = np.arange( + 1, len(self.coords[band_dim]) + 1 + ) # noqa: E501 if hasattr(meta_data, "global_attrs"): - self.attrs["institution"] = meta_data.global_attrs["institution"] + self.attrs["institution"] = meta_data.global_attrs[ + "institution" + ] # noqa: E501 self.attrs["contact"] = meta_data.global_attrs["contact"] self.attrs["comment"] = 
meta_data.global_attrs["comment"] else: @@ -239,7 +327,9 @@ def get_neighbourhood(self, id=True, cellsize=0.5): current_neighbours ] elif len(current_neighbours) > 0 and not id: - neighbour_ids[i, : len(current_neighbours)] = current_neighbours + neighbour_ids[i, : len(current_neighbours)] = ( + current_neighbours # noqa: E501 + ) neighbours = LPJmLData( data=neighbour_ids, @@ -255,9 +345,174 @@ def get_neighbourhood(self, id=True, cellsize=0.5): return neighbours - def transform(self): - """TODO: implement function to convert cell into lon/lat format""" - pass + def _wrap(self, data_array): + """Return a LPJmLData instance from a generic DataArray.""" + if isinstance(data_array, LPJmLData): + return data_array + if isinstance(data_array, xr.DataArray): + return self.__class__(data_array) + raise TypeError(f"Expected xarray.DataArray, got {type(data_array)!r}") + + def transform(self, to="lon_lat"): + """Transform the spatial layout between cell and lon/lat grid. + + Parameters + ---------- + to : {"lon_lat", "cell"}, default "lon_lat" + Target representation. ``"lon_lat"`` reshapes the data onto a + two-dimensional latitude/longitude grid using the embedded + coordinates; ``"cell"`` collapses a lon/lat grid back to the + original one-dimensional ``cell`` dimension. + + Returns + ------- + LPJmLData + Transformed data array. + """ + + to = (to or "").lower() + if to not in {"lon_lat", "cell"}: + raise ValueError( + f"Unsupported transform target '{to}'. " + "Use either 'lon_lat' or 'cell'." + ) + + if to == "lon_lat": + if "cell" not in self.dims: + return self.copy() + if not {"lon", "lat"} <= set(self.coords): + raise ValueError( + "Cannot transform to lon/lat grid without 'lon' and 'lat' " + "coordinates attached to the 'cell' dimension." 
+ ) + + # Ensure lat/lon combinations are unique before unstacking + index = pd.MultiIndex.from_arrays( + [self.coords["lat"].values, self.coords["lon"].values], + names=("lat", "lon"), + ) + if not index.is_unique: + raise ValueError( + "Duplicate lat/lon pairs detected; cannot build a unique grid." # noqa: E501 + ) + + lon_lat = self.set_index(cell=("lat", "lon")).unstack("cell") + # Sort coordinates for consistent orientation (lon ascending, + # lat descending) + if "lon" in lon_lat.dims: + lon_lat = lon_lat.sortby("lon") + if "lat" in lon_lat.dims: + lon_lat = lon_lat.sortby("lat", ascending=False) + + dims = list(lon_lat.dims) + ordered_dims = [dim for dim in ("lat", "lon") if dim in dims] + ordered_dims.extend(dim for dim in dims if dim not in ordered_dims) + lon_lat = lon_lat.transpose(*ordered_dims) + return self._wrap(lon_lat) + + # to == "cell" + if not {"lon", "lat"} <= set(self.dims): + return self.copy() + + cell_da = self.sortby("lon").sortby("lat", ascending=False) + cell_da = cell_da.stack(cell=("lat", "lon")).reset_index("cell") + # Ensure lon/lat remain attached to the cell dimension + if "lon" in cell_da.coords: + cell_da.coords["lon"] = ("cell", cell_da.coords["lon"].values) + if "lat" in cell_da.coords: + cell_da.coords["lat"] = ("cell", cell_da.coords["lat"].values) + + dims = list(cell_da.dims) + ordered_dims = ["cell"] + ordered_dims.extend(dim for dim in dims if dim != "cell") + cell_da = cell_da.transpose(*ordered_dims) + return self._wrap(cell_da) + + def to_netcdf( + self, + path: str | os.PathLike[str] | None = None, + *, + compression: bool = True, + complevel: int = 4, + fill_value: float | None = None, + engine: str = "netcdf4", + mode: str = "w", + compute: bool = True, + **kwargs, + ): + """Write this LPJmLData to NetCDF, analogous to + :meth:`xarray.DataArray.to_netcdf`, but gridding ``cell`` dimensions + before persisting. [#da_netcdf]_ + + .. 
[#da_netcdf] https://docs.xarray.dev/en/latest/generated/xarray.DataArray.to_netcdf.html # noqa: E501 + """ + + kwargs = dict(kwargs) + kwargs.pop("lpjml_style", None) + + lpjml = self + if lpjml.name is None and path is None: + raise ValueError( + "LPJmLData must have a name when no output path is provided." + ) + + if "cell" in lpjml.dims: + if not {"lon", "lat"} <= set(lpjml.coords): + raise ValueError( + "Cannot write LPJmLData with a 'cell' dimension that lacks " # noqa: E501 + "'lon' and 'lat' coordinates." + ) + lpjml = lpjml.transform("lon_lat") + + global_attrs = _ensure_cf_metadata(lpjml) + var_name = lpjml.name or "__lpjml_data__" + dataset = lpjml.to_dataset(name=var_name) + _suppress_coordinate_fill(dataset) + if global_attrs: + dataset.attrs.update(dict(global_attrs)) + dataset.attrs.setdefault("Conventions", "CF-1.8") + dtype = lpjml.dtype + encoding: dict[str, dict[str, object]] = {} + + enc: dict[str, object] = {} + target_fill = fill_value + if target_fill is None: + if np.issubdtype(dtype, np.floating): + target_fill = DEFAULT_NETCDF_FILL_VALUE + elif np.issubdtype(dtype, np.integer): + target_fill = DEFAULT_NETCDF_FILL_VALUE_INT + if target_fill is not None: + enc["_FillValue"] = target_fill + + if compression and np.issubdtype(dtype, np.number): + enc["zlib"] = True + enc["complevel"] = complevel + + encoding[var_name] = enc + + if path is None: + return dataset.to_netcdf( + path=None, + engine=engine, + mode=mode, + encoding=encoding, + compute=compute, + **kwargs, + ) + + target_path = Path(path) + if not target_path.suffix: + target_path = target_path.with_suffix(".nc4") + + dataset.to_netcdf( + target_path, + engine=engine, + mode=mode, + encoding=encoding, + compute=compute, + **kwargs, + ) + return str(target_path) class LPJmLDataSet(xr.Dataset): @@ -283,7 +538,9 @@ class LPJmLDataSet(xr.Dataset): def __init__(self, *args, **kwargs): super(LPJmLDataSet, self).__init__(*args, **kwargs) - if self.data_vars and ("cellsize" in 
self[list(self.data_vars)[0]].attrs): + if self.data_vars and ( + "cellsize" in self[list(self.data_vars)[0]].attrs + ): # noqa: E501 first_attrs = self[list(self.data_vars)[0]].attrs self.attrs["source"] = first_attrs["source"] self.attrs["history"] = first_attrs["history"] @@ -300,8 +557,8 @@ def to_numpy(self): Returns ------- dict - Dictionary with data variables as keys and corresponding numpy arrays - as values. + Dictionary with data variables as keys and corresponding numpy + arrays as values. """ return {key: value.to_numpy() for key, value in self.data_vars.items()} @@ -338,9 +595,14 @@ def _construct_dataarray(self, name: Hashable) -> LPJmLData: _, name, variable = xr.core.dataset._get_virtual_variable( self._variables, name, self.dims ) + else: + # Work on a shallow copy so we don't mutate dataset variables + variable = variable.copy(deep=False) needed_dims = set(variable.dims) - stripped_dims = {re.sub(r"\s*\(.*?\)", "", item) for item in needed_dims} + stripped_dims = { + re.sub(r"\s*\(.*?\)", "", item) for item in needed_dims + } # noqa: E501 coords: dict[Hashable, Variable] = {} # preserve ordering @@ -349,16 +611,25 @@ def _construct_dataarray(self, name: Hashable) -> LPJmLData: set(self.variables[k].dims) <= needed_dims or set(self.variables[k].dims) <= stripped_dims ): - coords[k] = self.variables[k] - - indexes = xr.core.indexes.filter_indexes_from_coords(self._indexes, set(coords)) + coords[k] = self.variables[k].copy(deep=False) + + indexes = xr.core.indexes.filter_indexes_from_coords( + self._indexes, set(coords) + ) # noqa: E501 + # Copy indexes to avoid mutating dataset-level state + indexes = { + key: (idx.copy() if hasattr(idx, "copy") else idx) + for key, idx in indexes.items() + } # TODO: this is a hack to get around the fact that we don't have # a proper way to represent band dimensions in xarray # get the corresponding band dimension band_dim = [ - dim for dim in variable._dims if dim.startswith("band") and dim != "band" + dim + for 
dim in variable._dims + if dim.startswith("band") and dim != "band" # noqa: E501 ] if band_dim: variable._dims = variable._parse_dimensions( @@ -369,7 +640,9 @@ def _construct_dataarray(self, name: Hashable) -> LPJmLData: ) # get the corresponding "band" index and delete all other band indexes - band_idx = [key for key in coords if key.startswith("band") and key != "band"] + band_idx = [ + key for key in coords if key.startswith("band") and key != "band" + ] # noqa: E501 if band_idx: for key in band_idx: if name not in key: @@ -398,11 +671,14 @@ def _construct_dataarray(self, name: Hashable) -> LPJmLData: if name.startswith("band") and name != "band": name = "band" - return LPJmLData(variable, coords, name=name, indexes=indexes, fastpath=True) + return LPJmLData( + variable, coords, name=name, indexes=indexes, fastpath=True + ) # noqa: E501 def to_dict(self, data="list", encoding=False): """ - Convert this dataset to a dictionary following xarray naming conventions. + Convert this dataset to a dictionary following xarray naming + conventions. Converts all variables and attributes to native Python objects. Useful for converting to JSON. To avoid datetime incompatibility, @@ -410,13 +686,14 @@ def to_dict(self, data="list", encoding=False): Parameters ---------- - data : bool or {"list", "array", "lpjmldata"}, optional, default: "list" + data : bool or {"list", "array", "lpjmldata"}, optional, default: + "list" Whether to include the actual data in the dictionary. - If set to ``False``, returns just the schema. - If set to ``"array"``, returns data as the underlying array type. - If set to ``"list"`` (or ``True`` for backwards compatibility), - returns data in lists of Python data types. For efficient "list" output, - use ``ds.compute().to_dict(data="list")``. + returns data in lists of Python data types. For efficient "list" + output, use ``ds.compute().to_dict(data="list")``. 
encoding : bool, optional, default: False Whether to include the Dataset's encoding in the dictionary. @@ -436,6 +713,108 @@ def to_dict(self, data="list", encoding=False): return super().to_dict(data=data, encoding=encoding) + def transform(self, to="lon_lat"): + """Transform all LPJmLData variables to the requested spatial + layout.""" + + transformed = {} + for name, data in self.data_vars.items(): + if isinstance(data, LPJmLData): + transformed[name] = data.transform(to=to) + else: + transformed[name] = data + + result = LPJmLDataSet(transformed) + result.attrs = self.attrs.copy() + return result + + def to_netcdf( + self, + path: str | os.PathLike[str], + *, + lpjml_style: bool = True, + per_variable: bool = True, + file_prefix: str | None = None, + suffix: str = ".nc4", + compression: bool = True, + complevel: int = 4, + fill_value: float | None = None, + engine: str | None = None, + mode: str = "w", + compute: bool = True, + **kwargs, + ) -> str | Dict[str, str]: + """Write the dataset to NetCDF. + + By default LPJmL conventions are applied: every variable is aligned on + the LPJmL grid and written to a dedicated ``_.nc4`` file + (``per_variable=True``). Set ``per_variable=False`` to emit a single + multi-variable file. Pass ``lpjml_style=False`` to defer entirely to + the underlying :meth:`xarray.Dataset.to_netcdf`. 
+ """ + + kwargs = dict(kwargs) + per_variable = kwargs.pop("split_vars", per_variable) + kwargs.pop("file_prefix", None) + kwargs.pop("suffix", None) + + if not lpjml_style: + return super().to_netcdf( + path=path, + engine=engine, + mode=mode, + compute=compute, + **kwargs, + ) + + aligned = _grid_aligned_dataset(self) + _suppress_coordinate_fill(aligned) + if per_variable: + target_dir = Path(path) + target_dir.mkdir(parents=True, exist_ok=True) + prefix_default = ( + file_prefix + or aligned.attrs.get("sim_name") + or target_dir.name + or "lpjml" + ) + prefix = _sanitize_prefix(prefix_default) + written: Dict[str, str] = {} + for name, data in aligned.data_vars.items(): + target_file = target_dir / f"{prefix}_{name}{suffix}" + LPJmLData(data).to_netcdf( + target_file, + compression=compression, + complevel=complevel, + fill_value=fill_value, + engine=engine or "netcdf4", + mode=mode, + compute=compute, + **kwargs, + ) + written[name] = str(target_file) + return written + + encoding = { + name: _netcdf_encoding( + data, + fill_value=fill_value, + compression=compression, + complevel=complevel, + ) + for name, data in aligned.data_vars.items() + } + target_path = os.fspath(path) + aligned.to_netcdf( + target_path, + engine=engine or "netcdf4", + mode=mode, + encoding=encoding, + compute=compute, + **kwargs, + ) + return target_path + def read_data(file_name, var_name=None, multiple_bands=False): """Read netcdf file and return data as numpy array or xarray.DataArray. 
@@ -463,7 +842,9 @@ def read_data(file_name, var_name=None, multiple_bands=False): data.coords["time"].attrs["units"] = unit data.coords["time"] = date_time.year - other_dims = [dim for dim in data.dims if dim not in ["lat", "lon", "time"]] + other_dims = [ + dim for dim in data.dims if dim not in ["lat", "lon", "time"] + ] # noqa: E501 # handle multiple bands if var_name and multiple_bands: @@ -633,7 +1014,7 @@ def __repr__(self, sub_repr=False): summary_list = [summary] summary_list.extend( [ - f" * {torepr}{(13-len(torepr))*' '} {getattr(self, torepr)}" + f" * {torepr}{(13-len(torepr))*' '} {getattr(self, torepr)}" # noqa: E501 for torepr in other_attr ] ) @@ -657,6 +1038,53 @@ def read_meta(file_name): return LPJmLMetaData(read_json(file_name)) +def _netcdf_encoding( + data: xr.DataArray, + *, + fill_value: float | None, + compression: bool, + complevel: int, +) -> Dict[str, object]: + dtype = data.dtype + encoding: Dict[str, object] = {} + + if fill_value is not None: + encoding["_FillValue"] = fill_value + else: + if np.issubdtype(dtype, np.floating): + encoding["_FillValue"] = DEFAULT_NETCDF_FILL_VALUE + elif np.issubdtype(dtype, np.integer): + encoding["_FillValue"] = DEFAULT_NETCDF_FILL_VALUE_INT + + if compression and np.issubdtype(dtype, np.number): + encoding["zlib"] = True + encoding["complevel"] = complevel + + return encoding + + +def _grid_aligned_dataset(ds: xr.Dataset) -> xr.Dataset: + """Return dataset where all cell variables are converted to lon/lat + grids.""" + + data_vars = {} + for name, data in ds.data_vars.items(): + arr = data if isinstance(data, LPJmLData) else LPJmLData(data) + if "cell" in arr.dims: + if not {"lon", "lat"} <= set(arr.coords): + raise ValueError( + f"Variable '{name}' is indexed by 'cell' but lacks " + "corresponding 'lon'/'lat' coordinates." 
+ ) + arr = arr.transform("lon_lat") + data_vars[name] = arr + + aligned = xr.Dataset(data_vars) + aligned.attrs.update(ds.attrs) + aligned.attrs.setdefault("Conventions", "CF-1.8") + return aligned + + # Function has been derived from the lpjmlkit R package # https://github.com/PIK-LPJmL/lpjmlkit # Author of original R function: Sebastian Ostberg @@ -888,7 +1316,9 @@ def get_headersize(filename): header = read_header(filename, to_dict=True) version = header["header"]["version"] if version < 1 or version > 4: - raise ValueError("Invalid header version. Expecting value between 1 and 4.") + raise ValueError( + "Invalid header version. Expecting value between 1 and 4." + ) # noqa: E501 headersize = len(header["name"]) + {1: 7, 2: 9, 3: 11, 4: 13}[version] * 4 return headersize diff --git a/pycoupler/release.py b/pycoupler/release.py index 58eb055..3cc3155 100644 --- a/pycoupler/release.py +++ b/pycoupler/release.py @@ -14,8 +14,11 @@ 2. Run tests with pytest 3. Run linting with flake8 4. Update CITATION.cff (if needed) -5. Commit changes +5. Commit CITATION.cff changes (if updated) 6. Create Git tag + +Note: You should commit your changes manually before running this script. +The script will only commit CITATION.cff updates automatically. 
""" import sys @@ -233,10 +236,13 @@ def main(): print(" - Run tests with pytest") print(" - Run linting with flake8") print(" - Update CITATION.cff (if needed)") - print(" - Commit changes") + print(" - Commit CITATION.cff changes (if updated)") print(" - Delete existing local tag (if it exists)") print(" - Create Git tag") print("") + print("Note: Commit your changes manually before running this script.") + print("The script will only commit CITATION.cff updates automatically.") + print("") print("Prerequisites: Install dev dependencies first:") print(" pip install -e .[dev]") sys.exit(0) diff --git a/pycoupler/run.py b/pycoupler/run.py index 4299842..9c74f5a 100644 --- a/pycoupler/run.py +++ b/pycoupler/run.py @@ -1,10 +1,112 @@ import os +import subprocess from datetime import datetime -from subprocess import run, Popen, PIPE, CalledProcessError +from pathlib import Path +from subprocess import CalledProcessError, PIPE, Popen, run + from pycoupler.config import read_config import multiprocessing as mp +from pycoupler.utils import warn_deprecated_alias + + +def kill_stale_lpjml_processes(port=None, verbose=False): + """Kill stale LPJmL processes that may be left over from previous runs. + + This function identifies and terminates LPJmL processes that are no longer + needed, which can happen when simulations crash or are interrupted. + + Parameters + ---------- + port : int, optional + If provided, also kill any process using this port. Defaults to None. + verbose : bool, optional + If True, print information about killed processes. Defaults to False. + + Returns + ------- + int + Number of processes killed, or -1 if an error occurred. 
+ """ + killed_count = 0 + + # Kill processes by name (lpjml) + try: + # Find LPJmL processes by name + result = subprocess.run( + ["pgrep", "-f", "bin/lpjml"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0 and result.stdout.strip(): + pids = result.stdout.strip().split("\n") + for pid in pids: + if pid.strip(): + try: + kill_result = subprocess.run( + ["kill", "-9", pid.strip()], + timeout=5, + capture_output=True, + ) + if kill_result.returncode == 0: + killed_count += 1 + if verbose: + print( + f"Killed LPJmL process with PID {pid.strip()}" + ) # noqa: E501 + except subprocess.TimeoutExpired: + # Ignore timeout errors during best-effort port cleanup. + pass + except ( + subprocess.TimeoutExpired, + subprocess.CalledProcessError, + FileNotFoundError, + ): + pass + + # Also kill processes on the specified port + if port is not None: + try: + result = subprocess.run( + ["lsof", "-ti", f":{port}"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0 and result.stdout.strip(): + pids = result.stdout.strip().split("\n") + for pid in pids: + if pid.strip(): + try: + kill_result = subprocess.run( + ["kill", "-9", pid.strip()], + timeout=5, + capture_output=True, + ) + if kill_result.returncode == 0: + killed_count += 1 + if verbose: + print( + f"Killed process on port {port} " + f"with PID {pid.strip()}" + ) + except subprocess.TimeoutExpired: + # Ignore timeout errors during best-effort port cleanup. + pass + except ( + subprocess.TimeoutExpired, + subprocess.CalledProcessError, + FileNotFoundError, + ): + pass + + if verbose and killed_count > 0: + print(f"Total killed: {killed_count} process(es)") + + return killed_count + def operate_lpjml(config_file, std_to_file=False): """Run LPJmL using a generated (class LpjmlConfig) config file. 
@@ -22,7 +124,9 @@ def operate_lpjml(config_file, std_to_file=False): config = read_config(config_file) if not os.path.isdir(config.model_path): - raise ValueError(f"Folder of model_path '{config.model_path}' does not exist!") + raise ValueError( + f"Folder of model_path '{config.model_path}' does not exist!" + ) # noqa: E501 output_path = f"{config.sim_path}/output/{config.sim_name}" @@ -74,7 +178,7 @@ def operate_lpjml(config_file, std_to_file=False): raise CalledProcessError(p.returncode, p.args) -def run_lpjml(config_file, std_to_file=False): +def start_lpjml(config_file, std_to_file=False, cleanup_stale=True, port=2224): """Run LPJmL using a generated (class LpjmlConfig) config file. Similar to R function `lpjmlKit::run_lpjml`. @@ -85,11 +189,28 @@ def run_lpjml(config_file, std_to_file=False): std_to_file : bool, optional If True, stdout and stderr are written to files in the output folder. Defaults to False. + cleanup_stale : bool, optional + If True, kill any stale LPJmL processes before starting a new one. + This prevents "Another copan:LPJmL process is already running" errors. + Defaults to True. + port : int, optional + The coupling port to clean up. Only used if cleanup_stale is True. + Defaults to 2224. """ - run = mp.Process(target=operate_lpjml, args=(config_file, std_to_file)) - run.start() + if cleanup_stale: + kill_stale_lpjml_processes(port=port, verbose=False) + + process = mp.Process(target=operate_lpjml, args=(config_file, std_to_file)) + process.start() + + return process + - return run +def run_lpjml(*args, **kwargs): + """Backward-compatible alias for :func:`start_lpjml`.""" + + warn_deprecated_alias(start_lpjml, "run_lpjml", "start_lpjml") + return start_lpjml(*args, **kwargs) def submit_lpjml( @@ -103,6 +224,7 @@ def submit_lpjml( option=None, couple_to=None, venv_path=None, + slurm_jcf_dir=None, ): """Submit LPJmL run to Slurm using `lpjsubmit` and a generated (class LpjmlConfig) config file. 
@@ -130,8 +252,8 @@ def submit_lpjml( More information at and . dependency : int/str, optional - If there is a job that should be processed first (e.g. spinup) then pass - its job id here. + If there is a job that should be processed first (e.g. spinup) then + pass its job id here. blocking : int, optional Cores to be blocked. More information at and @@ -144,6 +266,10 @@ def submit_lpjml( venv_path : str, optional Path to a venv to run the coupled script in. This should be the path to the top folder of the venv. If not set, `python3` in PATH is used. + slurm_jcf_dir : str or Path, optional + Directory where slurm.jcf file should be written. If not set, uses + current working directory. Useful for tests to avoid polluting the + project root. Returns ------- @@ -153,7 +279,9 @@ def submit_lpjml( config = read_config(config_file) if not os.path.isdir(config.model_path): - raise ValueError(f"Folder of model_path '{config.model_path}' does not exist!") + raise ValueError( + f"Folder of model_path '{config.model_path}' does not exist!" 
+ ) # noqa: E501 output_path = f"{config.sim_path}/output/{config.sim_name}" @@ -199,6 +327,13 @@ def submit_lpjml( cmd.extend(["-option", opt]) # run in coupled mode and pass coupling program/model + needs_coupler_wait = bool(couple_to) + if slurm_jcf_dir is None: + slurm_jcf_dir = os.getcwd() + slurm_jcf_path = Path(slurm_jcf_dir) / "slurm.jcf" + + couple_file = None + if couple_to: python_path = "python3" if venv_path: @@ -227,6 +362,9 @@ def submit_lpjml( cmd.extend(["-couple", couple_file]) + if needs_coupler_wait: + cmd.append("-norun") + cmd.extend([str(ntasks), config_file]) # Intialize submit_status in higher scope @@ -244,22 +382,122 @@ def submit_lpjml( else: del os.environ["LPJROOT"] - # print stdout and stderr if not successful if submit_status is None: raise Exception("Process was not submitted.") - elif submit_status.returncode == 0: - print(submit_status.stdout.decode("utf-8")) - else: - print(submit_status.stdout.decode("utf-8")) - print(submit_status.stderr.decode("utf-8")) + + submit_stdout = submit_status.stdout.decode("utf-8") + submit_stderr = submit_status.stderr.decode("utf-8") + + if submit_status.returncode != 0: + print(submit_stdout) + print(submit_stderr) raise CalledProcessError(submit_status.returncode, submit_status.args) - # return job id - return ( - submit_status.stdout.decode("utf-8") - .split("Submitted batch job ")[1] - .split("\n")[0] + + print(submit_stdout) + + job_submission_output = submit_stdout + + if needs_coupler_wait: + job_submission_output = _patch_slurm_and_submit( + slurm_jcf_path=slurm_jcf_path, + couple_file=couple_file, + dependency=dependency, + ) + + if "Submitted batch job" not in job_submission_output: + raise RuntimeError( + "Could not determine job id from submission output:\n" + f"{job_submission_output}" + ) + + return job_submission_output.split("Submitted batch job ")[1].split("\n")[ + 0 + ] # noqa: E501 + + +def _patch_slurm_and_submit( + slurm_jcf_path: Path, couple_file: str | None, dependency +): 
# noqa: E501 + """Ensure the coupling helper is waited on before submitting the job. + + Older LPJmL versions background the coupler without waiting for it. We + patch the generated `slurm.jcf` to add `couple_pid` handling if it is + missing and then submit the job ourselves via `sbatch`. + """ + + if couple_file is None: + raise RuntimeError( + "Coupling file path is required for coupled submissions." + ) # noqa: E501 + + if not slurm_jcf_path.exists(): + raise FileNotFoundError( + f"lpjsubmit did not create expected job file at '{slurm_jcf_path}'." # noqa: E501 + ) + + slurm_text = slurm_jcf_path.read_text() + if "couple_pid" not in slurm_text: + slurm_text = _inject_coupler_wait(slurm_text, couple_file) + slurm_jcf_path.write_text(slurm_text) + + sbatch_cmd = ["sbatch"] + if dependency: + sbatch_cmd.append(f"--dependency=afterok:{dependency}") + sbatch_cmd.append(str(slurm_jcf_path)) + sbatch_status = run(sbatch_cmd, capture_output=True) + + sbatch_stdout = sbatch_status.stdout.decode("utf-8") + sbatch_stderr = sbatch_status.stderr.decode("utf-8") + + if sbatch_status.returncode != 0: + print(sbatch_stdout) + print(sbatch_stderr) + raise CalledProcessError(sbatch_status.returncode, sbatch_status.args) + + print(sbatch_stdout) + return sbatch_stdout + + +def _inject_coupler_wait(slurm_text: str, couple_file: str) -> str: + """Patch the slurm script so it waits for the coupling helper.""" + + launch_variants = [ + f"{couple_file} &", + f"{couple_file} &", + ] + + target_snippet = None + for variant in launch_variants: + snippet = f"{variant}\n\n" + if snippet in slurm_text: + target_snippet = snippet + break + + if target_snippet is None: + raise RuntimeError( + "Could not find the coupling launch line in slurm.jcf to patch." 
+ ) + + replacement = target_snippet.replace("&", "&\ncouple_pid=$!", 1) + slurm_text = slurm_text.replace(target_snippet, replacement, 1) + + exit_marker = "exit $rc # exit with return code" + if exit_marker not in slurm_text: + raise RuntimeError("Could not find exit marker in slurm.jcf.") + + wait_block = ( + 'if [ -n "${couple_pid:-}" ]; then\n' + " wait $couple_pid\n" + " couple_rc=$?\n" + " if [ $rc -eq 0 ]; then\n" + " rc=$couple_rc\n" + " fi\n" + "fi\n" + f"{exit_marker}" ) + return slurm_text.replace(exit_marker, wait_block, 1) + def check_lpjml(config_file): """Check if config file is set correctly. @@ -273,7 +511,9 @@ def check_lpjml(config_file): """ config = read_config(config_file) if not os.path.isdir(config.model_path): - raise ValueError(f"Folder of model_path '{config.model_path}' does not exist!") + raise ValueError( + f"Folder of model_path '{config.model_path}' does not exist!" + ) # noqa: E501 if os.path.isfile(f"{config.model_path}/bin/lpjcheck"): proc_status = run( ["./bin/lpjcheck", config_file], diff --git a/pycoupler/utils.py b/pycoupler/utils.py index 82a766d..72248c1 100644 --- a/pycoupler/utils.py +++ b/pycoupler/utils.py @@ -1,8 +1,28 @@ -import os import json +import os +import warnings + from fuzzywuzzy import fuzz, process +def warn_deprecated_alias(instance, old_name: str, new_name: str) -> None: + """Emit a standardized DeprecationWarning for attribute aliases.""" + + if callable(instance): + subject = old_name + replacement = new_name + else: + cls_name = instance.__class__.__name__ + subject = f"{cls_name}.{old_name}" + replacement = f"{cls_name}.{new_name}" + warnings.warn( + f"{subject} is deprecated and will be removed in a future " + f"release; use {replacement} instead.", + DeprecationWarning, + stacklevel=3, + ) + + def get_countries(): """Current workaround to get countries defined in LPJmL. 
diff --git a/tests/test_config.py b/tests/test_config.py index 9b204f9..9d4eb24 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -136,6 +136,8 @@ def test_set_coupled_config(test_path): # align both config objects check_config_coupled.restart_filename = config_coupled.restart_filename check_config_coupled.sim_path = config_coupled.sim_path + # align coupled_model (set_coupled sets it to 'copan' by default) + check_config_coupled.coupled_model = config_coupled.coupled_model # delete tracking enty changed from dict for comparison config_coupled_dict = config_coupled.to_dict() @@ -145,7 +147,7 @@ def test_set_coupled_config(test_path): assert ( repr(config_coupled) - == f"\nSettings: lpjml v5.8\n (general)\n * sim_name coupled_test\n * firstyear 2001\n * lastyear 2050\n * startgrid 27410\n * endgrid 27411\n * landuse yes\n (changed)\n * model_path {test_path}/data\n * sim_path {test_path}/data\n * outputyear 2022\n * output_metafile True\n * grid_type float\n * write_restart False\n * nspinup 0\n * float_grid True\n * restart_filename {test_path}/data/restart/restart_historic_run.lpj\n * outputyear 2022\n * radiation cloudiness\n * fix_co2 True\n * fix_co2_year 2018\n * fix_climate True\n * fix_climate_cycle 11\n * fix_climate_year 2013\n * river_routing False\n * tillage_type read\n * residue_treatment fixed_residue_remove\n * double_harvest False\n * intercrop True\nCoupled model: copan:CORE\n * start_coupling 2023\n * input (coupled) ['with_tillage']\n * output (coupled) ['grid', 'pft_harvestc', 'cftfrac', 'soilc_agr_layer', 'hdate', 'country', 'region']\n" # noqa + == f"\nSettings: lpjml v5.8\n (general)\n * sim_name coupled_test\n * firstyear 2001\n * lastyear 2050\n * startgrid 27410\n * endgrid 27411\n * landuse yes\n (changed)\n * model_path {test_path}/data\n * sim_path {test_path}/data\n * outputyear 2022\n * output_metafile True\n * grid_type float\n * write_restart False\n * nspinup 0\n * float_grid True\n * restart_filename 
{test_path}/data/restart/restart_historic_run.lpj\n * outputyear 2022\n * radiation cloudiness\n * fix_co2 True\n * fix_co2_year 2018\n * fix_climate True\n * fix_climate_cycle 11\n * fix_climate_year 2013\n * river_routing False\n * tillage_type read\n * residue_treatment fixed_residue_remove\n * double_harvest False\n * intercrop True\nCoupled model: copan\n * start_coupling 2023\n * input (coupled) ['with_tillage']\n * output (coupled) ['grid', 'pft_harvestc', 'cftfrac', 'soilc_agr_layer', 'hdate', 'country', 'region']\n" # noqa ) # noqa assert config_coupled_dict == check_config_coupled_dict diff --git a/tests/test_couple.py b/tests/test_couple.py index 51e2f5c..7565d59 100644 --- a/tests/test_couple.py +++ b/tests/test_couple.py @@ -11,12 +11,15 @@ def test_lpjml_coupler(model_path, sim_path, lpjml_coupler): hist_outputs = outputs.copy(deep=True) for year in lpjml_coupler.get_sim_years(): - inputs.time.values[0] = np.datetime64(f"{year}-12-31") + # Use assign_coords to set time (avoids read-only array issues) + # Use datetime64[ns] to match xarray's internal representation + new_time = np.datetime64(f"{year}-12-31", "ns") + inputs = inputs.assign_coords(time=[new_time]) # send input data to lpjml lpjml_coupler.send_input(inputs, year) # read output data from lpjml - outputs.time.values[0] = np.datetime64(f"{year}-12-31") + outputs = outputs.assign_coords(time=[new_time]) for name, output in lpjml_coupler.read_output(year).items(): outputs[name][:] = output[:] @@ -39,13 +42,12 @@ def test_lpjml_coupler(model_path, sim_path, lpjml_coupler): assert lpjml_coupler.ncell == 2 assert [year for year in lpjml_coupler.get_cells()] == [27410, 27411] - assert lpjml_coupler.historic_years == [] + # historic_years is computed from firstyear (2001) to start_coupling (2023) + assert lpjml_coupler.historic_years == list(range(2001, 2023)) assert lpjml_coupler.sim_years == [] assert lpjml_coupler.coupled_years == [] assert [year for year in lpjml_coupler.get_coupled_years()] == 
[] - assert ( - repr(lpjml_coupler) - == f""" + assert repr(lpjml_coupler) == f""" Simulation: (version: 3, localhost:) * sim_year 2050 * ncell 2 @@ -85,7 +87,6 @@ def test_lpjml_coupler(model_path, sim_path, lpjml_coupler): * input (coupled) ['with_tillage'] * output (coupled) ['grid', 'pft_harvestc', 'cftfrac', 'soilc_agr_layer', 'hdate', 'country', 'region'] """ # noqa - ) def test_lpjml_coupler_codes_name(lpjml_coupler): diff --git a/tests/test_coupler_utils.py b/tests/test_coupler_utils.py new file mode 100644 index 0000000..08feb62 --- /dev/null +++ b/tests/test_coupler_utils.py @@ -0,0 +1,227 @@ +"""Test port utility functions from coupler.py.""" + +import subprocess +from unittest.mock import MagicMock, patch + +import pytest + +from pycoupler.coupler import ( + kill_process_on_port, + cleanup_port_on_exit, + cleanup_port_context, + safe_port_binding, +) + + +class TestKillProcessOnPort: + """Test kill_process_on_port function.""" + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_success(self, mock_run): + """Test successfully killing processes on a port.""" + # Mock lsof finding two PIDs + mock_lsof = MagicMock() + mock_lsof.returncode = 0 + mock_lsof.stdout = "12345\n67890\n" + + # Mock kill commands + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_lsof, mock_kill, mock_kill] + + result = kill_process_on_port(8080) + + assert result == 2 + assert mock_run.call_count == 3 + # Check lsof was called correctly + mock_run.assert_any_call( + ["lsof", "-ti", ":8080"], capture_output=True, text=True, timeout=5 + ) + # Check kill was called for each PID + mock_run.assert_any_call( + ["kill", "-9", "12345"], timeout=5, capture_output=True + ) + mock_run.assert_any_call( + ["kill", "-9", "67890"], timeout=5, capture_output=True + ) + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_no_processes(self, mock_run): + """Test when no processes are using the port.""" + mock_lsof 
= MagicMock() + mock_lsof.returncode = 0 + mock_lsof.stdout = "" # No PIDs found + + mock_run.return_value = mock_lsof + + result = kill_process_on_port(8080) + + assert result == 0 + assert mock_run.call_count == 1 + mock_run.assert_called_once_with( + ["lsof", "-ti", ":8080"], capture_output=True, text=True, timeout=5 + ) + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_lsof_fails(self, mock_run): + """Test when lsof command fails.""" + mock_lsof = MagicMock() + mock_lsof.returncode = 1 # lsof failed + + mock_run.return_value = mock_lsof + + result = kill_process_on_port(8080) + + assert result == 0 + assert mock_run.call_count == 1 + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_timeout(self, mock_run): + """Test when lsof times out.""" + mock_run.side_effect = subprocess.TimeoutExpired("lsof", 5) + + result = kill_process_on_port(8080) + + assert result == -1 + assert mock_run.call_count == 1 + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_file_not_found(self, mock_run): + """Test when lsof command is not found.""" + mock_run.side_effect = FileNotFoundError("lsof not found") + + result = kill_process_on_port(8080) + + assert result == -1 + assert mock_run.call_count == 1 + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_kill_timeout(self, mock_run): + """Test when kill command times out.""" + mock_lsof = MagicMock() + mock_lsof.returncode = 0 + mock_lsof.stdout = "12345\n" + + mock_kill_timeout = subprocess.TimeoutExpired("kill", 5) + + mock_run.side_effect = [mock_lsof, mock_kill_timeout] + + result = kill_process_on_port(8080) + + # Should return 0 because no processes were successfully killed + assert result == 0 + assert mock_run.call_count == 2 + + @patch("pycoupler.coupler.subprocess.run") + def test_kill_process_on_port_whitespace_handling(self, mock_run): + """Test handling of whitespace in PID output.""" + mock_lsof = MagicMock() + 
mock_lsof.returncode = 0 + mock_lsof.stdout = " 12345 \n 67890 \n " # Extra whitespace + + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_lsof, mock_kill, mock_kill] + + result = kill_process_on_port(8080) + + assert result == 2 + # Check that strip() was applied + mock_run.assert_any_call( + ["kill", "-9", "12345"], timeout=5, capture_output=True + ) + mock_run.assert_any_call( + ["kill", "-9", "67890"], timeout=5, capture_output=True + ) + + +class TestCleanupPortOnExit: + """Test cleanup_port_on_exit function.""" + + @patch("pycoupler.coupler.atexit.register") + @patch("pycoupler.coupler.kill_process_on_port") + def test_cleanup_port_on_exit_registers(self, mock_kill, mock_atexit): + """Test that cleanup_port_on_exit registers an atexit handler.""" + cleanup_port_on_exit(8080) + + # Verify atexit.register was called + assert mock_atexit.call_count == 1 + # Get the registered function + registered_func = mock_atexit.call_args[0][0] + # Call it to verify it calls kill_process_on_port + registered_func() + mock_kill.assert_called_once_with(8080) + + +class TestCleanupPortContext: + """Test cleanup_port_context context manager.""" + + @patch("pycoupler.coupler.kill_process_on_port") + def test_cleanup_port_context_success(self, mock_kill): + """Test successful port cleanup.""" + mock_kill.return_value = 1 # Killed 1 process initially + + with cleanup_port_context(8080) as port: + assert port == 8080 + # Verify cleanup was called at start + assert mock_kill.call_count == 1 + mock_kill.assert_called_with(8080) + + # Verify cleanup was called again on exit + assert mock_kill.call_count == 2 + + @patch("pycoupler.coupler.kill_process_on_port") + def test_cleanup_port_context_exception(self, mock_kill): + """Test that cleanup happens even when exception occurs.""" + mock_kill.return_value = 0 + + try: + with cleanup_port_context(8080) as port: + assert port == 8080 + raise ValueError("Test exception") + except ValueError: + pass # 
Expected; cleanup runs in context manager finally block + + # Verify cleanup was called twice (start and finally) + assert mock_kill.call_count == 2 + + @patch("pycoupler.coupler.kill_process_on_port") + def test_cleanup_port_context_no_existing_processes(self, mock_kill): + """Test when no processes are using the port.""" + mock_kill.return_value = 0 # No processes killed + + with cleanup_port_context(8080) as port: + assert port == 8080 + + # Cleanup should still be called + assert mock_kill.call_count == 2 + + @patch("pycoupler.coupler.kill_process_on_port") + def test_cleanup_port_context_multiple_ports(self, mock_kill): + """Test using multiple ports sequentially.""" + with cleanup_port_context(8080) as port1: + assert port1 == 8080 + + with cleanup_port_context(8081) as port2: + assert port2 == 8081 + + # Each port should have cleanup called twice + assert mock_kill.call_count == 4 + + +class TestSafePortBinding: + """Test deprecated safe_port_binding backward compatibility.""" + + @patch("pycoupler.coupler.kill_process_on_port") + def test_safe_port_binding_delegates_to_cleanup_port_context(self, mock_kill): + """Test that safe_port_binding delegates and ignores host.""" + mock_kill.return_value = 0 + + with pytest.warns(DeprecationWarning): + with safe_port_binding("localhost", 8080) as port: + assert port == 8080 + + assert mock_kill.call_count == 2 + mock_kill.assert_any_call(8080) diff --git a/tests/test_data.py b/tests/test_data.py index 0c2f5f5..fa6697a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,17 +1,187 @@ """Test the LPJmLData class.""" +from pathlib import Path + import numpy as np +import xarray as xr +from netCDF4 import Dataset from pycoupler.data import ( - read_data, - read_meta, - read_header, + LPJmLData, + LPJmLDataSet, + append_to_dict, get_headersize, LPJmLInputType, - append_to_dict, + read_data, + read_header, + read_meta, ) +def _sample_lpjml_data(): + """Create a tiny LPJmLData object with cell+lat/lon 
information.""" + cell_ids = np.array([100, 101, 200, 201]) + lat = np.array([50.0, 50.0, 49.5, 49.5]) + lon = np.array([-1.0, -0.5, -1.0, -0.5]) + time = np.array([2000, 2001]) + + values = np.arange(cell_ids.size * time.size, dtype=float).reshape( + cell_ids.size, time.size + ) + return LPJmLData( + data=values, + dims=("cell", "time"), + coords=dict( + cell=cell_ids, + time=time, + lat=("cell", lat), + lon=("cell", lon), + ), + name="soilc", + ) + + +def test_lpjmldata_transform_roundtrip(): + """Validate cell <-> lon/lat transforms match lpjmlkit behaviour.""" + data = _sample_lpjml_data() + + lon_lat = data.transform("lon_lat") + assert set(lon_lat.dims) == {"lat", "lon", "time"} + np.testing.assert_allclose(lon_lat.lat.values, np.array([50.0, 49.5])) + np.testing.assert_allclose(lon_lat.lon.values, np.array([-1.0, -0.5])) + + roundtrip = lon_lat.transform("cell") + assert set(roundtrip.dims) == {"cell", "time"} + + roundtrip_sorted = roundtrip.sortby("lon").sortby("lat", ascending=False) + expected = data.sortby("lon").sortby("lat", ascending=False) + + roundtrip_sorted = roundtrip_sorted.assign_coords( + cell=("cell", np.arange(roundtrip_sorted.sizes["cell"])) + ) + expected = expected.assign_coords(cell=("cell", np.arange(expected.sizes["cell"]))) + + xr.testing.assert_allclose(roundtrip_sorted, expected) + + +def test_lpjmldataset_transform_and_netcdf(tmp_path): + """Ensure dataset transform enables writing gridded NetCDF output.""" + soilc = _sample_lpjml_data() + ds = LPJmLDataSet({"soilc": soilc}) + + lon_lat_ds = ds.transform("lon_lat") + lon_lat_var = lon_lat_ds["soilc"] + assert {"lat", "lon"} <= set(lon_lat_var.dims) + + nc_path = tmp_path / "soilc.nc4" + lon_lat_var.to_netcdf(nc_path) + with xr.open_dataset(nc_path) as reopened: + reopened_var = reopened["soilc"].transpose(*lon_lat_var.dims) + np.testing.assert_allclose(reopened_var.values, lon_lat_var.values) + + cell_ds = lon_lat_ds.transform("cell") + cell_sorted = 
cell_ds["soilc"].sortby("lon").sortby("lat", ascending=False) + expected = soilc.sortby("lon").sortby("lat", ascending=False) + cell_sorted = cell_sorted.assign_coords( + cell=("cell", np.arange(cell_sorted.sizes["cell"])) + ) + expected = expected.assign_coords(cell=("cell", np.arange(expected.sizes["cell"]))) + xr.testing.assert_allclose(cell_sorted, expected) + + +def test_write_lpjmldata_netcdf_helper(tmp_path): + """Ensure helper writes grid and non-grid variables.""" + soilc = _sample_lpjml_data() + target = tmp_path / "soilc.nc4" + soilc.to_netcdf(target) + assert target.exists() + with xr.open_dataset(target) as reopened: + assert reopened["soilc"].dims == soilc.transform("lon_lat").dims + np.testing.assert_allclose( + reopened["soilc"].values, + soilc.transform("lon_lat").values, + ) + + world = LPJmLData( + data=np.array([1.0, 2.0]), + dims=("time",), + coords={"time": [2000, 2001]}, + name="world_var", + ) + world_target = tmp_path / "world_var.nc4" + world.to_netcdf(world_target) + with xr.open_dataset(world_target) as reopened: + np.testing.assert_allclose(reopened["world_var"].values, world.values) + + +def test_lpjmldata_method_to_netcdf(tmp_path): + soilc = _sample_lpjml_data() + target = tmp_path / "method_soilc.nc4" + result_path = soilc.to_netcdf(target) + assert Path(result_path).exists() + + +def test_lpjmldata_global_attrs_passthrough(tmp_path): + soilc = _sample_lpjml_data() + soilc.attrs["_global_attrs"] = {"title": "Test Title", "institution": "PIK"} + target = tmp_path / "global.nc4" + soilc.to_netcdf(target) + with xr.open_dataset(target) as reopened: + assert reopened.attrs["title"] == "Test Title" + assert reopened.attrs["institution"] == "PIK" + + +def test_netcdf_fill_values_are_finite(tmp_path): + soilc = _sample_lpjml_data() + target = tmp_path / "finite_fill.nc4" + soilc.to_netcdf(target) + + with Dataset(target) as nc: + soilc_var = nc.variables["soilc"] + assert abs(float(soilc_var._FillValue)) < 1000 + for coord_name in ("time", 
"lat", "lon"): + coord_var = nc.variables[coord_name] + assert "_FillValue" not in coord_var.ncattrs() + + +def test_lpjmldataset_to_netcdf_separate(tmp_path): + soilc = _sample_lpjml_data() + world = LPJmLData( + data=np.array([1.0, 2.0]), + dims=("time",), + coords={"time": [2000, 2001]}, + name="world_var", + ) + ds = LPJmLDataSet({"soilc": soilc, "world_var": world}) + + out_dir = tmp_path / "nc_out" + files = ds.to_netcdf(out_dir, file_prefix="run") + assert set(files) == {"soilc", "world_var"} + for file_path in files.values(): + assert Path(file_path).exists() + + +def test_lpjmldataset_to_netcdf_combined(tmp_path): + soilc = _sample_lpjml_data() + world = LPJmLData( + data=np.array([1.0, 2.0]), + dims=("time",), + coords={"time": [2000, 2001]}, + name="world_var", + ) + ds = LPJmLDataSet({"soilc": soilc, "world_var": world}) + + target = tmp_path / "combined.nc4" + result = ds.to_netcdf(target, per_variable=False) + assert Path(result).exists() + with xr.open_dataset(result) as reopened: + assert {"lat", "lon"} <= set(reopened["soilc"].dims) + np.testing.assert_allclose( + reopened["soilc"].values, + soilc.transform("lon_lat").values, + ) + + def test_read_data(test_path): """Test the set_config method of the LPJmLCoupler class.""" # create config for coupled run diff --git a/tests/test_run.py b/tests/test_run.py index 53365ad..b29ec92 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,8 +1,12 @@ -from pycoupler.run import submit_lpjml -import pytest +import json +from pathlib import Path from subprocess import CalledProcessError import pytest_subprocess # noqa: F401 +import pytest + +from pycoupler.run import submit_lpjml + class TestLpjSubmit: group = "copan" @@ -10,25 +14,42 @@ class TestLpjSubmit: ntasks = 256 wtime = "00:16:10" couple_script = "/some/path/to/script.py" + sbatch_job_id = "4242" + + @pytest.fixture() + def slurm_wait_state(self, request): + return getattr(request, "param", "missing") @pytest.fixture(autouse=True) - def 
mock_lpjsubmit(self, fp, request): + def mock_lpjsubmit(self, fp, request, slurm_wait_state, config_coupled, tmp_path): # We expect chmod to actually modify permissions fp.pass_command([fp.program("chmod"), "+x", fp.any(min=1, max=1)]) - if hasattr(request, "param") and request.param == "no mocking": + fail_mode = getattr(request, "param", None) + if fail_mode == "no mocking": return - # Register a fake process for lpjsubmit - # (see https://pytest-subprocess.readthedocs.io/en/latest/usage.html#non-exact-command-matching) # noqa: E501 - return fp.register( + + slurm_jcf_path = tmp_path / "slurm.jcf" + + def _lpjsubmit_callback(_): + slurm_text = self._build_slurm_text( + config_coupled, + has_wait=(slurm_wait_state == "present"), + ) + slurm_jcf_path.write_text(slurm_text) + + fp.register( [fp.program("lpjsubmit"), fp.any()], - stdout="Mock lpjsubmit\nSubmitted batch job 42\nsome stuff", - returncode=( - 1 - if hasattr(request, "param") and request.param == "non-zero errorcode" - else 0 - ), + stdout="Mock lpjsubmit\nSubmitted batch job 41\nsome stuff", + returncode=(1 if fail_mode == "non-zero errorcode" else 0), + callback=_lpjsubmit_callback, ) + fp.register( + ["sbatch", str(slurm_jcf_path)], + stdout=f"Submitted batch job {self.sbatch_job_id}\n", + ) + return slurm_jcf_path + @pytest.fixture() def mock_venv(self, tmp_path_factory, request): if hasattr(request, "param") and request.param == "none": @@ -47,6 +68,7 @@ def submit( sim_path, config_coupled, request, + tmp_path, ): return submit_lpjml( config_coupled, @@ -56,10 +78,11 @@ def submit( wtime=self.wtime, couple_to=self.couple_script, venv_path=mock_venv, + slurm_jcf_dir=tmp_path, ) def test_job_id(self, submit): - assert submit == "42" + assert submit == self.sbatch_job_id @pytest.mark.parametrize( "mock_lpjsubmit", @@ -97,6 +120,7 @@ def test_command(self, sim_path, config_coupled, fp, submit): self.wtime, "-couple", str(run_script_path), + "-norun", str(self.ntasks), config_coupled, ] @@ -104,6 
+128,12 @@ def test_command(self, sim_path, config_coupled, fp, submit): == 1 ), "lpjsubmit should be called exactly once with correct parameters" + def test_sbatch_called(self, fp, submit, tmp_path): + slurm_jcf_path = tmp_path / "slurm.jcf" + assert ( + fp.call_count(["sbatch", str(slurm_jcf_path)]) == 1 + ), "sbatch should be invoked once with generated slurm.jcf" + @pytest.mark.parametrize( "mock_venv", [ @@ -120,9 +150,7 @@ def test_run_script(self, sim_path, config_coupled, mock_venv, request, submit): run_script_path.stat().st_mode & 0o0100 ), "run script should be executable" with run_script_path.open("r") as f: - assert ( - f.read() - == f"""#!/bin/bash + assert f.read() == f"""#!/bin/bash # Define the path to the config file config_file="{config_coupled}" @@ -131,4 +159,51 @@ def test_run_script(self, sim_path, config_coupled, mock_venv, request, submit): {f"{mock_venv}/bin/python" if mock_venv else "python3"} {self.couple_script} \ $config_file """ + + def test_slurm_wait_block_injected(self, config_coupled, submit, tmp_path): + slurm_text = (tmp_path / "slurm.jcf").read_text() + assert "couple_pid=$!" 
in slurm_text + assert "wait $couple_pid" in slurm_text + + @pytest.mark.parametrize("slurm_wait_state", ["present"], indirect=True) + def test_slurm_wait_block_respected( + self, config_coupled, submit, slurm_wait_state, tmp_path + ): + expected = self._build_slurm_text(config_coupled, has_wait=True) + assert (tmp_path / "slurm.jcf").read_text() == expected + + def _build_slurm_text(self, config_path: str, has_wait: bool) -> str: + couple_file = self._couple_file(config_path) + base = ( + "#!/bin/bash\n\n" + f"{couple_file} &\n\n" + "mpirun $LPJROOT/bin/lpjml args\n\n" + "rc=$?\n" + "exit $rc # exit with return code\n" + ) + if has_wait: + base = base.replace( + f"{couple_file} &\n\n", + f"{couple_file} &\ncouple_pid=$!\n\n", + 1, ) + base = base.replace( + "exit $rc # exit with return code", + 'if [ -n "${couple_pid:-}" ]; then\n' + " wait $couple_pid\n" + " couple_rc=$?\n" + " if [ $rc -eq 0 ]; then\n" + " rc=$couple_rc\n" + " fi\n" + "fi\n" + "exit $rc # exit with return code", + 1, + ) + return base + + def _couple_file(self, config_path: str) -> str: + with open(config_path) as fh: + cfg = json.load(fh) + return str( + Path(cfg["sim_path"]) / "output" / cfg["sim_name"] / "copan_lpjml.sh" + ) diff --git a/tests/test_run_additional.py b/tests/test_run_additional.py new file mode 100644 index 0000000..803ef56 --- /dev/null +++ b/tests/test_run_additional.py @@ -0,0 +1,416 @@ +"""Additional tests for run.py functions that need more coverage.""" + +import os +import subprocess +from unittest.mock import MagicMock, patch, mock_open +from subprocess import CalledProcessError, PIPE + +import pytest + +from pycoupler.run import ( + operate_lpjml, + start_lpjml, + run_lpjml, + kill_stale_lpjml_processes, +) +from pycoupler.utils import warn_deprecated_alias + + +class TestKillStaleLpjmlProcesses: + """Test kill_stale_lpjml_processes function.""" + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_by_name(self, mock_run): + """Test killing 
LPJmL processes by pgrep.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "12345\n67890\n" + + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_pgrep, mock_kill, mock_kill] + + result = kill_stale_lpjml_processes(port=None, verbose=False) + + assert result == 2 + assert mock_run.call_count == 3 + mock_run.assert_any_call( + ["pgrep", "-f", "bin/lpjml"], + capture_output=True, + text=True, + timeout=5, + ) + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_with_port(self, mock_run): + """Test killing processes on a specific port.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "" + + mock_lsof = MagicMock() + mock_lsof.returncode = 0 + mock_lsof.stdout = "11111\n" + + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_pgrep, mock_lsof, mock_kill] + + result = kill_stale_lpjml_processes(port=2224, verbose=False) + + assert result == 1 + mock_run.assert_any_call( + ["lsof", "-ti", ":2224"], + capture_output=True, + text=True, + timeout=5, + ) + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_verbose(self, mock_run): + """Test verbose output when processes are killed.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "12345\n" + + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_pgrep, mock_kill] + + with patch("builtins.print") as mock_print: + result = kill_stale_lpjml_processes(port=None, verbose=True) + + assert result == 1 + mock_print.assert_any_call("Killed LPJmL process with PID 12345") + mock_print.assert_any_call("Total killed: 1 process(es)") + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_no_processes(self, mock_run): + """Test when no processes are found.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "" + + mock_run.return_value = 
mock_pgrep + + result = kill_stale_lpjml_processes(port=None, verbose=False) + + assert result == 0 + assert mock_run.call_count == 1 + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_pgrep_fails(self, mock_run): + """Test when pgrep fails (e.g. no pgrep on Windows).""" + mock_run.side_effect = FileNotFoundError("pgrep not found") + + result = kill_stale_lpjml_processes(port=None, verbose=False) + + assert result == 0 + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_kill_fails_no_count(self, mock_run): + """Test that failed kill does not increment count.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "12345\n" + + mock_kill = MagicMock() + mock_kill.returncode = 1 # kill failed (e.g. process already gone) + + mock_run.side_effect = [mock_pgrep, mock_kill] + + result = kill_stale_lpjml_processes(port=None, verbose=False) + + assert result == 0 + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_kill_timeout(self, mock_run): + """Test that TimeoutExpired during kill is ignored.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "12345\n" + + mock_run.side_effect = [mock_pgrep, subprocess.TimeoutExpired("kill", 5)] + + result = kill_stale_lpjml_processes(port=None, verbose=False) + + assert result == 0 + + @patch("pycoupler.run.subprocess.run") + def test_kill_stale_lpjml_processes_port_verbose(self, mock_run): + """Test verbose output when killing process on port.""" + mock_pgrep = MagicMock() + mock_pgrep.returncode = 0 + mock_pgrep.stdout = "" + + mock_lsof = MagicMock() + mock_lsof.returncode = 0 + mock_lsof.stdout = "99999\n" + + mock_kill = MagicMock() + mock_kill.returncode = 0 + + mock_run.side_effect = [mock_pgrep, mock_lsof, mock_kill] + + with patch("builtins.print") as mock_print: + result = kill_stale_lpjml_processes(port=2224, verbose=True) + + assert result == 1 + mock_print.assert_any_call("Killed 
process on port 2224 with PID 99999") + mock_print.assert_any_call("Total killed: 1 process(es)") + + +class TestOperateLpjml: + """Test operate_lpjml function.""" + + @patch("pycoupler.run.Popen") + @patch("pycoupler.run.read_config") + @patch("os.path.isdir") + @patch("os.makedirs") + def test_operate_lpjml_std_to_file( + self, mock_makedirs, mock_isdir, mock_read_config, mock_popen + ): + """Test operate_lpjml with std_to_file=True.""" + # Setup initial environment + initial_env = {"I_MPI_DAPL_UD": "enable", "I_MPI_FABRICS": "shm:dapl"} + with patch.dict(os.environ, initial_env, clear=False): + # Setup mocks + mock_config = MagicMock() + mock_config.model_path = "/fake/model/path" + mock_config.sim_path = "/fake/sim/path" + mock_config.sim_name = "test_sim" + mock_read_config.return_value = mock_config + mock_isdir.return_value = True + + # Mock Popen + mock_process = MagicMock() + mock_process.returncode = 0 + mock_process.__enter__ = MagicMock(return_value=mock_process) + mock_process.__exit__ = MagicMock(return_value=False) + mock_popen.return_value = mock_process + + # Mock file opening + with patch("builtins.open", mock_open()): + operate_lpjml("/fake/config.json", std_to_file=True) + + # Verify Popen was called correctly + assert mock_popen.called + call_args = mock_popen.call_args + assert call_args[1]["cwd"] == "/fake/model/path" + assert call_args[1]["bufsize"] == 1 + assert call_args[1]["universal_newlines"] is True + + # Verify environment was reset after function completes + assert os.environ["I_MPI_DAPL_UD"] == "enable" + assert os.environ["I_MPI_FABRICS"] == "shm:dapl" + assert "I_MPI_DAPL_FABRIC" not in os.environ + + @patch("pycoupler.run.Popen") + @patch("pycoupler.run.read_config") + @patch("os.path.isdir") + @patch("os.makedirs") + @patch("os.environ", {"I_MPI_DAPL_UD": "enable", "I_MPI_FABRICS": "shm:dapl"}) + def test_operate_lpjml_std_to_console( + self, mock_makedirs, mock_isdir, mock_read_config, mock_popen + ): + """Test operate_lpjml 
with std_to_file=False.""" + # Setup mocks + mock_config = MagicMock() + mock_config.model_path = "/fake/model/path" + mock_config.sim_path = "/fake/sim/path" + mock_config.sim_name = "test_sim" + mock_read_config.return_value = mock_config + mock_isdir.return_value = True + + # Mock Popen with stdout/stderr + mock_process = MagicMock() + mock_process.returncode = 0 + mock_process.stdout = ["line1\n", "line2\n"] + mock_process.stderr = ["error1\n"] + mock_process.__enter__ = MagicMock(return_value=mock_process) + mock_process.__exit__ = MagicMock(return_value=False) + mock_popen.return_value = mock_process + + with patch("builtins.print"): + operate_lpjml("/fake/config.json", std_to_file=False) + + # Verify Popen was called with PIPE + call_args = mock_popen.call_args + assert call_args[1]["stdout"] == PIPE + assert call_args[1]["stderr"] == PIPE + + @patch("pycoupler.run.Popen") + @patch("pycoupler.run.read_config") + @patch("os.path.isdir") + @patch("os.makedirs") + def test_operate_lpjml_model_path_not_exists( + self, mock_makedirs, mock_isdir, mock_read_config, mock_popen + ): + """Test operate_lpjml when model_path doesn't exist.""" + mock_config = MagicMock() + mock_config.model_path = "/fake/model/path" + mock_read_config.return_value = mock_config + mock_isdir.return_value = False + + with pytest.raises(ValueError, match="Folder of model_path"): + operate_lpjml("/fake/config.json") + + @patch("pycoupler.run.Popen") + @patch("pycoupler.run.read_config") + @patch("os.path.isdir") + @patch("os.makedirs") + def test_operate_lpjml_creates_output_path( + self, mock_makedirs, mock_isdir, mock_read_config, mock_popen + ): + """Test that operate_lpjml creates output path if it doesn't exist.""" + mock_config = MagicMock() + mock_config.model_path = "/fake/model/path" + mock_config.sim_path = "/fake/sim/path" + mock_config.sim_name = "test_sim" + mock_read_config.return_value = mock_config + mock_isdir.side_effect = ( + lambda p: p == "/fake/model/path" + ) # Only 
model_path exists + + mock_process = MagicMock() + mock_process.returncode = 0 + mock_process.__enter__ = MagicMock(return_value=mock_process) + mock_process.__exit__ = MagicMock(return_value=False) + mock_popen.return_value = mock_process + + with patch("builtins.open", mock_open()), patch("builtins.print"): + operate_lpjml("/fake/config.json", std_to_file=True) + + # Verify output path was created + mock_makedirs.assert_called() + + @patch("pycoupler.run.Popen") + @patch("pycoupler.run.read_config") + @patch("os.path.isdir") + @patch("os.makedirs") + def test_operate_lpjml_process_error( + self, mock_makedirs, mock_isdir, mock_read_config, mock_popen + ): + """Test operate_lpjml when process returns non-zero exit code.""" + mock_config = MagicMock() + mock_config.model_path = "/fake/model/path" + mock_config.sim_path = "/fake/sim/path" + mock_config.sim_name = "test_sim" + mock_read_config.return_value = mock_config + mock_isdir.return_value = True + + mock_process = MagicMock() + mock_process.returncode = 1 + mock_process.args = ["lpjml", "/fake/config.json"] + mock_process.__enter__ = MagicMock(return_value=mock_process) + mock_process.__exit__ = MagicMock(return_value=False) + mock_popen.return_value = mock_process + + with patch("builtins.open", mock_open()), patch("os.environ", {}): + with pytest.raises(CalledProcessError): + operate_lpjml("/fake/config.json", std_to_file=True) + + +class TestStartLpjml: + """Test start_lpjml function.""" + + @patch("multiprocessing.Process") + def test_start_lpjml(self, mock_process_class): + """Test start_lpjml creates and starts a process.""" + from pycoupler.run import operate_lpjml + + mock_process = MagicMock() + mock_process_class.return_value = mock_process + + result = start_lpjml("/fake/config.json", std_to_file=True) + + # Verify Process was created with correct target and args + mock_process_class.assert_called_once_with( + target=operate_lpjml, args=("/fake/config.json", True) + ) + # Verify process was started 
+ mock_process.start.assert_called_once() + # Verify correct process was returned + assert result == mock_process + + @patch("multiprocessing.Process") + @patch("pycoupler.run.kill_stale_lpjml_processes") + def test_start_lpjml_with_cleanup_stale(self, mock_kill, mock_process_class): + """Test start_lpjml calls kill_stale_lpjml_processes when cleanup_stale=True.""" + + mock_process = MagicMock() + mock_process_class.return_value = mock_process + + result = start_lpjml( + "/fake/config.json", + std_to_file=False, + cleanup_stale=True, + port=2224, + ) + + mock_kill.assert_called_once_with(port=2224, verbose=False) + mock_process.start.assert_called_once() + assert result == mock_process + + @patch("multiprocessing.Process") + @patch("pycoupler.run.kill_stale_lpjml_processes") + def test_start_lpjml_without_cleanup_stale(self, mock_kill, mock_process_class): + """Test start_lpjml skips cleanup when cleanup_stale=False.""" + + mock_process = MagicMock() + mock_process_class.return_value = mock_process + + start_lpjml("/fake/config.json", cleanup_stale=False) + + mock_kill.assert_not_called() + + +class TestRunLpjml: + """Test run_lpjml deprecated alias.""" + + @patch("pycoupler.run.start_lpjml") + @patch("pycoupler.run.warn_deprecated_alias") + def test_run_lpjml_calls_start_lpjml(self, mock_warn, mock_start): + """Test that run_lpjml calls start_lpjml and emits a deprecation warning.""" + from pycoupler.run import start_lpjml + + mock_start.return_value = MagicMock() + + result = run_lpjml("/fake/config.json", std_to_file=False) + + # Verify deprecation warning was emitted + mock_warn.assert_called_once_with(start_lpjml, "run_lpjml", "start_lpjml") + # Verify start_lpjml was called with correct args + mock_start.assert_called_once_with("/fake/config.json", std_to_file=False) + # Verify result is returned + assert result == mock_start.return_value + + +class TestWarnDeprecatedAlias: + """Test warn_deprecated_alias function.""" + + def 
test_warn_deprecated_alias_callable(self): + """Test warning for callable (function).""" + + def test_func(): + pass + + with pytest.warns(DeprecationWarning, match="run_lpjml is deprecated"): + warn_deprecated_alias(test_func, "run_lpjml", "start_lpjml") + + def test_warn_deprecated_alias_instance(self): + """Test warning for instance (class method).""" + + class TestClass: + def method(self): + pass + + instance = TestClass() + + with pytest.warns( + DeprecationWarning, match="TestClass.old_method is deprecated" + ): + warn_deprecated_alias(instance, "old_method", "new_method") diff --git a/tests/test_utils.py b/tests/test_utils.py index b55ac54..02433c9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,7 +16,7 @@ def test_search_country(): assert netherlands == "NLD" -def test_detect_io_type(test_path): +def test_detect_io_type(test_path, tmp_path): # Test meta file detection meta_file = detect_io_type(f"{test_path}/data/output/coupled_test/grid.nc4.json") assert meta_file == "meta" @@ -40,3 +40,11 @@ def test_detect_io_type(test_path): # Test invalid file (should raise FileNotFoundError) with pytest.raises(FileNotFoundError): detect_io_type(f"{test_path}/data/non_existent_file.txt") + + # Test binary file that can't be decoded as UTF-8 (triggers UnicodeDecodeError) + # Create a binary file that's not valid UTF-8 (use tmp_path to avoid writing + # into the checked-in tests/data directory) + binary_file = tmp_path / "invalid_utf8.bin" + binary_file.write_bytes(b"\xff\xfe\x00\x01") # Invalid UTF-8 sequence + result = detect_io_type(str(binary_file)) + assert result == "raw" # Should default to 'raw' when decode fails diff --git a/tests/test_utils_additional.py b/tests/test_utils_additional.py new file mode 100644 index 0000000..a746f1a --- /dev/null +++ b/tests/test_utils_additional.py @@ -0,0 +1,105 @@ +"""Additional tests for utils.py functions that need more coverage.""" + +import json +from unittest.mock import patch + +import pytest + +from 
pycoupler.utils import create_subdirs, read_json + + +class TestCreateSubdirs: + """Test create_subdirs function.""" + + def test_create_subdirs_all_exist(self, tmp_path): + """Test when all subdirectories already exist.""" + base_path = tmp_path / "base" + base_path.mkdir() + (base_path / "input").mkdir() + (base_path / "output" / "test_sim").mkdir(parents=True) + (base_path / "restart").mkdir() + + result = create_subdirs(str(base_path), "test_sim") + + assert result == str(base_path) + + def test_create_subdirs_none_exist(self, tmp_path): + """Test when no subdirectories exist.""" + base_path = tmp_path / "base" + base_path.mkdir() + + with patch("builtins.print"): + result = create_subdirs(str(base_path), "test_sim") + + assert result == str(base_path) + assert (base_path / "input").exists() + assert (base_path / "output" / "test_sim").exists() + assert (base_path / "restart").exists() + + def test_create_subdirs_partial_exist(self, tmp_path): + """Test when some subdirectories exist.""" + base_path = tmp_path / "base" + base_path.mkdir() + (base_path / "input").mkdir() + # output and restart don't exist + + with patch("builtins.print"): + result = create_subdirs(str(base_path), "test_sim") + + assert result == str(base_path) + assert (base_path / "input").exists() + assert (base_path / "output" / "test_sim").exists() + assert (base_path / "restart").exists() + + def test_create_subdirs_base_path_not_exists(self, tmp_path): + """Test when base_path doesn't exist.""" + base_path = tmp_path / "nonexistent" + + with pytest.raises(OSError, match="Path.*does not exist"): + create_subdirs(str(base_path), "test_sim") + + +class TestReadJson: + """Test read_json function.""" + + def test_read_json_simple(self, tmp_path): + """Test reading a simple JSON file.""" + json_file = tmp_path / "test.json" + data = {"key": "value", "number": 42} + json_file.write_text(json.dumps(data)) + + result = read_json(str(json_file)) + + assert result == data + + def 
test_read_json_with_object_hook(self, tmp_path): + """Test reading JSON with object_hook.""" + json_file = tmp_path / "test.json" + data = {"key": "value"} + json_file.write_text(json.dumps(data)) + + def custom_hook(dct): + return {k.upper(): v for k, v in dct.items()} + + result = read_json(str(json_file), object_hook=custom_hook) + + assert result == {"KEY": "value"} + + def test_read_json_nested(self, tmp_path): + """Test reading nested JSON.""" + json_file = tmp_path / "test.json" + data = {"level1": {"level2": {"level3": "value"}}, "array": [1, 2, 3]} + json_file.write_text(json.dumps(data)) + + result = read_json(str(json_file)) + + assert result == data + assert result["level1"]["level2"]["level3"] == "value" + assert result["array"] == [1, 2, 3] + + def test_read_json_file_not_found(self, tmp_path): + """Test reading non-existent JSON file.""" + json_file = tmp_path / "nonexistent.json" + + with pytest.raises(FileNotFoundError): + read_json(str(json_file))