diff --git a/src/imas_streams/ids_consumers.py b/src/imas_streams/ids_consumers.py index bfcb9cf..b645069 100644 --- a/src/imas_streams/ids_consumers.py +++ b/src/imas_streams/ids_consumers.py @@ -3,7 +3,11 @@ import numpy as np from imas.ids_toplevel import IDSToplevel -from imas_streams.imas_utils import get_dynamic_aos_ancestor, get_path_from_aos +from imas_streams.imas_utils import ( + get_dynamic_aos_ancestor, + get_path_from_aos, + resize_and_return_dynamic_aos_ancestor, +) from imas_streams.metadata import StreamingIMASMetadata @@ -226,14 +230,8 @@ def __init__( # Verify that IMAS-Python keeps the view of our buffer assert ids_node.value is dataview else: - # This is a dynamic variable inside a time-dependent AoS: find that aos - aos = get_dynamic_aos_ancestor(ids_node) - # First ensure there's an entry for every batch_size time slices: - if len(aos) != batch_size: - assert len(aos) == 1 - aos.resize(batch_size, keep=True) - for i in range(1, batch_size): - aos[i] = copy.deepcopy(aos[0]) + # Dynamic variable inside a time-dependent AoS, find and resize the AoS: + aos = resize_and_return_dynamic_aos_ancestor(ids_node, batch_size) path_from_aos = get_path_from_aos(dyndata.path, aos) if ids_node.metadata.ndim == 0: # This is a scalar node @@ -242,6 +240,7 @@ def __init__( # Loop over all time slices and create views: for i in range(batch_size): dataview = self._array_view[i, idx : idx + n] + dataview = dataview.reshape(dyndata.shape) aos[i][path_from_aos].value = dataview # Verify that IMAS-Python keeps the view of our buffer assert aos[i][path_from_aos].value is dataview diff --git a/src/imas_streams/imas_utils.py b/src/imas_streams/imas_utils.py index a8bd7e1..5d58991 100644 --- a/src/imas_streams/imas_utils.py +++ b/src/imas_streams/imas_utils.py @@ -1,3 +1,5 @@ +import copy + import imas # noqa: F401 -- module required in doctests from imas.ids_primitive import IDSPrimitive from imas.ids_struct_array import IDSStructArray @@ -32,6 +34,25 @@ def 
get_dynamic_aos_ancestor(ids_node: IDSPrimitive) -> IDSStructArray: return node +def resize_and_return_dynamic_aos_ancestor( + ids_node: IDSPrimitive, batch_size: int +) -> IDSStructArray: + """Resize the dynamic Array of Structures ancestor to 'batch_size' elements and + return that AoS. + """ + # First find AoS ancestor + aos = get_dynamic_aos_ancestor(ids_node) + # Ensure there's an entry for all batch_size time slices: + if len(aos) != batch_size: + # We expect that the input IDS has a single time slice initially: + assert len(aos) == 1 + aos.resize(batch_size, keep=True) + # Copy any static data to all batched time slices: + for i in range(1, batch_size): + aos[i] = copy.deepcopy(aos[0]) + return aos + + def get_path_from_aos(path: str, aos: IDSStructArray) -> str: """Get the component of path relative to the provided Arrays of Structures ancestor. diff --git a/src/imas_streams/xarray_consumers.py b/src/imas_streams/xarray_consumers.py index efae1ee..b5e2a91 100644 --- a/src/imas_streams/xarray_consumers.py +++ b/src/imas_streams/xarray_consumers.py @@ -1,3 +1,4 @@ +import copy import itertools import re from collections.abc import Iterable @@ -7,6 +8,10 @@ from imas.util import to_xarray from imas_streams import StreamingIMASMetadata +from imas_streams.imas_utils import ( + get_path_from_aos, + resize_and_return_dynamic_aos_ancestor, +) if TYPE_CHECKING: import xarray @@ -71,21 +76,19 @@ class StreamingXArrayConsumer: ... """ - def __init__(self, metadata: StreamingIMASMetadata) -> None: + def __init__(self, metadata: StreamingIMASMetadata, *, batch_size: int = 1) -> None: """Consumer of streaming IMAS data which outputs xarray.Datasets. Args: metadata: Metadata of the IMAS data stream. 
""" + if batch_size < 1: + raise ValueError(f"Invalid batch size: {batch_size}") + self._metadata = metadata - ids = metadata.static_data - # Add entries for dynamic data in the IDS, so the IMAS-Python to_xarray will - # create the corresponding xarray.DataArrays for us - for dyndata in metadata.dynamic_data: - ids[dyndata.path].value = np.zeros(dyndata.shape) - self._dataset = to_xarray(ids) - # pandas is optional (through IMAS-Python), so import locally - from pandas import Index + self._batch_size = batch_size + # Setup dataset with batched time dimensions + self._dataset = self._prepare_dataset() # Setup array view buffer buffersize = 0 @@ -100,6 +103,9 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None: self._tensor_buffer = np.ndarray(buffersize, dtype=dtype) readonly_view = memoryview(self._tensor_buffer).toreadonly() + # pandas is optional (through IMAS-Python), so import locally + from pandas import Index + # Setup array views tensor_idx = 0 to_update = {} @@ -125,24 +131,63 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None: # Set up the index array for writing received messages into the tensor buffer: self._index_array = np.zeros(metadata.nbytes // 8, dtype=int) + self._time_offsets = np.zeros(metadata.nbytes // 8, dtype=int) idx = 0 for dyndata in metadata.dynamic_data: path, indices = path_to_xarray_name_and_indices(dyndata.path) # First check if this works before attempting to speed up array = self._dataset[path].data + time_dim = self._dataset[path].dims.index("time") + time_offset = array.strides[time_dim] base_address = np_address_of(array) + offset_in_array(array, indices) subarray = array[indices] for index in itertools.product(*[range(i) for i in dyndata.shape]): self._index_array[idx] = base_address + offset_in_array(subarray, index) + self._time_offsets[idx] = time_offset idx += 1 self._index_array -= np_address_of(self._tensor_buffer) - self._index_array //= 8 # go from bytes to indices in the numpy array + # Convert 
memory offsets in bytes to array offsets: + self._index_array //= 8 + self._time_offsets //= 8 # Message buffer and non-tensorized array view self._msg_buffer = memoryview(bytearray(metadata.nbytes)) self._array_view = np.frombuffer(self._msg_buffer, dtype=dtype) + # Current index in the batch + self._cur_idx = 0 + self._finished = False - def process_message(self, data: bytes | bytearray) -> "xarray.Dataset": + def _prepare_dataset(self) -> "xarray.Dataset": + """Prepare the IDS by setting all time-dependent quantities to the correct size. + + This takes the batch_size into account and resizes time-dependent quantities as + appropriate. + """ + ids = copy.deepcopy(self._metadata.static_data) + + # Add entries for dynamic data in the IDS, so the IMAS-Python to_xarray will + # create the corresponding xarray.DataArrays for us + for dyndata in self._metadata.dynamic_data: + assert dyndata.data_type == "f64" + ids_node = ids[dyndata.path] + assert ids_node.metadata.type.is_dynamic + if dyndata.path == "time" or ( + ids_node.metadata.ndim + and ids_node.metadata.coordinates[0].is_time_coordinate + ): + # Node has explicit time axis + batched_shape = (self._batch_size,) + dyndata.shape[1:] + ids_node.value = np.zeros(batched_shape) + else: + # Dynamic variable inside a time-dependent AoS, find and resize the AoS: + aos = resize_and_return_dynamic_aos_ancestor(ids_node, self._batch_size) + path_from_aos = get_path_from_aos(dyndata.path, aos) + for item in aos: + item[path_from_aos].value = np.zeros(dyndata.shape) + + return to_xarray(ids) + + def process_message(self, data: bytes | bytearray) -> "xarray.Dataset | None": """Process a dynamic data message and return the resulting xarray Dataset. Note that for efficiency we return the same dataset with each call. You should @@ -152,6 +197,8 @@ def process_message(self, data: bytes | bytearray) -> "xarray.Dataset": Args: data: Binary data corresponding to one time slice of dynamic data. 
"""
+        if self._finished:
+            raise RuntimeError("Cannot process messages after finalize() has been called")
         if len(data) != len(self._msg_buffer):
             raise ValueError(
                 f"Unexpected size of data: {len(data)}. "
@@ -159,9 +206,26 @@
             )
         # Copy data to internal buffer, then write into the tensor view:
         self._msg_buffer[:] = data
-        self._tensor_buffer[self._index_array] = self._array_view
-        return self._dataset
-
-    def finalize(self) -> None:
+        buffer_indices = self._index_array
+        if self._cur_idx > 0:
+            buffer_indices = buffer_indices + self._cur_idx * self._time_offsets
+        self._tensor_buffer[buffer_indices] = self._array_view
+
+        # Bookkeeping
+        self._cur_idx += 1
+        if self._cur_idx == self._batch_size:
+            # Completed a batch
+            self._cur_idx = 0
+            return self._dataset
+        # Batch is not finished yet
+        return None
+
+    def finalize(self) -> "xarray.Dataset | None":
         """Indicate that the final message is received and return any remaining data."""
-        return None # No data remaining
+        self._finished = True
+        n_time = self._cur_idx
+        if n_time == 0:
+            return None # No data remaining, easy!
+ + # Let xarray handle resizing of the time dimension + return self._dataset.isel(time=slice(None, n_time)) diff --git a/tests/test_ids_stream.py b/tests/test_ids_stream.py index ac2a88e..a06d654 100644 --- a/tests/test_ids_stream.py +++ b/tests/test_ids_stream.py @@ -7,9 +7,17 @@ from imas_streams.xarray_consumers import StreamingXArrayConsumer +def get_training_db_entry(): + try: + # convert parameter added in IMAS-Python 2.3 + return imas.training.get_training_db_entry(convert=True) + except TypeError: + return imas.training.get_training_db_entry() + + @pytest.fixture(scope="module") def testdb(): - with imas.training.get_training_db_entry() as entry: + with get_training_db_entry() as entry: yield entry @@ -81,3 +89,26 @@ def test_stream_core_profiles_xarray(testdb): xrds_deserialized = consumer.process_message(data) # Check that both datasets are identical assert xrds_orig.equals(xrds_deserialized) + + +def test_stream_core_profiles_xarray_batched(testdb): + ids_name = "core_profiles" + times = testdb.get(ids_name, lazy=True).time.value + first_slice = testdb.get_slice(ids_name, times[0], CLOSEST_INTERP) + producer = StreamingIDSProducer(first_slice, static_paths=cp_static_paths) + consumer = StreamingXArrayConsumer(producer.metadata, batch_size=len(times)) + + for i, t in enumerate(times): + time_slice = testdb.get_slice(ids_name, t, CLOSEST_INTERP) + data = producer.create_message(time_slice) + + xrds_deserialized = consumer.process_message(data) + if i != len(times) - 1: + assert xrds_deserialized is None + + # Compare against full IDS + ids = testdb.get(ids_name) + xrds_orig = imas.util.to_xarray(ids) + # Check that the data is identical + assert xrds_deserialized is not None + assert xrds_orig.equals(xrds_deserialized) diff --git a/tests/test_xarray_consumer.py b/tests/test_xarray_consumer.py index 7846ac8..4ed07e2 100644 --- a/tests/test_xarray_consumer.py +++ b/tests/test_xarray_consumer.py @@ -46,6 +46,7 @@ def test_xarray_consumer(magnetics_metadata): 
data = np.arange(7, dtype="