diff --git a/imas/backends/netcdf/ids_tensorizer.py b/imas/backends/netcdf/ids_tensorizer.py
index 7e9e33e..f907414 100644
--- a/imas/backends/netcdf/ids_tensorizer.py
+++ b/imas/backends/netcdf/ids_tensorizer.py
@@ -3,7 +3,7 @@
 """Tensorization logic to convert IDSs to netCDF files and/or xarray Datasets."""
 
 from collections import deque
-from typing import List
+from typing import List, Tuple
 
 import numpy
 
@@ -203,3 +203,58 @@ def tensorize(self, path, fillvalue):
             tmp_var[aos_coords + tuple(map(slice, node.shape))] = node.value
 
         return tmp_var
+
+    def recursively_convert_to_list(
+        self, path: str, inactive_index: Tuple, shape: Tuple, i_dim: int
+    ):
+        """Collect the filled nodes of ``path`` into nested lists.
+
+        Args:
+            path: Key into ``self.filled_data``.
+            inactive_index: Indices already fixed for dimensions before ``i_dim``.
+            shape: Number of indices to visit along each dimension.
+            i_dim: Position in ``shape`` of the dimension this call iterates over.
+
+        Returns:
+            A list with ``shape[i_dim]`` entries: node ``.value`` attributes at the
+            last dimension of ``shape``, nested lists before it.
+        """
+        entry = []
+        for index in range(shape[i_dim]):
+            new_index = inactive_index + (index,)
+            if i_dim == len(shape) - 1:
+                entry.append(self.filled_data[path][new_index].value)
+            else:
+                entry.append(
+                    self.recursively_convert_to_list(path, new_index, shape, i_dim + 1)
+                )
+        return entry
+
+    def awkward_tensorize(self, path: str):
+        """Gather the data at ``path`` without padding it to a regular tensor.
+
+        Args:
+            path: The path to the data in the IDS.
+
+        Returns:
+            An empty list when nothing is filled at ``path``, the node value
+            itself when the filled entries are keyed by an empty index, and
+            nested lists (one level per index dimension) otherwise.
+        """
+        if not self.filled_data[path]:
+            return []
+        hdf5_dim = len(next(iter(self.filled_data[path])))
+
+        if hdf5_dim == 0:
+            return self.filled_data[path][()].value
+
+        if path in self.shapes:
+            shape = self.shapes[path].shape[:hdf5_dim]
+        else:
+            dimensions = self.ncmeta.get_dimensions(path, self.homogeneous_time)
+            full_shape = tuple(self.dimension_size[dim] for dim in dimensions)
+            # Get the split between HDF5 indices and stored matrices
+            # i.e. equilibrium.time_slice.profiles_2d <-> psi
+            shape = full_shape[:hdf5_dim]
+
+        return self.recursively_convert_to_list(path, tuple(), shape, 0)
diff --git a/imas/test/test_wrangle.py b/imas/test/test_wrangle.py
new file mode 100644
index 0000000..e7152b4
--- /dev/null
+++ b/imas/test/test_wrangle.py
@@ -0,0 +1,261 @@
+import pytest
+import numpy as np
+
+# Optional dependency (pip install imas-python[awkward]): skip when not installed
+ak = pytest.importorskip("awkward")
+
+from imas.wrangler import wrangle, unwrangle
+from imas.ids_factory import IDSFactory
+from imas.util import idsdiffgen
+
+
+@pytest.fixture
+def test_data():
+    """Raw sizes and arrays for the equilibrium and thomson_scattering test cases."""
+    data = {"equilibrium": {}}
+    data["equilibrium"]["N_time"] = 100
+    data["equilibrium"]["N_radial"] = 100
+    data["equilibrium"]["N_grid"] = 1
+    data["equilibrium"]["time"] = np.linspace(0.0, 5.0, data["equilibrium"]["N_time"])
+    data["equilibrium"]["psi_1d"] = np.linspace(
+        0.0, 1.0, data["equilibrium"]["N_radial"]
+    )
+    data["equilibrium"]["r"] = np.linspace(1.0, 2.0, data["equilibrium"]["N_radial"])
+    data["equilibrium"]["z"] = np.linspace(-1.0, 1.0, data["equilibrium"]["N_radial"])
+    r_grid, z_grid = np.meshgrid(
+        data["equilibrium"]["r"], data["equilibrium"]["z"], indexing="ij"
+    )
+    data["equilibrium"]["psi_2d"] = (r_grid - 1.5) ** 2 + z_grid**2
+
+    data["thomson_scattering"] = {}
+    data["thomson_scattering"]["N_ch"] = (20, 10)
+    N = data["thomson_scattering"]["N_ch"][0] + data["thomson_scattering"]["N_ch"][1]
+    data["thomson_scattering"]["identifier"] = np.asarray(
+        "channel_" + np.asarray(np.linspace(1, N + 1, N, dtype=int), dtype="|U2"),
+        dtype="|U10",
+    )
+    data["thomson_scattering"]["N_time"] = (100, 300)
+    data["thomson_scattering"]["r"] = np.concatenate(
+        [
+            np.ones(data["thomson_scattering"]["N_ch"][0]) * 1.6,
+            np.ones(data["thomson_scattering"]["N_ch"][1]) * 1.7,
+        ]
+    )
+    data["thomson_scattering"]["z"] = np.concatenate(
+        [
+            np.linspace(-1.0, 1.0, data["thomson_scattering"]["N_ch"][0]),
+            np.linspace(-1.0, 1.0, data["thomson_scattering"]["N_ch"][1]),
+        ]
+    )
+    data["thomson_scattering"]["t_e"] = data["thomson_scattering"]["z"] ** 2 * 5.0e3
+    data["thomson_scattering"]["n_e"] = data["thomson_scattering"]["z"] ** 2 * 5.0e19
+    data["thomson_scattering"]["time"] = (
+        np.linspace(0, 5.0, data["thomson_scattering"]["N_time"][0]),
+        np.linspace(0, 5.0, data["thomson_scattering"]["N_time"][1]),
+    )
+    return data
+
+
+@pytest.fixture
+def flat(test_data):
+    """Dotted-path -> array mapping describing the same data as ``test_ids_dict``."""
+    flat = {}
+    # Equilibrium test data
+    flat["equilibrium.time"] = test_data["equilibrium"]["time"]
+    flat["equilibrium.time_slice.time"] = test_data["equilibrium"]["time"]
+    flat["equilibrium.ids_properties.homogeneous_time"] = 1
+    flat["equilibrium.time_slice.profiles_1d.psi"] = np.zeros(
+        (test_data["equilibrium"]["N_time"], test_data["equilibrium"]["N_radial"])
+    )
+    flat["equilibrium.time_slice.profiles_1d.psi"][:] = test_data["equilibrium"][
+        "psi_1d"
+    ]
+    flat["equilibrium.time_slice.profiles_2d.grid.dim1"] = np.zeros(
+        (
+            test_data["equilibrium"]["N_time"],
+            test_data["equilibrium"]["N_grid"],
+            test_data["equilibrium"]["N_radial"],
+        )
+    )
+    flat["equilibrium.time_slice.profiles_2d.grid.dim1"][:] = test_data["equilibrium"][
+        "r"
+    ][None, :]
+    flat["equilibrium.time_slice.profiles_2d.grid.dim2"] = np.zeros(
+        (
+            test_data["equilibrium"]["N_time"],
+            test_data["equilibrium"]["N_grid"],
+            test_data["equilibrium"]["N_radial"],
+        )
+    )
+    flat["equilibrium.time_slice.profiles_2d.grid.dim2"][:] = test_data["equilibrium"][
+        "z"
+    ][None, :]
+    flat["equilibrium.time_slice.profiles_2d.psi"] = np.zeros(
+        (
+            test_data["equilibrium"]["N_time"],
+            test_data["equilibrium"]["N_grid"],
+            test_data["equilibrium"]["N_radial"],
+            test_data["equilibrium"]["N_radial"],
+        )
+    )
+    flat["equilibrium.time_slice.profiles_2d.psi"][:] = test_data["equilibrium"][
+        "psi_2d"
+    ][None, ...]
+    # Thomson scattering test data (ragged)
+    flat["thomson_scattering.channel.identifier"] = test_data["thomson_scattering"][
+        "identifier"
+    ]
+    flat["thomson_scattering.ids_properties.homogeneous_time"] = 0
+    flat["thomson_scattering.channel.t_e.time"] = ak.concatenate(
+        [
+            np.tile(
+                test_data["thomson_scattering"]["time"][0],
+                (test_data["thomson_scattering"]["N_ch"][0], 1),
+            ),
+            np.tile(
+                test_data["thomson_scattering"]["time"][1],
+                (test_data["thomson_scattering"]["N_ch"][1], 1),
+            ),
+        ]
+    )
+    flat["thomson_scattering.channel.t_e.data"] = ak.concatenate(
+        [
+            np.repeat(
+                test_data["thomson_scattering"]["t_e"][
+                    : test_data["thomson_scattering"]["N_ch"][0], None
+                ],
+                test_data["thomson_scattering"]["N_time"][0],
+                axis=1,
+            ),
+            np.repeat(
+                test_data["thomson_scattering"]["t_e"][
+                    test_data["thomson_scattering"]["N_ch"][0] :, None
+                ],
+                test_data["thomson_scattering"]["N_time"][1],
+                axis=1,
+            ),
+        ]
+    )
+    flat["thomson_scattering.channel.n_e.time"] = ak.concatenate(
+        [
+            np.tile(
+                test_data["thomson_scattering"]["time"][0],
+                (test_data["thomson_scattering"]["N_ch"][0], 1),
+            ),
+            np.tile(
+                test_data["thomson_scattering"]["time"][1],
+                (test_data["thomson_scattering"]["N_ch"][1], 1),
+            ),
+        ]
+    )
+    flat["thomson_scattering.channel.n_e.data"] = ak.concatenate(
+        [
+            np.repeat(
+                test_data["thomson_scattering"]["n_e"][
+                    : test_data["thomson_scattering"]["N_ch"][0], None
+                ],
+                test_data["thomson_scattering"]["N_time"][0],
+                axis=1,
+            ),
+            np.repeat(
+                test_data["thomson_scattering"]["n_e"][
+                    test_data["thomson_scattering"]["N_ch"][0] :, None
+                ],
+                test_data["thomson_scattering"]["N_time"][1],
+                axis=1,
+            ),
+        ]
+    )
+    flat["thomson_scattering.channel.position.r"] = test_data["thomson_scattering"]["r"]
+    flat["thomson_scattering.channel.position.z"] = test_data["thomson_scattering"]["z"]
+    return flat
+
+
+@pytest.fixture
+def test_ids_dict(test_data):
+    """IDS objects filled node by node with the ``test_data`` values."""
+    factory = IDSFactory("3.41.0")
+    equilibrium = factory.equilibrium()
+    equilibrium.time = test_data["equilibrium"]["time"]
+    equilibrium.time_slice.resize(test_data["equilibrium"]["N_time"])
+    equilibrium.ids_properties.homogeneous_time = 1
+    for i in range(test_data["equilibrium"]["N_time"]):
+        equilibrium.time_slice[i].time = test_data["equilibrium"]["time"][i]
+        equilibrium.time_slice[i].profiles_1d.psi = test_data["equilibrium"]["psi_1d"]
+        equilibrium.time_slice[i].profiles_2d.resize(1)
+        equilibrium.time_slice[i].profiles_2d[0].grid.dim1 = test_data["equilibrium"][
+            "r"
+        ]
+        equilibrium.time_slice[i].profiles_2d[0].grid.dim2 = test_data["equilibrium"][
+            "z"
+        ]
+        equilibrium.time_slice[i].profiles_2d[0].psi = test_data["equilibrium"][
+            "psi_2d"
+        ]
+
+    thomson_scattering = factory.thomson_scattering()
+    thomson_scattering.ids_properties.homogeneous_time = 0
+    N = (
+        test_data["thomson_scattering"]["N_ch"][0]
+        + test_data["thomson_scattering"]["N_ch"][1]
+    )
+    thomson_scattering.channel.resize(N)
+    index = 0
+    for i in range(N):
+        if i == test_data["thomson_scattering"]["N_ch"][0]:
+            index = 1
+        thomson_scattering.channel[i].identifier = test_data["thomson_scattering"][
+            "identifier"
+        ][i]
+        thomson_scattering.channel[i].t_e.time = test_data["thomson_scattering"][
+            "time"
+        ][index]
+        thomson_scattering.channel[i].t_e.data = np.tile(
+            test_data["thomson_scattering"]["t_e"][i],
+            test_data["thomson_scattering"]["N_time"][index],
+        )
+        thomson_scattering.channel[i].n_e.time = test_data["thomson_scattering"][
+            "time"
+        ][index]
+        thomson_scattering.channel[i].n_e.data = np.tile(
+            test_data["thomson_scattering"]["n_e"][i],
+            test_data["thomson_scattering"]["N_time"][index],
+        )
+        thomson_scattering.channel[i].position.r = test_data["thomson_scattering"]["r"][
+            i
+        ]
+        thomson_scattering.channel[i].position.z = test_data["thomson_scattering"]["z"][
+            i
+        ]
+
+    return {"equilibrium": equilibrium, "thomson_scattering": thomson_scattering}
+
+
+def test_wrangle(test_ids_dict, flat):
+    """``wrangle`` must rebuild IDSs identical to the hand-filled reference ones."""
+    wrangled = wrangle(flat, "3.41.0")
+    for key in test_ids_dict:
+        diff = list(idsdiffgen(wrangled[key], test_ids_dict[key]))
+        assert not diff, diff
+
+
+def get_dtype(arr):
+    """Get dtype from either numpy or awkward array."""
+    if isinstance(arr, ak.Array):
+        # The last "*"-separated token of the type string names the numpy dtype
+        return np.dtype(arr.typestr.split("*")[-1].strip())
+    if hasattr(arr, "dtype"):
+        return arr.dtype
+    else:
+        return type(arr)
+
+
+def test_unwrangle(test_ids_dict, flat):
+    """``unwrangle`` must reproduce every entry of the flat mapping."""
+    result, failed = unwrangle(list(flat), test_ids_dict)
+    assert len(failed) == 0, f"The following fields failed to load {failed}"
+    for key in flat:
+        if np.issubdtype(get_dtype(result[key]), np.floating):
+            assert ak.almost_equal(result[key], flat[key])
+        else:
+            assert ak.array_equal(result[key], flat[key])
diff --git a/imas/wrangler.py b/imas/wrangler.py
new file mode 100644
index 0000000..f44563c
--- /dev/null
+++ b/imas/wrangler.py
@@ -0,0 +1,137 @@
+"""Convert between flat ``{"ids.path.to.node": values}`` mappings and IDS objects."""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import awkward as ak
+import numpy as np
+
+from . import IDSFactory
+from .ids_convert import convert_ids
+from .ids_toplevel import IDSToplevel
+from .backends.netcdf.ids_tensorizer import IDSTensorizer
+
+
+def recursively_put(location, value, ids):
+    """Assign ``value`` to the node at dotted ``location`` below ``ids``.
+
+    Args:
+        location: Dotted path relative to ``ids``, e.g.
+            ``time_slice.profiles_1d.psi``.
+        value: Data to store. For every path component that has a ``size``
+            attribute, ``value`` is indexed once per element of that component.
+        ids: Object on which the first component of ``location`` is looked up.
+
+    Returns:
+        The ``ids`` object that was passed in.
+
+    Raises:
+        ValueError: If a sized component is non-empty and its size differs
+            from ``len(value)``.
+    """
+    if "." in location:
+        position, sub_location = location.split(".", 1)
+        sub_ids = getattr(ids, position)
+        if hasattr(sub_ids, "size"):
+            N = len(value)
+            if sub_ids.size == 0:
+                sub_ids.resize(N)
+            elif sub_ids.size != N:
+                raise ValueError(
+                    f"Inconsistent size across flat entries {location}, "
+                    f"{N} (flat) vs. ids {sub_ids.size}!"
+                )
+            # Need to iterate over indices (e.g. equilibrium.time_slice[:].)
+            for index in range(N):
+                recursively_put(sub_location, value[index], sub_ids[index])
+        else:
+            # No size to iterate over: descend with the rest of the path
+            recursively_put(sub_location, value, sub_ids)
+    else:
+        setattr(ids, location, value)
+    return ids
+
+
+def wrangle(flat: Dict, source_version: str) -> Dict[str, IDSToplevel]:
+    """Build IDS objects from a flat mapping of dotted paths to values.
+
+    Args:
+        flat: Mapping whose keys are ``<ids name>.<dotted path>`` and whose
+            values are the data to store at that path.
+        source_version: Version string handed to ``IDSFactory``.
+
+    Returns:
+        One IDS per distinct IDS name found in the keys of ``flat``.
+    """
+    wrangled = {}
+    factory = IDSFactory(source_version)
+    for key in flat:
+        ids, location = key.split(".", 1)
+        if ids not in wrangled:
+            wrangled[ids] = getattr(factory, ids)()
+        wrangled[ids] = recursively_put(location, flat[key], wrangled[ids])
+    return wrangled
+
+
+def split_location_across_ids(locations: List[str]) -> Dict[str, List[str]]:
+    """Group dotted locations by IDS name, rewriting the rest with ``/`` separators.
+
+    Args:
+        locations: Paths of the form ``<ids name>.<dotted path>``.
+
+    Returns:
+        Mapping from IDS name to the list of its ``/``-separated paths.
+    """
+    ids_locations = {}
+    for location in locations:
+        ids, path = location.split(".", 1)
+        ids_locations.setdefault(ids, []).append(path.replace(".", "/"))
+    return ids_locations
+
+
+def unwrangle(
+    locations: List[str],
+    ids_dict: Dict[str, IDSToplevel],
+    target_version: Optional[str] = None,
+) -> Tuple[Dict[str, Union[ak.Array, np.ndarray]], List[str]]:
+    """Extract the requested locations from IDS objects into flat arrays.
+
+    Args:
+        locations: Paths of the form ``<ids name>.<dotted path>``.
+        ids_dict: IDS objects keyed by IDS name.
+        target_version: When given, each IDS is passed through ``convert_ids``
+            with this version before its data is read.
+
+    Returns:
+        A tuple ``(flat, failed_locations)``. ``flat`` maps each location that
+        could be read to its values: a numpy array when ``numpy.asarray``
+        accepts them, an awkward array when it raises ``ValueError``, and the
+        bare value when it has no ``__getitem__``. ``failed_locations`` lists
+        the locations for which the tensorizer raised ``KeyError``.
+    """
+    flat = {}
+    ids_locations = split_location_across_ids(locations)
+    failed_locations = []
+    for key in ids_locations:
+        ids = ids_dict[key]
+        if target_version is not None:
+            ids = convert_ids(ids, target_version)
+        tensorizer = IDSTensorizer(ids, ids_locations[key])
+        tensorizer.include_coordinate_paths()
+        tensorizer.collect_filled_data()
+        tensorizer.determine_data_shapes()
+        for ids_location in ids_locations[key]:
+            location = key + "." + ids_location.replace("/", ".")
+            try:
+                values = tensorizer.awkward_tensorize(ids_location)
+            except KeyError:
+                failed_locations.append(location)
+                continue
+            if hasattr(values, "__getitem__"):
+                # Not a scalar (unlike e.g. homogeneous_time)
+                try:
+                    flat[location] = np.asarray(values)
+                except ValueError:
+                    flat[location] = ak.Array(values)
+            else:
+                flat[location] = values
+    return flat, failed_locations