From 0fc6a594de43aa36a8db3ba95736fbfc28068e5f Mon Sep 17 00:00:00 2001 From: Maarten Sebregts Date: Tue, 10 Feb 2026 17:36:13 +0100 Subject: [PATCH 1/5] [WIP] Implement batched IDS consumer --- src/imas_streams/ids_consumers.py | 142 +++++++++++++++++++++++++++++- tests/test_ids_stream.py | 23 +++++ 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index 4830abe..3ad38a1 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -1,7 +1,10 @@ import copy import numpy as np +from imas.ids_primitive import IDSPrimitive +from imas.ids_struct_array import IDSStructArray from imas.ids_toplevel import IDSToplevel +from imas.util import get_parent from imas_streams.metadata import StreamingIMASMetadata @@ -9,7 +12,7 @@ class MessageProcessor: """Logic for building data arrays from streaming IMAS messages""" - def __init__(self, metadata: "StreamingIMASMetadata"): + def __init__(self, metadata: StreamingIMASMetadata): self._metadata = metadata self._msg_buffer = memoryview(bytearray(metadata.nbytes)) readonly_view = self._msg_buffer.toreadonly() @@ -146,3 +149,140 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel: if self._return_copy: return copy.deepcopy(self._ids) return self._ids + + +# TODO: MOVE? +def get_dynamic_aos_parent(ids_node: IDSPrimitive) -> IDSStructArray: + node = get_parent(ids_node) + while node is not None and ( + not isinstance(node, IDSStructArray) + or not node.metadata.coordinates[0].is_time_coordinate + ): + node = get_parent(node) + if node is None: + raise RuntimeError( + f"IDS node {ids_node} is not part of a time-dependent Array of Structures." 
+ ) + return node + + +def get_path_from_aos(path: str, aos: IDSStructArray) -> str: + path_parts = path.split("/") + aos_parts = aos.metadata.path.parts + return "/".join(path_parts[len(aos_parts) :]) + + +class BatchedIDSConsumser: + """Consumer of streaming IMAS data which outputs IDSs. + + This streaming IMAS data consumer produces an IDS for every N time slices. + + Example: + .. code-block:: python + + # Create metadata (from JSON): + metadata = StreamingIMASMetadata.model_validate_json(json_metadata) + # Create reader + reader = StreamingIDSConsumer(metadata) + + # Consume dynamic data + for dynamic_data in dynamic_data_stream: + ids = reader.process_message(dynamic_data) + if ids is not None: + # Use IDS + ... + """ + + def __init__( + self, metadata: StreamingIMASMetadata, batch_size: int, *, return_copy=True + ) -> None: + """Consumer of streaming IMAS data which outputs IDSs. + + Args: + metadata: Metadata of the IMAS data stream. + batch_size: Number of time slices to batch in each returned IDS. + + Keyword Args: + return_copy: See the description in StreamingIDSConsumer + """ + if batch_size < 1: + raise ValueError(f"Invalid batch size: {batch_size}") + + self._metadata = metadata + self._batch_size = batch_size + self._return_copy = return_copy + self._ids = copy.deepcopy(metadata.static_data) + self._cur_idx = 0 + + self._buffer = memoryview(bytearray(metadata.nbytes * batch_size)) + readonly_view = self._buffer.toreadonly() + dtype = " IDSToplevel | None: + nbytes = self._metadata.nbytes + if len(data) != nbytes: + raise ValueError( + f"Unexpected size of data: {len(data)}. Was expecting {nbytes}." 
+ ) + # Update buffer + self._buffer[self._cur_idx * nbytes : (self._cur_idx + 1) * nbytes] = data + # Set scalar values + for aos, path_from_aos, idx in self._scalars: + aos[self._cur_idx][path_from_aos] = self._array_view[self._cur_idx, idx] + + # Bookkeeping + self._cur_idx += 1 + if self._cur_idx == self._batch_size: + # Completed a batch: + self._cur_idx = 0 + if self._return_copy: + return copy.deepcopy(self._ids) + return self._ids + # Batch is not finished yet + return None diff --git a/tests/test_ids_stream.py b/tests/test_ids_stream.py index 9df2ec9..3bc6881 100644 --- a/tests/test_ids_stream.py +++ b/tests/test_ids_stream.py @@ -4,6 +4,7 @@ from imas.ids_defs import CLOSEST_INTERP from imas_streams import StreamingIDSConsumer, StreamingIDSProducer +from imas_streams.ids_consumers import BatchedIDSConsumser from imas_streams.xarray_consumers import StreamingXArrayConsumer @@ -44,6 +45,28 @@ def test_stream_core_profiles(testdb): assert list(imas.util.idsdiffgen(time_slice, deserialized)) == [] +def test_stream_core_profiles_batched(testdb): + ids_name = "core_profiles" + times = testdb.get(ids_name, lazy=True).time.value + first_slice = testdb.get_slice(ids_name, times[0], CLOSEST_INTERP) + producer = StreamingIDSProducer(first_slice, static_paths=cp_static_paths) + consumer = BatchedIDSConsumser(producer.metadata, len(times), return_copy=False) + + for i, t in enumerate(times): + time_slice = testdb.get_slice(ids_name, t, CLOSEST_INTERP) + data = producer.create_message(time_slice) + + deserialized = consumer.process_message(data) + if i != len(times) - 1: + assert deserialized is None + + # Compare against full IDS + ids = testdb.get(ids_name) + # Check that the data is identical + assert deserialized is not None + assert list(imas.util.idsdiffgen(ids, deserialized)) == [] + + def test_stream_core_profiles_xarray(testdb): ids_name = "core_profiles" times = testdb.get(ids_name, lazy=True).time.value From f10f1f8ef2ebd3462e68ed69bef8d3243249ef60 Mon Sep 
17 00:00:00 2001 From: Maarten Sebregts Date: Wed, 11 Feb 2026 15:17:15 +0100 Subject: [PATCH 2/5] Additional tests and documentation for batched IDS consumer --- src/imas_streams/__init__.py | 3 ++- src/imas_streams/ids_consumers.py | 39 +++++++++------------------ src/imas_streams/imas_utils.py | 45 +++++++++++++++++++++++++++++++ tests/test_ids_consumer.py | 31 +++++++++++++++++++-- tests/test_ids_stream.py | 5 ++-- 5 files changed, 90 insertions(+), 33 deletions(-) create mode 100644 src/imas_streams/imas_utils.py diff --git a/src/imas_streams/__init__.py b/src/imas_streams/__init__.py index c8d2929..ff6a508 100644 --- a/src/imas_streams/__init__.py +++ b/src/imas_streams/__init__.py @@ -1,9 +1,10 @@ -from imas_streams.ids_consumers import StreamingIDSConsumer +from imas_streams.ids_consumers import BatchedIDSConsumer, StreamingIDSConsumer from imas_streams.metadata import DynamicData, StreamingIMASMetadata from imas_streams.producer import StreamingIDSProducer __all__ = [ "DynamicData", + "BatchedIDSConsumer", "StreamingIDSConsumer", "StreamingIDSProducer", "StreamingIMASMetadata", diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index 3ad38a1..392b682 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -1,11 +1,9 @@ import copy import numpy as np -from imas.ids_primitive import IDSPrimitive -from imas.ids_struct_array import IDSStructArray from imas.ids_toplevel import IDSToplevel -from imas.util import get_parent +from imas_streams.imas_utils import get_dynamic_aos_ancestor, get_path_from_aos from imas_streams.metadata import StreamingIMASMetadata @@ -151,28 +149,7 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel: return self._ids -# TODO: MOVE? 
-def get_dynamic_aos_parent(ids_node: IDSPrimitive) -> IDSStructArray: - node = get_parent(ids_node) - while node is not None and ( - not isinstance(node, IDSStructArray) - or not node.metadata.coordinates[0].is_time_coordinate - ): - node = get_parent(node) - if node is None: - raise RuntimeError( - f"IDS node {ids_node} is not part of a time-dependent Array of Structures." - ) - return node - - -def get_path_from_aos(path: str, aos: IDSStructArray) -> str: - path_parts = path.split("/") - aos_parts = aos.metadata.path.parts - return "/".join(path_parts[len(aos_parts) :]) - - -class BatchedIDSConsumser: +class BatchedIDSConsumer: """Consumer of streaming IMAS data which outputs IDSs. This streaming IMAS data consumer produces an IDS for every N time slices. @@ -183,10 +160,12 @@ class BatchedIDSConsumser: # Create metadata (from JSON): metadata = StreamingIMASMetadata.model_validate_json(json_metadata) # Create reader - reader = StreamingIDSConsumer(metadata) + reader = BatchedIDSConsumer(metadata, batch_size=100) # Consume dynamic data for dynamic_data in dynamic_data_stream: + # process_message returns an IDS after every 100 (=batch_size) messages + # and None otherwise: ids = reader.process_message(dynamic_data) if ids is not None: # Use IDS @@ -243,7 +222,7 @@ def __init__( assert ids_node.value is dataview else: # This is a dynamic variable inside a time-dependent AoS: find that aos - aos = get_dynamic_aos_parent(ids_node) + aos = get_dynamic_aos_ancestor(ids_node) # First ensure there's an entry for every batch_size time slices: if len(aos) != batch_size: assert len(aos) == 1 @@ -265,6 +244,12 @@ def __init__( idx += n def process_message(self, data: bytes | bytearray) -> IDSToplevel | None: + """Process a single streaming IMAS message. + + This method returns None until a full batch is completed. Once ``batch_size`` + messages are processed a single IDSToplevel is returned, which contains all data + from the ``batch_size`` messages. 
+ """ nbytes = self._metadata.nbytes if len(data) != nbytes: raise ValueError( diff --git a/src/imas_streams/imas_utils.py b/src/imas_streams/imas_utils.py new file mode 100644 index 0000000..a8bd7e1 --- /dev/null +++ b/src/imas_streams/imas_utils.py @@ -0,0 +1,45 @@ +import imas # noqa: F401 -- module required in doctests +from imas.ids_primitive import IDSPrimitive +from imas.ids_struct_array import IDSStructArray +from imas.util import get_parent + + +def get_dynamic_aos_ancestor(ids_node: IDSPrimitive) -> IDSStructArray: + """Returns the dynamic Arrays of Structures ancestor for the provided node. + + Examples: + >>> cp = imas.IDSFactory("4.0.0").core_profiles() + >>> cp.profiles_1d.resize(1) + >>> get_dynamic_aos_ancestor(cp.profiles_1d[0].zeff) is cp.profiles_1d + True + >>> eq = imas.IDSFactory("4.0.0").equilibrium() + >>> eq.time_slice.resize(1) + >>> eq.time_slice[0].profiles_2d.resize(1) + >>> aos_ancestor = get_dynamic_aos_ancestor(eq.time_slice[0].profiles_2d[0].psi) + >>> aos_ancestor is eq.time_slice + True + """ + node = get_parent(ids_node) + while node is not None and ( + not isinstance(node, IDSStructArray) + or not node.metadata.coordinates[0].is_time_coordinate + ): + node = get_parent(node) + if node is None: + raise RuntimeError( + f"IDS node {ids_node} is not part of a time-dependent Array of Structures." + ) + return node + + +def get_path_from_aos(path: str, aos: IDSStructArray) -> str: + """Get the component of path relative to the provided Arrays of Structures ancestor. 
+ + Examples: + >>> cp = imas.IDSFactory("4.0.0").core_profiles() + >>> get_path_from_aos("profiles_1d[0]/ion[1]/temperature", cp.profiles_1d) + 'ion[1]/temperature' + """ + path_parts = path.split("/") + aos_parts = aos.metadata.path.parts + return "/".join(path_parts[len(aos_parts) :]) diff --git a/tests/test_ids_consumer.py b/tests/test_ids_consumer.py index 463f3fa..a4f300e 100644 --- a/tests/test_ids_consumer.py +++ b/tests/test_ids_consumer.py @@ -3,8 +3,9 @@ import imas import numpy as np import pytest +from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS -from imas_streams import StreamingIDSConsumer, StreamingIMASMetadata +from imas_streams import BatchedIDSConsumer, StreamingIDSConsumer, StreamingIMASMetadata from imas_streams.metadata import DynamicData DD_VERSION = os.getenv("IMAS_VERSION", "4.0.0") @@ -14,7 +15,7 @@ def magnetics_metadata(): ids = imas.IDSFactory(DD_VERSION).new("magnetics") - ids.ids_properties.homogeneous_time = imas.ids_defs.IDS_TIME_MODE_HOMOGENEOUS + ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS ids.time = [0.0] ids.flux_loop.resize(5) @@ -107,3 +108,29 @@ def test_streaming_reader_nocopy(magnetics_metadata): assert ids is ids2 assert ids2.time[0] == 1.0 assert len(ids2.flux_loop) == 0 + + +@pytest.mark.parametrize("batch_size", [1, 2, 5, 10, 20]) +def test_batched_reader(magnetics_metadata, batch_size): + reader = BatchedIDSConsumer(magnetics_metadata, batch_size) + + # Pretend sending 20 messages + for i in range(20): + test_data = np.arange(len(magnetics_metadata.dynamic_data), dtype=" Date: Wed, 11 Feb 2026 17:01:56 +0100 Subject: [PATCH 3/5] Cache metadata.nbytes instead of recalculating every time --- src/imas_streams/ids_consumers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index 392b682..88baaed 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -193,7 +193,8 @@ def 
__init__( self._ids = copy.deepcopy(metadata.static_data) self._cur_idx = 0 - self._buffer = memoryview(bytearray(metadata.nbytes * batch_size)) + self._msg_bytes = metadata.nbytes + self._buffer = memoryview(bytearray(self._msg_bytes * batch_size)) readonly_view = self._buffer.toreadonly() dtype = " IDSToplevel | None: messages are processed a single IDSToplevel is returned, which contains all data from the ``batch_size`` messages. """ - nbytes = self._metadata.nbytes + nbytes = self._msg_bytes if len(data) != nbytes: raise ValueError( f"Unexpected size of data: {len(data)}. Was expecting {nbytes}." From 99adfb96cd41c8f365bfafa1dc42d3a8190e7dfb Mon Sep 17 00:00:00 2001 From: Maarten Sebregts Date: Thu, 12 Feb 2026 16:59:24 +0100 Subject: [PATCH 4/5] BatchedIDSConsumer: handle remaining data at stream end --- src/imas_streams/ids_consumers.py | 44 +++++++++++++- src/imas_streams/kafka.py | 4 ++ src/imas_streams/protocols.py | 1 + src/imas_streams/xarray_consumers.py | 4 ++ tests/test_ids_consumer.py | 89 ++++++++++++++++++++++++---- 5 files changed, 129 insertions(+), 13 deletions(-) diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index 88baaed..fee56af 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -148,6 +148,10 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel: return copy.deepcopy(self._ids) return self._ids + def finalize(self) -> None: + """Indicate that the final message is received and return any remaining data.""" + return None # No data remaining + class BatchedIDSConsumer: """Consumer of streaming IMAS data which outputs IDSs. 
@@ -192,6 +196,7 @@ def __init__( self._return_copy = return_copy self._ids = copy.deepcopy(metadata.static_data) self._cur_idx = 0 + self._finished = False self._msg_bytes = metadata.nbytes self._buffer = memoryview(bytearray(self._msg_bytes * batch_size)) @@ -209,9 +214,8 @@ def __init__( ids_node = self._ids[dyndata.path] assert ids_node.metadata.type.is_dynamic n = np.prod(dyndata.shape, dtype=int) - if ( - dyndata.path == "time" - or ids_node.metadata.ndim + if dyndata.path == "time" or ( + ids_node.metadata.ndim and ids_node.metadata.coordinates[0].is_time_coordinate ): # Great! This IDS node is time-dependent by itself, and we can create a @@ -251,6 +255,8 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel | None: messages are processed a single IDSToplevel is returned, which contains all data from the ``batch_size`` messages. """ + if self._finished: + raise RuntimeError("Cannot process messages after finalize().") nbytes = self._msg_bytes if len(data) != nbytes: raise ValueError( @@ -272,3 +278,35 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel | None: return self._ids # Batch is not finished yet return None + + def finalize(self) -> IDSToplevel | None: + """Indicate that the final message is received and return any remaining data. + + Returns: + IDS with as many time slices as were remaining, or None in case no data was + remaining. + """ + self._finished = True + n_time = self._cur_idx + if n_time == 0: + return None # No data remaining, easy! + + # Resize dynamic quantities in the IDS: + for dyndata in self._metadata.dynamic_data: + ids_node = self._ids[dyndata.path] + if dyndata.path == "time" or ( + ids_node.metadata.ndim + and ids_node.metadata.coordinates[0].is_time_coordinate + ): + # Great!
This IDS node is time-dependent by itself, and we just need to + # create a smaller view of the data: + ids_node.value = ids_node.value[:n_time] + else: + # This is a dynamic variable inside a time-dependent AoS: find that aos + aos = get_dynamic_aos_ancestor(ids_node) + # And resize it: + if len(aos) != n_time: + assert len(aos) == self._batch_size + aos.resize(n_time, keep=True) + + return self._ids diff --git a/src/imas_streams/kafka.py b/src/imas_streams/kafka.py index 19da882..580803c 100644 --- a/src/imas_streams/kafka.py +++ b/src/imas_streams/kafka.py @@ -235,6 +235,10 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]: self._settings.topic_name, timeout, ) + # Yield any remaining data + result = self._stream_consumer.finalize() + if result is not None: + yield result break if msg.error(): raise msg.error() diff --git a/src/imas_streams/protocols.py b/src/imas_streams/protocols.py index 197dadf..e732aae 100644 --- a/src/imas_streams/protocols.py +++ b/src/imas_streams/protocols.py @@ -8,3 +8,4 @@ class StreamConsumer(Protocol): def __init__(self, metadata: StreamingIMASMetadata, **kwargs) -> None: ... def process_message(self, data: bytes | bytearray) -> Any: ... + def finalize(self) -> Any: ... 
diff --git a/src/imas_streams/xarray_consumers.py b/src/imas_streams/xarray_consumers.py index e981f5d..efae1ee 100644 --- a/src/imas_streams/xarray_consumers.py +++ b/src/imas_streams/xarray_consumers.py @@ -161,3 +161,7 @@ def process_message(self, data: bytes | bytearray) -> "xarray.Dataset": self._msg_buffer[:] = data self._tensor_buffer[self._index_array] = self._array_view return self._dataset + + def finalize(self) -> None: + """Indicate that the final message is received and return any remaining data.""" + return None # No data remaining diff --git a/tests/test_ids_consumer.py b/tests/test_ids_consumer.py index a4f300e..b003868 100644 --- a/tests/test_ids_consumer.py +++ b/tests/test_ids_consumer.py @@ -110,27 +110,96 @@ def test_streaming_reader_nocopy(magnetics_metadata): assert len(ids2.flux_loop) == 0 -@pytest.mark.parametrize("batch_size", [1, 2, 5, 10, 20]) +@pytest.mark.parametrize("batch_size", [1, 2, 5, 7, 10, 13, 20]) def test_batched_reader(magnetics_metadata, batch_size): reader = BatchedIDSConsumer(magnetics_metadata, batch_size) + def check_data(ids, expected_time): + assert np.array_equal(ids.time, expected_time) + assert np.array_equal(ids.flux_loop[0].flux.data, expected_time + 1) + assert np.array_equal(ids.flux_loop[1].flux.data, expected_time + 2) + assert np.array_equal(ids.flux_loop[2].flux.data, expected_time + 3) + assert np.array_equal(ids.flux_loop[3].flux.data, expected_time + 4) + assert np.array_equal(ids.flux_loop[4].flux.data, expected_time + 5) + assert np.array_equal(ids.flux_loop[0].voltage.data, expected_time + 6) + # Pretend sending 20 messages for i in range(20): test_data = np.arange(len(magnetics_metadata.dynamic_data), dtype=" Date: Fri, 13 Feb 2026 10:09:04 +0100 Subject: [PATCH 5/5] Take `return_copy` into account when finalizing batched IDS consumer --- src/imas_streams/ids_consumers.py | 2 ++ src/imas_streams/kafka.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/imas_streams/ids_consumers.py 
b/src/imas_streams/ids_consumers.py index fee56af..bfcb9cf 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -309,4 +309,6 @@ def finalize(self) -> IDSToplevel | None: assert len(aos) == self._batch_size aos.resize(n_time, keep=True) + if self._return_copy: + return copy.deepcopy(self._ids) return self._ids diff --git a/src/imas_streams/kafka.py b/src/imas_streams/kafka.py index 580803c..78defd0 100644 --- a/src/imas_streams/kafka.py +++ b/src/imas_streams/kafka.py @@ -225,6 +225,9 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]: Yields: Data produced by the StreamConsumer, e.g. an IDS for the StreamingIDSConsumer. + + For batching consumers (such as the BatchedIDSConsumer) the last yielded + value may contain fewer time slices than the batch size. """ try: while True: