Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/imas_streams/ids_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np
from imas.ids_toplevel import IDSToplevel

from imas_streams.imas_utils import get_dynamic_aos_ancestor, get_path_from_aos
from imas_streams.imas_utils import (
get_dynamic_aos_ancestor,
get_path_from_aos,
resize_and_return_dynamic_aos_ancestor,
)
from imas_streams.metadata import StreamingIMASMetadata


Expand Down Expand Up @@ -226,14 +230,8 @@ def __init__(
# Verify that IMAS-Python keeps the view of our buffer
assert ids_node.value is dataview
else:
# This is a dynamic variable inside a time-dependent AoS: find that aos
aos = get_dynamic_aos_ancestor(ids_node)
# First ensure there's an entry for every batch_size time slices:
if len(aos) != batch_size:
assert len(aos) == 1
aos.resize(batch_size, keep=True)
for i in range(1, batch_size):
aos[i] = copy.deepcopy(aos[0])
# Dynamic variable inside a time-dependent AoS, find and resize the AoS:
aos = resize_and_return_dynamic_aos_ancestor(ids_node, batch_size)
path_from_aos = get_path_from_aos(dyndata.path, aos)
if ids_node.metadata.ndim == 0:
# This is a scalar node
Expand All @@ -242,6 +240,7 @@ def __init__(
# Loop over all time slices and create views:
for i in range(batch_size):
dataview = self._array_view[i, idx : idx + n]
dataview = dataview.reshape(dyndata.shape)
aos[i][path_from_aos].value = dataview
# Verify that IMAS-Python keeps the view of our buffer
assert aos[i][path_from_aos].value is dataview
Expand Down
21 changes: 21 additions & 0 deletions src/imas_streams/imas_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import imas # noqa: F401 -- module required in doctests
from imas.ids_primitive import IDSPrimitive
from imas.ids_struct_array import IDSStructArray
Expand Down Expand Up @@ -32,6 +34,25 @@ def get_dynamic_aos_ancestor(ids_node: IDSPrimitive) -> IDSStructArray:
return node


def resize_and_return_dynamic_aos_ancestor(
    ids_node: IDSPrimitive, batch_size: int
) -> IDSStructArray:
    """Locate the dynamic Array of Structures ancestor of ``ids_node`` and resize it.

    Args:
        ids_node: Node whose dynamic AoS ancestor should be resized.
        batch_size: Desired number of elements (time slices) in the AoS.

    Returns:
        The (resized) dynamic AoS ancestor of ``ids_node``.
    """
    aos = get_dynamic_aos_ancestor(ids_node)
    if len(aos) == batch_size:
        # Already the requested size: nothing to do
        return aos
    # We expect that the input IDS has a single time slice initially:
    assert len(aos) == 1
    aos.resize(batch_size, keep=True)
    # Replicate the static data of the first slice into every batched slice:
    template = aos[0]
    for slice_idx in range(1, batch_size):
        aos[slice_idx] = copy.deepcopy(template)
    return aos


def get_path_from_aos(path: str, aos: IDSStructArray) -> str:
"""Get the component of path relative to the provided Arrays of Structures ancestor.

Expand Down
96 changes: 80 additions & 16 deletions src/imas_streams/xarray_consumers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import itertools
import re
from collections.abc import Iterable
Expand All @@ -7,6 +8,10 @@
from imas.util import to_xarray

from imas_streams import StreamingIMASMetadata
from imas_streams.imas_utils import (
get_path_from_aos,
resize_and_return_dynamic_aos_ancestor,
)

if TYPE_CHECKING:
import xarray
Expand Down Expand Up @@ -71,21 +76,19 @@ class StreamingXArrayConsumer:
...
"""

def __init__(self, metadata: StreamingIMASMetadata) -> None:
def __init__(self, metadata: StreamingIMASMetadata, *, batch_size: int = 1) -> None:
"""Consumer of streaming IMAS data which outputs xarray.Datasets.

Args:
metadata: Metadata of the IMAS data stream.
"""
if batch_size < 1:
raise ValueError(f"Invalid batch size: {batch_size}")

self._metadata = metadata
ids = metadata.static_data
# Add entries for dynamic data in the IDS, so the IMAS-Python to_xarray will
# create the corresponding xarray.DataArrays for us
for dyndata in metadata.dynamic_data:
ids[dyndata.path].value = np.zeros(dyndata.shape)
self._dataset = to_xarray(ids)
# pandas is optional (through IMAS-Python), so import locally
from pandas import Index
self._batch_size = batch_size
# Setup dataset with batched time dimensions
self._dataset = self._prepare_dataset()

# Setup array view buffer
buffersize = 0
Expand All @@ -100,6 +103,9 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None:
self._tensor_buffer = np.ndarray(buffersize, dtype=dtype)
readonly_view = memoryview(self._tensor_buffer).toreadonly()

# pandas is optional (through IMAS-Python), so import locally
from pandas import Index

# Setup array views
tensor_idx = 0
to_update = {}
Expand All @@ -125,24 +131,63 @@ def __init__(self, metadata: StreamingIMASMetadata) -> None:

# Set up the index array for writing received messages into the tensor buffer:
self._index_array = np.zeros(metadata.nbytes // 8, dtype=int)
self._time_offsets = np.zeros(metadata.nbytes // 8, dtype=int)
idx = 0
for dyndata in metadata.dynamic_data:
path, indices = path_to_xarray_name_and_indices(dyndata.path)
# First check if this works before attempting to speed up
array = self._dataset[path].data
time_dim = self._dataset[path].dims.index("time")
time_offset = array.strides[time_dim]
base_address = np_address_of(array) + offset_in_array(array, indices)
subarray = array[indices]
for index in itertools.product(*[range(i) for i in dyndata.shape]):
self._index_array[idx] = base_address + offset_in_array(subarray, index)
self._time_offsets[idx] = time_offset
idx += 1
self._index_array -= np_address_of(self._tensor_buffer)
self._index_array //= 8 # go from bytes to indices in the numpy array
# Convert memory offsets in bytes to array offsets:
self._index_array //= 8
self._time_offsets //= 8

# Message buffer and non-tensorized array view
self._msg_buffer = memoryview(bytearray(metadata.nbytes))
self._array_view = np.frombuffer(self._msg_buffer, dtype=dtype)
# Current index in the batch
self._cur_idx = 0
self._finished = False

def process_message(self, data: bytes | bytearray) -> "xarray.Dataset":
def _prepare_dataset(self) -> "xarray.Dataset":
    """Build the batched xarray.Dataset skeleton from the static IDS data.

    Every dynamic quantity is allocated with ``self._batch_size`` time slices so
    that IMAS-Python's ``to_xarray`` creates DataArrays of the batched shape.
    """
    ids = copy.deepcopy(self._metadata.static_data)

    # Add placeholder arrays for all dynamic data in the IDS, so the IMAS-Python
    # to_xarray will create the corresponding xarray.DataArrays for us
    for dyn in self._metadata.dynamic_data:
        assert dyn.data_type == "f64"
        node = ids[dyn.path]
        assert node.metadata.type.is_dynamic
        has_time_axis = dyn.path == "time" or bool(
            node.metadata.ndim and node.metadata.coordinates[0].is_time_coordinate
        )
        if has_time_axis:
            # The node itself carries the time axis: batch its first dimension
            node.value = np.zeros((self._batch_size,) + dyn.shape[1:])
        else:
            # Dynamic variable inside a time-dependent AoS, find and resize the AoS:
            ancestor = resize_and_return_dynamic_aos_ancestor(node, self._batch_size)
            relative_path = get_path_from_aos(dyn.path, ancestor)
            for element in ancestor:
                element[relative_path].value = np.zeros(dyn.shape)

    return to_xarray(ids)

def process_message(self, data: bytes | bytearray) -> "xarray.Dataset | None":
"""Process a dynamic data message and return the resulting xarray Dataset.

Note that for efficiency we return the same dataset with each call. You should
Expand All @@ -152,16 +197,35 @@ def process_message(self, data: bytes | bytearray) -> "xarray.Dataset":
Args:
data: Binary data corresponding to one time slice of dynamic data.
"""
if self._finished:
raise RuntimeError("")
if len(data) != len(self._msg_buffer):
raise ValueError(
f"Unexpected size of data: {len(data)}. "
f"Was expecting {len(self._msg_buffer)}."
)
# Copy data to internal buffer, then write into the tensor view:
self._msg_buffer[:] = data
self._tensor_buffer[self._index_array] = self._array_view
return self._dataset

def finalize(self) -> None:
buffer_indices = self._index_array
if self._cur_idx > 0:
buffer_indices = buffer_indices + self._cur_idx * self._time_offsets
self._tensor_buffer[buffer_indices] = self._array_view

# Bookkeeping
self._cur_idx += 1
if self._cur_idx == self._batch_size:
# Completed a batch
self._cur_idx = 0
return self._dataset
# Batch is not finished yet
return None

def finalize(self) -> "xarray.Dataset | None":
    """Indicate that the final message is received and return any remaining data."""
    self._finished = True
    if self._cur_idx == 0:
        # No partially filled batch left over, nothing to return
        return None
    # Let xarray trim the batched time dimension to the slices actually received
    return self._dataset.isel(time=slice(None, self._cur_idx))
33 changes: 32 additions & 1 deletion tests/test_ids_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
from imas_streams.xarray_consumers import StreamingXArrayConsumer


def get_training_db_entry():
    """Open the IMAS training database entry.

    The ``convert`` parameter was added in IMAS-Python 2.3; pass it only when the
    installed version supports it. Inspecting the signature (instead of catching
    ``TypeError``) avoids masking a genuine ``TypeError`` raised from *inside* a
    ``convert=True`` call and silently retrying without conversion.
    """
    # Local import: only needed for this version-compatibility check
    import inspect

    params = inspect.signature(imas.training.get_training_db_entry).parameters
    if "convert" in params:
        # convert parameter added in IMAS-Python 2.3
        return imas.training.get_training_db_entry(convert=True)
    return imas.training.get_training_db_entry()


@pytest.fixture(scope="module")
def testdb():
    """Module-scoped training database entry, closed automatically on teardown."""
    with get_training_db_entry() as entry:
        yield entry


Expand Down Expand Up @@ -81,3 +89,26 @@ def test_stream_core_profiles_xarray(testdb):
xrds_deserialized = consumer.process_message(data)
# Check that both datasets are identical
assert xrds_orig.equals(xrds_deserialized)


def test_stream_core_profiles_xarray_batched(testdb):
    """Stream all core_profiles slices in a single batch and compare to the full IDS."""
    ids_name = "core_profiles"
    times = testdb.get(ids_name, lazy=True).time.value
    n_slices = len(times)
    producer = StreamingIDSProducer(
        testdb.get_slice(ids_name, times[0], CLOSEST_INTERP),
        static_paths=cp_static_paths,
    )
    consumer = StreamingXArrayConsumer(producer.metadata, batch_size=n_slices)

    result = None
    for slice_no, time in enumerate(times):
        message = producer.create_message(
            testdb.get_slice(ids_name, time, CLOSEST_INTERP)
        )
        result = consumer.process_message(message)
        if slice_no < n_slices - 1:
            # The batch only completes when the final slice arrives
            assert result is None

    # Compare against the full (non-streamed) IDS
    expected = imas.util.to_xarray(testdb.get(ids_name))
    assert result is not None
    assert expected.equals(result)
40 changes: 40 additions & 0 deletions tests/test_xarray_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_xarray_consumer(magnetics_metadata):

data = np.arange(7, dtype="<f8")
dataset = consumer.process_message(data.tobytes())
assert dataset is not None

assert np.array_equal(dataset.time, [0.0])
assert np.array_equal(
Expand Down Expand Up @@ -82,6 +83,7 @@ def test_xarray_consumer_shuffled_aos(magnetics_metadata):

data = np.arange(8, dtype="<f8")
dataset = consumer.process_message(data.tobytes())
assert dataset is not None

assert np.array_equal(dataset.time, [0.0])
assert np.array_equal(
Expand All @@ -93,3 +95,41 @@ def test_xarray_consumer_shuffled_aos(magnetics_metadata):
[[np.nan], [1.0], [np.nan], [5.0], [np.nan]],
equal_nan=True,
)


@pytest.mark.parametrize("batch_size", [1, 2, 5, 7, 10, 13, 20])
def test_xarray_batched(magnetics_metadata, batch_size):
    """Verify batched consumption: full batches plus the finalize() remainder."""
    reader = StreamingXArrayConsumer(magnetics_metadata, batch_size=batch_size)
    n_dynamic = len(magnetics_metadata.dynamic_data)
    n_messages = 20

    def check_data(dataset, expected_time):
        # Time coordinate must match the slice times of this batch
        assert np.array_equal(dataset.time, expected_time)
        flux_base = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
        assert np.array_equal(
            dataset["flux_loop.flux.data"], flux_base + expected_time[None, :]
        )
        voltage_base = np.array([[6.0], [np.nan], [np.nan], [np.nan], [np.nan]])
        assert np.array_equal(
            dataset["flux_loop.voltage.data"],
            voltage_base + expected_time[None, :],
            equal_nan=True,
        )

    # Pretend sending 20 messages
    for msg_no in range(n_messages):
        payload = (np.arange(n_dynamic, dtype="<f8") + msg_no).tobytes()
        dataset = reader.process_message(payload)
        if dataset is None:
            # A dataset is only produced once a full batch was received
            assert (msg_no + 1) % batch_size != 0
        else:
            start = msg_no + 1 - batch_size
            check_data(dataset, np.arange(batch_size, dtype=float) + start)

    # Check that any remaining (partial-batch) messages are handled
    remainder = n_messages % batch_size
    dataset = reader.finalize()
    if dataset is None:
        assert remainder == 0
    else:
        assert len(dataset.time) == remainder != 0
        check_data(dataset, np.arange(remainder, dtype=float) + n_messages - remainder)
Loading