From 6fad78140154a36bb701dd9b4164de63644690f0 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 23 Jun 2026 06:53:29 -0400 Subject: [PATCH] Moved registration across channels --- scallops/stitch/fuse.py | 93 +++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/scallops/stitch/fuse.py b/scallops/stitch/fuse.py index 2630303..f106d47 100644 --- a/scallops/stitch/fuse.py +++ b/scallops/stitch/fuse.py @@ -393,6 +393,50 @@ def _fuse( ) +def _register_across_channels( + img: np.ndarray, + channel_reference: int | None = None, + channel_cross_correlation_upsample: int = 1, + channel_window: int = 1, + channel_filter_percentiles: tuple[float, float] | None = None, +): + ref_values = img[channel_reference] + img_dtype = ref_values.dtype + if channel_window > 1: + ref_values = _apply_window(ref_values, channel_window) + if channel_filter_percentiles is not None: + ref_values = _filter_percentiles( + ref_values, + q1=channel_filter_percentiles[0], + q2=channel_filter_percentiles[1], + ) + + for c in range(img.shape[0]): + if c != channel_reference: + moving = img[c] + if channel_window > 1: + moving = _apply_window(moving, channel_window) + if channel_filter_percentiles is not None: + moving = _filter_percentiles( + moving, + q1=channel_filter_percentiles[0], + q2=channel_filter_percentiles[1], + ) + + offset, _ = calc_best_shift( + moving, + ref_values, + upsample_factor=channel_cross_correlation_upsample, + overlap_min=0, + ) + # offset, _, _ = phase_cross_correlation(moving, ref_values) + + if not np.all(offset == 0.0): + # skimage SimilarityTransform has (x,y,[z]) convention + st = SimilarityTransform(translation=offset[::-1]) + img[c] = warp(img[c], st, preserve_range=True).astype(img_dtype) + + def _fuse_image( image_paths: list[str] | zarr.Group, image_attrs: dict[str, str | list[str]], @@ -431,46 +475,11 @@ def _fuse_image( if "z" in img.dims: img = img.max(dim="z") if not isinstance(z_index, int) else img.isel(z=z_index) img = img.values - if channel_reference is not None: - ref_values = img[channel_reference] - img_dtype = ref_values.dtype - if channel_window > 1: - ref_values = _apply_window(ref_values, channel_window) - if channel_filter_percentiles is not None: - ref_values = _filter_percentiles( - ref_values, - q1=channel_filter_percentiles[0], - q2=channel_filter_percentiles[1], - ) - for c in range(img.shape[0]): - if c != channel_reference: - moving = img[c] - if channel_window > 1: - moving = _apply_window(moving, channel_window) - if channel_filter_percentiles is not None: - moving = _filter_percentiles( - moving, - q1=channel_filter_percentiles[0], - q2=channel_filter_percentiles[1], - ) - - offset, _ = calc_best_shift( - moving, - ref_values, - upsample_factor=channel_cross_correlation_upsample, - overlap_min=0, - ) - # offset, _, _ = phase_cross_correlation(moving, ref_values) - - if not np.all(offset == 0.0): - # skimage SimilarityTransform has (x,y,[z]) convention - st = SimilarityTransform(translation=offset[::-1]) - img[c] = warp(img[c], st, preserve_range=True).astype(img_dtype) - if output_channels is not None: + if output_channels is not None and channel_reference is None: img = img[output_channels] - # order: radial, illumination, crop + # order: radial, illumination, channel registration, crop if radial_correction_k is not None: img = radial_correct(img, radial_correction_k) img = dtype_convert(img, target_dtype) @@ -481,6 +490,16 @@ def _fuse_image( if ffp is not None: img /= ffp img.clip(0, 1, out=img) + if channel_reference is not None: + _register_across_channels( + img=img, + channel_reference=channel_reference, + channel_cross_correlation_upsample=channel_cross_correlation_upsample, + channel_window=channel_window, + channel_filter_percentiles=channel_filter_percentiles, + ) + if output_channels is not None: + img = img[output_channels] img = _crop_image(img, crop_width)