diff --git a/src/imas_streams/__init__.py b/src/imas_streams/__init__.py index c8d2929..ff6a508 100644 --- a/src/imas_streams/__init__.py +++ b/src/imas_streams/__init__.py @@ -1,9 +1,10 @@ -from imas_streams.ids_consumers import StreamingIDSConsumer +from imas_streams.ids_consumers import BatchedIDSConsumer, StreamingIDSConsumer from imas_streams.metadata import DynamicData, StreamingIMASMetadata from imas_streams.producer import StreamingIDSProducer __all__ = [ "DynamicData", + "BatchedIDSConsumer", "StreamingIDSConsumer", "StreamingIDSProducer", "StreamingIMASMetadata", diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index 4830abe..bfcb9cf 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -3,13 +3,14 @@ import numpy as np from imas.ids_toplevel import IDSToplevel +from imas_streams.imas_utils import get_dynamic_aos_ancestor, get_path_from_aos from imas_streams.metadata import StreamingIMASMetadata class MessageProcessor: """Logic for building data arrays from streaming IMAS messages""" - def __init__(self, metadata: "StreamingIMASMetadata"): + def __init__(self, metadata: StreamingIMASMetadata): self._metadata = metadata self._msg_buffer = memoryview(bytearray(metadata.nbytes)) readonly_view = self._msg_buffer.toreadonly() @@ -146,3 +147,168 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel: if self._return_copy: return copy.deepcopy(self._ids) return self._ids + + def finalize(self) -> None: + """Indicate that the final message is received and return any remaining data.""" + return None # No data remaining + + +class BatchedIDSConsumer: + """Consumer of streaming IMAS data which outputs IDSs. + + This streaming IMAS data consumer produces an IDS for every N time slices. + + Example: + .. 
code-block:: python + + # Create metadata (from JSON): + metadata = StreamingIMASMetadata.model_validate_json(json_metadata) + # Create reader + reader = BatchedIDSConsumer(metadata, batch_size=100) + + # Consume dynamic data + for dynamic_data in dynamic_data_stream: + # process_message returns an IDS after every 100 (=batch_size) messages + # and None otherwise: + ids = reader.process_message(dynamic_data) + if ids is not None: + # Use IDS + ... + """ + + def __init__( + self, metadata: StreamingIMASMetadata, batch_size: int, *, return_copy=True + ) -> None: + """Consumer of streaming IMAS data which outputs IDSs. + + Args: + metadata: Metadata of the IMAS data stream. + batch_size: Number of time slices to batch in each returned IDS. + + Keyword Args: + return_copy: See the description in StreamingIDSConsumer + """ + if batch_size < 1: + raise ValueError(f"Invalid batch size: {batch_size}") + + self._metadata = metadata + self._batch_size = batch_size + self._return_copy = return_copy + self._ids = copy.deepcopy(metadata.static_data) + self._cur_idx = 0 + self._finished = False + + self._msg_bytes = metadata.nbytes + self._buffer = memoryview(bytearray(self._msg_bytes * batch_size)) + readonly_view = self._buffer.toreadonly() + dtype = " IDSToplevel | None: + """Process a single streaming IMAS message. + + This method returns None until a full batch is completed. Once ``batch_size`` + messages are processed a single IDSToplevel is returned, which contains all data + from the ``batch_size`` messages. + """ + if self._finished: + raise RuntimeError("Cannot process new messages: finalize() has already been called.") + nbytes = self._msg_bytes + if len(data) != nbytes: + raise ValueError( + f"Unexpected size of data: {len(data)}. Was expecting {nbytes}." 
+ ) + # Update buffer + self._buffer[self._cur_idx * nbytes : (self._cur_idx + 1) * nbytes] = data + # Set scalar values + for aos, path_from_aos, idx in self._scalars: + aos[self._cur_idx][path_from_aos] = self._array_view[self._cur_idx, idx] + + # Bookkeeping + self._cur_idx += 1 + if self._cur_idx == self._batch_size: + # Completed a batch: + self._cur_idx = 0 + if self._return_copy: + return copy.deepcopy(self._ids) + return self._ids + # Batch is not finished yet + return None + + def finalize(self) -> IDSToplevel | None: + """Indicate that the final message is received and return any remaining data. + + Returns: + IDS with as many time slices as were remaining, or None in case no data was + remaining. + """ + self._finished = True + n_time = self._cur_idx + if n_time == 0: + return None # No data remaining, easy! + + # Resize dynamic quantities in the IDS: + for dyndata in self._metadata.dynamic_data: + ids_node = self._ids[dyndata.path] + if dyndata.path == "time" or ( + ids_node.metadata.ndim + and ids_node.metadata.coordinates[0].is_time_coordinate + ): + # Great! 
This IDS node is time-dependent by itself, and we just need to + # create a smaller view of the data: + ids_node.value = ids_node.value[:n_time] + else: + # This is a dynamic variable inside a time-dependent AoS: find that aos + aos = get_dynamic_aos_ancestor(ids_node) + # And resize it: + if len(aos) != n_time: + assert len(aos) == self._batch_size + aos.resize(n_time, keep=True) + + if self._return_copy: + return copy.deepcopy(self._ids) + return self._ids diff --git a/src/imas_streams/imas_utils.py b/src/imas_streams/imas_utils.py new file mode 100644 index 0000000..a8bd7e1 --- /dev/null +++ b/src/imas_streams/imas_utils.py @@ -0,0 +1,45 @@ +import imas # noqa: F401 -- module required in doctests +from imas.ids_primitive import IDSPrimitive +from imas.ids_struct_array import IDSStructArray +from imas.util import get_parent + + +def get_dynamic_aos_ancestor(ids_node: IDSPrimitive) -> IDSStructArray: + """Returns the dynamic Arrays of Structures ancestor for the provided node. + + Examples: + >>> cp = imas.IDSFactory("4.0.0").core_profiles() + >>> cp.profiles_1d.resize(1) + >>> get_dynamic_aos_ancestor(cp.profiles_1d[0].zeff) is cp.profiles_1d + True + >>> eq = imas.IDSFactory("4.0.0").equilibrium() + >>> eq.time_slice.resize(1) + >>> eq.time_slice[0].profiles_2d.resize(1) + >>> aos_ancestor = get_dynamic_aos_ancestor(eq.time_slice[0].profiles_2d[0].psi) + >>> aos_ancestor is eq.time_slice + True + """ + node = get_parent(ids_node) + while node is not None and ( + not isinstance(node, IDSStructArray) + or not node.metadata.coordinates[0].is_time_coordinate + ): + node = get_parent(node) + if node is None: + raise RuntimeError( + f"IDS node {ids_node} is not part of a time-dependent Array of Structures." + ) + return node + + +def get_path_from_aos(path: str, aos: IDSStructArray) -> str: + """Get the component of path relative to the provided Arrays of Structures ancestor. 
+ + Examples: + >>> cp = imas.IDSFactory("4.0.0").core_profiles() + >>> get_path_from_aos("profiles_1d[0]/ion[1]/temperature", cp.profiles_1d) + 'ion[1]/temperature' + """ + path_parts = path.split("/") + aos_parts = aos.metadata.path.parts + return "/".join(path_parts[len(aos_parts) :]) diff --git a/src/imas_streams/kafka.py b/src/imas_streams/kafka.py index 19da882..78defd0 100644 --- a/src/imas_streams/kafka.py +++ b/src/imas_streams/kafka.py @@ -225,6 +225,9 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]: Yields: Data produced by the StreamConsumer, e.g. an IDS for the StreamingIDSConsumer. + + For batching consumers (such as the BatchedIDSConsumer) the last yielded + value may contain fewer time slices than the batch size. """ try: while True: @@ -235,6 +238,10 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]: self._settings.topic_name, timeout, ) + # Yield any remaining data + result = self._stream_consumer.finalize() + if result is not None: + yield result break if msg.error(): raise msg.error() diff --git a/src/imas_streams/protocols.py b/src/imas_streams/protocols.py index 197dadf..e732aae 100644 --- a/src/imas_streams/protocols.py +++ b/src/imas_streams/protocols.py @@ -8,3 +8,4 @@ class StreamConsumer(Protocol): def __init__(self, metadata: StreamingIMASMetadata, **kwargs) -> None: ... def process_message(self, data: bytes | bytearray) -> Any: ... + def finalize(self) -> Any: ... 
diff --git a/src/imas_streams/xarray_consumers.py b/src/imas_streams/xarray_consumers.py index e981f5d..efae1ee 100644 --- a/src/imas_streams/xarray_consumers.py +++ b/src/imas_streams/xarray_consumers.py @@ -161,3 +161,7 @@ def process_message(self, data: bytes | bytearray) -> "xarray.Dataset": self._msg_buffer[:] = data self._tensor_buffer[self._index_array] = self._array_view return self._dataset + + def finalize(self) -> None: + """Indicate that the final message is received and return any remaining data.""" + return None # No data remaining diff --git a/tests/test_ids_consumer.py b/tests/test_ids_consumer.py index 463f3fa..b003868 100644 --- a/tests/test_ids_consumer.py +++ b/tests/test_ids_consumer.py @@ -3,8 +3,9 @@ import imas import numpy as np import pytest +from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS -from imas_streams import StreamingIDSConsumer, StreamingIMASMetadata +from imas_streams import BatchedIDSConsumer, StreamingIDSConsumer, StreamingIMASMetadata from imas_streams.metadata import DynamicData DD_VERSION = os.getenv("IMAS_VERSION", "4.0.0") @@ -14,7 +15,7 @@ def magnetics_metadata(): ids = imas.IDSFactory(DD_VERSION).new("magnetics") - ids.ids_properties.homogeneous_time = imas.ids_defs.IDS_TIME_MODE_HOMOGENEOUS + ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS ids.time = [0.0] ids.flux_loop.resize(5) @@ -107,3 +108,98 @@ def test_streaming_reader_nocopy(magnetics_metadata): assert ids is ids2 assert ids2.time[0] == 1.0 assert len(ids2.flux_loop) == 0 + + +@pytest.mark.parametrize("batch_size", [1, 2, 5, 7, 10, 13, 20]) +def test_batched_reader(magnetics_metadata, batch_size): + reader = BatchedIDSConsumer(magnetics_metadata, batch_size) + + def check_data(ids, expected_time): + assert np.array_equal(ids.time, expected_time) + assert np.array_equal(ids.flux_loop[0].flux.data, expected_time + 1) + assert np.array_equal(ids.flux_loop[1].flux.data, expected_time + 2) + assert np.array_equal(ids.flux_loop[2].flux.data, 
expected_time + 3) + assert np.array_equal(ids.flux_loop[3].flux.data, expected_time + 4) + assert np.array_equal(ids.flux_loop[4].flux.data, expected_time + 5) + assert np.array_equal(ids.flux_loop[0].voltage.data, expected_time + 6) + + # Pretend sending 20 messages + for i in range(20): + test_data = np.arange(len(magnetics_metadata.dynamic_data), dtype="