From b5c2a812aeba3863b09709a2baa25ef30e9e865e Mon Sep 17 00:00:00 2001 From: Frans Irgolitsch Date: Fri, 1 May 2026 12:24:52 -0400 Subject: [PATCH] fix_galvo_shift_zarr: per-tile detect + threaded column strips + --skip_tiles for manual overrides --- scripts/linum_fix_galvo_shift_zarr.py | 426 ++++++++++++++++---------- 1 file changed, 268 insertions(+), 158 deletions(-) diff --git a/scripts/linum_fix_galvo_shift_zarr.py b/scripts/linum_fix_galvo_shift_zarr.py index fc9f50db..de683a62 100644 --- a/scripts/linum_fix_galvo_shift_zarr.py +++ b/scripts/linum_fix_galvo_shift_zarr.py @@ -19,8 +19,11 @@ projection, and calling the same dark-band detector used for raw tiles. The fix per chunk uses a circular roll (``np.roll``) identical to the raw-tile -fix, moving the dark band to the end of the tile's A-line range. Those edge -pixels are then linearly interpolated from the adjoining valid columns. +fix (``linumpy.geometry.galvo.fix_galvo_shift``), moving the dark galvo-return +band to the end of the tile's A-line range. No interpolation is performed -- +the galvo-return columns are valid data once rolled to the right edge, and the +downstream pipeline crops them away. Pass ``--use_gpu`` to run the per-chunk +roll on a CuPy device through ``linumpy.gpu.corrections.fix_galvo_shift``. ``--mode undo`` reverses a previously applied fix by rolling each chunk in the opposite direction. Use this when the pipeline incorrectly applied a galvo fix @@ -42,7 +45,7 @@ linum_fix_galvo_shift_zarr.py mosaic_grid_3d_z47.ome.zarr fixed_z47.ome.zarr \\ --band_start 440 --band_width 40 -Undo an incorrectly applied fix (shift value from pipeline log or slice_config):: +Undo an incorrectly applied fix (shift value from the pipeline log):: linum_fix_galvo_shift_zarr.py mosaic_grid_3d_z50.ome.zarr fixed_z50.ome.zarr \\ --mode undo --shift 60 @@ -57,16 +60,16 @@ import linumpy.config.threads # noqa: F401 import argparse +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any import numpy as np from tqdm.auto import tqdm from linumpy.cli.args import add_overwrite_arg, assert_output_exists -from linumpy.geometry.galvo import detect_galvo_band_in_tile, detect_galvo_shift +from linumpy.geometry.galvo import detect_galvo_band_in_tile, detect_galvo_shift, fix_galvo_shift from linumpy.io import slice_config as slice_config_io -from linumpy.io.zarr import OmeZarrWriter +from linumpy.io.zarr import OmeZarrWriter, read_omezarr def _build_arg_parser() -> argparse.ArgumentParser: @@ -87,11 +90,14 @@ def _build_arg_parser() -> argparse.ArgumentParser: detect_group.add_argument( "--n_extra", type=int, - default=None, - help="Number of galvo-return pixels (n_extra from acquisition " - "metadata). When provided, uses the same gradient-pair " - "detector as the pipeline for reliable detection. " - "Find this in the tile info.txt files or Nextflow config.", + default=40, + help="Number of galvo-return pixels (the ``n_extra`` field from acquisition metadata, " + "typically 40). Enables the gradient-pair detector used by the pipeline; " + "strongly recommended for reliable detection. " + "Note: the assembled mosaic tile width is already cropped to ``n_alines`` (the trailing " + "n_extra guard columns are stripped during pre-processing), but when the galvo fix was " + "missed the dark band still sits inside the kept tile range and is still ~n_extra pixels " + "wide -- so this value remains the correct one to pass.", ) detect_group.add_argument( "--band_start", @@ -114,18 +120,28 @@ def _build_arg_parser() -> argparse.ArgumentParser: default=None, help="Explicit roll shift for --mode undo. Equals the shift that was applied during pipeline creation.", ) - detect_group.add_argument( - "--detection_level", - type=int, - default=1, - help="Pyramid level used for auto-detection (0=full res). Default: 1 (2x downsampled for speed).", - ) detect_group.add_argument( "--min_confidence", type=float, - default=0.2, + default=0.5, help="Minimum detection confidence to proceed with fix in auto mode [%(default)s].", ) + detect_group.add_argument( + "--min_tile_signal", + type=float, + default=5.0, + help="Mean intensity threshold below which a tile is treated as background and " + "left UNCHANGED (no roll applied) [%(default)s]. Detection on dark tiles is " + "unreliable and produces visible displacement artefacts in the output.", + ) + detect_group.add_argument( + "--skip_tiles", + default="", + help="Semicolon-separated list of 'kx,ky' tile coordinates to leave UNCHANGED " + "(no roll). Use to manually patch a small set of tiles where the " + "auto-detected shift wraps tissue across the tile boundary. " + "Example: '13,4;13,8;3,3'.", + ) config_group = p.add_argument_group("Slice config update") config_group.add_argument( @@ -141,12 +157,6 @@ def _build_arg_parser() -> argparse.ArgumentParser: metavar="OUT_PNG", help="Save a before/after comparison PNG after fixing. Uses the same 3-panel XY/XZ/YZ layout as the pipeline preview.", ) - preview_group.add_argument( - "--preview_level", - type=int, - default=2, - help="Pyramid level used for the preview (0=full res). Default: 2 (4x downsampled, faster). ", - ) preview_group.add_argument("--cmap", default="magma", help="Colormap for the preview [%(default)s].") scan_group = p.add_argument_group( @@ -162,10 +172,23 @@ def _build_arg_parser() -> argparse.ArgumentParser: type=int, metavar=("START", "STOP", "STEP"), default=None, - help="Range of band_start values to try, in level-0 pixels. E.g. --scan_range 50 250 10", + help="Range of band_start values to try, in pixels. E.g. --scan_range 50 250 10", ) p.add_argument("-v", "--verbose", action="store_true", help="Print per-chunk detection results.") + p.add_argument( + "--use_gpu", + action="store_true", + help="Run the per-strip roll on a CuPy device via linumpy.gpu.corrections.fix_galvo_shift. " + "Detection always runs on CPU; only useful when zarr I/O is not the bottleneck.", + ) + p.add_argument( + "--workers", + type=int, + default=4, + help="Number of threads pipelining read/roll/write across tile columns [%(default)s]. " + "Each worker holds one tile-column strip in memory (~chunk_x * ny * nz bytes).", + ) add_overwrite_arg(p) return p @@ -179,7 +202,6 @@ def _generate_comparison_preview( before_path: Path, after_path: Path, out_png: Path, - level: int = 2, cmap: str = "magma", band_start: int | None = None, band_width: int | None = None, @@ -198,25 +220,22 @@ def _generate_comparison_preview( OME-Zarr directories to compare. out_png : Path Output PNG file path. - level : int - Pyramid level to read (higher = faster, lower res). Clamped - to the number of available levels. cmap : str Matplotlib colourmap. band_start : int or None - Start column of the galvo band in level-0 pixels (optional overlay). + Start column of the galvo band in pixels (optional overlay). band_width : int or None - Width of the galvo band in level-0 pixels (optional overlay). + Width of the galvo band in pixels (optional overlay). chunk_x : int or None - Tile chunk width in level-0 pixels (optional overlay). + Tile chunk width in pixels (optional overlay). """ import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt - def _read_panels(zarr_path: Path, level: int) -> Any: - arr, _, _actual, _ = _open_level(zarr_path, level) + def _read_panels(zarr_path: Path) -> tuple: + arr, _ = read_omezarr(zarr_path) vol = np.asarray(arr, dtype=np.float32) # Pick the Z slice with the highest mean signal so tissue is always visible. z_means = vol.mean(axis=(1, 2)) @@ -232,10 +251,10 @@ def _read_panels(zarr_path: Path, level: int) -> Any: yz = np.array(vol[:, :, y])[::-1] return xy, xz, yz - print(f"Reading before zarr for preview (level {level}) ...") - before_panels = _read_panels(before_path, level) - print(f"Reading after zarr for preview (level {level}) ...") - after_panels = _read_panels(after_path, level) + print("Reading before zarr for preview ...") + before_panels = _read_panels(before_path) + print("Reading after zarr for preview ...") + after_panels = _read_panels(after_path) # Shared colour limits from the after volume (cleaner signal). all_after = np.concatenate([p.ravel() for p in after_panels]) @@ -313,38 +332,11 @@ def _read_panels(zarr_path: Path, level: int) -> Any: # --------------------------------------------------------------------------- -def _open_level(zarr_root: Path, level: int) -> Any: - """Open a specific pyramid level from an OME-Zarr, returning (zarr_array, res).""" - import zarr - from ome_zarr.io import parse_url - from ome_zarr.reader import Multiscales, Reader - - location = parse_url(str(zarr_root)) - if location is None: - raise FileNotFoundError(f"Cannot open as OME-Zarr: {zarr_root}") - reader = Reader(location) - nodes = list(reader()) - image_node = nodes[0] - multiscale = next(s for s in image_node.specs if isinstance(s, Multiscales)) - - # Clamp to available levels - actual_level = min(level, len(multiscale.datasets) - 1) - arr = zarr.open_array(zarr_root / multiscale.datasets[actual_level], mode="r") - - coord_transforms = image_node.metadata["coordinateTransformations"][0] - res = [1.0] * len(arr.shape) - for tr in coord_transforms: - if tr["type"] == "scale": - res = tr["scale"] - break - - return arr, res, actual_level, multiscale - - -def _auto_detect(zarr_root: Path, detection_level: int, n_extra: int | None = None, verbose: bool = False) -> tuple: - """Sample representative chunks and return (band_start, band_width, confidence). +def _auto_detect(zarr_root: Path, n_extra: int | None = None, verbose: bool = False) -> tuple: + """Sample representative chunks and return ``(band_start, band_width, confidence)``. - band_start and band_width are expressed in level-0 (full-resolution) pixels. + All values are in pixels. Mosaic-grid OME-Zarrs are written without a + pyramid, so detection always runs at full resolution. When *n_extra* is provided the same gradient-pair detector used by the pipeline (``detect_galvo_shift``) is applied to each chunk AIP -- this is @@ -355,24 +347,18 @@ def _auto_detect(zarr_root: Path, detection_level: int, n_extra: int | None = No ---------- zarr_root : Path Path to the OME-Zarr root directory. - detection_level : int - Pyramid level to use for detection (higher = faster, lower res). n_extra : int or None Number of galvo-return pixels from acquisition metadata (the ``n_extra`` field in info.txt / Nextflow config). Strongly recommended. verbose : bool If True, print per-chunk detection details. """ - det_arr, _, actual_level, _ = _open_level(zarr_root, detection_level) - scale_factor = 2**actual_level # ratio between detection level and level 0 + arr, _ = read_omezarr(zarr_root) - chunk_x = det_arr.chunks[1] - chunk_y = det_arr.chunks[2] - n_cx = det_arr.shape[1] // chunk_x - n_cy = det_arr.shape[2] // chunk_y - - # n_extra at the downsampled level - n_extra_ds = round(n_extra / scale_factor) if n_extra else None + chunk_x = arr.chunks[1] + chunk_y = arr.chunks[2] + n_cx = arr.shape[1] // chunk_x + n_cy = arr.shape[2] // chunk_y # Sample a spread of chunks from the central region (more likely tissue). cx_lo = max(0, n_cx // 4) @@ -389,7 +375,7 @@ def _auto_detect(zarr_root: Path, detection_level: int, n_extra: int | None = No ys = cy_mid * chunk_y ye = ys + chunk_y - chunk = np.asarray(det_arr[:, xs:xe, ys:ye], dtype=np.float32) + chunk = np.asarray(arr[:, xs:xe, ys:ye], dtype=np.float32) if float(chunk.mean()) < 5.0: if verbose: print(f" Chunk ({cx}, {cy_mid}): skipped (low signal mean={chunk.mean():.1f})") @@ -397,27 +383,25 @@ def _auto_detect(zarr_root: Path, detection_level: int, n_extra: int | None = No tile_aip = chunk.mean(axis=0) # (chunk_x, chunk_y) - if n_extra_ds: + if n_extra: # Use the proven gradient-pair detector from the pipeline. - # detect_galvo_shift returns (shift, confidence) where - # so band_start = chunk_x - shift - n_extra - shift_ds, conf = detect_galvo_shift(tile_aip, n_pixel_return=n_extra_ds) - bs_ds = chunk_x - shift_ds - n_extra_ds - bw_ds = n_extra_ds + # detect_galvo_shift returns (shift, confidence); the band sits at + # band_start = chunk_x - shift - n_extra after the implied roll. + shift, conf = detect_galvo_shift(tile_aip, n_pixel_return=n_extra) + bs = chunk_x - shift - n_extra + bw = n_extra else: # Fallback: threshold-based detector (less reliable) - bs_ds, bw_ds, conf = detect_galvo_band_in_tile(tile_aip) + bs, bw, conf = detect_galvo_band_in_tile(tile_aip) if verbose: - bs_l0 = round(bs_ds * scale_factor) - bw_l0 = round(bw_ds * scale_factor) print( f" Chunk ({cx:3d}, {cy_mid}): " - f"band_start={bs_l0:4d}px band_width={bw_l0:3d}px " - f"confidence={conf:.3f}" + (" [gradient-pair]" if n_extra_ds else " [threshold fallback]") + f"band_start={bs:4d}px band_width={bw:3d}px " + f"confidence={conf:.3f}" + (" [gradient-pair]" if n_extra else " [threshold fallback]") ) - detections.append((bs_ds, bw_ds, conf)) + detections.append((bs, bw, conf)) if not detections: return 0, 0, 0.0 @@ -448,11 +432,7 @@ def _auto_detect(zarr_root: Path, detection_level: int, n_extra: int | None = No f"±{tol:.0f}px → confidence penalty factor {consistency**0.5:.3f}" ) - # Scale back to level-0 pixels. - band_start_l0 = round(med_start * scale_factor) - band_width_l0 = round(med_width * scale_factor) - - return band_start_l0, band_width_l0, best_conf + return round(med_start), round(med_width), best_conf # --------------------------------------------------------------------------- @@ -467,7 +447,6 @@ def _scan_band_start( scan_stop: int, scan_step: int, out_png: Path, - level: int = 1, cmap: str = "magma", ) -> None: """Sweep *band_start* over a range and save a contact-sheet PNG. @@ -481,13 +460,11 @@ def _scan_band_start( zarr_root : Path Path to the OME-Zarr root directory to scan. band_width : int - Width of the dark band in level-0 pixels (typically ``n_extra``). + Width of the dark band in pixels (typically ``n_extra``). scan_start, scan_stop, scan_step : int - Range in level-0 pixels (Python-style: *scan_stop* is exclusive). + Range in pixels (Python-style: *scan_stop* is exclusive). out_png : Path Output contact-sheet PNG. - level : int - Pyramid level to use for speed (images are downsampled). cmap : str Matplotlib colourmap name. """ @@ -496,18 +473,14 @@ def _scan_band_start( matplotlib.use("Agg") import matplotlib.pyplot as plt - arr, _, actual_level, _ = _open_level(zarr_root, level) - scale_factor = 2**actual_level + arr, _ = read_omezarr(zarr_root) chunk_x = arr.chunks[1] chunk_y = arr.chunks[2] n_cx = arr.shape[1] // chunk_x n_cy = arr.shape[2] // chunk_y - # Scale level-0 parameters to the detection level. - bw_ds = max(1, round(band_width / scale_factor)) - start_ds = max(0, round(scan_start / scale_factor)) - stop_ds = round(scan_stop / scale_factor) - step_ds = max(1, round(scan_step / scale_factor)) + bw = max(1, band_width) + step = max(1, scan_step) # Sample a spread of central tiles. cx_lo = max(0, n_cx // 4) @@ -533,8 +506,8 @@ def _scan_band_start( vmin = float(np.percentile(avg_tile, 0.5)) vmax = float(np.percentile(avg_tile, 99.5)) - candidates_ds = list(range(start_ds, stop_ds, step_ds)) - n_cand = len(candidates_ds) + candidates = list(range(scan_start, scan_stop, step)) + n_cand = len(candidates) n_cols = min(8, n_cand + 1) n_rows = (n_cand + 1 + n_cols - 1) // n_cols @@ -547,21 +520,19 @@ def _scan_band_start( axes_flat[0].set_title("ORIGINAL", color="white", fontsize=8) axes_flat[0].set_axis_off() - for i, bs_ds in enumerate(candidates_ds): - bs_l0 = round(bs_ds * scale_factor) - roll = chunk_x - bs_ds - bw_ds + for i, bs in enumerate(candidates): + roll = chunk_x - bs - bw fixed = np.roll(avg_tile, roll, axis=0) ax = axes_flat[i + 1] ax.imshow(fixed.T, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto", origin="lower") - roll_l0 = round(roll * scale_factor) - ax.set_title(f"bs={bs_l0} r={roll_l0}", color="white", fontsize=7) + ax.set_title(f"bs={bs} r={roll}", color="white", fontsize=7) ax.set_axis_off() for j in range(n_cand + 1, len(axes_flat)): axes_flat[j].set_visible(False) fig.suptitle( - f"band_start scan | band_width={band_width}px | pyramid level {actual_level} ({scale_factor}x downsampled)", + f"band_start scan | band_width={band_width}px", color="white", fontsize=10, ) @@ -570,8 +541,8 @@ def _scan_band_start( plt.close(fig) print(f"Scan contact sheet saved → {out_png}") - print(f" {n_cand} candidates in level-0 range [{scan_start}:{scan_stop}:{scan_step}]px") - print(" Title format: bs= r= (level-0 px)") + print(f" {n_cand} candidates in range [{scan_start}:{scan_stop}:{scan_step}]px") + print(" Title format: bs= r= (px)") # --------------------------------------------------------------------------- @@ -579,10 +550,37 @@ def _scan_band_start( # --------------------------------------------------------------------------- +def _parse_skip_tiles(spec: str) -> frozenset[tuple[int, int]]: + """Parse a 'kx,ky;kx,ky' string into a set of tile coords.""" + if not spec: + return frozenset() + out: set[tuple[int, int]] = set() + for chunk in spec.split(";"): + chunk = chunk.strip() + if not chunk: + continue + kx_str, ky_str = chunk.split(",") + out.add((int(kx_str), int(ky_str))) + return frozenset(out) + + def _apply_fix( - zarr_root: Path, output_path: Path, band_start: int, band_width: int, mode: str, undo_shift: int, _verbose: bool = False + zarr_root: Path, + output_path: Path, + band_start: int, + band_width: int, + mode: str, + undo_shift: int, + overwrite: bool = True, + use_gpu: bool = False, + workers: int = 4, + n_extra: int | None = None, + min_confidence: float = 0.5, + min_tile_signal: float = 5.0, + skip_tiles: frozenset[tuple[int, int]] = frozenset(), + _verbose: bool = False, ) -> None: - """Write a corrected OME-Zarr, processing each level-0 chunk individually. + """Write a corrected OME-Zarr, processing each chunk individually. **fix mode**: The galvo desynchronisation means A-lines are out of order within each tile chunk. A single circular roll by ``chunk_x - band_start - band_width`` @@ -606,9 +604,30 @@ def _apply_fix( ``'fix'`` or ``'undo'``. undo_shift : int The roll shift that was applied by the pipeline (undo mode). + overwrite : bool + Overwrite *output_path* if it already exists. + use_gpu : bool + Run the per-strip roll on a CuPy device via ``linumpy.gpu.corrections``. + workers : int + Number of threads pipelining read → roll → write across tile columns. + Each worker holds one tile-column strip in memory. + n_extra : int or None + If set, per-tile detection uses the gradient-pair detector + (``detect_galvo_shift``) with this guard width. Same value as the + global detection step. + min_confidence : float + Confidence threshold for accepting a per-tile shift. Tiles below this + threshold fall back to the global ``band_start`` / ``band_width``. + min_tile_signal : float + Tiles whose mean intensity is below this value are treated as + background and left unchanged (no roll). Detection on dark tiles is + unreliable and produces visible displacement artefacts. + skip_tiles : frozenset of (kx, ky) + Tile coordinates manually flagged to be left unchanged (no roll), + in addition to background tiles. Use to patch a small set of tiles + where auto-detection wraps tissue across the tile boundary. """ - arr, res, _, multiscale = _open_level(zarr_root, level=0) - n_levels_in = len(multiscale.datasets) + arr, res = read_omezarr(zarr_root) shape = arr.shape # (nz, nx_mosaic, ny_mosaic) chunk_x = arr.chunks[1] # OCT tile width in X (A-line axis) chunk_y = arr.chunks[2] # OCT tile height in Y (B-scan axis) @@ -617,47 +636,136 @@ def _apply_fix( n_cx = shape[1] // chunk_x n_cy = shape[2] // chunk_y - roll_amount = 0 if mode == "fix": + if not 0 <= band_start < chunk_x or band_width <= 0 or band_start + band_width > chunk_x: + raise ValueError( + f"Band [{band_start}:{band_start + band_width}] does not fit inside a tile of width {chunk_x}px. " + "Check --band_start / --band_width or detection inputs." + ) band_end = band_start + band_width - roll_amount = chunk_x - band_start - band_width + default_shift = chunk_x - band_start - band_width print( - f"Rolling each tile chunk by +{roll_amount} px " - f"(band [{band_start}:{band_end}] → right edge of tile) " - f"in {n_cx}x{n_cy} tile chunks." + f"Per-tile galvo fix: fallback shift +{default_shift} px " + f"(global band [{band_start}:{band_end}]) for tiles below " + f"min_confidence={min_confidence:.2f} in {n_cx}x{n_cy} grid." ) else: - print(f"Rolling each tile chunk by {-undo_shift:+d} px to reverse applied galvo fix") + default_shift = -int(undo_shift) + print(f"Rolling each tile chunk by {default_shift:+d} px to reverse applied galvo fix") + + # CPU roll helper (used for per-tile sub-blocks). The GPU roll helper is + # only meaningful in undo mode where the whole strip shares one shift; for + # fix mode, per-tile detection means many small rolls and the CPU path is + # the right choice. + def _roll_cpu(block: np.ndarray, shift: int) -> np.ndarray: + return fix_galvo_shift(block, shift=shift, axis=1) + + if use_gpu and mode != "fix": + from linumpy.gpu.corrections import fix_galvo_shift as _fix_galvo_shift_gpu + + def _roll_strip(strip: np.ndarray) -> np.ndarray: + return _fix_galvo_shift_gpu(strip, default_shift, axis=1, use_gpu=True) + else: + + def _roll_strip(strip: np.ndarray) -> np.ndarray: + return fix_galvo_shift(strip, shift=default_shift, axis=1) writer = OmeZarrWriter( output_path, shape=shape, chunk_shape=(shape[0], chunk_x, chunk_y), dtype=dtype, - overwrite=True, + overwrite=overwrite, ) - for kx in tqdm(range(n_cx), desc="Tile columns (axis 1)"): + nz = shape[0] + + # Per-tile-column accounting (thread-safe: each kx writes its own slot). + n_per_tile = np.zeros(n_cx, dtype=np.int32) + n_fallback = np.zeros(n_cx, dtype=np.int32) + n_skipped = np.zeros(n_cx, dtype=np.int32) + shifts_used: list[list[int]] = [[] for _ in range(n_cx)] + + def _detect_tile_shift(tile_aip: np.ndarray) -> tuple[int, float, bool]: + """Return (shift, confidence, used_per_tile) for a single tile AIP.""" + if n_extra: + sh, cf = detect_galvo_shift(tile_aip, n_pixel_return=n_extra) + sh = int(sh) + else: + bs, bw, cf = detect_galvo_band_in_tile(tile_aip) + sh = chunk_x - int(bs) - int(bw) if bw else default_shift + if float(cf) >= min_confidence: + return sh, float(cf), True + return default_shift, float(cf), False + + def _process_column(kx: int) -> None: xs = kx * chunk_x xe = xs + chunk_x + strip = arr[:, xs:xe, :] + + if mode != "fix": + writer[0:nz, xs:xe, :] = _roll_strip(strip) + return + # Per-tile detect+roll: AIP once, then n_cy small rolls. + aip_strip = strip.mean(axis=0) # (chunk_x, ny_total), float + out = np.empty_like(strip) for ky in range(n_cy): ys = ky * chunk_y ye = ys + chunk_y + tile_aip = aip_strip[:, ys:ye] + # Manual override: skip tiles the user has flagged as + # producing wrap artefacts. + if (kx, ky) in skip_tiles: + out[:, :, ys:ye] = strip[:, :, ys:ye] + n_skipped[kx] += 1 + continue + # Background tiles: leave content untouched. Detection on noise + # produces spurious shifts that visibly displace tile content. + if float(tile_aip.mean()) < min_tile_signal: + out[:, :, ys:ye] = strip[:, :, ys:ye] + n_skipped[kx] += 1 + continue + sh, _cf, used = _detect_tile_shift(tile_aip) + out[:, :, ys:ye] = _roll_cpu(strip[:, :, ys:ye], sh) + shifts_used[kx].append(sh) + if used: + n_per_tile[kx] += 1 + else: + n_fallback[kx] += 1 + writer[0:nz, xs:xe, :] = out - chunk = np.asarray(arr[:, xs:xe, ys:ye], dtype=np.float32) - - fixed = np.roll(chunk, roll_amount, axis=1) if mode == "fix" else np.roll(chunk, -undo_shift, axis=1) + n_workers = max(1, int(workers)) + if n_workers == 1: + for kx in tqdm(range(n_cx), desc="Tile columns"): + _process_column(kx) + else: + with ThreadPoolExecutor(max_workers=n_workers) as pool: + list( + tqdm( + pool.map(_process_column, range(n_cx)), + total=n_cx, + desc=f"Tile columns ({n_workers} workers)", + ) + ) - writer[0 : shape[0], xs:xe, ys:ye] = fixed.astype(dtype) + if mode == "fix": + total = int(n_per_tile.sum() + n_fallback.sum() + n_skipped.sum()) + all_shifts = [s for col in shifts_used for s in col] + if all_shifts: + uniq, counts = np.unique(np.asarray(all_shifts), return_counts=True) + order = np.argsort(-counts) + top = ", ".join(f"{int(uniq[i]):+d}px x{int(counts[i])}" for i in order[:5]) + print( + f"Per-tile detection: {int(n_per_tile.sum())}/{total} per-tile, " + f"{int(n_fallback.sum())} fallback (+{default_shift}px), " + f"{int(n_skipped.sum())} skipped (background, mean<{min_tile_signal:.1f}). " + f"Shift histogram (top {min(5, len(uniq))}): {top}" + ) - if n_levels_in > 1: - print(f"Regenerating OME-Zarr pyramid ({n_levels_in} levels) ...") - else: - print("Input has no pyramid -- writing single-level OME-Zarr.") - # n_levels in finalize() counts *additional* downsampled levels beyond level 0, - # so pass (n_levels_in - 1) to reproduce the same number of levels as the input. - writer.finalize(res, n_levels=n_levels_in - 1) + # Mosaic grids are written without a pyramid (single level), so we + # finalize with no extra levels too. + writer.finalize(res, n_levels=0) # --------------------------------------------------------------------------- @@ -715,8 +823,7 @@ def main() -> None: parser.error("--scan requires --band_width.") print( f"Band-start scan: range [{args.scan_range[0]}:{args.scan_range[1]}:{args.scan_range[2]}]px " - f"band_width={args.band_width}px " - f"(pyramid level {args.detection_level}) ..." + f"band_width={args.band_width}px ..." ) _scan_band_start( input_path, @@ -725,7 +832,6 @@ def main() -> None: scan_stop=args.scan_range[1], scan_step=args.scan_range[2], out_png=Path(args.scan), - level=args.detection_level, cmap=args.cmap, ) return @@ -745,16 +851,15 @@ def main() -> None: else: detector = "gradient-pair" if args.n_extra else "threshold fallback" print( - f"Auto-detecting galvo band using {detector} detector " - f"(pyramid level {args.detection_level})" + (f", n_extra={args.n_extra}px" if args.n_extra else "") + " ..." - ) - band_start, band_width, confidence = _auto_detect( - input_path, args.detection_level, n_extra=args.n_extra, verbose=args.verbose + f"Auto-detecting galvo band using {detector} detector" + + (f", n_extra={args.n_extra}px" if args.n_extra else "") + + " ..." ) + band_start, band_width, confidence = _auto_detect(input_path, n_extra=args.n_extra, verbose=args.verbose) band_start += args.band_offset - print("\nDetection result (scaled to level-0 pixels):") + print("\nDetection result (pixels):") print(f" band_start = {band_start} px" + (f" (offset: {args.band_offset:+d}px)" if args.band_offset else "")) print(f" band_width = {band_width} px") print(f" confidence = {confidence:.3f}") @@ -782,15 +887,15 @@ def main() -> None: print(f"[undo] will reverse roll shift={undo_shift}px per tile chunk") # ------------------------------------------------------------------ - # Step 2 - open level-0 array to report tile metadata + # Step 2 - open array to report tile metadata # ------------------------------------------------------------------ - arr, _res, _, _ = _open_level(input_path, level=0) + arr, _res = read_omezarr(input_path) chunk_x = arr.chunks[1] chunk_y = arr.chunks[2] n_cx = arr.shape[1] // chunk_x n_cy = arr.shape[2] // chunk_y - print("\nMosaic info (level 0):") + print("\nMosaic info:") print(f" shape = {arr.shape} (Z, Y, X)") print(f" tile chunks = ({chunk_x}, {chunk_y}) px in (X, Y)") print(f" tile grid = {n_cx} x {n_cy} tiles") @@ -812,6 +917,13 @@ def main() -> None: band_width=band_width, mode=args.mode, undo_shift=undo_shift, + overwrite=args.overwrite, + use_gpu=args.use_gpu, + workers=args.workers, + n_extra=args.n_extra, + min_confidence=args.min_confidence, + min_tile_signal=args.min_tile_signal, + skip_tiles=_parse_skip_tiles(args.skip_tiles), _verbose=args.verbose, ) print(f"Corrected zarr written: {output_path}") @@ -821,16 +933,14 @@ def main() -> None: # ------------------------------------------------------------------ if args.preview: preview_path = Path(args.preview) - arr0, _, _, _ = _open_level(input_path, level=0) _generate_comparison_preview( input_path, output_path, preview_path, - level=args.preview_level, cmap=args.cmap, band_start=band_start if args.mode == "fix" else None, band_width=band_width if args.mode == "fix" else None, - chunk_x=arr0.chunks[1], + chunk_x=chunk_x, ) # ------------------------------------------------------------------