Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/imas_streams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from imas_streams.ids_consumers import StreamingIDSConsumer
from imas_streams.ids_consumers import BatchedIDSConsumer, StreamingIDSConsumer
from imas_streams.metadata import DynamicData, StreamingIMASMetadata
from imas_streams.producer import StreamingIDSProducer

__all__ = [
"DynamicData",
"BatchedIDSConsumer",
"StreamingIDSConsumer",
"StreamingIDSProducer",
"StreamingIMASMetadata",
Expand Down
168 changes: 167 additions & 1 deletion src/imas_streams/ids_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import numpy as np
from imas.ids_toplevel import IDSToplevel

from imas_streams.imas_utils import get_dynamic_aos_ancestor, get_path_from_aos
from imas_streams.metadata import StreamingIMASMetadata


class MessageProcessor:
"""Logic for building data arrays from streaming IMAS messages"""

def __init__(self, metadata: "StreamingIMASMetadata"):
def __init__(self, metadata: StreamingIMASMetadata):
self._metadata = metadata
self._msg_buffer = memoryview(bytearray(metadata.nbytes))
readonly_view = self._msg_buffer.toreadonly()
Expand Down Expand Up @@ -146,3 +147,168 @@ def process_message(self, data: bytes | bytearray) -> IDSToplevel:
if self._return_copy:
return copy.deepcopy(self._ids)
return self._ids

def finalize(self) -> None:
    """Indicate that the final message is received and return any remaining data.

    This consumer emits an IDS for every processed message and never buffers
    data across messages, so there is never anything left to flush.

    Returns:
        None, always: no data remains after the last processed message.
    """
    return None  # No data remaining


class BatchedIDSConsumer:
    """Consumer of streaming IMAS data which outputs IDSs.

    This streaming IMAS data consumer produces an IDS for every N time slices.

    Example:
        .. code-block:: python

            # Create metadata (from JSON):
            metadata = StreamingIMASMetadata.model_validate_json(json_metadata)
            # Create reader
            reader = BatchedIDSConsumer(metadata, batch_size=100)

            # Consume dynamic data
            for dynamic_data in dynamic_data_stream:
                # process_message returns an IDS after every 100 (=batch_size)
                # messages and None otherwise:
                ids = reader.process_message(dynamic_data)
                if ids is not None:
                    # Use IDS
                    ...
    """

    def __init__(
        self, metadata: StreamingIMASMetadata, batch_size: int, *, return_copy=True
    ) -> None:
        """Consumer of streaming IMAS data which outputs IDSs.

        Args:
            metadata: Metadata of the IMAS data stream.
            batch_size: Number of time slices to batch in each returned IDS.

        Keyword Args:
            return_copy: See the description in StreamingIDSConsumer

        Raises:
            ValueError: When ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"Invalid batch size: {batch_size}")

        self._metadata = metadata
        self._batch_size = batch_size
        self._return_copy = return_copy
        # Private copy, so filling in dynamic data does not mutate the metadata:
        self._ids = copy.deepcopy(metadata.static_data)
        # Index of the next time slice within the current batch:
        self._cur_idx = 0
        self._finished = False

        # One contiguous buffer holds batch_size messages back-to-back;
        # _array_view exposes it as one row of f64 values per message:
        self._msg_bytes = metadata.nbytes
        self._buffer = memoryview(bytearray(self._msg_bytes * batch_size))
        readonly_view = self._buffer.toreadonly()
        dtype = "<f8"  # little-endian IEEE-754 64-bits floating point number
        self._array_view = np.frombuffer(readonly_view, dtype=dtype).reshape(
            (batch_size, -1)
        )

        # Setup array views for batched data
        self._scalars = []
        idx = 0
        for dyndata in self._metadata.dynamic_data:
            assert dyndata.data_type == "f64"
            ids_node = self._ids[dyndata.path]
            assert ids_node.metadata.type.is_dynamic
            n = np.prod(dyndata.shape, dtype=int)
            if dyndata.path == "time" or (
                ids_node.metadata.ndim
                and ids_node.metadata.coordinates[0].is_time_coordinate
            ):
                # Great! This IDS node is time-dependent by itself, and we can
                # create a single view for it:
                new_shape = (batch_size,) + dyndata.shape[1:]
                dataview = self._array_view[:, idx : idx + n].reshape(new_shape)
                ids_node.value = dataview
                # Verify that IMAS-Python keeps the view of our buffer
                assert ids_node.value is dataview
            else:
                # This is a dynamic variable inside a time-dependent AoS: find that aos
                aos = get_dynamic_aos_ancestor(ids_node)
                # First ensure there's an entry for every batch_size time slices:
                if len(aos) != batch_size:
                    assert len(aos) == 1
                    aos.resize(batch_size, keep=True)
                    # Replicate the static data of the first slice into all slices:
                    for i in range(1, batch_size):
                        aos[i] = copy.deepcopy(aos[0])
                path_from_aos = get_path_from_aos(dyndata.path, aos)
                if ids_node.metadata.ndim == 0:
                    # This is a scalar node: scalars cannot store an array view,
                    # so their values are copied per message in process_message()
                    self._scalars.append((aos, path_from_aos, idx))
                else:
                    # Loop over all time slices and create views:
                    for i in range(batch_size):
                        dataview = self._array_view[i, idx : idx + n]
                        aos[i][path_from_aos].value = dataview
                        # Verify that IMAS-Python keeps the view of our buffer
                        assert aos[i][path_from_aos].value is dataview

            idx += n

    def process_message(self, data: bytes | bytearray) -> IDSToplevel | None:
        """Process a single streaming IMAS message.

        This method returns None until a full batch is completed. Once ``batch_size``
        messages are processed a single IDSToplevel is returned, which contains all data
        from the ``batch_size`` messages.

        Args:
            data: Raw message bytes, exactly ``metadata.nbytes`` long.

        Raises:
            RuntimeError: When called after :meth:`finalize`.
            ValueError: When ``data`` does not have the expected size.
        """
        if self._finished:
            # Was: RuntimeError("") -- raise with an actionable message instead
            raise RuntimeError(
                "Cannot process further messages: finalize() was already called."
            )
        nbytes = self._msg_bytes
        if len(data) != nbytes:
            raise ValueError(
                f"Unexpected size of data: {len(data)}. Was expecting {nbytes}."
            )
        # Update buffer (the array views into it update automatically)
        self._buffer[self._cur_idx * nbytes : (self._cur_idx + 1) * nbytes] = data
        # Set scalar values, which are not backed by buffer views
        for aos, path_from_aos, idx in self._scalars:
            aos[self._cur_idx][path_from_aos] = self._array_view[self._cur_idx, idx]

        # Bookkeeping
        self._cur_idx += 1
        if self._cur_idx == self._batch_size:
            # Completed a batch:
            self._cur_idx = 0
            if self._return_copy:
                return copy.deepcopy(self._ids)
            return self._ids
        # Batch is not finished yet
        return None

    def finalize(self) -> IDSToplevel | None:
        """Indicate that the final message is received and return any remaining data.

        Returns:
            IDS with as many time slices as were remaining, or None in case no data was
            remaining.
        """
        self._finished = True
        n_time = self._cur_idx
        if n_time == 0:
            return None  # No data remaining, easy!

        # Resize dynamic quantities in the IDS:
        for dyndata in self._metadata.dynamic_data:
            ids_node = self._ids[dyndata.path]
            if dyndata.path == "time" or (
                ids_node.metadata.ndim
                and ids_node.metadata.coordinates[0].is_time_coordinate
            ):
                # Great! This IDS node is time-dependent by itself, and we just need to
                # create a smaller view of the data:
                ids_node.value = ids_node.value[:n_time]
            else:
                # This is a dynamic variable inside a time-dependent AoS: find that aos
                aos = get_dynamic_aos_ancestor(ids_node)
                # And resize it:
                if len(aos) != n_time:
                    assert len(aos) == self._batch_size
                    aos.resize(n_time, keep=True)

        if self._return_copy:
            return copy.deepcopy(self._ids)
        return self._ids
45 changes: 45 additions & 0 deletions src/imas_streams/imas_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import imas # noqa: F401 -- module required in doctests
from imas.ids_primitive import IDSPrimitive
from imas.ids_struct_array import IDSStructArray
from imas.util import get_parent


def get_dynamic_aos_ancestor(ids_node: IDSPrimitive) -> IDSStructArray:
    """Returns the dynamic Arrays of Structures ancestor for the provided node.

    Walks up the parent chain until it reaches an Array of Structures whose
    first coordinate is time.

    Examples:
        >>> cp = imas.IDSFactory("4.0.0").core_profiles()
        >>> cp.profiles_1d.resize(1)
        >>> get_dynamic_aos_ancestor(cp.profiles_1d[0].zeff) is cp.profiles_1d
        True
        >>> eq = imas.IDSFactory("4.0.0").equilibrium()
        >>> eq.time_slice.resize(1)
        >>> eq.time_slice[0].profiles_2d.resize(1)
        >>> aos_ancestor = get_dynamic_aos_ancestor(eq.time_slice[0].profiles_2d[0].psi)
        >>> aos_ancestor is eq.time_slice
        True
    """
    ancestor = get_parent(ids_node)
    while ancestor is not None:
        if (
            isinstance(ancestor, IDSStructArray)
            and ancestor.metadata.coordinates[0].is_time_coordinate
        ):
            return ancestor
        ancestor = get_parent(ancestor)
    # Walked all the way to the IDS toplevel without finding a dynamic AoS:
    raise RuntimeError(
        f"IDS node {ids_node} is not part of a time-dependent Array of Structures."
    )


def get_path_from_aos(path: str, aos: IDSStructArray) -> str:
    """Get the component of path relative to the provided Arrays of Structures ancestor.

    Examples:
        >>> cp = imas.IDSFactory("4.0.0").core_profiles()
        >>> get_path_from_aos("profiles_1d[0]/ion[1]/temperature", cp.profiles_1d)
        'ion[1]/temperature'
    """
    # Drop as many leading path components as the AoS path has:
    n_ancestor_parts = len(aos.metadata.path.parts)
    relative_parts = path.split("/")[n_ancestor_parts:]
    return "/".join(relative_parts)
7 changes: 7 additions & 0 deletions src/imas_streams/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]:
Yields:
Data produced by the StreamConsumer, e.g. an IDS for the
StreamingIDSConsumer.

For batching consumers (such as the BatchingIDSConsumer) the last yielded
value may contain fewer time slices than the batch size.
"""
try:
while True:
Expand All @@ -235,6 +238,10 @@ def stream(self, *, timeout=DEFAULT_KAFKA_CONSUMER_TIMEOUT) -> Iterator[Any]:
self._settings.topic_name,
timeout,
)
# Yield any remaining data
result = self._stream_consumer.finalize()
if result is not None:
yield result
break
if msg.error():
raise msg.error()
Expand Down
1 change: 1 addition & 0 deletions src/imas_streams/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
class StreamConsumer(Protocol):
    # Structural interface shared by all stream consumers in this package.
    # A consumer is constructed from stream metadata plus optional
    # consumer-specific keyword arguments:
    def __init__(self, metadata: StreamingIMASMetadata, **kwargs) -> None: ...
    # Process one raw message; the return type depends on the consumer
    # (e.g. an IDS, a dataset, or None while a batch is incomplete):
    def process_message(self, data: bytes | bytearray) -> Any: ...
    # Signal end-of-stream; returns any remaining buffered data, or None:
    def finalize(self) -> Any: ...
4 changes: 4 additions & 0 deletions src/imas_streams/xarray_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,7 @@ def process_message(self, data: bytes | bytearray) -> "xarray.Dataset":
self._msg_buffer[:] = data
self._tensor_buffer[self._index_array] = self._array_view
return self._dataset

def finalize(self) -> None:
    """Indicate that the final message is received and return any remaining data.

    This consumer updates its dataset on every message and never buffers data
    across messages, so there is never anything left to flush.

    Returns:
        None, always: no data remains after the last processed message.
    """
    return None  # No data remaining
100 changes: 98 additions & 2 deletions tests/test_ids_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import imas
import numpy as np
import pytest
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS

from imas_streams import StreamingIDSConsumer, StreamingIMASMetadata
from imas_streams import BatchedIDSConsumer, StreamingIDSConsumer, StreamingIMASMetadata
from imas_streams.metadata import DynamicData

DD_VERSION = os.getenv("IMAS_VERSION", "4.0.0")
Expand All @@ -14,7 +15,7 @@
def magnetics_metadata():
ids = imas.IDSFactory(DD_VERSION).new("magnetics")

ids.ids_properties.homogeneous_time = imas.ids_defs.IDS_TIME_MODE_HOMOGENEOUS
ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS
ids.time = [0.0]

ids.flux_loop.resize(5)
Expand Down Expand Up @@ -107,3 +108,98 @@ def test_streaming_reader_nocopy(magnetics_metadata):
assert ids is ids2
assert ids2.time[0] == 1.0
assert len(ids2.flux_loop) == 0


@pytest.mark.parametrize("batch_size", [1, 2, 5, 7, 10, 13, 20])
def test_batched_reader(magnetics_metadata, batch_size):
    """Batches from the magnetics stream contain the expected values."""
    consumer = BatchedIDSConsumer(magnetics_metadata, batch_size)
    n_messages = 20
    n_signals = len(magnetics_metadata.dynamic_data)

    def verify_batch(batch, expected_time):
        # Each dynamic signal is offset from the time value by its signal index
        assert np.array_equal(batch.time, expected_time)
        for loop_idx in range(5):
            assert np.array_equal(
                batch.flux_loop[loop_idx].flux.data, expected_time + loop_idx + 1
            )
        assert np.array_equal(batch.flux_loop[0].voltage.data, expected_time + 6)

    # Pretend sending 20 messages
    for msg_idx in range(n_messages):
        payload = (np.arange(n_signals, dtype="<f8") + msg_idx).tobytes()
        batch = consumer.process_message(payload)
        if batch is None:
            # A result may only be produced on batch boundaries
            assert (msg_idx + 1) % batch_size != 0
        else:
            first_time = msg_idx + 1 - batch_size
            verify_batch(batch, np.arange(batch_size, dtype=float) + first_time)

    # Check that any remainders are handled
    remainder = n_messages % batch_size
    final_batch = consumer.finalize()
    if final_batch is None:
        assert remainder == 0
    else:
        assert len(final_batch.time) == remainder != 0
        first_time = n_messages - remainder
        verify_batch(final_batch, np.arange(remainder, dtype=float) + first_time)


@pytest.mark.parametrize("batch_size", [1, 2, 5, 7, 10, 13, 20])
def test_batched_reader_cp(batch_size):
    """Test BatchedIDSConsumer with dynamic data inside a time-dependent AoS."""
    ids = imas.IDSFactory(DD_VERSION).new("core_profiles")

    ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS
    ids.time = [0.0]
    ids.profiles_1d.resize(1)
    ids.profiles_1d[0].grid.rho_tor_norm = np.linspace(0, 1, 6)
    ids.profiles_1d[0].ion.resize(1)
    ids.profiles_1d[0].ion[0].z_ion = 1

    metadata = StreamingIMASMetadata(
        data_dictionary_version=DD_VERSION,
        # Fixed: this test streams a core_profiles IDS (was "magnetics",
        # apparently copied from the magnetics fixture)
        ids_name="core_profiles",
        static_data=ids,
        dynamic_data=[
            # 14 values per message: 1 (time) + 6 (zeff) + 6 (density) + 1 (ip)
            DynamicData(path="time", shape=(1,), data_type="f64"),
            DynamicData(path="profiles_1d[0]/zeff", shape=(6,), data_type="f64"),
            DynamicData(
                path="profiles_1d[0]/ion[0]/density", shape=(6,), data_type="f64"
            ),
            DynamicData(path="global_quantities/ip", shape=(1,), data_type="f64"),
        ],
    )
    reader = BatchedIDSConsumer(metadata, batch_size)

    def check_data(batch_ids, expected_time):
        assert np.array_equal(batch_ids.time, expected_time)
        assert np.array_equal(batch_ids.global_quantities.ip, expected_time + 13)
        # Check dynamic AoS
        assert len(batch_ids.profiles_1d) == len(expected_time)
        for j, p1d in enumerate(batch_ids.profiles_1d):
            assert np.array_equal(p1d.zeff, np.arange(6) + expected_time[j] + 1)
            assert np.array_equal(
                p1d.ion[0].density, np.arange(6) + expected_time[j] + 7
            )
            # Static data must be replicated into every time slice:
            assert p1d.ion[0].z_ion == 1

    # Pretend sending 20 messages (batch_ids renamed from ids to avoid
    # shadowing the static IDS built above)
    for i in range(20):
        test_data = np.arange(14, dtype="<f8") + i
        batch_ids = reader.process_message(test_data.tobytes())
        # Only expect a result after batch_size items are processed
        if batch_ids is None:
            assert (i + 1) % batch_size != 0
            continue
        expected_time = np.arange(batch_size, dtype=float) + (i + 1 - batch_size)
        check_data(batch_ids, expected_time)

    # Check that any remainders are handled
    msg_remaining = 20 % batch_size
    batch_ids = reader.finalize()
    if batch_ids is None:
        assert msg_remaining == 0
    else:
        assert len(batch_ids.time) == msg_remaining != 0
        expected_time = np.arange(msg_remaining, dtype=float) + 20 - msg_remaining
        check_data(batch_ids, expected_time)
Loading