Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion httomo/data/dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import time
import h5py
from typing import List, Literal, Optional, Tuple, Union
from httomo.data.hdf._utils.reslice import reslice
from httomo.data.hdf._utils.reslice import reslice, reslice_memory_estimator
from httomo.data.padding import extrapolate_after, extrapolate_before
from httomo.runner.auxiliary_data import AuxiliaryData
from httomo.runner.dataset import DataSetBlock
Expand Down Expand Up @@ -286,6 +286,10 @@ def __init__(
start = time.perf_counter()
self._data = self._reslice(source.slicing_dim, slicing_dim, source_data)
end = time.perf_counter()
log_once(
f"reslice_memory_estimator: {reslice_memory_estimator(source_data.shape, source_data.dtype, source.slicing_dim, slicing_dim, self._comm)}",
level=logging.ERROR,
)
if slicing_dim == 1:
log_once(
f"Slicing axis change (reslice) from projection to sinogram took {(end - start):.9f}s.",
Expand Down
74 changes: 74 additions & 0 deletions httomo/data/hdf/_utils/reslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,77 @@ def reslice(

start_idx = 0 if comm.rank == 0 else split_indices[comm.rank - 1]
return new_data, next_slice_dim, start_idx


def reslice_memory_estimator(
    data_shape: Tuple[int, int, int],
    dtype: numpy.dtype,
    current_slice_dim: int,
    next_slice_dim: int,
    comm: Comm,
) -> float:
    """Estimate the peak host-memory footprint (in bytes) of a reslice.

    Mirrors the allocation pattern of :func:`reslice` without moving any
    data: the input block, the per-rank split pieces, the concatenated
    output, and the temporary send/receive buffers used during the ring
    exchange (doubled when the transfer must be chunked to stay below the
    MPI per-call element limit).

    Parameters
    ----------
    data_shape
        Shape of this rank's local 3D data block.
    dtype
        Element dtype of the data.
    current_slice_dim
        Dimension (0-2) the data is currently sliced along.
    next_slice_dim
        Dimension (0-2) the data will be sliced along after reslicing.
    comm
        MPI communicator over which the reslice is performed.

    Returns
    -------
    float
        Estimated peak memory in bytes, including a 1% safety margin.
    """
    rank = comm.rank
    nprocs = comm.size
    itemsize = numpy.dtype(dtype).itemsize
    input_size = numpy.prod(data_shape) * itemsize

    # Split the local block along the *next* slicing dimension, one piece
    # per rank, using the same rounding rule as reslice() so the estimated
    # piece sizes match the real ones.
    length = data_shape[next_slice_dim]
    split_indices = [round((length / nprocs) * r) for r in range(1, nprocs)]

    split_sizes = []
    prev_idx = 0
    for i in range(nprocs):
        next_idx = split_indices[i] if i < len(split_indices) else length
        split_shape = list(data_shape)
        split_shape[next_slice_dim] = next_idx - prev_idx
        split_sizes.append(numpy.prod(split_shape) * itemsize)
        prev_idx = next_idx

    total_split_size = sum(split_sizes)

    # Exchange split sizes so each rank knows how many bytes it receives
    # from every other rank (recv_sizes[p] = bytes arriving from rank p).
    all_split_sizes = comm.allgather(split_sizes)
    recv_sizes = [all_split_sizes[p][rank] for p in range(nprocs)]

    # Reconstruct the output extent along the current slicing dimension
    # from the received byte counts.
    # NOTE(review): the divisor uses this rank's own extents for the two
    # non-next dimensions, which assumes all ranks hold identically shaped
    # blocks — verify against how reslice() distributes uneven chunks.
    bytes_per_unit = itemsize * numpy.prod(
        [data_shape[d] for d in range(3) if d != next_slice_dim]
    )
    output_shape = list(data_shape)
    output_shape[current_slice_dim] = sum(
        recv_sizes[p] // bytes_per_unit for p in range(nprocs)
    )
    output_size = numpy.prod(output_shape) * itemsize

    max_send_buffer = max(split_sizes)
    max_recv_buffer = max(recv_sizes)

    # Local import — presumably to avoid a circular dependency with
    # httomo.data.mpiutil; kept function-scoped for that reason.
    from httomo.data.mpiutil import _mpi_max_elements

    # Largest single transfer in elements; if it exceeds what one MPI call
    # can carry, the exchange falls back to chunked sends, which need an
    # extra flattened copy of both the send and the receive buffers.
    max_elements = _mpi_max_elements - 1
    max_transfer_elements = max(max_send_buffer, max_recv_buffer) // itemsize
    needs_chunking = max_transfer_elements > max_elements

    if needs_chunking:
        chunk_overhead_send = max_send_buffer
        chunk_overhead_recv = max_recv_buffer
    else:
        chunk_overhead_send = 0
        chunk_overhead_recv = 0

    # Input, all split pieces, and the output array coexist before the
    # ring exchange starts.
    peak_before_ring = input_size + total_split_size + output_size

    peak_during_ring = (
        peak_before_ring
        + max_send_buffer  # Temporary send buffer
        + max_recv_buffer  # Temporary recv buffer
        + chunk_overhead_send  # Flattened send array (if chunking)
        + chunk_overhead_recv  # Flattened recv array (if chunking)
    )

    # After the exchange only the output remains; always <= the peaks
    # above, kept for symmetry/clarity of the three phases.
    peak_after_ring = output_size

    # Worst-case phase plus a 1% safety margin.
    return max(peak_before_ring, peak_during_ring, peak_after_ring) * 1.01
Loading