Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 56 additions & 37 deletions scallops/stitch/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,50 @@ def _fuse(
)


def _register_across_channels(
img: np.ndarray,
channel_reference: int | None = None,
channel_cross_correlation_upsample: int = 1,
channel_window: int = 1,
channel_filter_percentiles: tuple[float, float] | None = None,
):
ref_values = img[channel_reference]
img_dtype = ref_values.dtype
if channel_window > 1:
ref_values = _apply_window(ref_values, channel_window)
if channel_filter_percentiles is not None:
ref_values = _filter_percentiles(
ref_values,
q1=channel_filter_percentiles[0],
q2=channel_filter_percentiles[1],
)

for c in range(img.shape[0]):
if c != channel_reference:
moving = img[c]
if channel_window > 1:
moving = _apply_window(moving, channel_window)
if channel_filter_percentiles is not None:
moving = _filter_percentiles(
moving,
q1=channel_filter_percentiles[0],
q2=channel_filter_percentiles[1],
)

offset, _ = calc_best_shift(
moving,
ref_values,
upsample_factor=channel_cross_correlation_upsample,
overlap_min=0,
)
# offset, _, _ = phase_cross_correlation(moving, ref_values)

if not np.all(offset == 0.0):
# skimage SimilarityTransform has (x,y,[z]) convention
st = SimilarityTransform(translation=offset[::-1])
img[c] = warp(img[c], st, preserve_range=True).astype(img_dtype)


def _fuse_image(
image_paths: list[str] | zarr.Group,
image_attrs: dict[str, str | list[str]],
Expand Down Expand Up @@ -431,46 +475,11 @@ def _fuse_image(
if "z" in img.dims:
img = img.max(dim="z") if not isinstance(z_index, int) else img.isel(z=z_index)
img = img.values
if channel_reference is not None:
ref_values = img[channel_reference]
img_dtype = ref_values.dtype
if channel_window > 1:
ref_values = _apply_window(ref_values, channel_window)
if channel_filter_percentiles is not None:
ref_values = _filter_percentiles(
ref_values,
q1=channel_filter_percentiles[0],
q2=channel_filter_percentiles[1],
)

for c in range(img.shape[0]):
if c != channel_reference:
moving = img[c]
if channel_window > 1:
moving = _apply_window(moving, channel_window)
if channel_filter_percentiles is not None:
moving = _filter_percentiles(
moving,
q1=channel_filter_percentiles[0],
q2=channel_filter_percentiles[1],
)

offset, _ = calc_best_shift(
moving,
ref_values,
upsample_factor=channel_cross_correlation_upsample,
overlap_min=0,
)
# offset, _, _ = phase_cross_correlation(moving, ref_values)

if not np.all(offset == 0.0):
# skimage SimilarityTransform has (x,y,[z]) convention
st = SimilarityTransform(translation=offset[::-1])
img[c] = warp(img[c], st, preserve_range=True).astype(img_dtype)
if output_channels is not None:
if output_channels is not None and channel_reference is None:
img = img[output_channels]

# order: radial, illumination, crop
# order: radial, illumination, channel registration, crop
if radial_correction_k is not None:
img = radial_correct(img, radial_correction_k)
img = dtype_convert(img, target_dtype)
Expand All @@ -481,6 +490,16 @@ def _fuse_image(
if ffp is not None:
img /= ffp
img.clip(0, 1, out=img)
if channel_reference is not None:
_register_across_channels(
img=img,
channel_reference=channel_reference,
channel_cross_correlation_upsample=channel_cross_correlation_upsample,
channel_window=channel_window,
channel_filter_percentiles=channel_filter_percentiles,
)
if output_channels is not None:
img = img[output_channels]

img = _crop_image(img, crop_width)

Expand Down