Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 36 additions & 82 deletions httomo/runner/dataset_store_backing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Callable, List, ParamSpec, Tuple
from typing import List, Tuple

import numpy as np
from numpy.typing import DTypeLike
Expand Down Expand Up @@ -56,50 +56,6 @@ class DataSetStoreBacking(Enum):
File = 2


P = ParamSpec("P")


def _reduce_decorator_factory(
    comm: MPI.Comm,
) -> Callable[[Callable[P, DataSetStoreBacking]], Callable[P, DataSetStoreBacking]]:
    """
    Generate a decorator for a store-backing calculator function that will use the
    given MPI communicator for the reduce operation.

    Parameters
    ----------
    comm
        MPI communicator used for the logical-OR all-reduce across processes.

    Returns
    -------
    A decorator whose wrapped function returns ``DataSetStoreBacking.File`` if any
    process calculated the ``File`` variant, and ``DataSetStoreBacking.RAM``
    otherwise.
    """
    # Local import keeps this fix self-contained; `functools` is stdlib.
    from functools import wraps

    def reduce_decorator(
        func: Callable[P, DataSetStoreBacking],
    ) -> Callable[P, DataSetStoreBacking]:
        """
        Decorator for store-backing calculator function.
        """

        # Preserve the wrapped function's __name__/__doc__ for debugging/logging.
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> DataSetStoreBacking:
            """
            Perform store-backing calculation across all MPI processes and reduce.
            """
            # Reduce store backing enum variant across all processes - if any has
            # the `File` variant, all should use a file.
            send_buffer = np.zeros(1, dtype=bool)
            recv_buffer = np.zeros(1, dtype=bool)
            store_backing = func(*args, **kwargs)

            if store_backing is DataSetStoreBacking.File:
                send_buffer[0] = True

            # Logical OR of the "needs file backing" flag across all processes.
            comm.Allreduce([send_buffer, MPI.BOOL], [recv_buffer, MPI.BOOL], MPI.LOR)

            if bool(recv_buffer[0]):
                return DataSetStoreBacking.File

            return DataSetStoreBacking.RAM

        return wrapper

    return reduce_decorator


def _non_last_section_in_pipeline(
memory_limit_bytes: int,
write_chunk_bytes: int,
Expand All @@ -117,19 +73,6 @@ def _non_last_section_in_pipeline(
return DataSetStoreBacking.RAM


def _last_section_in_pipeline(
    memory_limit_bytes: int,
    write_chunk_bytes: int,
) -> DataSetStoreBacking:
    """
    Calculate backing of dataset store for last section in pipeline.

    Returns the ``File`` variant when a memory limit is enforced (limit > 0) and
    the bytes to be written meet or exceed it; otherwise returns ``RAM``.
    A limit of zero (or less) means no limit is enforced.
    """
    limit_disabled = memory_limit_bytes <= 0
    fits_in_limit = write_chunk_bytes < memory_limit_bytes
    if limit_disabled or fits_in_limit:
        return DataSetStoreBacking.RAM
    return DataSetStoreBacking.File


def determine_store_backing(
comm: MPI.Comm,
sections: List[Section],
Expand All @@ -138,41 +81,52 @@ def determine_store_backing(
global_shape: Tuple[int, int, int],
section_idx: int,
) -> DataSetStoreBacking:
reduce_decorator = _reduce_decorator_factory(comm)
# Get chunk shape created by reader of section `n` (the current section) that will account
# for padding. This chunk shape is based on the chunk shape written by the writer of
# section `n - 1` (the previous section)
padded_input_chunk_shape = calculate_section_chunk_shape(
comm=comm,
global_shape=global_shape,
slicing_dim=_get_slicing_dim(sections[section_idx].pattern) - 1,
padding=determine_section_padding(sections[section_idx]),
)
padded_input_chunk_bytes = int(
np.prod(padded_input_chunk_shape) * np.dtype(dtype).itemsize
)

# Get chunk shape input to section
current_chunk_shape = calculate_section_chunk_shape(
# Get unpadded chunk shape input to current section (for calculation of bytes in output
# chunk for the current section)
input_chunk_shape = calculate_section_chunk_shape(
comm=comm,
global_shape=global_shape,
slicing_dim=_get_slicing_dim(sections[section_idx].pattern) - 1,
padding=(0, 0),
)

# Get the number of bytes in the input chunk to the section w/ potential modifications to
# the non-slicing dims
current_chunk_bytes = calculate_section_chunk_bytes(
chunk_shape=current_chunk_shape,
# the non-slicing dims, to then determine the number of bytes in the output chunk written
# by the current section
output_chunk_bytes = calculate_section_chunk_bytes(
chunk_shape=input_chunk_shape,
dtype=dtype,
section=sections[section_idx],
)

if section_idx == len(sections) - 1:
return reduce_decorator(_last_section_in_pipeline)(
memory_limit_bytes=memory_limit_bytes,
write_chunk_bytes=current_chunk_bytes,
)

# Get chunk shape created by reader of section `n+1`, that will add padding to the
# chunk shape written by the writer of section `n`
next_chunk_shape = calculate_section_chunk_shape(
comm=comm,
global_shape=global_shape,
slicing_dim=_get_slicing_dim(sections[section_idx + 1].pattern) - 1,
padding=determine_section_padding(sections[section_idx + 1]),
)
next_chunk_bytes = int(np.prod(next_chunk_shape) * np.dtype(dtype).itemsize)
return reduce_decorator(_non_last_section_in_pipeline)(
send_buffer = np.zeros(1, dtype=bool)
recv_buffer = np.zeros(1, dtype=bool)
store_backing = _non_last_section_in_pipeline(
memory_limit_bytes=memory_limit_bytes,
write_chunk_bytes=current_chunk_bytes,
read_chunk_bytes=next_chunk_bytes,
read_chunk_bytes=padded_input_chunk_bytes,
write_chunk_bytes=output_chunk_bytes,
)

if store_backing is DataSetStoreBacking.File:
send_buffer[0] = True

# do a logical OR of all the enum variants across the processes
comm.Allreduce([send_buffer, MPI.BOOL], [recv_buffer, MPI.BOOL], MPI.LOR)

if bool(recv_buffer[0]) is True:
return DataSetStoreBacking.File

return DataSetStoreBacking.RAM
17 changes: 8 additions & 9 deletions httomo/runner/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,18 @@ def _setup_source_sink(self, section: Section, idx: int):
self.source.finalize()
self.source = new_source

store_backing = determine_store_backing(
comm=self.comm,
sections=self._sections,
memory_limit_bytes=self._memory_limit_bytes,
dtype=self.source.dtype,
global_shape=self.source.global_shape,
section_idx=idx,
)

if section.is_last:
# we don't need to store the results - this sink just discards it
self.sink = DummySink(slicing_dim_section)
else:
store_backing = determine_store_backing(
comm=self.comm,
sections=self._sections,
memory_limit_bytes=self._memory_limit_bytes,
dtype=self.source.dtype,
global_shape=self.source.global_shape,
section_idx=idx,
)
self.sink = DataSetStoreWriter(
slicing_dim_section,
self.comm,
Expand Down
Loading
Loading