From df4b96f4e63f0ab59b92483232356440b5294b18 Mon Sep 17 00:00:00 2001 From: Maarten Sebregts Date: Mon, 9 Feb 2026 17:03:45 +0100 Subject: [PATCH 1/4] Add consumer class that produces an xarray Dataset --- pyproject.toml | 5 + src/imas_streams/xarray_consumers.py | 155 +++++++++++++++++++++++++++ tests/test_ids_stream.py | 55 ++++++---- tests/test_xarray_consumer.py | 95 ++++++++++++++++ 4 files changed, 292 insertions(+), 18 deletions(-) create mode 100644 src/imas_streams/xarray_consumers.py create mode 100644 tests/test_xarray_consumer.py diff --git a/pyproject.toml b/pyproject.toml index bae68ec..bac18b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-randomly", + "imaspy[xarray]", ] [project.urls] @@ -48,6 +49,10 @@ local_scheme = "no-local-version" [tool.pytest.ini_options] testpaths = [ "tests", + "src", +] +addopts = [ + "--doctest-modules", ] [tool.ruff] diff --git a/src/imas_streams/xarray_consumers.py b/src/imas_streams/xarray_consumers.py new file mode 100644 index 0000000..6ef1363 --- /dev/null +++ b/src/imas_streams/xarray_consumers.py @@ -0,0 +1,155 @@ +import itertools +import re +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import numpy as np +from imas.util import to_xarray + +from imas_streams import StreamingIMASMetadata + +if TYPE_CHECKING: + import xarray + + +_index_pattern = re.compile(r"\[(\d+)\]") + + +def path_to_xarray_name_and_indices(path: str) -> tuple[str, tuple[int, ...]]: + """Convert the IMAS DD path to its tensorized xarray name and corresponding indices. + + Examples: + >>> path_to_xarray_name_and_indices("time") + ('time', ()) + >>> path_to_xarray_name_and_indices("profiles_1d[0]/grid/rho_tor_norm") + ('profiles_1d.grid.rho_tor_norm', (0,)) + >>> path_to_xarray_name_and_indices("profiles_1d[0]/ion[2]/temperature") + ('profiles_1d.ion.temperature', (0, 2)) + """ + indices = tuple(int(match.group(1)) for match in _index_pattern.finditer(path)) + path = _index_pattern.sub("", path).replace("/", ".") + return path, indices + + +def np_address_of(arr: np.ndarray) -> int: + """Return the memory address of the first item in the provided numpy array.""" + return arr.__array_interface__["data"][0] + + +def offset_in_array(array: np.ndarray, index: Iterable[int]) -> int: + """Calculate the offset (in bytes) of the provided index in the array. + + Examples: + >>> array = np.arange(15, dtype=float).reshape(3, 5) + >>> offset_in_array(array, (0, 0)) # First item is stored at offset 0 + 0 + >>> offset_in_array(array, (0, 1)) # Second item is offset by 8 bytes + 8 + >>> offset_in_array(array, (1, 0)) # Second row is offset by 5*8 bytes + 40 + """ + return sum(i * stride for i, stride in zip(index, array.strides, strict=False)) + + +class StreamingXArrayConsumer: + """Consumer of streaming IMAS data which outputs xarray.Datasets. + + This streaming IMAS data consumer updates an xarray.Dataset for each time slice. + + Example: + .. code-block:: python + + # Create metadata (from JSON) + metadata = StreamingIMASMetadata.model_validate_json(json_metadata) + # Create reader + reader = StreamingXArrayConsumer(metadata) + + # Consume dynamic data + for dynamic_data in dynamic_data_stream: + ds = reader.process_message(dynamic_data) + # Use Dataset + ... + """ + + def __init__(self, metadata: StreamingIMASMetadata) -> None: + """Consumer of streaming IMAS data which outputs xarray.Datasets. + + Args: + metadata: Metadata of the IMAS data stream. + """ + self._metadata = metadata + ids = metadata.static_data + # Add entries for dynamic data in the IDS, so the IMAS-Python to_xarray will + # create the corresponding xarray.DataArrays for us + for dyndata in metadata.dynamic_data: + ids[dyndata.path].value = np.zeros(dyndata.shape) + self._dataset = to_xarray(ids) + + # Setup array view buffer + buffersize = 0 + tensorized_paths = [] + for dyndata in metadata.dynamic_data: + assert dyndata.data_type == "f64" + path = path_to_xarray_name_and_indices(dyndata.path)[0] + if path not in tensorized_paths: + tensorized_paths.append(path) + buffersize += self._dataset[path].size + dtype = " "xarray.Dataset": + """Process a dynamic data message and return the resulting xarray Dataset. + + Note that for efficiency we return the same dataset with each call. You should + not modify the dataset in-place, or future calls to this method may not work + correctly. + + Args: + data: Binary data corresponding to one time slice of dynamic data. + """ + if len(data) != len(self._msg_buffer): + raise ValueError( + f"Unexpected size of data: {len(data)}. " + f"Was expecting {len(self._msg_buffer)}." + ) + # Copy data to internal buffer, then write into the tensor view: + self._msg_buffer[:] = data + self._tensor_buffer[self._index_array] = self._array_view + return self._dataset diff --git a/tests/test_ids_stream.py b/tests/test_ids_stream.py index 1ff57c8..9df2ec9 100644 --- a/tests/test_ids_stream.py +++ b/tests/test_ids_stream.py @@ -1,8 +1,10 @@ # Integration tests for streaming IDSs import imas.training import pytest +from imas.ids_defs import CLOSEST_INTERP from imas_streams import StreamingIDSConsumer, StreamingIDSProducer +from imas_streams.xarray_consumers import StreamingXArrayConsumer @pytest.fixture(scope="module") @@ -11,32 +13,49 @@ def testdb(): yield entry +# INT and STR are not supported for streaming, but it is actually static data: +cp_static_paths = [ + "profiles_1d/ion/element/z_n", + "profiles_1d/ion/element/atoms_n", + "profiles_1d/ion/name", + "profiles_1d/ion/neutral_index", + "profiles_1d/ion/multiple_states_flag", + "profiles_1d/neutral/element/z_n", + "profiles_1d/neutral/element/atoms_n", + "profiles_1d/neutral/name", + "profiles_1d/neutral/ion_index", + "profiles_1d/neutral/multiple_states_flag", +] + + def test_stream_core_profiles(testdb): ids_name = "core_profiles" - # INT and STR are not supported for streaming, but it is actually static data: - static_paths = [ - "profiles_1d/ion/element/z_n", - "profiles_1d/ion/element/atoms_n", - "profiles_1d/ion/name", - "profiles_1d/ion/neutral_index", - "profiles_1d/ion/multiple_states_flag", - "profiles_1d/neutral/element/z_n", - "profiles_1d/neutral/element/atoms_n", - "profiles_1d/neutral/name", - "profiles_1d/neutral/ion_index", - "profiles_1d/neutral/multiple_states_flag", - ] - times = testdb.get(ids_name, lazy=True).time.value - - first_slice = testdb.get_slice(ids_name, times[0], imas.ids_defs.CLOSEST_INTERP) - producer = StreamingIDSProducer(first_slice, static_paths=static_paths) + first_slice = testdb.get_slice(ids_name, times[0], CLOSEST_INTERP) + producer = StreamingIDSProducer(first_slice, static_paths=cp_static_paths) consumer = StreamingIDSConsumer(producer.metadata, return_copy=False) for t in times: - time_slice = testdb.get_slice(ids_name, t, imas.ids_defs.CLOSEST_INTERP) + time_slice = testdb.get_slice(ids_name, t, CLOSEST_INTERP) data = producer.create_message(time_slice) deserialized = consumer.process_message(data) # Check that the data is identical assert list(imas.util.idsdiffgen(time_slice, deserialized)) == [] + + +def test_stream_core_profiles_xarray(testdb): + ids_name = "core_profiles" + times = testdb.get(ids_name, lazy=True).time.value + first_slice = testdb.get_slice(ids_name, times[0], CLOSEST_INTERP) + producer = StreamingIDSProducer(first_slice, static_paths=cp_static_paths) + consumer = StreamingXArrayConsumer(producer.metadata) + + for t in times: + time_slice = testdb.get_slice(ids_name, t, CLOSEST_INTERP) + data = producer.create_message(time_slice) + + xrds_orig = imas.util.to_xarray(time_slice) + xrds_deserialized = consumer.process_message(data) + # Check that both datasets are identical + assert xrds_orig.equals(xrds_deserialized) diff --git a/tests/test_xarray_consumer.py b/tests/test_xarray_consumer.py new file mode 100644 index 0000000..7846ac8 --- /dev/null +++ b/tests/test_xarray_consumer.py @@ -0,0 +1,95 @@ +import os + +import imas +import numpy as np +import pytest + +from imas_streams import StreamingIMASMetadata +from imas_streams.metadata import DynamicData +from imas_streams.xarray_consumers import StreamingXArrayConsumer + +DD_VERSION = os.getenv("IMAS_VERSION", "4.0.0") + + +@pytest.fixture +def magnetics_metadata(): + ids = imas.IDSFactory(DD_VERSION).new("magnetics") + + ids.ids_properties.homogeneous_time = imas.ids_defs.IDS_TIME_MODE_HOMOGENEOUS + ids.time = [0.0] + + ids.flux_loop.resize(5) + for i, loop in enumerate(ids.flux_loop): + loop.name = f"flux_loop_{i}" + loop.position.resize(1) + loop.position[0].r = i / 2 + loop.position[0].z = i / 2 + + return StreamingIMASMetadata( + data_dictionary_version=DD_VERSION, + ids_name="magnetics", + static_data=ids, + dynamic_data=[ + DynamicData(path="time", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[0]/flux/data", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[1]/flux/data", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[2]/flux/data", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[3]/flux/data", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[4]/flux/data", shape=(1,), data_type="f64"), + DynamicData(path="flux_loop[0]/voltage/data", shape=(1,), data_type="f64"), + ], + ) + + +def test_xarray_consumer(magnetics_metadata): + consumer = StreamingXArrayConsumer(magnetics_metadata) + + data = np.arange(7, dtype=" Date: Mon, 9 Feb 2026 17:12:45 +0100 Subject: [PATCH 2/4] Use importlib import mode to fix doctest failures --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index bac18b4..352f542 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ testpaths = [ ] addopts = [ "--doctest-modules", + "--import-mode=importlib", ] [tool.ruff] From b215f430fdfa457d302ec5344ac3dd412c454c79 Mon Sep 17 00:00:00 2001 From: Maarten Sebregts Date: Mon, 9 Feb 2026 17:16:15 +0100 Subject: [PATCH 3/4] Fix incorrect imas-python package name --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 352f542..250caaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-randomly", - "imaspy[xarray]", + "imas-python[xarray]", ] [project.urls] From 5a29d5b572c8b11454952c213e8411afb5bbd0e0 Mon Sep 17 00:00:00 2001 From: Maarten Sebregts Date: Tue, 10 Feb 2026 11:35:56 +0100 Subject: [PATCH 4/4] Ensure pandas Index doesn't make a copy of the data --- src/imas_streams/xarray_consumers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/imas_streams/xarray_consumers.py b/src/imas_streams/xarray_consumers.py index 6ef1363..e981f5d 100644 --- a/src/imas_streams/xarray_consumers.py +++ b/src/imas_streams/xarray_consumers.py @@ -84,6 +84,8 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None: for dyndata in metadata.dynamic_data: ids[dyndata.path].value = np.zeros(dyndata.shape) self._dataset = to_xarray(ids) + # pandas is optional (through IMAS-Python), so import locally + from pandas import Index # Setup array view buffer buffersize = 0 @@ -101,6 +103,7 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None: # Setup array views tensor_idx = 0 to_update = {} + tensorviews = {} for path in tensorized_paths: xrda = self._dataset[path] # Fill tensor buffer with initial values of data array @@ -109,11 +112,16 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None: # And put a readonly view of the tensor buffer back buffer = readonly_view[tensor_idx : tensor_idx + size] tensorview = np.frombuffer(buffer, dtype=dtype).reshape(xrda.shape) + tensorviews[path] = tensorview + if path in self._dataset.indexes: + # Prevent xarray from creating a copy of the data: + tensorview = Index(tensorview, copy=False) to_update[path] = (xrda.dims, tensorview) tensor_idx += size self._dataset = self._dataset.assign(to_update) - for path, (_, tensorview) in to_update.items(): - assert self._dataset[path].data is tensorview + # Check that all data arrays are indeed views of our tensor buffer: + for path, tensorview in tensorviews.items(): + assert np_address_of(self._dataset[path].data) == np_address_of(tensorview) # Set up the index array for writing received messages into the tensor buffer: self._index_array = np.zeros(metadata.nbytes // 8, dtype=int)