From 31f7811f9d9f3ec7d603f0535ed39dc83d214d80 Mon Sep 17 00:00:00 2001 From: Frans Irgolitsch Date: Wed, 29 Apr 2026 15:38:13 -0400 Subject: [PATCH] feat: improve Allen atlas integration and RAS alignment (#101) --- linumpy/reference/allen.py | 431 ++++- linumpy/tests/test_io_allen.py | 284 +++ scripts/linum_align_to_ras.py | 1079 +++++++++++ scripts/linum_analyze_shifts.py | 298 --- scripts/linum_assess_slice_quality.py | 400 ---- scripts/linum_generate_pipeline_report.py | 2028 --------------------- scripts/tests/test_align_to_ras.py | 238 +++ 7 files changed, 2031 insertions(+), 2727 deletions(-) create mode 100644 linumpy/tests/test_io_allen.py create mode 100755 scripts/linum_align_to_ras.py delete mode 100644 scripts/linum_analyze_shifts.py delete mode 100644 scripts/linum_assess_slice_quality.py delete mode 100644 scripts/linum_generate_pipeline_report.py create mode 100644 scripts/tests/test_align_to_ras.py diff --git a/linumpy/reference/allen.py b/linumpy/reference/allen.py index fc8ea7f3..9b403992 100644 --- a/linumpy/reference/allen.py +++ b/linumpy/reference/allen.py @@ -1,7 +1,10 @@ """Methods to download data from the Allen Institute.""" +from collections.abc import Callable, Sequence from pathlib import Path +from typing import Any +import numpy as np import requests import SimpleITK as sitk from tqdm import tqdm @@ -9,6 +12,40 @@ AVAILABLE_RESOLUTIONS = [10, 25, 50, 100] +def numpy_to_sitk_image(volume: np.ndarray, spacing: tuple | Sequence, cast_dtype: type | None = None) -> sitk.Image: + """Convert numpy array (Z, Y, X) to SimpleITK image format. + + Parameters + ---------- + volume : np.ndarray + 3D volume with shape (Z, Y, X) matching the project-wide convention + (axis 0 = Z/depth, axis 1 = Y/row, axis 2 = X/column). + spacing : tuple + Voxel spacing in mm as (res_z, res_y, res_x). + cast_dtype : numpy dtype or None + If provided, cast the volume to this dtype before creating the SITK image + (useful for registration where float32 is expected). 
If None, preserve + the input numpy dtype. + + Returns + ------- + sitk.Image + SimpleITK image with proper spacing and orientation + """ + # sitk.GetImageFromArray interprets a numpy array with shape (Z, Y, X) as a + # SITK image with size (X, Y, Z), so no transpose is needed. The SITK call + # copies the buffer into its own storage, so we only allocate an extra + # numpy array when an explicit dtype cast is requested. + vol_for_sitk = volume.astype(cast_dtype, copy=False) if cast_dtype is not None else volume + vol_sitk = sitk.GetImageFromArray(vol_for_sitk) + # Spacing: SimpleITK uses (X, Y, Z) = (width, height, depth). + # Our spacing is (res_z, res_y, res_x), so SITK spacing is (res_x, res_y, res_z). + vol_sitk.SetSpacing([spacing[2], spacing[1], spacing[0]]) + vol_sitk.SetOrigin([0, 0, 0]) + vol_sitk.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) + return vol_sitk + + def download_template(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image: """Download a 3D average mouse brain. @@ -41,7 +78,7 @@ def download_template(resolution: int, cache: bool = True, cache_dir: str = ".da if not (nrrd_file.is_file()): # Download the template response = requests.get(url, stream=True) - with nrrd_file.open("wb") as f: + with Path(nrrd_file).open("wb") as f: for data in tqdm(response.iter_content()): f.write(data) @@ -53,3 +90,395 @@ def download_template(resolution: int, cache: bool = True, cache_dir: str = ".da nrrd_file.unlink() # Removes the nrrd file return vol + + +def download_template_ras_aligned(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image: + """Download a 3D average mouse brain and align it to RAS+ orientation. + + The Allen CCF v3 template is stored in PIR orientation + (SITK axes ``(X, Y, Z) = (AP, DV, ML)`` with ``+X = Posterior``, + ``+Y = Inferior``, ``+Z = Right``). 
Converting to RAS+ + (``+X = Right``, ``+Y = Anterior``, ``+Z = Superior``) requires + ``PermuteAxes((2, 0, 1))`` followed by flipping **both** the Y and Z + axes (I → S and P → A). + + Parameters + ---------- + resolution + Allen template resolution in micron. Must be 10, 25, 50 or 100. + cache + Keep the downloaded volume in cache + cache_dir + Cache directory + + Returns + ------- + Allen average mouse brain in RAS+ orientation. + """ + vol = download_template(resolution, cache, cache_dir) + + # Preparing the affine to align the template in the RAS+ + r_mm = resolution / 1e3 # Convert the resolution from micron to mm + vol.SetSpacing([r_mm] * 3) # Set the spacing in mm + # Ensure origin/direction are standardized so physical coordinates are stable + vol.SetOrigin([0.0, 0.0, 0.0]) + vol.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) + + # Convert PIR → RAS: + # PermuteAxes((2, 0, 1)) maps (P, I, R) → (R, P, I) + # Flip Y (P → A) and Z (I → S) to reach (R, A, S). + vol = sitk.PermuteAxes(vol, (2, 0, 1)) + vol = sitk.Flip(vol, (False, True, True)) + # After permuting/flipping, also ensure origin/direction are identity/zero + vol.SetOrigin([0.0, 0.0, 0.0]) + vol.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) + + return vol + + +def register_3d_rigid_to_allen( + moving_image: np.ndarray, + moving_spacing: tuple, + allen_resolution: int = 100, + metric: str = "MI", + max_iterations: int = 1000, + verbose: bool = False, + progress_callback: Callable[[Any], None] | None = None, + initial_rotation_deg: tuple = (0.0, 0.0, 0.0), +) -> tuple: + """Perform 3D rigid registration of a brain volume to the Allen atlas. 
+ + Parameters + ---------- + moving_image : np.ndarray + 3D brain volume to register (shape: Z, Y, X) + moving_spacing : tuple + Voxel spacing in mm (res_z, res_y, res_x) + allen_resolution : int + Allen template resolution in micron (default: 100) + metric : str + Similarity metric: 'MI' (mutual information), 'MSE', 'CC' (correlation), + or 'AntsCC' (ANTS correlation) + max_iterations : int + Maximum number of iterations + verbose : bool + Print registration progress + progress_callback : callable, optional + Callback function called on each iteration with the registration method. + Function signature: callback(registration_method) + initial_rotation_deg : tuple, optional + Initial rotation in degrees (rx, ry, rz) applied before optimization. + + Returns + ------- + transform : sitk.Euler3DTransform + Rigid transform to align moving_image to Allen atlas + stop_condition : str + Optimizer stopping condition + error : float + Final registration metric value + """ + # Download and prepare Allen atlas in RAS orientation + allen_atlas = download_template_ras_aligned(allen_resolution, cache=True) + + # If the moving image is coarser than the Allen atlas along any axis, + # downsample the atlas to match the moving resolution. The registration + # cost is dominated by the fixed (Allen) image size, so downsampling the + # atlas up-front is much cheaper than upsampling moving to a finer grid + # that carries no additional information. 
+ moving_min_spacing_mm = min(moving_spacing) + allen_spacing_mm = allen_atlas.GetSpacing() + allen_min_spacing_mm = min(allen_spacing_mm) + if moving_min_spacing_mm > allen_min_spacing_mm * 1.2: + target_spacing_mm = float(moving_min_spacing_mm) + allen_size = allen_atlas.GetSize() + new_size = [max(1, round(sz * sp / target_spacing_mm)) for sz, sp in zip(allen_size, allen_spacing_mm, strict=False)] + ref = sitk.Image(new_size, allen_atlas.GetPixelIDValue()) + ref.SetOrigin(allen_atlas.GetOrigin()) + ref.SetDirection(allen_atlas.GetDirection()) + ref.SetSpacing((target_spacing_mm,) * 3) + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(ref) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(0) + allen_atlas = resampler.Execute(allen_atlas) + if verbose: + print( + f"Downsampled Allen atlas to match moving spacing: " + f"{allen_spacing_mm} mm → {allen_atlas.GetSpacing()} mm, " + f"size {allen_size} → {allen_atlas.GetSize()}" + ) + + # Crop moving image to tissue bounding box to reduce volume size. + # Large motor drift during acquisition inflates the canvas with empty space, + # causing the Allen-domain resampling to clip away brain tissue. Cropping + # first keeps the volume compact so most of the brain survives resampling, + # giving the optimizer a much better cost-function landscape. 
+ margin_voxels = 10 + crop_origin_mm = (0.0, 0.0, 0.0) # physical offset in (Z, Y, X) order + nonzero_coords = np.nonzero(moving_image) + if len(nonzero_coords[0]) > 0: + bbox_slices = tuple( + slice( + max(0, int(dim.min()) - margin_voxels), + min(moving_image.shape[ax], int(dim.max()) + margin_voxels + 1), + ) + for ax, dim in enumerate(nonzero_coords) + ) + crop_origin_mm = ( + bbox_slices[0].start * moving_spacing[0], + bbox_slices[1].start * moving_spacing[1], + bbox_slices[2].start * moving_spacing[2], + ) + cropped = moving_image[bbox_slices] + if verbose: + print(f"Cropped tissue bounding box: {moving_image.shape} -> {cropped.shape}") + moving_image = cropped + + # Convert moving image to SimpleITK format. + # Origin stays at (0,0,0) so the compact brain sits at the start of physical + # space and overlaps with the Allen atlas domain during resampling. The crop + # offset is added to the final transform's translation after registration so + # the transform remains valid for the original (uncropped) full volume. + moving_sitk = numpy_to_sitk_image(moving_image, moving_spacing) + + # Compute a preliminary brain centre BEFORE any resampling. + # This is used as the fallback only when needs_resample=False (images already + # share the same physical space). When resampling IS needed, this value is + # overwritten below with the centroid of the clipped brain within the Allen + # domain, because the full-brain geometric centre can be tens of mm outside + # the Allen atlas extent and would produce a translation that maps every + # Allen voxel outside the resampled moving image buffer. + original_moving_size = moving_sitk.GetSize() + original_moving_center_idx = [s / 2.0 for s in original_moving_size] + original_moving_center = np.array(moving_sitk.TransformContinuousIndexToPhysicalPoint(original_moving_center_idx)) + + # Resample moving image to match Allen atlas spacing and size for better registration. 
+ # NOTE: we deliberately keep the original moving center computed above so that the + # centre-aligned fallback initialisation is always correct even after resampling. + allen_spacing = allen_atlas.GetSpacing() + allen_size = allen_atlas.GetSize() + moving_spacing_sitk = moving_sitk.GetSpacing() + moving_size_sitk = moving_sitk.GetSize() + + # Check if resampling is needed (if spacing differs significantly or sizes are very different) + spacing_ratio = np.array(allen_spacing) / np.array(moving_spacing_sitk) + size_ratio = np.array(allen_size, dtype=float) / np.array(moving_size_sitk, dtype=float) + + # Resample if spacing differs by more than 10% or if volumes are very different sizes + needs_resample = np.any(np.abs(spacing_ratio - 1.0) > 0.1) or np.any(size_ratio < 0.5) or np.any(size_ratio > 2.0) + + if needs_resample: + if verbose: + print( + f"Resampling moving image from {moving_spacing_sitk} mm, size {moving_size_sitk} " + f"to {allen_spacing} mm, size {allen_size}" + ) + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(allen_atlas) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(0) + moving_sitk = resampler.Execute(moving_sitk) + + # Recompute the effective brain centre from the RESAMPLED image. + # The pre-resampling centre can lie far outside the Allen domain (e.g. a + # large 25 µm brain whose geometric centre is at ~37 mm, while the Allen + # atlas only spans ~11 mm). Using that centre directly gives a translation + # of +31 mm, which maps every Allen voxel outside the moving image buffer. + # Instead, use the centroid of the non-zero (brain-tissue) voxels that + # survived the clipping into the Allen domain. 
+ moving_arr = sitk.GetArrayFromImage(moving_sitk) # shape (Z, Y, X) in numpy + nonzero_idx = np.argwhere(moving_arr > 0) # rows are (z, y, x) + if len(nonzero_idx) > 0: + centroid_zyx = nonzero_idx.mean(axis=0) + # SITK index order is (x, y, z), reverse of numpy (z, y, x) + centroid_xyz = [float(centroid_zyx[2]), float(centroid_zyx[1]), float(centroid_zyx[0])] + original_moving_center = np.array(moving_sitk.TransformContinuousIndexToPhysicalPoint(centroid_xyz)) + if verbose: + print(f"Resampled brain centroid (physical): {original_moving_center} mm") + # If all voxels are zero (brain entirely outside Allen domain), keep + # the pre-resampling centre and accept a potentially poor initialization. + + # Normalize images for better registration + fixed_image = sitk.Normalize(allen_atlas) + moving_image_sitk = sitk.Normalize(moving_sitk) + + if verbose: + print(f"Fixed (Allen) image: size={fixed_image.GetSize()}, spacing={fixed_image.GetSpacing()}") + print(f"Moving (brain) image: size={moving_image_sitk.GetSize()}, spacing={moving_image_sitk.GetSpacing()}") + + # Initialize registration + registration_method = sitk.ImageRegistrationMethod() + + # Set metric + # Note: For correlation-based metrics, negative values are possible + # The optimizer will maximize MI/CC and minimize MSE + if metric.upper() == "MI": + registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) + elif metric.upper() == "MSE": + registration_method.SetMetricAsMeanSquares() + elif metric.upper() == "CC": + registration_method.SetMetricAsCorrelation() + elif metric.upper() == "ANTSCC": + registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=20) + else: + raise ValueError(f"Unknown metric: {metric}. 
Choose from: MI, MSE, CC, AntsCC") + + # Set metric sampling - use regular sampling for reproducibility and speed + registration_method.SetMetricSamplingStrategy(registration_method.REGULAR) + registration_method.SetMetricSamplingPercentage(0.25) # 25% of pixels is usually sufficient + + # Set optimizer with conservative parameters + # Use smaller learning rate and steps to prevent overshooting + learning_rate = 0.5 # Smaller learning rate for stability + min_step = 0.0001 + registration_method.SetOptimizerAsRegularStepGradientDescent( + learningRate=learning_rate, + minStep=min_step, + numberOfIterations=max_iterations, + relaxationFactor=0.5, + gradientMagnitudeTolerance=1e-8, + ) + + # Use physical shift for scaling - more appropriate for physical coordinate registration + # This computes scales based on how a 1mm shift affects the metric + registration_method.SetOptimizerScalesFromPhysicalShift() + + # Multi-resolution approach - start coarse, refine progressively + # More levels for robustness + registration_method.SetShrinkFactorsPerLevel([8, 4, 2, 1]) + registration_method.SetSmoothingSigmasPerLevel([4, 2, 1, 0]) + registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() + + # Initialize rigid transform with guaranteed overlap. + # Use the ORIGINAL moving image centre (before any resampling) so that + # the centre-aligned fallback always produces a meaningful initial translation + # regardless of the resolution/size relationship between the two images. + initial_transform = sitk.Euler3DTransform() + + # Calculate image centres in physical space + fixed_size = fixed_image.GetSize() + fixed_center_idx = [s / 2.0 for s in fixed_size] + fixed_center = np.array(fixed_image.TransformContinuousIndexToPhysicalPoint(fixed_center_idx)) + + # Translation to align brain centre with Allen centre (ensures initial overlap). 
+ # ITK transform maps fixed→moving: T(p) = R(p - c) + c + t + # For identity rotation and c=fixed_center: T(fixed_center) = fixed_center + t + # We need T(fixed_center) = original_moving_center, so t = moving_center - fixed_center. + translation = tuple(original_moving_center - fixed_center) + + # Set center of rotation to fixed image center + initial_transform.SetCenter(fixed_center) + + # Convert initial rotation from degrees to radians + rx_rad = np.deg2rad(initial_rotation_deg[0]) + ry_rad = np.deg2rad(initial_rotation_deg[1]) + rz_rad = np.deg2rad(initial_rotation_deg[2]) + + # Set translation to align centers and apply initial rotation + initial_transform.SetTranslation(translation) + initial_transform.SetRotation(rx_rad, ry_rad, rz_rad) + + if verbose: + print(f"Initial center alignment: fixed={fixed_center}, moving (original)={original_moving_center}") + print(f"Translation to align centers: {translation}") + if any(r != 0 for r in initial_rotation_deg): + print(f"Initial rotation (deg): {initial_rotation_deg}") + + # Only try MOMENTS initialization if no initial rotation was specified + # (user-specified rotation takes precedence) and the image was NOT resampled + # into the Allen domain. After resampling, the brain occupies only a small + # corner of the 640³ Allen image; sitk.Normalize then gives the large + # zero-padded background a uniform negative value that dominates the + # centre-of-mass computation, producing translation ≈ 0 which places every + # sample point outside the brain buffer. 
+ if all(r == 0 for r in initial_rotation_deg) and not needs_resample: + try: + # Use MOMENTS initialization which is more robust + init_transform = sitk.Euler3DTransform() + init_transform = sitk.CenteredTransformInitializer( + fixed_image, moving_image_sitk, init_transform, sitk.CenteredTransformInitializerFilter.MOMENTS + ) + # Verify the initialized transform has reasonable translation + init_params = init_transform.GetParameters() + init_translation = np.array(init_params[3:6]) + + # Check if the initialized transform is reasonable (translation not too large) + # If translation is reasonable, use it; otherwise use our center-aligned one + translation_magnitude = np.linalg.norm(init_translation) + fixed_size_mm = np.array(fixed_image.GetSpacing()) * np.array(fixed_image.GetSize()) + max_reasonable_translation = np.linalg.norm(fixed_size_mm) * 0.5 # Half the image size + + if translation_magnitude < max_reasonable_translation: + initial_transform = init_transform + if verbose: + print(f"Using MOMENTS initialization (translation magnitude: {translation_magnitude:.2f} mm)") + else: + if verbose: + print( + f"MOMENTS initialization translation too large ({translation_magnitude:.2f} mm), using center-aligned" + ) + except Exception as e: + if verbose: + print(f"MOMENTS initialization failed: {e}, using center-aligned translation") + + if verbose: + final_params = initial_transform.GetParameters() + final_center = initial_transform.GetCenter() + print(f"Final initial transform: rotation={final_params[:3]}, translation={final_params[3:]}") + print(f"Transform center: {final_center}") + + registration_method.SetInitialTransform(initial_transform) + registration_method.SetInterpolator(sitk.sitkLinear) + + # Set up iteration callback + if verbose or progress_callback is not None: + + def command_iteration(method: Any) -> None: + if verbose: + if method.GetOptimizerIteration() == 0: + print(f"Estimated scales: {method.GetOptimizerScales()}") + print( + f"Iteration 
{method.GetOptimizerIteration():3d} = " + f"{method.GetMetricValue():7.5f} : " + f"{method.GetOptimizerPosition()}" + ) + if progress_callback is not None: + progress_callback(method) + + registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method)) + + # Execute registration + final_transform = registration_method.Execute(fixed_image, moving_image_sitk) + + stop_condition = registration_method.GetOptimizerStopConditionDescription() + error = registration_method.GetMetricValue() + + if verbose: + print(f"Registration complete: {stop_condition}") + print(f"Final metric value: {error:.6f}") + final_params = final_transform.GetParameters() + print(f"Final transform: rotation={final_params[:3]}, translation={final_params[3:]}") + print(f"Fixed image size: {fixed_image.GetSize()}, spacing: {fixed_image.GetSpacing()}") + print(f"Moving image size: {moving_image_sitk.GetSize()}, spacing: {moving_image_sitk.GetSpacing()}") + + # Restore crop offset in the translation so the transform is valid for the + # full original (uncropped) brain volume. Derivation: + # T(p) = R(p-c)+c+t maps Allen coords to cropped-brain coords (origin=0). + # Same tissue in full brain is at (cropped_coord + crop_origin_mm). + # So t_full = t_crop + crop_origin_sitk (center c cancels out). 
+ if any(v != 0.0 for v in crop_origin_mm): + params = list(final_transform.GetParameters()) + # SITK Euler3D params: (rx, ry, rz, tx, ty, tz) in SITK XYZ order + # numpy axis order (Z, Y, X) -> SITK (X, Y, Z): + params[3] += crop_origin_mm[2] # SITK X = numpy axis 2 + params[4] += crop_origin_mm[1] # SITK Y = numpy axis 1 + params[5] += crop_origin_mm[0] # SITK Z = numpy axis 0 + final_transform.SetParameters(params) + if verbose: + print( + f"Adjusted translation for crop: +" + f"[{crop_origin_mm[2]:.3f}, {crop_origin_mm[1]:.3f}, {crop_origin_mm[0]:.3f}] mm (SITK XYZ)" + ) + + return final_transform, stop_condition, error diff --git a/linumpy/tests/test_io_allen.py b/linumpy/tests/test_io_allen.py new file mode 100644 index 00000000..6c33a782 --- /dev/null +++ b/linumpy/tests/test_io_allen.py @@ -0,0 +1,284 @@ +"""Tests for linumpy/io/allen.py — orientation handling and registration. + +The Allen template download is monkey-patched to return a synthetic PIR-oriented +volume with a deliberately asymmetric tissue distribution. That keeps these +tests offline and lets us verify that ``download_template_ras_aligned`` really +produces a RAS+ volume (``+X = Right``, ``+Y = Anterior``, ``+Z = Superior``). +""" + +from __future__ import annotations + +import numpy as np +import pytest +import SimpleITK as sitk + +from linumpy.reference import allen + +# --------------------------------------------------------------------------- +# Synthetic PIR-oriented Allen template +# --------------------------------------------------------------------------- + + +def _make_synthetic_pir_template(resolution_um: int = 100) -> sitk.Image: + """Build a small synthetic volume that mimics the Allen CCF nrrd layout. + + Allen CCF v3 stores the template in PIR: + nrrd axis 0 = AP (+=Posterior) + nrrd axis 1 = DV (+=Inferior) + nrrd axis 2 = ML (+=Right) + + ``sitk.ReadImage`` maps nrrd axis k to SITK axis k, so the returned + SITK image has ``(X, Y, Z) = (AP, DV, ML)``. 
Each axis is given a + unique, monotonically increasing gradient so we can identify the + resulting orientation unambiguously after the RAS reorientation. + """ + # Pick axis sizes that are all distinct so permutations are detectable. + ap_size, dv_size, ml_size = 12, 8, 10 + + # numpy shape (Z, Y, X) for sitk.GetImageFromArray: + # numpy Z ↔ SITK Z = ML + # numpy Y ↔ SITK Y = DV + # numpy X ↔ SITK X = AP + ap = np.arange(ap_size, dtype=np.float32)[None, None, :] * 1.0 # unit step + dv = np.arange(dv_size, dtype=np.float32)[None, :, None] * 100.0 + ml = np.arange(ml_size, dtype=np.float32)[:, None, None] * 10000.0 + + arr = ap + dv + ml # each axis contributes a distinct decimal place + + vol = sitk.GetImageFromArray(arr) + r_mm = resolution_um / 1e3 + vol.SetSpacing((r_mm, r_mm, r_mm)) + vol.SetOrigin((0.0, 0.0, 0.0)) + vol.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1)) + return vol + + +# --------------------------------------------------------------------------- +# download_template_ras_aligned — orientation +# --------------------------------------------------------------------------- + + +class TestDownloadTemplateRasAligned: + """Verify the RAS reorientation of the Allen template.""" + + @pytest.fixture + def ras_template(self, monkeypatch): + def fake_download_template(resolution, cache=True, cache_dir=".data/"): + return _make_synthetic_pir_template(resolution) + + monkeypatch.setattr(allen, "download_template", fake_download_template) + return allen.download_template_ras_aligned(100) + + def test_spacing_is_isotropic_and_in_mm(self, ras_template): + spacing = ras_template.GetSpacing() + assert spacing == pytest.approx((0.1, 0.1, 0.1)) + + def test_origin_is_zero(self, ras_template): + assert ras_template.GetOrigin() == pytest.approx((0.0, 0.0, 0.0)) + + def test_direction_is_identity(self, ras_template): + assert ras_template.GetDirection() == pytest.approx((1, 0, 0, 0, 1, 0, 0, 0, 1)) + + def test_size_reflects_permutation(self, ras_template): + """After 
``PermuteAxes((2, 0, 1))`` the SITK size becomes (ML, AP, DV).""" + # Input sizes: AP=12, DV=8, ML=10 → output (ML, AP, DV) = (10, 12, 8) + assert ras_template.GetSize() == (10, 12, 8) + + def test_positive_x_is_right(self, ras_template): + """+X must point toward Right (originally +ML in nrrd).""" + arr = sitk.GetArrayFromImage(ras_template) + # numpy axis 2 = SITK X; ML gradient was the `10000` coefficient. + col = arr[0, 0, :] + diffs = np.diff(col) + # Gradient along X in RAS-aligned volume should increase monotonically. + assert np.all(diffs > 0), f"+X is not monotonic along ML (Right): {col}" + + def test_positive_y_is_anterior(self, ras_template): + """+Y must point toward Anterior (originally -AP in nrrd). + + Raw AP gradient increases with +Posterior, so after reorientation the + AP gradient should DECREASE along +Y (since +Y = Anterior). + """ + arr = sitk.GetArrayFromImage(ras_template) + # numpy axis 1 = SITK Y; AP gradient was the `1.0` coefficient. + # Extract AP component by taking the modulo-100 decimal of a single X,Z column. + col = arr[0, :, 0] % 100.0 # keep only AP contribution (0 .. 11) + diffs = np.diff(col) + assert np.all(diffs < 0), f"+Y is not anterior (AP should decrease): {col}" + + def test_positive_z_is_superior(self, ras_template): + """+Z must point toward Superior (originally -DV in nrrd). + + Raw DV gradient increases with +Inferior, so after reorientation the + DV gradient should DECREASE along +Z (since +Z = Superior). + """ + arr = sitk.GetArrayFromImage(ras_template) + # numpy axis 0 = SITK Z; DV gradient was the `100` coefficient. + # Extract DV component using (value % 10000) // 100. + col = (arr[:, 0, 0] % 10000.0) // 100.0 # 0 .. 
7 + diffs = np.diff(col) + assert np.all(diffs < 0), f"+Z is not superior (DV should decrease): {col}" + + +# --------------------------------------------------------------------------- +# numpy_to_sitk_image +# --------------------------------------------------------------------------- + + +class TestNumpyToSitkImage: + def test_roundtrip_preserves_values(self): + arr = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4) + img = allen.numpy_to_sitk_image(arr, spacing=(0.1, 0.2, 0.3)) + back = sitk.GetArrayFromImage(img) + np.testing.assert_array_equal(back, arr) + + def test_spacing_is_permuted_to_xyz(self): + arr = np.zeros((2, 3, 4), dtype=np.float32) + img = allen.numpy_to_sitk_image(arr, spacing=(0.1, 0.2, 0.3)) + # spacing=(res_z, res_y, res_x) → SITK GetSpacing=(res_x, res_y, res_z) + assert img.GetSpacing() == pytest.approx((0.3, 0.2, 0.1)) + + def test_size_is_reversed_from_numpy_shape(self): + arr = np.zeros((2, 3, 4), dtype=np.float32) + img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0)) + assert img.GetSize() == (4, 3, 2) + + def test_origin_and_direction_are_identity(self): + arr = np.zeros((2, 3, 4), dtype=np.float32) + img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0)) + assert img.GetOrigin() == (0.0, 0.0, 0.0) + assert img.GetDirection() == (1, 0, 0, 0, 1, 0, 0, 0, 1) + + def test_cast_dtype_produces_float32(self): + arr = np.ones((2, 3, 4), dtype=np.uint16) + img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0), cast_dtype=np.float32) + assert img.GetPixelID() == sitk.sitkFloat32 + + def test_no_cast_preserves_dtype(self): + arr = np.ones((2, 3, 4), dtype=np.uint16) + img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0)) + assert img.GetPixelID() == sitk.sitkUInt16 + + def test_input_array_not_modified(self): + arr = np.arange(24, dtype=np.float32).reshape(2, 3, 4) + original = arr.copy() + allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0), cast_dtype=np.float32) + np.testing.assert_array_equal(arr, 
original) + + +# --------------------------------------------------------------------------- +# register_3d_rigid_to_allen — end-to-end self-registration +# --------------------------------------------------------------------------- + + +def _make_synthetic_brain(shape=(24, 24, 24), spacing=(0.2, 0.2, 0.2)): + """Small asymmetric synthetic brain with a unique intensity pattern per axis.""" + z, y, x = np.indices(shape, dtype=np.float32) + # Ellipsoid mask offset from centre, asymmetric along each axis. + cz, cy, cx = shape[0] * 0.55, shape[1] * 0.5, shape[2] * 0.45 + rz, ry, rx = shape[0] * 0.35, shape[1] * 0.3, shape[2] * 0.4 + mask = ((z - cz) / rz) ** 2 + ((y - cy) / ry) ** 2 + ((x - cx) / rx) ** 2 < 1 + brain = np.zeros(shape, dtype=np.float32) + # Distinct gradient along each axis so registration has more than a single + # rotationally symmetric blob to work with. + brain[mask] = 1.0 + 0.3 * (z[mask] / shape[0]) + 0.5 * (y[mask] / shape[1]) + 0.7 * (x[mask] / shape[2]) + return brain + + +class TestRegisterRigidToAllen: + """End-to-end registration tests using a synthetic Allen template.""" + + @pytest.fixture(autouse=True) + def patch_allen(self, monkeypatch): + def fake_download_template(resolution, cache=True, cache_dir=".data/"): + return _make_synthetic_pir_template(resolution) + + monkeypatch.setattr(allen, "download_template", fake_download_template) + + def test_self_registration_recovers_identity(self): + """Registering the RAS Allen template against itself yields ~identity.""" + target = allen.download_template_ras_aligned(100) + moving = sitk.GetArrayFromImage(target) # numpy (Z, Y, X) + # SITK spacing is (X, Y, Z); moving_spacing is (res_z, res_y, res_x) + sx, sy, sz = target.GetSpacing() + transform, stop, _err = allen.register_3d_rigid_to_allen( + moving_image=moving, + moving_spacing=(sz, sy, sx), + allen_resolution=100, + metric="MSE", + max_iterations=50, + verbose=False, + ) + params = transform.GetParameters() + rotation = 
np.array(params[:3]) + translation = np.array(params[3:6]) + # The MSE minimum is at identity; allow generous tolerances because the + # synthetic volume is tiny. + assert np.max(np.abs(rotation)) < 0.1, f"Rotation too large: {rotation}" + assert np.max(np.abs(translation)) < 1.0, f"Translation too large: {translation}" + assert stop # non-empty stop-condition string + + def test_downsamples_allen_when_moving_is_coarser(self, capsys): + """If moving resolution > allen resolution, allen must be downsampled.""" + # Moving at 200 µm, allen synthetic at 100 µm → expect downsampling. + shape = (10, 10, 10) + moving = _make_synthetic_brain(shape, spacing=(0.2, 0.2, 0.2)) + _, _, _ = allen.register_3d_rigid_to_allen( + moving_image=moving, + moving_spacing=(0.2, 0.2, 0.2), + allen_resolution=100, + metric="MSE", + max_iterations=3, + verbose=True, + ) + captured = capsys.readouterr().out + assert "Downsampled Allen atlas" in captured + + def test_does_not_downsample_when_already_coarse(self, capsys): + """If moving resolution ≤ allen resolution, allen must NOT be downsampled.""" + shape = (10, 10, 10) + moving = _make_synthetic_brain(shape, spacing=(0.05, 0.05, 0.05)) + _, _, _ = allen.register_3d_rigid_to_allen( + moving_image=moving, + moving_spacing=(0.05, 0.05, 0.05), + allen_resolution=100, + metric="MSE", + max_iterations=3, + verbose=True, + ) + captured = capsys.readouterr().out + assert "Downsampled Allen atlas" not in captured + + def test_crop_offset_reported_in_verbose_output(self, capsys): + """The ``crop_origin_mm`` restoration must add an offset proportional to + the leading zero-padding of the moving volume. We use a plain cube so + the non-zero bounding box equals the cube's shape exactly, making the + expected crop origin easy to compute. + """ + # A fully filled cube — nonzero bbox equals the full cube shape. 
+ cube_size = 12 + cube = np.ones((cube_size, cube_size, cube_size), dtype=np.float32) + leading_pad = (20, 15, 25) # (pad_z, pad_y, pad_x); each > 10 (margin) + canvas = np.pad(cube, [(p, 5) for p in leading_pad], mode="constant", constant_values=0) + + _, _, _ = allen.register_3d_rigid_to_allen( + moving_image=canvas, + moving_spacing=(0.1, 0.1, 0.1), + allen_resolution=100, + metric="MSE", + max_iterations=0, + verbose=True, + ) + captured = capsys.readouterr().out + # Expected crop start per numpy axis (voxels): pad_axis - margin = pad - 10. + margin = 10 + spacing = 0.1 + expected_numpy = tuple((p - margin) * spacing for p in leading_pad) + # SITK XYZ = numpy axes (X=2, Y=1, Z=0) + expected_sitk_xyz = (expected_numpy[2], expected_numpy[1], expected_numpy[0]) + expected_log = ( + "Adjusted translation for crop: +[" + f"{expected_sitk_xyz[0]:.3f}, {expected_sitk_xyz[1]:.3f}, {expected_sitk_xyz[2]:.3f}" + "] mm (SITK XYZ)" + ) + assert expected_log in captured, f"Expected log not found. Got:\n{captured}" diff --git a/scripts/linum_align_to_ras.py b/scripts/linum_align_to_ras.py new file mode 100755 index 00000000..ac8812cb --- /dev/null +++ b/scripts/linum_align_to_ras.py @@ -0,0 +1,1079 @@ +#!/usr/bin/env python3 + +""" +Align a 3D brain volume to RAS orientation using rigid registration to the Allen atlas. + +This script computes a rigid transform from the input brain volume to a RAS-aligned +version by registering it to the Allen Brain Atlas. The transform can be applied +directly to the zarr file (resampling) or stored in OME-Zarr metadata. 
+""" + +# Configure thread limits before numpy/scipy imports +import linumpy.config.threads # noqa: F401 + +import argparse +import json +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import SimpleITK as sitk +from tqdm.auto import tqdm + +from linumpy.imaging.orientation import ( + apply_orientation_transform, + parse_orientation_code, + reorder_resolution, +) +from linumpy.io.zarr import AnalysisOmeZarrWriter, read_omezarr +from linumpy.reference import allen + +matplotlib.use("Agg") # Non-interactive backend + +# Constants +DEFAULT_ALLEN_RESOLUTION = 100 +DEFAULT_MAX_ITERATIONS = 1000 +DEFAULT_METRIC = "MI" + + +def _debug_log(message: str, **fields: Any) -> None: + """Append an NDJSON line describing a slicing/labelling decision. + + Active only when ``LINUMPY_DEBUG_LOG`` is set, so production runs pay + nothing. Used to capture runtime evidence of which volume conventions + each preview function actually receives. 
+ """ + import os + + path = os.environ.get("LINUMPY_DEBUG_LOG") + if not path: + return + try: + import time + + entry = { + "id": f"log_{int(time.time() * 1000)}_panels", + "timestamp": int(time.time() * 1000), + "sessionId": "6fa1b3", + "runId": "panels-fix", + "hypothesisId": "H1", + "location": "linum_align_to_ras.py", + "message": message, + "data": fields, + } + with Path(path).open("a") as f: + f.write(json.dumps(entry) + "\n") + except Exception: + pass + + +def _build_arg_parser() -> argparse.ArgumentParser: + """Build the command-line argument parser.""" + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + p.add_argument("input_zarr", help="Input OME-Zarr file from 3D reconstruction pipeline") + p.add_argument("output_zarr", help="Output OME-Zarr file (RAS-aligned)") + p.add_argument( + "--allen-resolution", + type=int, + default=DEFAULT_ALLEN_RESOLUTION, + choices=allen.AVAILABLE_RESOLUTIONS, + help="Allen atlas resolution in micron [%(default)s]", + ) + p.add_argument( + "--metric", + type=str, + default=DEFAULT_METRIC, + choices=["MI", "MSE", "CC", "AntsCC"], + help="Registration metric [%(default)s]", + ) + p.add_argument( + "--max-iterations", + type=int, + default=DEFAULT_MAX_ITERATIONS, + help="Maximum registration iterations [%(default)s]", + ) + p.add_argument( + "--store-transform-only", action="store_true", help="Store transform in metadata only (don't resample volume)" + ) + p.add_argument("--level", type=int, default=0, help="Pyramid level for registration (0 = full resolution) [%(default)s]") + p.add_argument( + "--chunks", type=int, nargs=3, default=None, help="Chunk size for output zarr. Uses input chunks when None." + ) + p.add_argument( + "--n-levels", type=int, default=None, help="Number of pyramid levels for output. Uses Allen atlas levels when None." 
+ ) + p.add_argument( + "--pyramid_resolutions", + type=float, + nargs="+", + default=None, + help="Target pyramid resolution levels in µm (e.g. 10 25 50 100).\n" + "If omitted, inherits levels from input zarr metadata or uses Allen resolutions.", + ) + p.add_argument( + "--make_isotropic", action="store_true", default=True, help="Resample to isotropic voxels at each pyramid level." + ) + p.add_argument("--no_isotropic", dest="make_isotropic", action="store_false") + p.add_argument("--verbose", action="store_true", help="Print registration progress") + p.add_argument("--preview", type=str, default=None, help="Generate preview image showing alignment comparison") + p.add_argument( + "--input-orientation", + type=str, + default=None, + help="Input volume orientation code (3 letters: R/L, A/P, S/I)\nExamples: 'RAS' (Allen), 'LPI', 'PIR'", + ) + p.add_argument( + "--initial-rotation", + type=float, + nargs=3, + default=[0.0, 0.0, 0.0], + metavar=("RX", "RY", "RZ"), + help="Initial rotation angles in degrees (Rx, Ry, Rz).\nUse to provide initial orientation hint for registration.", + ) + p.add_argument("--preview-only", action="store_true", help="Only generate preview of input volume (no registration)") + p.add_argument( + "--orientation-preview", + type=str, + default=None, + metavar="PATH", + help="Save a 3-panel preview of the volume after --input-orientation and\n" + "--initial-rotation are applied. 
Use to verify these parameters\n" + "before committing to a full registration run.", + ) + p.add_argument( + "--orientation-preview-only", + action="store_true", + help="Generate --orientation-preview and exit without running registration.", + ) + return p + + +# ============================================================================= +# Orientation utilities — imported from linumpy.imaging.orientation +# ============================================================================= + + +def create_registration_progress_callback( + max_iterations: int, + n_resolution_levels: int = 3, + pbar: tqdm | None = None, + registration_start_step: int = 0, + registration_steps: int = 0, +) -> Callable: + """ + Create a progress callback for registration. + + Parameters + ---------- + max_iterations : int + Maximum iterations per level + n_resolution_levels : int + Number of resolution levels in the registration pyramid + pbar : tqdm, optional + Progress bar to update + registration_start_step : int + Step number where registration starts in progress bar + registration_steps : int + Number of steps allocated for registration + + Returns + ------- + callable + Progress callback function compatible with SimpleITK registration + """ + total_iterations = [0] + level_counter = [0] + last_iteration = [-1] + # Worst-case budget (used only as the denominator for the progress bar). + estimated_total = float(max_iterations * n_resolution_levels) + + def callback(method: Any) -> None: + """Update progress during registration iterations.""" + iteration = method.GetOptimizerIteration() + metric = method.GetMetricValue() + + # Detect resolution-level transitions (iteration counter resets to 0 + # when SimpleITK starts the next pyramid level). 
+ if iteration < last_iteration[0]: + level_counter[0] += 1 + last_iteration[0] = iteration + + total_iterations[0] += 1 + + if pbar is not None: + # Blend "within-level" progress with completed levels so the bar + # advances smoothly across resolutions and does not stall when a + # level converges early or hits max_iterations. + within_level = min(1.0, (iteration + 1) / max_iterations) + level_progress = (level_counter[0] + within_level) / n_resolution_levels + progress_ratio = min(1.0, max(level_progress, total_iterations[0] / estimated_total)) + target_step = registration_start_step + int(registration_steps * progress_ratio) + if target_step > pbar.n: + pbar.n = target_step + pbar.set_postfix_str(f"metric={metric:.6f} level={level_counter[0] + 1}/{n_resolution_levels}") + pbar.refresh() + + return callback + + +# ============================================================================= +# Transform utilities +# ============================================================================= + + +def sitk_transform_to_affine_matrix(transform: sitk.Transform) -> np.ndarray: + """ + Convert SimpleITK transform to 4x4 affine matrix. + + Parameters + ---------- + transform : sitk.Transform + SimpleITK Euler3DTransform or AffineTransform + + Returns + ------- + np.ndarray + 4x4 affine matrix in (Z, Y, X) coordinate ordering, matching the + OME-NGFF axis declaration used by the pipeline. 
+ """ + if isinstance(transform, sitk.Euler3DTransform): + center = np.array(transform.GetCenter()) + params = transform.GetParameters() + rx, ry, rz = params[:3] + translation = np.array(params[3:6]) + + # Build rotation matrix from Euler angles + cx, cy, cz = np.cos([rx, ry, rz]) + sx, sy, sz = np.sin([rx, ry, rz]) + + r = np.array( + [ + [cz * cy, cz * sy * sx - sz * cx, cz * sy * cx + sz * sx], + [sz * cy, sz * sy * sx + cz * cx, sz * sy * cx - cz * sx], + [-sy, cy * sx, cy * cx], + ] + ) + + matrix = np.eye(4) + matrix[:3, :3] = r + matrix[:3, 3] = translation + center - r @ center + + elif isinstance(transform, sitk.AffineTransform): + r = np.array(transform.GetMatrix()).reshape(3, 3) + translation = np.array(transform.GetTranslation()) + center = np.array(transform.GetCenter()) + + matrix = np.eye(4) + matrix[:3, :3] = r + matrix[:3, 3] = translation + center - r @ center + else: + raise ValueError(f"Unsupported transform type: {type(transform)}") + + # Permute from SimpleITK (X, Y, Z) to our (Z, Y, X) ordering (OME-NGFF axis order). 
+ permute = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]) + return permute @ matrix @ permute.T + + +def store_transform_in_metadata(zarr_path: Path, transform: sitk.Transform) -> None: + """Store transform in OME-Zarr metadata as affine coordinate transformation.""" + affine_matrix = sitk_transform_to_affine_matrix(transform) + zattrs_path = Path(zarr_path) / ".zattrs" + + if not zattrs_path.exists(): + raise FileNotFoundError(f".zattrs not found: {zarr_path}") + + with Path(zattrs_path).open(encoding="utf-8") as f: + metadata = json.load(f) + + affine_transform = {"type": "affine", "affine": affine_matrix.flatten().tolist()} + + multiscales = metadata.get("multiscales", []) + if not multiscales: + raise ValueError("No multiscales entry found in metadata") + + for dataset in multiscales[0].get("datasets", []): + existing = dataset.get("coordinateTransformations", []) + dataset["coordinateTransformations"] = [affine_transform, *existing] + + with Path(zattrs_path).open("w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + print(f"Stored affine transform in metadata: {zattrs_path}") + + +# ============================================================================= +# Resolution utilities +# ============================================================================= + + +def get_pyramid_resolutions_from_zarr(zarr_path: Path) -> list[float] | None: + """ + Extract pyramid resolution levels from OME-Zarr metadata. 
+ + Parameters + ---------- + zarr_path : Path + Path to OME-Zarr file + + Returns + ------- + list of float or None + Target resolutions in microns, or None if not found + """ + for metadata_file in ["zarr.json", ".zattrs"]: + metadata_path = zarr_path / metadata_file + if not metadata_path.exists(): + continue + + try: + with Path(metadata_path).open(encoding="utf-8") as f: + metadata = json.load(f) + except (OSError, json.JSONDecodeError): + continue + + multiscales = metadata.get("multiscales", []) + if not multiscales: + continue + + resolutions = [] + for dataset in multiscales[0].get("datasets", []): + transforms = dataset.get("coordinateTransformations", []) + for tr in transforms: + if tr.get("type") == "scale" and "scale" in tr: + # Get finest spatial dimension, convert mm to µm + scale = tr["scale"][-3:] + res_um = min(float(s) for s in scale) * 1000 + resolutions.append(res_um) + break + + if resolutions: + return resolutions + + return None + + +# ============================================================================= +# Core processing functions +# ============================================================================= + + +def compute_centered_reference_and_transform( + moving_sitk: sitk.Image, transform: sitk.Transform, output_spacing: tuple | None = None +) -> tuple[sitk.Image, sitk.Transform]: + """ + Compute a reference image and modified transform that centers the output volume. + + This creates an output that is centered in the volume (brain in the middle), + preserving the original resolution. + + Parameters + ---------- + moving_sitk : sitk.Image + The input moving image + transform : sitk.Transform + Transform to apply (moving -> fixed/RAS space) + output_spacing : tuple, optional + Output voxel spacing. If None, uses moving image spacing. 
+ + Returns + ------- + ref : sitk.Image + Reference image for resampling, with origin at 0 + composite_transform : sitk.Transform + Modified transform that maps moving image to centered output + """ + if output_spacing is None: + output_spacing = moving_sitk.GetSpacing() + + # Get corners of the moving image in physical coordinates + size = moving_sitk.GetSize() + corners = [ + (0, 0, 0), + (size[0] - 1, 0, 0), + (0, size[1] - 1, 0), + (0, 0, size[2] - 1), + (size[0] - 1, size[1] - 1, 0), + (size[0] - 1, 0, size[2] - 1), + (0, size[1] - 1, size[2] - 1), + (size[0] - 1, size[1] - 1, size[2] - 1), + ] + + # Map brain corners to FIXED/RAS space. + # The registration transform maps fixed→moving (ResampleImageFilter convention), + # so we use its inverse (moving→fixed) to find where the brain corners land + # in the fixed (RAS/Allen) coordinate system. + inv_transform = transform.GetInverse() + transformed_pts = [] + for idx in corners: + phys = moving_sitk.TransformContinuousIndexToPhysicalPoint(idx) + transformed_pts.append(inv_transform.TransformPoint(phys)) + + pts = np.array(transformed_pts) + pts_min = pts.min(axis=0) + pts_max = pts.max(axis=0) + + # Compute output size to cover the full transformed brain extent + spacing = np.array(output_spacing) + extent = pts_max - pts_min + new_size = np.ceil(extent / spacing).astype(int) + + # Reference image: origin at (0,0,0), spanning [0, new_size*spacing]. + # Output voxel p maps to fixed-space coordinate (p + pts_min). + ref = sitk.Image([int(s) for s in new_size], moving_sitk.GetPixelIDValue()) + ref.SetSpacing(tuple(spacing)) + ref.SetOrigin((0.0, 0.0, 0.0)) + ref.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1)) # Identity direction (RAS) + + # Shift transform: output space → fixed space (translate by pts_min). + # This maps output origin (0,0,0) to the brain's fixed-space bounding box minimum. 
+ shift_transform = sitk.TranslationTransform(3) + shift_transform.SetOffset(tuple(pts_min)) + + # Composite transform for resampling: + # output point → (shift) → fixed space → (T) → moving space + # SimpleITK CompositeTransform applies transforms in REVERSE order of + # addition (the most recently added transform is applied first, matching + # ITK's stack convention). To obtain ``transform(shift(p))`` we must add + # ``transform`` first and ``shift`` last. + composite = sitk.CompositeTransform(3) + composite.AddTransform(transform) # added first → applied last (fixed → moving) + composite.AddTransform(shift_transform) # added last → applied first (output → fixed) + + return ref, composite + + +def apply_transform_to_zarr( + input_path: Path, + output_path: Path, + transform: sitk.Transform, + chunks: tuple | None = None, + n_levels: int | None = None, + pyramid_resolutions: list | None = None, + make_isotropic: bool = True, + orientation_permutation: tuple | None = None, + orientation_flips: tuple | None = None, + pbar: tqdm | None = None, +) -> None: + """ + Apply transform to zarr file by resampling into RAS-aligned space. + + The output is centered on the transformed brain volume, preserving the + original resolution. This corrects any rotation/off-axis alignment without + placing the brain in the Allen atlas coordinate system. 
+ + Parameters + ---------- + input_path: Path + Path to input OME-Zarr + output_path: Path + Path to output OME-Zarr + transform : sitk.Transform + Transform to apply + chunks : tuple, optional + Chunk size for output + n_levels : int, optional + Number of pyramid levels (if None, use source pyramid or Allen resolutions) + orientation_permutation : tuple, optional + Axis permutation for orientation correction + orientation_flips : tuple, optional + Axis flips for orientation correction + pbar : tqdm, optional + Progress bar + pyramid_resolutions : list, optional + Explicit list of resolutions for the output pyramid + make_isotropic : bool + If True, resample output to isotropic resolution + """ + + def update_pbar() -> None: + if pbar: + pbar.update(1) + + # Load volume at full resolution (level 0) and capture its actual spacing. + # base_resolution comes from the downsampled registration level, so we must + # read the level-0 spacing from the file to get the correct physical extent. + vol_zarr, level0_resolution = read_omezarr(input_path, level=0) + if chunks is None: + chunks = getattr(vol_zarr, "chunks", None) + if chunks is None: + chunks = (128,) * len(vol_zarr.shape) + + vol = np.asarray(vol_zarr[:]) + original_dtype = vol.dtype + update_pbar() + + # Apply orientation correction + resolution = level0_resolution + if orientation_permutation is not None: + vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips) + resolution = reorder_resolution(resolution, orientation_permutation) + + # Compute a tissue-representative background value on the numpy array + # BEFORE allocating the (potentially large) SimpleITK float32 copy. Using + # this as the default pixel value avoids black borders that would skew + # downstream normalization and visualization. 
+ nonzero_mask = vol > 0 + bg_value = float(np.percentile(vol[nonzero_mask], 1)) if nonzero_mask.any() else 0.0 + del nonzero_mask + + # Convert to SimpleITK + vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32) + del vol # free original volume before resampling + update_pbar() + + # Compute reference image and modified transform that centers the output + reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform) + + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(reference) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(bg_value) + resampler.SetTransform(centered_transform) + + transformed_sitk = resampler.Execute(vol_sitk) + del vol_sitk # free input before allocating output array + transformed = sitk.GetArrayFromImage(transformed_sitk) + del transformed_sitk # free SimpleITK image after extracting numpy array + update_pbar() + + # GetArrayFromImage already yields numpy (Z, Y, X) matching our convention. 
+ update_pbar() + + # Convert back to original dtype + if np.issubdtype(original_dtype, np.integer): + info = np.iinfo(original_dtype) + transformed = np.clip(np.rint(transformed), info.min, info.max).astype(original_dtype) + else: + transformed = transformed.astype(original_dtype) + + # Write output + writer = AnalysisOmeZarrWriter( + output_path, + shape=transformed.shape, + chunk_shape=chunks, + dtype=transformed.dtype, + overwrite=True, + ) + writer[:] = transformed + + if n_levels is not None: + writer.finalize(list(resolution), n_levels=n_levels) + else: + if pyramid_resolutions is not None: + target_resolutions = pyramid_resolutions + else: + # Fallback: inherit levels from input zarr metadata, or use Allen resolutions + target_resolutions = get_pyramid_resolutions_from_zarr(Path(input_path)) + if target_resolutions is None: + target_resolutions = list(allen.AVAILABLE_RESOLUTIONS) + writer.finalize(list(resolution), target_resolutions_um=target_resolutions, make_isotropic=make_isotropic) + + update_pbar() + + +# ============================================================================= +# Preview generation +# ============================================================================= + + +def create_input_preview(input_path: Path, output_path: Path, level: int = 0) -> None: + """Create preview of input volume to help determine orientation.""" + vol_zarr, resolution = read_omezarr(input_path, level=level) + vol = np.asarray(vol_zarr[:]) + + z_mid = vol.shape[0] // 2 + x_mid = vol.shape[1] // 2 + y_mid = vol.shape[2] // 2 + + vmin, vmax = np.percentile(vol, [1, 99]) + + fig, axes = plt.subplots(2, 2, figsize=(14, 14)) + fig.suptitle(f"Input Volume Preview\nShape: {vol.shape} (Z, Y, X), Resolution: {resolution} mm", fontsize=14, y=0.98) + + # Axial slice (dim0 midpoint) + axes[0, 0].imshow(vol[z_mid, :, :].T, cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[0, 0].set_title("Slice at dim0 midpoint\nShows: dim1 × dim2") + axes[0, 
0].set_xlabel("dim1 →") + axes[0, 0].set_ylabel("dim2 →") + + # Sagittal slice (dim1 midpoint) + axes[0, 1].imshow(vol[::-1, x_mid, :], cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[0, 1].set_title("Slice at dim1 midpoint\nShows: dim2 × dim0") + axes[0, 1].set_xlabel("dim2 →") + axes[0, 1].set_ylabel("dim0 →") + + # Coronal slice (dim2 midpoint) + axes[1, 0].imshow(vol[::-1, :, y_mid], cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[1, 0].set_title("Slice at dim2 midpoint\nShows: dim1 × dim0") + axes[1, 0].set_xlabel("dim1 →") + axes[1, 0].set_ylabel("dim0 →") + + # Help text + axes[1, 1].axis("off") + help_text = """ +ORIENTATION GUIDE (Allen Atlas = RAS+) + +Allen RAS+ convention: + • R (Right): +X direction + • A (Anterior): +Y direction (nose) + • S (Superior): +Z direction (top) + +For each dimension, identify the anatomical direction: + R/L for right/left + A/P for anterior/posterior + S/I for superior/inferior + +Example: + dim0→Superior, dim1→Anterior, dim2→Right + → orientation code = 'SAR' +""" + axes[1, 1].text( + 0.02, + 0.98, + help_text, + transform=axes[1, 1].transAxes, + fontsize=10, + verticalalignment="top", + fontfamily="monospace", + bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5}, + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Input preview saved to: {output_path}") + + +def create_alignment_preview( + input_path: Path, + output_path: Path | None, + transform: sitk.Transform, + resolution: tuple, + preview_path: str, + allen_resolution: int = DEFAULT_ALLEN_RESOLUTION, + level: int = 0, + orientation_permutation: tuple | None = None, + orientation_flips: tuple | None = None, + pbar: tqdm | None = None, +) -> None: + """Create preview comparing original, aligned, and Allen template. + + Shows center slices from each volume in their own coordinate frames. 
+ The Allen template is shown for reference but may not spatially align + with the brain volume since we're not placing it in Allen coordinate space. + """ + + def update_pbar() -> None: + if pbar: + pbar.update(1) + + # Load original + vol_original, orig_res = read_omezarr(input_path, level=level) + vol_original = np.asarray(vol_original[:]) + + if orientation_permutation is not None: + vol_original = apply_orientation_transform(vol_original, orientation_permutation, orientation_flips) + orig_res = reorder_resolution(tuple(orig_res), orientation_permutation) + + # apply_orientation_transform yields linumpy convention (S, R, A): dim0=S, + # dim1=R, dim2=A. The aligned and Allen-template volumes below are in + # standard RAS — numpy (S, A, R): dim0=S, dim1=A, dim2=R. Permute the + # original to (S, A, R) here so all three columns share one convention and + # a single set of "Axial / Coronal / Sagittal" labels applies uniformly. + vol_original = np.transpose(vol_original, (0, 2, 1)) + orig_res = (orig_res[0], orig_res[2], orig_res[1]) + update_pbar() + + # Load aligned volume from output file, or compute it + if output_path and Path(output_path).exists(): + vol_aligned, _aligned_res = read_omezarr(output_path, level=level) + vol_aligned = np.asarray(vol_aligned[:]) + else: + # Compute aligned volume using the transform + vol_sitk = allen.numpy_to_sitk_image(vol_original, resolution) + # Create reference and centered transform + reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform) + + vol_arr = sitk.GetArrayViewFromImage(vol_sitk) + nonzero = vol_arr[vol_arr > 0] + bg_value = float(np.percentile(nonzero, 1)) if len(nonzero) > 0 else 0.0 + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(reference) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(bg_value) + resampler.SetTransform(centered_transform) + transformed_sitk = resampler.Execute(vol_sitk) + vol_aligned = 
sitk.GetArrayFromImage(transformed_sitk) + update_pbar() + + # Load Allen template at native resolution for reference + # We'll just show it as a reference, not spatially aligned + allen_sitk = allen.download_template_ras_aligned(allen_resolution, cache=True) + allen_template = sitk.GetArrayFromImage(allen_sitk) + # GetArrayFromImage already yields numpy (Z, Y, X) matching our convention. + update_pbar() + + # Helper functions + def get_center_slices(vol: Any) -> Any: + """Get center slices in each plane.""" + z, y, x = vol.shape[0] // 2, vol.shape[1] // 2, vol.shape[2] // 2 + return vol[z, :, :], vol[:, y, :], vol[:, :, x] + + def get_display_range(vol: Any) -> Any: + """Get display range from non-zero values.""" + nonzero = vol[vol > 0] + if len(nonzero) > 0: + return np.percentile(nonzero, [1, 99]) + return 0, 1 + + def find_content_center_slices(vol: Any) -> Any: + """Find the slice with maximum content independently for each axis. + + Using a shared 3D centroid for all three views fails when the brain is + asymmetric (e.g. cut at 45°): the centroid lands near the cut boundary, + so one or more of the orthogonal slice views passes through the cut plane + and shows a black stripe. Instead, pick each index independently as the + slice with the highest total signal along that axis. 
+ """ + if vol.max() == 0: + return get_center_slices(vol) + z = int(np.argmax(vol.sum(axis=(1, 2)))) + x = int(np.argmax(vol.sum(axis=(0, 2)))) + y = int(np.argmax(vol.sum(axis=(0, 1)))) + return vol[z, :, :], vol[:, x, :], vol[:, :, y] + + # Get slices - use content-centered slices for aligned volume + orig_slices = get_center_slices(vol_original) + aligned_slices = find_content_center_slices(vol_aligned) + allen_slices = get_center_slices(allen_template) + + orig_vmin, orig_vmax = get_display_range(vol_original) + align_vmin, align_vmax = get_display_range(vol_aligned) + allen_vmin, allen_vmax = get_display_range(allen_template) + + # Create figure + fig, axes = plt.subplots(3, 3, figsize=(18, 18)) + fig.suptitle("Alignment Preview: Original vs Aligned vs Allen Template (Reference)", fontsize=16) + + # All three volumes are in standard RAS, numpy (S, A, R): + # dim0=S (Superior), dim1=A (Anterior), dim2=R (Right). + # Slicing → anatomical plane: + # vol[z, :, :] fixes S → AXIAL (rows=A, cols=R) + # vol[:, y, :] fixes A → CORONAL (rows=S, cols=R) + # vol[:, :, x] fixes R → SAGITTAL (rows=S, cols=A) + plane_names = ["Axial (AR)", "Coronal (SR)", "Sagittal (SA)"] + + _debug_log( + "create_alignment_preview: shapes & labels", + original_shape=list(vol_original.shape), + aligned_shape=list(vol_aligned.shape), + allen_shape=list(allen_template.shape), + plane_names=plane_names, + ) + + for row, plane_name in enumerate(plane_names): + # Original - use .T for row 0 (XY plane) to match display convention + data = orig_slices[row].T if row == 0 else orig_slices[row][::-1, :] + axes[row, 0].imshow(data, cmap="gray", origin="lower", vmin=orig_vmin, vmax=orig_vmax) + axes[row, 0].set_title(f"Original - {plane_name}") + axes[row, 0].axis("off") + + # Aligned + data = aligned_slices[row].T if row == 0 else aligned_slices[row][::-1, :] + axes[row, 1].imshow(data, cmap="gray", origin="lower", vmin=align_vmin, vmax=align_vmax) + axes[row, 1].set_title(f"Aligned - {plane_name}") + 
axes[row, 1].axis("off")
+
+        data = allen_slices[row].T if row == 0 else allen_slices[row][::-1, :]
+        axes[row, 2].imshow(data, cmap="gray", origin="lower", vmin=allen_vmin, vmax=allen_vmax)
+        axes[row, 2].set_title(f"Allen {allen_resolution}µm - {plane_name}")
+        axes[row, 2].axis("off")
+
+    # Add info text
+    info_text = (
+        f"Original shape: {vol_original.shape}\nAligned shape: {vol_aligned.shape}\nAllen shape: {allen_template.shape}"
+    )
+    bbox_props = {"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5}
+    fig.text(0.02, 0.02, info_text, fontsize=10, family="monospace", bbox=bbox_props)
+
+    plt.tight_layout()
+    Path(preview_path).parent.mkdir(parents=True, exist_ok=True)
+    fig.savefig(preview_path, dpi=150, bbox_inches="tight")
+    plt.close(fig)
+    update_pbar()
+
+    print(f"Alignment preview saved to: {preview_path}")
+
+
+# =============================================================================
+# Orientation preview
+# =============================================================================
+
+
+def create_orientation_preview(
+    input_path: Path,
+    preview_path: str,
+    level: int = 0,
+    orientation_permutation: tuple | None = None,
+    orientation_flips: tuple | None = None,
+    initial_rotation_deg: tuple = (0.0, 0.0, 0.0),
+) -> None:
+    """
+    Save a 3-panel orthogonal preview of the volume after orientation correction and initial rotation are applied.
+
+    Axes are labelled in RAS space (dim0=S, dim1=R, dim2=A) so the result can be
+    inspected directly against the Allen atlas orientation.
+
+    Parameters
+    ----------
+    input_path: Path
+        Path to input OME-Zarr.
+    preview_path : str
+        Output PNG path.
+    level : int
+        Pyramid level to load (lower = higher resolution but slower).
+    orientation_permutation : tuple, optional
+        Axis permutation from ``parse_orientation_code``.
+    orientation_flips : tuple, optional
+        Axis flips from ``parse_orientation_code``.
+ initial_rotation_deg : tuple of float + (Rx, Ry, Rz) initial rotation angles in degrees applied after orientation. + """ + vol_zarr, resolution = read_omezarr(input_path, level=level) + vol = np.asarray(vol_zarr[:]).astype(np.float32) + + # Apply orientation permutation + flips + if orientation_permutation is not None: + vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips) + resolution = list(reorder_resolution(tuple(resolution), orientation_permutation)) + + # Apply initial rotation via SimpleITK (same path as the registration uses) + if any(r != 0.0 for r in initial_rotation_deg): + vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32) + center = vol_sitk.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in vol_sitk.GetSize()]) + rx, ry, rz = [np.deg2rad(a) for a in initial_rotation_deg] + t = sitk.Euler3DTransform() + t.SetCenter(center) + t.SetRotation(rx, ry, rz) + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(vol_sitk) + resampler.SetTransform(t.GetInverse()) + resampler.SetInterpolator(sitk.sitkLinear) + vol = sitk.GetArrayFromImage(resampler.Execute(vol_sitk)) + + # Display range from non-zero voxels + nonzero = vol[vol > 0] + vmin, vmax = np.percentile(nonzero if len(nonzero) else vol.ravel(), [1, 99]) + + # Build title + applied = [] + if orientation_permutation is not None: + applied.append("orientation") + if any(r != 0.0 for r in initial_rotation_deg): + applied.append(f"rotation {list(initial_rotation_deg)}°") + subtitle = f"({', '.join(applied)} applied)" if applied else "(no corrections applied)" + + z_mid = vol.shape[0] // 2 + y_mid = vol.shape[1] // 2 + x_mid = vol.shape[2] // 2 + + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + fig.suptitle( + f"Orientation Preview — {subtitle}\n" + f"Shape: {vol.shape} | After corrections: dim0=S (Superior), dim1=R (Right), dim2=A (Anterior)", + fontsize=11, + ) + + # After apply_orientation_transform the volume is in linumpy 
convention + # (S, R, A): dim0=S (Superior), dim1=R (Right), dim2=A (Anterior). + # Slicing → anatomical plane: + # vol[z, :, :] fixes S → AXIAL (rows=R, cols=A) + # vol[:, y, :] fixes R → SAGITTAL (rows=S, cols=A) + # vol[:, :, x] fixes A → CORONAL (rows=S, cols=R) + # `.T` on the axial view + row reversal on the others orients the figure + # so Superior is up and Right/Anterior point in the natural directions. + axes[0].imshow(vol[z_mid, :, :].T, cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[0].set_title(f"Axial (dim0=S={z_mid})") + axes[0].set_xlabel("dim1=R (← L R →)") + axes[0].set_ylabel("dim2=A (← P A →)") + + axes[1].imshow(vol[::-1, y_mid, :], cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[1].set_title(f"Sagittal (dim1=R={y_mid})") + axes[1].set_xlabel("dim2=A (← P A →)") + axes[1].set_ylabel("dim0=S (← I S →)") + + axes[2].imshow(vol[::-1, :, x_mid], cmap="gray", origin="lower", vmin=vmin, vmax=vmax) + axes[2].set_title(f"Coronal (dim2=A={x_mid})") + axes[2].set_xlabel("dim1=R (← L R →)") + axes[2].set_ylabel("dim0=S (← I S →)") + + _debug_log( + "create_orientation_preview: slicing decisions", + vol_shape=list(vol.shape), + panels=[ + {"axes": 0, "slice": f"vol[{z_mid}, :, :].T", "fixed_axis": "dim0=S", "plane": "Axial"}, + {"axes": 1, "slice": f"vol[::-1, {y_mid}, :]", "fixed_axis": "dim1=R", "plane": "Sagittal"}, + {"axes": 2, "slice": f"vol[::-1, :, {x_mid}]", "fixed_axis": "dim2=A", "plane": "Coronal"}, + ], + ) + + plt.tight_layout() + Path(preview_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(preview_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Orientation preview saved to: {preview_path}") + + +# ============================================================================= +# Main entry point +# ============================================================================= + + +def main() -> None: + """Run the script. 
parse arguments and run alignment workflow.""" + parser = _build_arg_parser() + args = parser.parse_args() + + input_path = Path(args.input_zarr) + output_path = Path(args.output_zarr) + + if not input_path.exists(): + raise FileNotFoundError(f"Input zarr not found: {input_path}") + + # Preview-only mode + if args.preview_only: + preview_path = Path(args.preview) if args.preview else Path("input_preview.png") + create_input_preview(input_path, preview_path, level=args.level) + return + + # Parse orientation + orientation_permutation = None + orientation_flips = None + if args.input_orientation: + try: + orientation_permutation, orientation_flips = parse_orientation_code(args.input_orientation) + print(f"Input orientation '{args.input_orientation}':") + print(f" Axis permutation: {orientation_permutation}") + print(f" Axis flips: {orientation_flips}") + except ValueError as e: + parser.error(str(e)) + + # Orientation + initial-rotation preview (can exit before registration) + if args.orientation_preview or args.orientation_preview_only: + preview_out = args.orientation_preview or "orientation_preview.png" + create_orientation_preview( + input_path, + preview_out, + level=args.level, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + initial_rotation_deg=tuple(args.initial_rotation), + ) + if args.orientation_preview_only: + return + + # Load input volume + vol_zarr, zarr_resolution = read_omezarr(Path(input_path), level=args.level) + resolution = tuple(zarr_resolution) + + # Progress bar - allocate steps for each phase + registration_steps = 3 # Steps allocated for registration progress + base_steps = 2 if args.store_transform_only else 5 # Load + save steps + total_steps = base_steps + registration_steps + if args.preview: + total_steps += 4 + pbar = tqdm(total=total_steps, desc="Aligning to RAS") + + vol = np.asarray(vol_zarr[:]) + pbar.update(1) + + if args.verbose: + print(f"Volume shape: {vol.shape}, Resolution: 
{resolution} mm") + + # Apply orientation correction for registration + if orientation_permutation is not None: + vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips) + resolution = reorder_resolution(resolution, orientation_permutation) + + # Create progress callback for registration + registration_start_step = pbar.n + progress_callback = create_registration_progress_callback( + max_iterations=args.max_iterations, + n_resolution_levels=3, + pbar=pbar, + registration_start_step=registration_start_step, + registration_steps=registration_steps, + ) + + # Register to Allen atlas + pbar.set_postfix_str("registering...") + transform, stop_condition, error = allen.register_3d_rigid_to_allen( + moving_image=vol, + moving_spacing=resolution, + allen_resolution=args.allen_resolution, + metric=args.metric, + max_iterations=args.max_iterations, + verbose=args.verbose, + progress_callback=progress_callback, + initial_rotation_deg=tuple(args.initial_rotation), + ) + # Ensure progress bar reaches end of registration steps + pbar.n = registration_start_step + registration_steps + pbar.refresh() + + print(f"Registration complete: {stop_condition}") + print(f"Final metric value: {error:.6f}") + del vol # free registration-level volume before loading full-resolution data + + # Apply or store transform + if args.store_transform_only: + store_transform_in_metadata(input_path, transform) + pbar.update(1) + else: + apply_transform_to_zarr( + input_path, + output_path, + transform, + chunks=tuple(args.chunks) if args.chunks else None, + n_levels=args.n_levels, + pyramid_resolutions=args.pyramid_resolutions, + make_isotropic=args.make_isotropic, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + pbar=pbar, + ) + print(f"Aligned volume saved to: {output_path}") + + # Save transform file + # Strip the compound .ome.zarr extension (Path.stem only removes the last suffix) + stem = 
output_path.with_suffix("").with_suffix("").name + transform_path = output_path.parent / f"{stem}_transform.tfm" + sitk.WriteTransform(transform, str(transform_path)) + print(f"Transform saved to: {transform_path}") + pbar.update(1) + + # Generate preview + if args.preview: + pbar.set_postfix_str("generating preview...") + create_alignment_preview( + input_path, + output_path if not args.store_transform_only else None, + transform, + resolution, + args.preview, + allen_resolution=args.allen_resolution, + level=args.level, + orientation_permutation=orientation_permutation, + orientation_flips=orientation_flips, + pbar=pbar, + ) + + pbar.set_postfix_str("complete") + pbar.close() + + +if __name__ == "__main__": + main() diff --git a/scripts/linum_analyze_shifts.py b/scripts/linum_analyze_shifts.py deleted file mode 100644 index 03b97e7d..00000000 --- a/scripts/linum_analyze_shifts.py +++ /dev/null @@ -1,298 +0,0 @@ -#!/usr/bin/env python3 -""" -Analyze XY shifts from a shifts file and generate a drift analysis report. - -Produces: -- Summary statistics of pairwise shifts -- Outlier detection using IQR method -- Cumulative drift analysis -- Visualization of drift patterns - -Useful for debugging alignment issues and understanding sample drift during acquisition. 
-""" - -import linumpy.config.threads # noqa: F401 - -import argparse -import logging -from pathlib import Path -from typing import Any - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from linumpy.cli.args import add_overwrite_arg, assert_output_exists - -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") -logger = logging.getLogger(__name__) - - -def _build_arg_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - p.add_argument("in_shifts", help="Input shifts CSV file (shifts_xy.csv)") - p.add_argument("out_directory", help="Output directory for analysis results") - - p.add_argument( - "--resolution", type=float, default=10.0, help="Resolution in µm/pixel for converting mm to pixels [%(default)s]" - ) - p.add_argument("--iqr_multiplier", type=float, default=1.5, help="IQR multiplier for outlier detection [%(default)s]") - p.add_argument("--slice_config", default=None, help="Optional slice config file to mark excluded slices") - - add_overwrite_arg(p) - return p - - -def load_shifts(shifts_path: Path) -> Any: - """Load shifts CSV file. - - Rows are sorted by ``moving_id`` so that every ``cumsum`` downstream - reflects slice order rather than CSV row order. 
- """ - df = pd.read_csv(shifts_path) - required_cols = ["fixed_id", "moving_id", "x_shift_mm", "y_shift_mm"] - for col in required_cols: - if col not in df.columns: - raise ValueError(f"Missing required column: {col}") - return df.sort_values("moving_id").reset_index(drop=True) - - -def detect_outliers(df: Any, iqr_multiplier: float = 1.5) -> Any: - """Detect outliers using IQR method on shift magnitude.""" - shift_mag = np.sqrt(df["x_shift_mm"] ** 2 + df["y_shift_mm"] ** 2) - q1 = shift_mag.quantile(0.25) - q3 = shift_mag.quantile(0.75) - iqr = q3 - q1 - upper_bound = q3 + iqr_multiplier * iqr - outlier_mask = shift_mag > upper_bound - return outlier_mask, upper_bound, q1, q3, iqr - - -def filter_with_local_median(df: Any, outlier_mask: Any) -> Any: - """Replace outliers with local median of neighbors.""" - df_filtered = df.copy() - for idx in df[outlier_mask].index: - pos = df.index.get_loc(idx) - neighbors_x, neighbors_y = [], [] - for offset in [-2, -1, 1, 2]: - neighbor_pos = pos + offset - if 0 <= neighbor_pos < len(df): - neighbor_idx = df.index[neighbor_pos] - if not outlier_mask[neighbor_idx]: - neighbors_x.append(df.loc[neighbor_idx, "x_shift_mm"]) - neighbors_y.append(df.loc[neighbor_idx, "y_shift_mm"]) - if neighbors_x: - df_filtered.loc[idx, "x_shift_mm"] = np.median(neighbors_x) - df_filtered.loc[idx, "y_shift_mm"] = np.median(neighbors_y) - return df_filtered - - -def generate_report(df: Any, df_filtered: Any, outlier_mask: Any, stats: Any, resolution: Any, output_dir: Path) -> str: - """Generate text report.""" - px_per_mm = 1000 / resolution - - report_lines = [ - "=" * 60, - "SHIFTS ANALYSIS REPORT", - "=" * 60, - "", - "OVERVIEW", - "-" * 40, - f"Total shift pairs: {len(df)}", - f"Resolution: {resolution} µm/pixel", - "", - "PAIRWISE SHIFT STATISTICS", - "-" * 40, - f"X shift (mm): Mean={df['x_shift_mm'].mean():.4f}, Std={df['x_shift_mm'].std():.4f}", - f"Y shift (mm): Mean={df['y_shift_mm'].mean():.4f}, Std={df['y_shift_mm'].std():.4f}", - "", 
- "OUTLIER DETECTION (IQR Method)", - "-" * 40, - f"Q1={stats['q1']:.4f}, Q3={stats['q3']:.4f}, IQR={stats['iqr']:.4f}", - f"Upper bound: {stats['upper_bound']:.4f} mm", - f"Outliers detected: {outlier_mask.sum()}", - ] - - if outlier_mask.sum() > 0: - report_lines.append("") - report_lines.append("Outlier shifts:") - shift_mag = np.sqrt(df["x_shift_mm"] ** 2 + df["y_shift_mm"] ** 2) - for idx in df[outlier_mask].index: - row = df.loc[idx] - mag = shift_mag[idx] - report_lines.append( - f" {int(row['fixed_id'])}->{int(row['moving_id'])}: " - f"({row['x_shift_mm']:.3f}, {row['y_shift_mm']:.3f}) mm, mag={mag:.3f} mm" - ) - - # Cumulative drift - cumsum_x_orig = df["x_shift_mm"].cumsum() - cumsum_y_orig = df["y_shift_mm"].cumsum() - cumsum_x_filt = df_filtered["x_shift_mm"].cumsum() - cumsum_y_filt = df_filtered["y_shift_mm"].cumsum() - - report_lines.extend( - [ - "", - "CUMULATIVE DRIFT", - "-" * 40, - f"Before filtering: X={cumsum_x_orig.iloc[-1]:.3f} mm, Y={cumsum_y_orig.iloc[-1]:.3f} mm", - f"After filtering: X={cumsum_x_filt.iloc[-1]:.3f} mm, Y={cumsum_y_filt.iloc[-1]:.3f} mm", - "", - f"In pixels (at {resolution} µm/pixel):", - f" Before: X={cumsum_x_orig.iloc[-1] * px_per_mm:.0f} px, Y={cumsum_y_orig.iloc[-1] * px_per_mm:.0f} px", - f" After: X={cumsum_x_filt.iloc[-1] * px_per_mm:.0f} px, Y={cumsum_y_filt.iloc[-1] * px_per_mm:.0f} px", - ] - ) - - # Centered drift - mid_idx = len(cumsum_x_filt) // 2 - centered_x = cumsum_x_filt - cumsum_x_filt.iloc[mid_idx] - centered_y = cumsum_y_filt - cumsum_y_filt.iloc[mid_idx] - - report_lines.extend( - [ - "", - f"CENTERED DRIFT (around slice {mid_idx})", - "-" * 40, - f"X range: {centered_x.min() * px_per_mm:.0f} to {centered_x.max() * px_per_mm:.0f} px", - f"Y range: {centered_y.min() * px_per_mm:.0f} to {centered_y.max() * px_per_mm:.0f} px", - "", - "=" * 60, - ] - ) - - report_text = "\n".join(report_lines) - - # Save report - report_path = Path(output_dir) / "shifts_analysis.txt" - with report_path.open("w") as f: 
- f.write(report_text) - - return report_text - - -def generate_plots(df: Any, df_filtered: Any, _outlier_mask: Any, stats: Any, resolution: Any, output_dir: Path) -> Path: - """Generate visualization plots.""" - px_per_mm = 1000 / resolution - upper_bound = stats["upper_bound"] - - # Calculate cumulative drift - cumsum_x_orig = df["x_shift_mm"].cumsum() - cumsum_y_orig = df["y_shift_mm"].cumsum() - cumsum_x_filt = df_filtered["x_shift_mm"].cumsum() - cumsum_y_filt = df_filtered["y_shift_mm"].cumsum() - - mid_idx = len(cumsum_x_filt) // 2 - centered_x = cumsum_x_filt - cumsum_x_filt.iloc[mid_idx] - centered_y = cumsum_y_filt - cumsum_y_filt.iloc[mid_idx] - - # Create figure - fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - - # Plot 1: Pairwise shifts - ax = axes[0, 0] - ax.plot(df["moving_id"], df["x_shift_mm"], "b.-", label="X shift (original)", alpha=0.7) - ax.plot(df["moving_id"], df["y_shift_mm"], "r.-", label="Y shift (original)", alpha=0.7) - ax.plot(df["moving_id"], df_filtered["x_shift_mm"], "b-", label="X shift (filtered)", linewidth=2) - ax.plot(df["moving_id"], df_filtered["y_shift_mm"], "r-", label="Y shift (filtered)", linewidth=2) - ax.axhline(y=0, color="k", linestyle="--", alpha=0.3) - ax.axhline(y=upper_bound, color="g", linestyle=":", label=f"IQR threshold ({upper_bound:.2f}mm)") - ax.axhline(y=-upper_bound, color="g", linestyle=":") - ax.set_xlabel("Slice ID") - ax.set_ylabel("Shift (mm)") - ax.set_title("Pairwise Shifts") - ax.legend(fontsize=8) - ax.grid(True, alpha=0.3) - - # Plot 2: Cumulative drift - ax = axes[0, 1] - ax.plot(df["moving_id"], cumsum_x_orig, "b--", label="X original", alpha=0.5) - ax.plot(df["moving_id"], cumsum_y_orig, "r--", label="Y original", alpha=0.5) - ax.plot(df["moving_id"], cumsum_x_filt, "b-", label="X filtered", linewidth=2) - ax.plot(df["moving_id"], cumsum_y_filt, "r-", label="Y filtered", linewidth=2) - ax.axhline(y=0, color="k", linestyle="--", alpha=0.3) - ax.set_xlabel("Slice ID") - 
ax.set_ylabel("Cumulative Drift (mm)") - ax.set_title("Cumulative Drift") - ax.legend() - ax.grid(True, alpha=0.3) - - # Plot 3: Centered cumulative drift in pixels - ax = axes[1, 0] - ax.plot(df["moving_id"], centered_x * px_per_mm, "b-", label="X (centered)", linewidth=2) - ax.plot(df["moving_id"], centered_y * px_per_mm, "r-", label="Y (centered)", linewidth=2) - ax.axhline(y=0, color="k", linestyle="--", alpha=0.3) - ax.set_xlabel("Slice ID") - ax.set_ylabel(f"Centered Drift (pixels at {resolution}µm)") - ax.set_title("Centered Cumulative Drift") - ax.legend() - ax.grid(True, alpha=0.3) - - # Plot 4: Drift trajectory - ax = axes[1, 1] - ax.plot(cumsum_x_filt * px_per_mm, cumsum_y_filt * px_per_mm, "g-", linewidth=2) - ax.plot(cumsum_x_filt.iloc[0] * px_per_mm, cumsum_y_filt.iloc[0] * px_per_mm, "go", markersize=10, label="Start") - ax.plot(cumsum_x_filt.iloc[-1] * px_per_mm, cumsum_y_filt.iloc[-1] * px_per_mm, "ro", markersize=10, label="End") - ax.plot( - cumsum_x_filt.iloc[mid_idx] * px_per_mm, cumsum_y_filt.iloc[mid_idx] * px_per_mm, "ko", markersize=10, label="Middle" - ) - ax.set_xlabel("X position (pixels)") - ax.set_ylabel("Y position (pixels)") - ax.set_title("Drift Trajectory (filtered)") - ax.legend() - ax.grid(True, alpha=0.3) - ax.axis("equal") - - plt.tight_layout() - - # Save plot - plot_path = Path(output_dir) / "drift_analysis.png" - fig.savefig(plot_path, dpi=150, bbox_inches="tight") - plt.close(fig) - - logger.info("Saved plot: %s", plot_path) - return plot_path - - -def main() -> None: - """Run function.""" - parser = _build_arg_parser() - args = parser.parse_args() - - # Create output directory - assert_output_exists(args.out_directory, parser, args) - Path(args.out_directory).mkdir(parents=True) - - # Load shifts - logger.info("Loading shifts from %s", args.in_shifts) - df = load_shifts(args.in_shifts) - logger.info("Loaded %s shift pairs", len(df)) - - # Detect outliers - outlier_mask, upper_bound, q1, q3, iqr = detect_outliers(df, 
args.iqr_multiplier) - logger.info("Detected %s outliers (IQR bound: %.3f mm)", outlier_mask.sum(), upper_bound) - - # Filter outliers - df_filtered = filter_with_local_median(df, outlier_mask) - - # Statistics - stats = {"q1": q1, "q3": q3, "iqr": iqr, "upper_bound": upper_bound} - - # Generate report - report = generate_report(df, df_filtered, outlier_mask, stats, args.resolution, args.out_directory) - print(report) - - # Generate plots - generate_plots(df, df_filtered, outlier_mask, stats, args.resolution, args.out_directory) - - # Save filtered shifts (useful for debugging) - filtered_path = Path(args.out_directory) / "shifts_filtered.csv" - df_filtered.to_csv(filtered_path, index=False) - logger.info("Saved filtered shifts: %s", filtered_path) - - logger.info("Analysis complete. Results saved to %s", args.out_directory) - - -if __name__ == "__main__": - main() diff --git a/scripts/linum_assess_slice_quality.py b/scripts/linum_assess_slice_quality.py deleted file mode 100644 index c1867341..00000000 --- a/scripts/linum_assess_slice_quality.py +++ /dev/null @@ -1,400 +0,0 @@ -#!/usr/bin/env python3 -""" -Assess slice quality for 3D mosaic grids and optionally update slice configuration. - -This script analyzes mosaic grid slices to detect quality issues that might affect -reconstruction. It uses multiple metrics to identify problematic slices: - -1. **SSIM (Structural Similarity)**: Compares each slice to its neighbors -2. **Edge Preservation**: Detects if edge structures are preserved compared to neighbors -3. **Variance Consistency**: Checks for unusual signal variance (data loss/corruption) -4. **First Slice Detection**: Automatically identifies calibration slices (thicker/different) - -GPU acceleration is used when available (--use_gpu, default on) for SSIM and -edge-detection computations. Falls back to CPU automatically if no GPU is detected. 
- -The output can be: -- A new slice_config.csv with quality scores and recommendations -- An update to an existing slice_config.csv with quality assessments -- A quality report for review - -Example usage: - # Assess quality and create/update slice config - linum_assess_slice_quality.py /path/to/mosaics slice_config.csv - - # Assess and exclude low-quality slices automatically - linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --min_quality 0.3 - - # Exclude first N calibration slices - linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --exclude_first 1 - - # Update existing config with quality info - linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --update_existing - - # Force CPU fallback - linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --no-use_gpu -""" - -from __future__ import annotations - -# Configure thread limits before numpy/scipy imports -import linumpy.config.threads # noqa: F401 - -import argparse -import re -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from tqdm.auto import tqdm - -if TYPE_CHECKING: - import numpy as np - -from linumpy.cli.args import add_overwrite_arg, assert_output_exists -from linumpy.gpu import GPU_AVAILABLE -from linumpy.gpu.image_quality import ( - assess_slice_quality_gpu, - clear_gpu_memory, -) -from linumpy.io import slice_config as slice_config_io -from linumpy.io.zarr import read_omezarr -from linumpy.metrics.image_quality import ( - assess_slice_quality, - detect_calibration_slice, -) - - -def _build_arg_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - p.add_argument("input", help="Input directory containing mosaic grids (*.ome.zarr)") - p.add_argument("output_file", help="Output slice configuration CSV file") - - gpu_group = p.add_argument_group("GPU Options") - gpu_group.add_argument( - "--use_gpu", - default=True, - 
action=argparse.BooleanOptionalAction, - help="Use GPU acceleration if available. [%(default)s]", - ) - gpu_group.add_argument("--gpu_id", type=int, default=0, help="GPU device ID to use. [%(default)s]") - - quality_group = p.add_argument_group("Quality Assessment") - quality_group.add_argument( - "--min_quality", - type=float, - default=0.0, - help="Minimum quality score to include slice (0-1). Default: 0.0 (include all, just report)", - ) - quality_group.add_argument( - "--sample_depth", - type=int, - default=5, - help="Number of z-planes to sample per slice for faster assessment. Default: 5 (0=all)", - ) - quality_group.add_argument( - "--pyramid_level", - type=int, - default=0, - help="Pyramid level to use for assessment (0=full res). Higher levels are faster but less accurate. Default: 0", - ) - quality_group.add_argument( - "--roi_size", - type=int, - default=0, - help="Side length of center crop in XY (pixels) used for " - "all quality metrics. 0 = full plane (slow for large " - "single-resolution mosaics). Recommended: 1024.", - ) - quality_group.add_argument( - "--processes", - type=int, - default=1, - help="Number of parallel workers for slice assessment (CPU mode only).\n" - "Each worker reads its own zarr planes concurrently.\n" - "Default: 1 (sequential). Set to params.processes.", - ) - - calib_group = p.add_argument_group("Calibration Slice Handling") - calib_group.add_argument( - "--exclude_first", - type=int, - default=1, - help="Exclude first N slices as calibration slices. Default: 1 (first slice is usually calibration)", - ) - calib_group.add_argument( - "--detect_calibration", - action="store_true", - help="Automatically detect calibration slices by their different thickness/structure", - ) - calib_group.add_argument( - "--calibration_thickness_ratio", - type=float, - default=1.5, - help="Slices with thickness ratio > this are flagged as calibration. 
Default: 1.5", - ) - - update_group = p.add_argument_group("Update Existing Config") - update_group.add_argument( - "--update_existing", action="store_true", help="Update an existing slice_config.csv with quality info" - ) - update_group.add_argument("--existing_config", type=str, default=None, help="Path to existing slice config to update") - - output_group = p.add_argument_group("Output Options") - output_group.add_argument("--report_only", action="store_true", help="Only print report, don't write config file") - output_group.add_argument("-v", "--verbose", action="store_true", help="Print detailed quality metrics per slice") - - add_overwrite_arg(p) - return p - - -def get_mosaic_files(directory: Path) -> dict[int, Path]: - """Find all mosaic grid files and extract slice IDs.""" - pattern = r".*z(\d+).*\.ome\.zarr$" - mosaics = {} - - for f in directory.iterdir(): - if f.is_dir() and f.suffix == ".zarr": - match = re.match(pattern, f.name) - if match: - slice_id = int(match.group(1)) - mosaics[slice_id] = f - - return dict(sorted(mosaics.items())) - - -def read_existing_config(config_path: Path) -> dict[int, dict[str, Any]]: - """Read an existing slice configuration file keyed by integer ``slice_id``.""" - rows = slice_config_io.read(config_path) - return {int(sid): dict(row) for sid, row in rows.items()} - - -def write_slice_config_with_quality( - output_file: Path, - slice_ids: list[int], - quality_results: dict[int, dict[str, Any]], - exclude_ids: list[int], - existing_config: dict[int, dict[str, Any]] | None = None, -) -> None: - """Write ``slice_config.csv`` with the decision columns set from the quality. - - assessment. Raw per-metric scores (ssim_mean / edge_score / variance_score / - depth) intentionally stay out of the CSV — they live in the pipeline report - and per-stage diagnostics JSON, not in the per-slice decision trace. 
- """ - out_rows: list[dict[str, object]] = [] - for slice_id in slice_ids: - quality = quality_results.get(slice_id, {}) - use = "true" - reason = "" - if slice_id in exclude_ids: - use = "false" - if quality.get("is_calibration", False): - reason = "calibration_slice" - elif quality.get("overall", 1.0) < quality.get("min_threshold", 0): - reason = "low_quality" - elif quality.get("exclude_first", False): - reason = "first_slice_excluded" - else: - reason = "manually_excluded" - - existing = existing_config.get(slice_id, {}) if existing_config else {} - if existing.get("use", "true").lower() in ["false", "0", "no"]: - use = "false" - if not reason: - reason = existing.get("exclude_reason") or existing.get("notes") or "previously_excluded" - - row: dict[str, object] = { - "slice_id": f"{slice_id:02d}", - "use": use, - "quality_score": f"{float(quality.get('overall', 0.0)):.3f}", - "exclude_reason": reason, - } - if existing.get("galvo_confidence", ""): - row["galvo_confidence"] = existing["galvo_confidence"] - if existing.get("galvo_fix", ""): - row["galvo_fix"] = existing["galvo_fix"] - for carry in ("notes",): - val = existing.get(carry) - if val: - row[carry] = val - out_rows.append(row) - - slice_config_io.write(output_file, out_rows) - - -def main() -> None: - """Run function operation.""" - p = _build_arg_parser() - args = p.parse_args() - - input_path = Path(args.input) - output_file = Path(args.output_file) - - if not args.report_only: - assert_output_exists(output_file, p, args) - - if not input_path.is_dir(): - p.error(f"Input directory not found: {input_path}") - - use_gpu = args.use_gpu and GPU_AVAILABLE - if args.use_gpu and not GPU_AVAILABLE: - print("Warning: GPU requested but not available. Using CPU.") - elif use_gpu: - try: - import cupy as cp - - cp.cuda.Device(args.gpu_id).use() - print(f"Using GPU device {args.gpu_id}") - except Exception as e: - print(f"Warning: Could not select GPU {args.gpu_id}: {e}. 
Using default.") - - print(f"Scanning for mosaic grids in: {input_path}") - mosaic_files = get_mosaic_files(input_path) - - if not mosaic_files: - p.error(f"No mosaic grid files found in {input_path}") - - slice_ids = sorted(mosaic_files.keys()) - print(f"Found {len(slice_ids)} slices: {[f'{s:02d}' for s in slice_ids]}") - - existing_config = None - if args.update_existing: - config_to_load = args.existing_config if args.existing_config else output_file - if Path(config_to_load).exists(): - existing_config = read_existing_config(Path(config_to_load)) - print(f"Loaded existing config with {len(existing_config)} entries") - - exclude_ids = set() - - if args.exclude_first > 0: - first_slices = slice_ids[: args.exclude_first] - exclude_ids.update(first_slices) - print(f"Excluding first {args.exclude_first} slice(s) as calibration: {first_slices}") - - print(f"\nLoading slices (pyramid_level={args.pyramid_level})...") - volumes: dict[int, np.ndarray | None] = {} - for slice_id in tqdm(slice_ids, desc="Loading slices"): - try: - vol, _ = read_omezarr(mosaic_files[slice_id], level=args.pyramid_level) - volumes[slice_id] = vol - except Exception as e: - print(f" Warning: Could not load slice {slice_id:02d}: {e}") - volumes[slice_id] = None - - calibration_slices = [] - if args.detect_calibration: - print(f"Detecting calibration slices (thickness ratio > {args.calibration_thickness_ratio})...") - valid_volumes = {sid: vol for sid, vol in volumes.items() if vol is not None} - calibration_slices = detect_calibration_slice(valid_volumes, args.calibration_thickness_ratio) - if calibration_slices: - exclude_ids.update(calibration_slices) - print(f"Detected calibration slices: {calibration_slices}") - - print(f"\nAssessing slice quality (GPU={'enabled' if use_gpu else 'disabled'}, sample_depth={args.sample_depth})...") - quality_results: dict[int, dict[str, Any]] = {} - - if use_gpu: - for i, slice_id in enumerate(tqdm(slice_ids, desc="Assessing quality")): - vol = 
volumes.get(slice_id) - if vol is None: - quality_results[slice_id] = { - "overall": 0.0, - "ssim_mean": 0.0, - "edge_score": 0.0, - "variance_score": 0.0, - "depth": 0, - "has_data": False, - "error": "load_failed", - } - continue - vol_before = volumes.get(slice_ids[i - 1]) if i > 0 else None - vol_after = volumes.get(slice_ids[i + 1]) if i < len(slice_ids) - 1 else None - overall, metrics = assess_slice_quality_gpu(vol, vol_before, vol_after, args.sample_depth) - metrics["is_calibration"] = slice_id in calibration_slices - metrics["exclude_first"] = slice_id in slice_ids[: args.exclude_first] - metrics["min_threshold"] = args.min_quality - quality_results[slice_id] = metrics - if args.min_quality > 0 and overall < args.min_quality: - exclude_ids.add(slice_id) - clear_gpu_memory() - else: - from concurrent.futures import ThreadPoolExecutor, as_completed - - def _assess_one(idx_and_id: tuple) -> Any: - i, slice_id = idx_and_id - vol = volumes.get(slice_id) - if vol is None: - return slice_id, { - "overall": 0.0, - "ssim_mean": 0.0, - "edge_score": 0.0, - "variance_score": 0.0, - "depth": 0, - "has_data": False, - "error": "load_failed", - } - vol_before = volumes.get(slice_ids[i - 1]) if i > 0 else None - vol_after = volumes.get(slice_ids[i + 1]) if i < len(slice_ids) - 1 else None - _overall, metrics = assess_slice_quality(vol, vol_before, vol_after, args.sample_depth, xy_roi=args.roi_size) - metrics["is_calibration"] = slice_id in calibration_slices - metrics["exclude_first"] = slice_id in slice_ids[: args.exclude_first] - metrics["min_threshold"] = args.min_quality - return slice_id, metrics - - tasks = list(enumerate(slice_ids)) - with ThreadPoolExecutor(max_workers=args.processes) as executor: - futures = {executor.submit(_assess_one, t): t for t in tasks} - for future in tqdm(as_completed(futures), total=len(futures), desc="Assessing quality"): - slice_id, metrics = future.result() - quality_results[slice_id] = metrics - if args.min_quality > 0 and 
metrics.get("overall", 0.0) < args.min_quality: - exclude_ids.add(slice_id) - - print("\n" + "=" * 70) - print(f"SLICE QUALITY REPORT{' (GPU-accelerated)' if use_gpu else ' (CPU)'}") - print("=" * 70) - print(f"{'Slice':<8} {'Quality':<10} {'SSIM':<10} {'Edge':<10} {'Var':<10} {'Depth':<8} {'Status':<15}") - print("-" * 70) - - for slice_id in slice_ids: - q = quality_results.get(slice_id, {}) - status = [] - if slice_id in exclude_ids: - if q.get("is_calibration"): - status.append("CALIBRATION") - elif q.get("exclude_first"): - status.append("FIRST_SLICE") - elif q.get("overall", 1.0) < args.min_quality: - status.append("LOW_QUALITY") - else: - status.append("EXCLUDED") - else: - status.append("OK") - - status_str = ",".join(status) - print( - f"{slice_id:02d} {q.get('overall', 0):.3f} " - f"{q.get('ssim_mean', 0):.3f} {q.get('edge_score', 0):.3f} " - f"{q.get('variance_score', 0):.3f} {q.get('depth', 0):<8} {status_str}" - ) - - print("-" * 70) - print(f"Total slices: {len(slice_ids)}") - print(f"Excluded: {len(exclude_ids)}") - print(f"Included: {len(slice_ids) - len(exclude_ids)}") - - if args.min_quality > 0: - low_quality = [s for s in slice_ids if quality_results.get(s, {}).get("overall", 1.0) < args.min_quality] - if low_quality: - print(f"Low quality slices (< {args.min_quality}): {low_quality}") - - if not args.report_only: - write_slice_config_with_quality(output_file, slice_ids, quality_results, list(exclude_ids), existing_config) - print(f"\nSlice configuration written to: {output_file}") - - if exclude_ids: - print(f"\nExcluded slice IDs: {sorted(exclude_ids)}") - - -if __name__ == "__main__": - main() diff --git a/scripts/linum_generate_pipeline_report.py b/scripts/linum_generate_pipeline_report.py deleted file mode 100644 index b00c23df..00000000 --- a/scripts/linum_generate_pipeline_report.py +++ /dev/null @@ -1,2028 +0,0 @@ -#!/usr/bin/env python3 -""" -Generate a quality report from pipeline metrics. 
- -This script aggregates metrics from various pipeline steps and generates -a comprehensive report in HTML or text format to help identify potential -issues in the 3D reconstruction pipeline. -""" - -# Configure thread limits before numpy/scipy imports -import linumpy.config.threads # noqa: F401 - -import argparse -import base64 -import io as _io -import json -import re -import zipfile -from collections import defaultdict -from datetime import datetime -from pathlib import Path - -try: - from PIL import Image as _PILImage - - _PIL_AVAILABLE = True -except ImportError: - _PIL_AVAILABLE = False - -from typing import Any - -import numpy as np - -from linumpy.metrics import aggregate_metrics, compute_summary_statistics - -# Logical pipeline step ordering -STEP_ORDER = [ - "stitch_3d", - "xy_transform_estimation", - "normalize_intensities", - "psf_compensation", - "crop_interface", - "pairwise_registration", - "stack_slices", -] - -# Human-readable display names (step_name → display label) -STEP_DISPLAY_NAMES = { - "stitch_3d": "Stitch 3D", - "xy_transform_estimation": "XY Transform Estimation", - "normalize_intensities": "Normalize Intensities", - "psf_compensation": "PSF Compensation", - "crop_interface": "Crop Interface", - "pairwise_registration": "Pairwise Registration", - "stack_slices": "Stack Slices", -} - -# Human-readable descriptions for pipeline steps -STEP_DESCRIPTIONS = { - "stitch_3d": "Stitches individual mosaic tiles into a single 2D slice.", - "xy_transform_estimation": "Estimates the affine transformation for tile overlap correction.", - "normalize_intensities": "Normalizes per-slice intensities using agarose background.", - "psf_compensation": "Compensates for beam profile / PSF attenuation along the optical axis.", - "crop_interface": "Detects and crops the tissue-agarose interface.", - "pairwise_registration": "Registers consecutive serial sections to align the 3D volume.", - "stack_slices": "Stacks registered slices into the final 3D volume.", -} 
- -# Maps pipeline step_name → image category shown in that step section -STEP_PREVIEW_CATEGORY = { - "stitch_3d": "stitch_preview", - "pairwise_registration": "common_space_preview", -} - - -def _build_arg_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - p.add_argument("input_dir", help="Input directory containing pipeline output with metrics files.") - p.add_argument("output_report", help="Output report file path (.html, .zip, or .txt)") - p.add_argument( - "--format", - choices=["html", "text", "zip", "auto"], - default="auto", - help="Output format. 'auto' infers from extension. [%(default)s]", - ) - p.add_argument("--title", default="Pipeline Quality Report", help="Report title. [%(default)s]") - p.add_argument("--verbose", action="store_true", help="Include all metric details in the report.") - p.add_argument( - "--overview_png", type=Path, default=None, help="Path to the main volume PNG screenshot (embedded in summary)." - ) - p.add_argument( - "--annotated_png", type=Path, default=None, help="Path to the annotated volume PNG screenshot (embedded in summary)." - ) - p.add_argument("--max_overview_width", type=int, default=900, help="Max pixel width for overview images. [%(default)s]") - p.add_argument("--max_thumb_width", type=int, default=380, help="Max pixel width for gallery thumbnails. 
[%(default)s]") - p.add_argument("--no_images", action="store_true", help="Disable image discovery for zip bundles.") - return p - - -def get_status_color(status: str) -> str: - """Get HTML color for status.""" - colors = { - "ok": "#28a745", # green - "warning": "#ffc107", # yellow/amber - "error": "#dc3545", # red - "info": "#17a2b8", # blue - "unknown": "#6c757d", # gray - } - return colors.get(status, colors["unknown"]) - - -def get_status_emoji(status: str) -> str: - """Get emoji for status in text format.""" - emojis = {"ok": "✓", "warning": "⚠", "error": "✗", "info": "ℹ", "unknown": "?"} - return emojis.get(status, "?") - - -def format_value(value: float, precision: int = 4) -> str: - """Format a value for display.""" - if isinstance(value, float): - if abs(value) < 0.0001 or abs(value) > 10000: - return f"{value:.{precision}e}" - return f"{value:.{precision}f}" - elif isinstance(value, list) and len(value) > 5: - return f"[{len(value)} items]" - return str(value) - - -def sort_steps(aggregated: dict) -> dict: - """Sort pipeline steps in logical execution order.""" - - def step_key(step_name: str) -> Any: - try: - return (0, STEP_ORDER.index(step_name)) - except ValueError: - return (1, step_name) - - return dict(sorted(aggregated.items(), key=lambda x: step_key(x[0]))) - - -def extract_slice_id(source_file: str) -> str: - """Extract a meaningful slice identifier from a source file path.""" - path = Path(source_file) - # Search path components for a slice pattern like z01, z002, slice_3 - for part in reversed(path.parts): - m = re.search(r"(z\d+|slice_z?\d+)", part, re.IGNORECASE) - if m: - return m.group(1) - return path.stem - - -def parse_issue(issue_str: str) -> dict: - """Parse an issue string of the form 'source: metric: value op threshold (level)'.""" - parts = issue_str.split(": ", 2) - if len(parts) < 3: - return {"source": parts[0] if parts else "", "metric": "", "raw": issue_str, "value": None, "threshold": None} - source, metric, rest = parts[0], 
parts[1], parts[2] - m = re.match(r"([+-]?[\d.e+-]+)\s*([><]=?)\s*([+-]?[\d.e+-]+)", rest) - if m: - return { - "source": source, - "metric": metric, - "raw": issue_str, - "value": float(m.group(1)), - "op": m.group(2), - "threshold": float(m.group(3)), - } - return {"source": source, "metric": metric, "raw": issue_str, "value": None, "threshold": None} - - -def group_issues(issues: list[str]) -> list[dict]: - """ - Group issues by metric name. - - Returns a list of dicts with keys: metric, count, values, threshold, details. - """ - groups = defaultdict(list) - for issue in issues: - parsed = parse_issue(issue) - key = parsed["metric"] if parsed["metric"] else "__other__" - groups[key].append(parsed) - - result = [] - for metric, items in groups.items(): - values = [i["value"] for i in items if i.get("value") is not None] - threshold = items[0].get("threshold") if items else None - op = items[0].get("op", ">") if items else ">" - result.append( - { - "metric": metric if metric != "__other__" else "", - "count": len(items), - "values": values, - "threshold": threshold, - "op": op, - "details": [i["raw"] for i in items], - } - ) - return result - - -def separate_metrics_by_type(metrics_list: list[dict]) -> tuple[dict, dict]: - """ - Separate metrics into quality metrics and info/parameter fields. 
- - Returns - ------- - tuple - quality_metrics: {name: {'entries': [{value, status}], 'unit': str}} - info_fields: {name: {'values': [v], 'description': str, 'is_constant': bool, 'display_value': any}} - """ - quality_metrics: dict = {} - info_fields: dict = {} - - for m in metrics_list: - for name, data in m.get("metrics", {}).items(): - if not isinstance(data, dict): - continue - status = data.get("status", "ok") - value = data.get("value") - unit = data.get("unit") or "" - desc = data.get("description") or "" - - if status == "info": - if name not in info_fields: - info_fields[name] = {"values": [], "description": desc, "unit": unit} - info_fields[name]["values"].append(value) - else: - if name not in quality_metrics: - quality_metrics[name] = {"entries": [], "unit": unit, "description": desc} - quality_metrics[name]["entries"].append({"value": value, "status": status}) - - # Determine if each info field is constant across all files - for info in info_fields.values(): - vals = info["values"] - try: - numeric = [v for v in vals if isinstance(v, (int, float))] - if numeric and len(numeric) == len(vals): - is_const = float(np.std(numeric)) < 1e-10 - else: - is_const = len({str(v) for v in vals}) <= 1 - except Exception: - is_const = len({str(v) for v in vals}) <= 1 - info["is_constant"] = is_const - info["display_value"] = vals[0] if vals else None - - return quality_metrics, info_fields - - -def generate_sparkline_svg(values: list, statuses: list[str] | None = None, width: int = 160, height: int = 36) -> str: - """Generate an inline SVG bar-chart sparkline for a list of values.""" - numeric = [(i, v) for i, v in enumerate(values) if isinstance(v, (int, float))] - if len(numeric) < 2: - return "" - - all_vals = [v for _, v in numeric] - min_val, max_val = min(all_vals), max(all_vals) - val_range = max_val - min_val or 1.0 - - if statuses is None: - statuses = ["ok"] * len(values) - - n = len(values) - bar_w = width / n - rects = [] - for i, v in numeric: - h = 
max(2.0, (v - min_val) / val_range * (height - 4)) - y = height - h - color = get_status_color(statuses[i]) if i < len(statuses) else get_status_color("ok") - rects.append( - f'' - ) - - title = f"Min: {min_val:.3g} Max: {max_val:.3g} n={len(numeric)}" - return ( - f'' + "".join(rects) + "" - ) - - -def generate_trend_line_svg( - values: list, - _labels: list[str] | None = None, - width: int = 420, - height: int = 90, - show_trend: bool = True, - color: str = "#4a90d9", -) -> str: - """Generate an inline SVG line chart for cross-slice trend visualisation.""" - numeric = [(i, float(v)) for i, v in enumerate(values) if isinstance(v, (int, float))] - if len(numeric) < 2: - return "" - - xs = [p[0] for p in numeric] - ys = [p[1] for p in numeric] - min_y, max_y = min(ys), max(ys) - y_range = max_y - min_y or 1.0 - pad_x, pad_y = 30, 10 - - def to_svg_x(i: Any) -> Any: - return pad_x + (i / (len(values) - 1)) * (width - 2 * pad_x) - - def to_svg_y(v: Any) -> Any: - return height - pad_y - ((v - min_y) / y_range) * (height - 2 * pad_y) - - # Build polyline points - pts = " ".join(f"{to_svg_x(i):.1f},{to_svg_y(v):.1f}" for i, v in numeric) - - elements = [ - f'', - ] - - # Dots at each data point - for i, v in numeric: - elements.append(f'') - - # Trend line (least squares) - if show_trend and len(xs) >= 3: - x_arr = np.array(xs, dtype=float) - y_arr = np.array(ys, dtype=float) - slope = (np.mean(x_arr * y_arr) - np.mean(x_arr) * np.mean(y_arr)) / (np.mean(x_arr**2) - np.mean(x_arr) ** 2 + 1e-12) - intercept = np.mean(y_arr) - slope * np.mean(x_arr) - x0, x1 = xs[0], xs[-1] - y0, y1 = slope * x0 + intercept, slope * x1 + intercept - elements.append( - f'' - ) - - # Y-axis labels - elements.append( - f'{max_y:.3g}' - ) - elements.append( - f'{min_y:.3g}' - ) - - title_text = f"n={len(numeric)}, range [{min_y:.3g}, {max_y:.3g}]" - return ( - f'' + "".join(elements) + "" - ) - - -def compute_cross_slice_trends(aggregated: dict[str, list[dict]]) -> dict: - """ - Compute 
cross-slice aggregate trends from aggregated metrics. - - Returns a dict with trend groups, each containing: - 'label', 'description', 'series': [{name, values, unit}] - """ - trends = {} - - def _extract(metrics_list: Any, key: str) -> list: - """Extract sorted numerical values for a given metric key.""" - pairs = [] - for m in metrics_list: - src = m.get("source_file", "") - val = m.get("metrics", {}).get(key, {}).get("value") - if isinstance(val, (int, float)): - pairs.append((src, val)) - pairs.sort(key=lambda p: p[0]) # sort by source file path - return [v for _, v in pairs] - - # XY tile transform: scale and shear across slices - if "xy_transform_estimation" in aggregated: - ml = aggregated["xy_transform_estimation"] - t00 = _extract(ml, "transform_00") - t11 = _extract(ml, "transform_11") - rms = _extract(ml, "rms_residual") - acc_sys = _extract(ml, "accumulated_systematic_error_px") - acc_rnd = _extract(ml, "accumulated_random_error_px") - series = [] - if t00: - series.append({"name": "Step Y (px)", "values": t00, "unit": "px"}) - if t11: - series.append({"name": "Step X (px)", "values": t11, "unit": "px"}) - if rms: - series.append({"name": "RMS residual (px)", "values": rms, "unit": "px"}) - if acc_sys: - series.append({"name": "Accum. systematic error (px)", "values": acc_sys, "unit": "px"}) - if acc_rnd: - series.append({"name": "Accum. random error (px)", "values": acc_rnd, "unit": "px"}) - if series: - trends["xy_transform"] = { - "label": "XY Tile Transform Consistency", - "description": ( - "Tile step sizes and fitting residuals across slices. Large variation indicates unstable tile positioning." 
- ), - "series": series, - } - - # Pairwise registration: cumulative drift - if "pairwise_registration" in aggregated: - ml = aggregated["pairwise_registration"] - tx = _extract(ml, "translation_x") - ty = _extract(ml, "translation_y") - rot = _extract(ml, "rotation") - series = [] - if tx: - cum_tx = list(np.cumsum(tx)) - series.append({"name": "Cumulative tx (px)", "values": cum_tx, "unit": "px"}) - if ty: - cum_ty = list(np.cumsum(ty)) - series.append({"name": "Cumulative ty (px)", "values": cum_ty, "unit": "px"}) - if rot: - cum_rot = list(np.cumsum(rot)) - series.append({"name": "Cumulative rotation (deg)", "values": cum_rot, "unit": "deg"}) - if series: - trends["registration_drift"] = { - "label": "Cumulative Registration Drift", - "description": ( - "Accumulated translation and rotation across all slices. " - "A large net drift indicates systematic 3D volume distortion." - ), - "series": series, - } - - # Interface depth trend - if "crop_interface" in aggregated: - ml = aggregated["crop_interface"] - depth = _extract(ml, "detected_interface_depth_um") - if depth: - trends["interface_depth"] = { - "label": "Interface Depth Trend", - "description": ( - "Detected tissue-agarose interface depth across slices. " - "A systematic slope may indicate progressive tissue deformation." - ), - "series": [{"name": "Interface depth (µm)", "values": depth, "unit": "µm"}], - } - - # Background normalization drift - if "normalize_intensities" in aggregated: - ml = aggregated["normalize_intensities"] - bg = _extract(ml, "mean_background") - if bg: - trends["background_drift"] = { - "label": "Background Level Trend", - "description": ( - "Mean agarose background level across slices. " - "A strong trend indicates illumination drift during acquisition." 
- ), - "series": [{"name": "Mean background", "values": bg, "unit": ""}], - } - - return trends - - -# ============================================================================= -# Diagnostic data discovery -# ============================================================================= - - -def discover_interpolation_data(input_dir: Path) -> dict | None: - """ - Discover slice-interpolation outputs. - - Reads per-slice diagnostic JSONs written by ``linum_interpolate_missing_slice.py`` - (``slice_z*_interpolated_diagnostics.json``) and the preview PNGs. - ``slice_config_final.csv`` (produced by ``finalise_interpolation``) is - read via :mod:`linumpy.io.slice_config` to enrich the rows with the - per-slice trace fields (``interpolated``, ``interpolation_method_used``, - ``interpolation_fallback_reason``, ``use``, ``auto_excluded``, ...). - - Returns - ------- - dict or None - ``None`` when no interpolation happened. Otherwise a dict with keys - ``rows`` (list of per-slice dicts), ``images`` (list of preview - PNG paths), ``slice_config_final`` (path or None) and - ``summary`` (aggregated stats). 
- """ - from linumpy.io import slice_config as slice_config_io - - interp_dir = input_dir / "interpolate_missing_slice" - if not interp_dir.is_dir(): - return None - - diag_files = sorted(interp_dir.glob("slice_z*_interpolated_diagnostics.json")) - if not diag_files: - return None - - rows: list[dict] = [] - for path in diag_files: - try: - with path.open() as fh: - data = json.load(fh) - except Exception: - continue - rows.append( - { - "slice_id": str(data.get("slice_id") or "").strip(), - "method": str(data.get("method") or "unknown"), - "method_used": ( - "" - if data.get("interpolation_failed") is True - else str(data.get("method_used") or data.get("method") or "unknown") - ), - "fallback_reason": str(data.get("fallback_reason") or ""), - "interpolation_failed": bool(data.get("interpolation_failed", False)), - "pre_reg_ncc": data.get("pre_reg_ncc"), - "post_reg_ncc": data.get("post_reg_ncc"), - "ncc_improvement": data.get("ncc_improvement"), - "affine_determinant": data.get("affine_determinant"), - "output_path": str(data.get("output_path") or ""), - "diagnostics_path": str(path), - } - ) - - if not rows: - return None - - # Enrich from slice_config_final.csv when available (single source of truth). 
- slice_config_final = input_dir / "slice_config_final.csv" - if slice_config_final.exists(): - try: - sc_rows = slice_config_io.read(slice_config_final) - for r in rows: - sid = slice_config_io.normalize_slice_id(r["slice_id"]) - sc_row = sc_rows.get(sid) - if sc_row is not None: - r["slice_config_use"] = sc_row.get("use", "") - r["slice_config_interpolated"] = sc_row.get("interpolated", "") - r["slice_config_interpolation_failed"] = sc_row.get("interpolation_failed", "") - r["slice_config_auto_excluded"] = sc_row.get("auto_excluded", "") - r["slice_config_notes"] = sc_row.get("notes", "") - except Exception: - slice_config_final = None - - images: list[Path] = sorted(interp_dir.glob("slice_z*_interpolated_preview.png")) - - method_counts: dict[str, int] = {} - method_used_counts: dict[str, int] = {} - fallback_counts: dict[str, int] = {} - pre_nccs: list[float] = [] - post_nccs: list[float] = [] - improvements: list[float] = [] - - def _to_float(value: object) -> float | None: - if not isinstance(value, (int, float, str, bytes, bytearray)): - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - for r in rows: - method = (r.get("method") or "unknown").strip() or "unknown" - method_used = (r.get("method_used") or method).strip() or method - fallback = (r.get("fallback_reason") or "").strip() - method_counts[method] = method_counts.get(method, 0) + 1 - method_used_counts[method_used] = method_used_counts.get(method_used, 0) + 1 - if fallback: - fallback_counts[fallback] = fallback_counts.get(fallback, 0) + 1 - pre = _to_float(r.get("pre_reg_ncc")) - post = _to_float(r.get("post_reg_ncc")) - imp = _to_float(r.get("ncc_improvement")) - if pre is not None: - pre_nccs.append(pre) - if post is not None: - post_nccs.append(post) - if imp is not None: - improvements.append(imp) - - n_failed = sum(1 for r in rows if r.get("interpolation_failed")) - - summary = { - "count": len(rows), - "n_succeeded": len(rows) - n_failed, - 
"n_failed": n_failed, - "method_counts": method_counts, - "method_used_counts": method_used_counts, - "fallback_counts": fallback_counts, - "n_with_fallback": sum(fallback_counts.values()), - "pre_reg_ncc_mean": float(np.mean(pre_nccs)) if pre_nccs else None, - "post_reg_ncc_mean": float(np.mean(post_nccs)) if post_nccs else None, - "ncc_improvement_mean": float(np.mean(improvements)) if improvements else None, - } - - return { - "rows": rows, - "images": images, - "slice_config_final": slice_config_final if (slice_config_final and slice_config_final.exists()) else None, - "summary": summary, - } - - -def discover_diagnostic_data(input_dir: Path) -> dict[str, dict]: - """ - Discover diagnostic outputs in the pipeline output directory. - - Looks for known diagnostic subdirectories and reads their JSON data. - - Returns - ------- - dict - Maps diagnostic_name → {'label', 'description', 'json_data': [...], 'images': [Path]} - """ - import json as _json - - diagnostics: dict[str, dict] = {} - - diag_dir = input_dir / "diagnostics" - if not diag_dir.exists(): - return diagnostics - - # Define known diagnostics: (subdir, label, description) - known = [ - ("dilation_analysis", "Tile Dilation Analysis", "Per-slice scale factors and mosaic positioning accuracy."), - ("aggregated_dilation", "Aggregated Dilation Analysis", "Cross-slice tile dilation summary."), - ("rotation_analysis", "Rotation Drift Analysis", "Rotation angle drift across slices."), - ("acquisition_rotation", "Acquisition Rotation Analysis", "In-plane rotation estimated from acquisition metadata."), - ( - "motor_only_stitch", - "Motor-Only Stitching (comparison)", - "Stitched mosaic using motor positions only (no registration correction).", - ), - ( - "motor_only_stack", - "Motor-Only Stack (comparison)", - "Volume stacked without pairwise registration (motor positions only).", - ), - ( - "stitch_comparison", - "Stitching Comparison", - "Side-by-side comparison of registration-based vs motor-based 
stitching.", - ), - ] - - for subdir_name, label, description in known: - subdir = diag_dir / subdir_name - if not subdir.exists(): - continue - - json_data = [] - images = [] - - # Collect all JSON files (recursively for per-slice diagnostics) - for json_file in sorted(subdir.rglob("*.json")): - try: - with Path(json_file).open() as f: - data = _json.load(f) - data["_source"] = str(json_file) - json_data.append(data) - except Exception: - pass - - # Collect PNG images - images.extend(sorted(subdir.rglob("*.png"))) - - if json_data or images: - diagnostics[subdir_name] = { - "label": label, - "description": description, - "json_data": json_data, - "images": images, - } - - return diagnostics - - -def discover_images( - input_dir: Path, overview_png: Path | None = None, annotated_png: Path | None = None -) -> dict[str, list[Path]]: - """ - Discover preview images in the pipeline output directory. - - Returns a dict mapping category → sorted list of image paths: - 'overview' – main volume screenshots (up to 2) - 'stitch_preview' – per-slice stitched previews - 'common_space_preview' – common-space alignment previews - 'diag_*' – images found in diagnostics/ subdirs - """ - images: dict[str, list[Path]] = { - "overview": [], - "stitch_preview": [], - "common_space_preview": [], - } - - # Overview images from CLI (staged in Nextflow work dir) - for p in [overview_png, annotated_png]: - if p and Path(p).exists(): - images["overview"].append(Path(p)) - - # Stitched slice previews - stitch_dir = input_dir / "previews" / "stitched_slices" - if stitch_dir.exists(): - images["stitch_preview"] = sorted(stitch_dir.glob("*.png")) - - # Common-space alignment previews - cs_dir = input_dir / "common_space_previews" - if cs_dir.exists(): - images["common_space_preview"] = sorted(cs_dir.glob("*.png")) - - # Auto-detect overview from stack output directories if not provided via CLI - if not images["overview"]: - for stack_dir_name in ("stack_motor", "stack", 
"normalize_z_intensity"): - d = input_dir / stack_dir_name - if d.exists(): - pngs = sorted(d.glob("*.png")) - if pngs: - images["overview"] = pngs[:2] # at most overview + annotated - break - - # Diagnostic images: add one category per diagnostics subdir - diag_dir = input_dir / "diagnostics" - if diag_dir.exists(): - for subdir in sorted(diag_dir.iterdir()): - if subdir.is_dir(): - pngs = sorted(subdir.rglob("*.png")) - if pngs: - cat_key = f"diag_{subdir.name}" - images[cat_key] = pngs - - return images - - -def image_to_data_uri(path: Path, max_width: int | None = None) -> str: - """Encode a PNG image as a base64 data URI, optionally resizing.""" - if max_width and _PIL_AVAILABLE: - with _PILImage.open(path) as img: - if img.width > max_width: - ratio = max_width / img.width - new_size = (max_width, int(img.height * ratio)) - img = img.resize(new_size, _PILImage.Resampling.LANCZOS) - buf = _io.BytesIO() - img.save(buf, format="PNG", optimize=True) - data_bytes = buf.getvalue() - else: - data_bytes = path.read_bytes() - b64 = base64.b64encode(data_bytes).decode("ascii") - return f"data:image/png;base64,{b64}" - - -def render_image_gallery_html( - images: list[Path], mode: str = "embed", category: str = "images", _label: str = "Preview Images", max_width: int = 380 -) -> str: - """ - Render a collapsible image gallery section. - - Parameters - ---------- - images : list of Path - Image file paths to include in the gallery. - mode : str - Embedding mode: 'embed' (base64 in HTML) or 'link' (relative path for zip mode). - category : str - Image category name, used as subfolder in zip mode. - max_width : int - Maximum image width in pixels for embedded previews. 
- """ - if not images: - return "" - - items = [] - for p in images: - src = image_to_data_uri(p, max_width=max_width) if mode == "embed" else f"previews/{category}/{p.name}" - name = p.stem - items.append( - f'" - ) - - return f""" - -""" - - -def generate_zip_bundle(html: str, images: dict[str, list[Path]], output_path: Path) -> None: - """Bundle the HTML report and all image files into a zip archive.""" - with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf: - zf.writestr("index.html", html) - for category, paths in images.items(): - for p in paths: - zf.write(p, f"previews/{category}/{p.name}") - - -def compute_overall_status(aggregated: dict[str, list[dict]]) -> tuple: - """ - Compute overall status counts from aggregated metrics. - - Returns - ------- - tuple - (all_statuses, error_count, warning_count, ok_count) - """ - all_statuses = [m.get("overall_status", "unknown") for step_metrics in aggregated.values() for m in step_metrics] - - error_count = all_statuses.count("error") - warning_count = all_statuses.count("warning") - ok_count = all_statuses.count("ok") - - return all_statuses, error_count, warning_count, ok_count - - -def get_step_status(metrics_list: list[dict]) -> str: - """Get the overall status for a step based on its metrics.""" - step_statuses = [m.get("overall_status", "unknown") for m in metrics_list] - if "error" in step_statuses: - return "error" - elif "warning" in step_statuses: - return "warning" - return "ok" - - -def collect_issues(metrics_list: list[dict]) -> tuple: - """ - Collect all warnings and errors from a metrics list. 
- - Returns - ------- - tuple - (all_warnings, all_errors) - """ - all_warnings = [] - all_errors = [] - for m in metrics_list: - source = Path(m.get("source_file", "unknown")).stem - all_warnings.extend(f"{source}: {w}" for w in m.get("warnings", [])) - all_errors.extend(f"{source}: {e}" for e in m.get("errors", [])) - return all_warnings, all_errors - - -def _render_grouped_issues_html(grouped: list[dict], color_class: str, label: str) -> str: - """Render a collapsible grouped-issues section in HTML.""" - total = sum(g["count"] for g in grouped) - html = f""" -
- - {label} - {total} - -
-""" - for g in grouped: - if g["count"] == 1: - html += f'
{g["details"][0]}
\n' - else: - vals = g["values"] - val_str = f"range {min(vals):.3g} – {max(vals):.3g}" if vals else f"{g['count']} occurrences" - thresh_str = f", threshold: {g['threshold']:.3g}" if g["threshold"] is not None else "" - summary_line = f"{g['metric']}: {g['count']} slices affected ({val_str}{thresh_str})" - html += '
\n' - html += f' {summary_line}\n' - html += '
\n' - for detail in g["details"]: - html += f'
{detail}
\n' - html += "
\n" - html += "
\n" - html += "
\n
\n" - return html - - -def _render_interpolation_section_html( - interpolation: dict, - image_mode: str = "link", - max_thumb_width: int = 380, -) -> str: - """Render the slice-interpolation section of the HTML report.""" - summary = interpolation.get("summary", {}) - rows = interpolation.get("rows", []) - images = interpolation.get("images", []) - slice_config_final = interpolation.get("slice_config_final") - - count = summary.get("count", 0) - n_failed = summary.get("n_failed", 0) - n_succeeded = summary.get("n_succeeded", count - n_failed) - method_counts = summary.get("method_counts", {}) - method_used_counts = summary.get("method_used_counts", {}) - fallback_counts = summary.get("fallback_counts", {}) - pre_mean = summary.get("pre_reg_ncc_mean") - post_mean = summary.get("post_reg_ncc_mean") - imp_mean = summary.get("ncc_improvement_mean") - - status = "ok" - if n_failed > 0 and count > 0: - status = "warning" if n_failed < count else "error" - - html = '\n
\n' - html += "

Slice Interpolation

\n" - html += ( - '

' - "Missing slices reconstructed from their neighbours via zmorph. " - "Successful interpolations stamp interpolated=true and are flagged " - "reliable=0 in downstream pairwise registration. When quality gates " - "fail the slice is hard-skipped (interpolation_failed=true) " - "and the slot stays a genuine gap in the stacked volume \u2014 no blended volume is " - "written. See docs/SLICE_INTERPOLATION_FEATURE.md.

\n" - ) - - html += '
\n' - html += f'
{count}
' - html += '
Gaps Detected
\n' - ok_color = get_status_color("ok") - html += ( - f'
' - f'{n_succeeded}
Successfully Interpolated
\n' - ) - html += ( - f'
' - f'{n_failed}
Hard-Skipped (Gap)
\n' - ) - if pre_mean is not None: - html += f'
{pre_mean:.3f}
' - html += '
Mean Pre-Reg NCC
\n' - if post_mean is not None: - html += f'
{post_mean:.3f}
' - html += '
Mean Post-Reg NCC
\n' - if imp_mean is not None: - html += f'
{imp_mean:+.3f}
' - html += '
Mean NCC Improvement
\n' - html += "
\n" - - # Method breakdown - html += '
\n' - html += ' \n' - html += ' \n' - html += " \n" - html += " \n" - if fallback_counts: - html += " \n" - if slice_config_final is not None: - html += " " - html += f"\n" - html += "
Method requested" - html += ", ".join(f"{k}: {v}" for k, v in sorted(method_counts.items())) or "(none)" - html += "
Method actually used" - html += ", ".join(f"{k}: {v}" for k, v in sorted(method_used_counts.items())) or "(none)" - html += "
Hard-skip reasons" - html += ", ".join(f"{k}: {v}" for k, v in sorted(fallback_counts.items())) - html += "
Per-slice trace file{slice_config_final.name}
\n" - html += "
\n" - - # Per-slice table (cap to 50 rows; more than that is rare) - if rows: - html += '
\n' - html += ' ' - html += f"Per-slice interpolation diagnostics ({len(rows)} slice(s))\n" - html += ' \n' - html += ( - " " - "" - "" - "" - "\n" - ) - for r in rows[:50]: - sid = r.get("slice_id", "") or "?" - failed = bool(r.get("interpolation_failed")) - status_label = "SKIPPED" if failed else "OK" - method_used = r.get("method_used", "") or ("—" if failed else "") - fb = r.get("fallback_reason", "") or "" - pre = r.get("pre_reg_ncc", "") - post = r.get("post_reg_ncc", "") - imp = r.get("ncc_improvement", "") - det = r.get("affine_determinant", "") - - pre_fmt = f"{float(pre):.3f}" if pre not in ("", None) else "-" - post_fmt = f"{float(post):.3f}" if post not in ("", None) else "-" - imp_fmt = f"{float(imp):+.3f}" if imp not in ("", None) else "-" - det_fmt = f"{float(det):.3f}" if det not in ("", None) else "-" - - if failed: - row_style = ' style="background:#ffe5e5;"' - elif fb: - row_style = ' style="background:#fff8e1;"' - else: - row_style = "" - html += ( - f" " - f"" - f"" - f"" - "\n" - ) - if len(rows) > 50: - html += ( - f' \n' - ) - html += "
SliceStatusMethod UsedReasonPre NCCPost NCCΔNCC|det|
{sid}{status_label}{method_used}{fb}{pre_fmt}{post_fmt}{imp_fmt}{det_fmt}
(showing first 50 of {len(rows)} rows)
\n" - html += "
\n" - - # Preview image gallery (shown in zip/link mode only; embed mode skips images) - if images: - gallery = render_image_gallery_html( - images, - mode=image_mode, - category="diag_interpolate_missing_slice", - _label="Interpolation Previews", - max_width=max_thumb_width, - ) - html += gallery - - html += "
\n" - return html - - -def _render_interpolation_section_text(interpolation: dict) -> str: - """Render the slice-interpolation section of the text report.""" - summary = interpolation.get("summary", {}) - rows = interpolation.get("rows", []) - count = summary.get("count", 0) - n_failed = summary.get("n_failed", 0) - n_succeeded = summary.get("n_succeeded", count - n_failed) - pre_mean = summary.get("pre_reg_ncc_mean") - post_mean = summary.get("post_reg_ncc_mean") - imp_mean = summary.get("ncc_improvement_mean") - - lines = [] - lines.append("") - lines.append(f"{get_status_emoji('info')} SLICE INTERPOLATION") - lines.append("-" * 70) - lines.append(f" Gaps detected : {count}") - lines.append(f" Successfully interp'd : {n_succeeded}") - lines.append(f" Hard-skipped (gap) : {n_failed}") - if pre_mean is not None: - lines.append(f" Mean pre-reg NCC : {pre_mean:.3f}") - if post_mean is not None: - lines.append(f" Mean post-reg NCC : {post_mean:.3f}") - if imp_mean is not None: - lines.append(f" Mean NCC improvement : {imp_mean:+.3f}") - - method_used_counts = summary.get("method_used_counts", {}) - if method_used_counts: - mu_parts = ", ".join(f"{k}: {v}" for k, v in sorted(method_used_counts.items())) - lines.append(f" Methods used : {mu_parts}") - fallback_counts = summary.get("fallback_counts", {}) - if fallback_counts: - fb_parts = ", ".join(f"{k}: {v}" for k, v in sorted(fallback_counts.items())) - lines.append(f" Hard-skip reasons : {fb_parts}") - - if rows: - lines.append("") - lines.append(f" {'Slice':<6} {'Status':<8} {'Used':<14} {'Reason':<28} {'PreNCC':>7} {'PostNCC':>7}") - lines.append(" " + "-" * 80) - for r in rows[:50]: - sid = (r.get("slice_id", "") or "?")[:6] - failed = bool(r.get("interpolation_failed")) - status = "SKIP" if failed else "OK" - method_used = (r.get("method_used", "") or ("—" if failed else ""))[:14] - fb = (r.get("fallback_reason", "") or "")[:28] - pre = r.get("pre_reg_ncc", "") - post = r.get("post_reg_ncc", "") - try: - pre_fmt 
= f"{float(pre):.3f}" if pre not in ("", None) else "-" - except (TypeError, ValueError): - pre_fmt = "-" - try: - post_fmt = f"{float(post):.3f}" if post not in ("", None) else "-" - except (TypeError, ValueError): - post_fmt = "-" - lines.append(f" {sid:<6} {status:<8} {method_used:<14} {fb:<28} {pre_fmt:>7} {post_fmt:>7}") - if len(rows) > 50: - lines.append(f" ... ({len(rows) - 50} more row(s) not shown)") - - return "\n".join(lines) - - -def generate_html_report( - aggregated: dict[str, list[dict]], - title: str, - verbose: bool = False, - images: dict[str, list[Path]] | None = None, - image_mode: str = "embed", - max_overview_width: int = 900, - max_thumb_width: int = 380, - trends: dict | None = None, - diagnostics: dict | None = None, - interpolation: dict | None = None, -) -> str: - """Generate an HTML report from aggregated metrics.""" - aggregated = sort_steps(aggregated) - images = images or {} - - _, error_count, warning_count, ok_count = compute_overall_status(aggregated) - - if error_count > 0: - overall_status = "error" - overall_message = f"{error_count} error(s), {warning_count} warning(s)" - elif warning_count > 0: - overall_status = "warning" - overall_message = f"{warning_count} warning(s)" - else: - overall_status = "ok" - overall_message = "All checks passed" - - html = f""" - - - - - {title} - - - -
-

{title}

-
Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-
- -
-

Summary

-
- {overall_message} -
-
-
-
{len(aggregated)}
-
Pipeline Steps
-
-
-
{sum(len(v) for v in aggregated.values())}
-
Total Metrics Files
-
-
-
{ok_count}
-
OK
-
-
-
{warning_count}
-
Warnings
-
-
-
{error_count}
-
Errors
-
-
-
-""" - - # Overview images in the summary section - overview_imgs = images.get("overview", []) - if overview_imgs: - html += '
\n' - html += ' \n' - html += '
\n' - for p in overview_imgs: - if image_mode == "embed": - src = image_to_data_uri(p, max_width=max_overview_width) - else: - src = f"previews/overview/{p.name}" - html += ( - f"
" - f'' - f'{p.stem}' - f"
{p.stem}
\n" - ) - html += "
\n
\n" - - # Cross-slice trends section - if trends: - colors = ["#4a90d9", "#e67e22", "#27ae60", "#8e44ad", "#c0392b"] - html += '\n \n" - - # Generate section for each step - for step_name, metrics_list in aggregated.items(): - summary = compute_summary_statistics(metrics_list) - step_status = get_step_status(metrics_list) - description = STEP_DESCRIPTIONS.get(step_name, "") - - # Separate quality metrics from info/parameter fields - quality_metrics, info_fields = separate_metrics_by_type(metrics_list) - - html += f""" -
-
- {STEP_DISPLAY_NAMES.get(step_name, step_name.replace("_", " ").title())} - - {summary["count"]} items — {step_status.upper()} - -
-""" - if description: - html += f'
{description}
\n' - - # --- Quality metrics stats table with sparklines --- - if quality_metrics: - html += ' \n' - html += """ - - - - - - - - - -""" - for metric_name, mdata in quality_metrics.items(): - entries = mdata["entries"] - numeric_vals = [e["value"] for e in entries if isinstance(e.get("value"), (int, float))] - if not numeric_vals: - continue - arr = np.array(numeric_vals) - mean_v = float(np.mean(arr)) - median_v = float(np.median(arr)) - std_v = float(np.std(arr)) - min_v = float(np.min(arr)) - max_v = float(np.max(arr)) - statuses = [e.get("status", "ok") for e in entries] - unit = mdata.get("unit", "") - unit_str = f" {unit}" if unit else "" - - # Worst status in this metric - if "error" in statuses: - metric_status = "error" - elif "warning" in statuses: - metric_status = "warning" - else: - metric_status = "ok" - - sparkline = generate_sparkline_svg([e.get("value") for e in entries], statuses) - - html += f""" - - - - - - - - -""" - html += "
MetricMeanMedianStdMinMaxDistribution
- - {metric_name}{unit_str} - {format_value(mean_v)}{format_value(median_v)}{format_value(std_v)}{format_value(min_v)}{format_value(max_v)}{sparkline}
\n" - - # --- Errors and warnings (grouped, collapsible) --- - all_warnings, all_errors = collect_issues(metrics_list) - - if all_errors: - grouped_errors = group_issues(all_errors) - html += _render_grouped_issues_html(grouped_errors, "error", "Errors") - - if all_warnings: - grouped_warnings = group_issues(all_warnings) - html += _render_grouped_issues_html(grouped_warnings, "warning", "Warnings") - - # --- Info / parameter fields (collapsed) --- - if info_fields: - constant_params = {k: v for k, v in info_fields.items() if v["is_constant"]} - variable_infos = {k: v for k, v in info_fields.items() if not v["is_constant"]} - - if constant_params: - html += """ -
- Pipeline Parameters - -""" - for name, info in constant_params.items(): - val = info["display_value"] - unit = info.get("unit", "") - unit_str = f" {unit}" if unit else "" - html += f""" - - - -""" - html += "
{name}{format_value(val)}{unit_str}
\n
\n" - - if variable_infos: - html += """ -
- Variable Info Fields (per-slice) - - - - - - - - -""" - for name, info in variable_infos.items(): - numeric = [v for v in info["values"] if isinstance(v, (int, float))] - if not numeric: - continue - arr = np.array(numeric) - unit = info.get("unit", "") - unit_str = f" {unit}" if unit else "" - html += f""" - - - - - - -""" - html += "
FieldMeanStdMinMax
{name}{unit_str}{format_value(float(np.mean(arr)))}{format_value(float(np.std(arr)))}{format_value(float(np.min(arr)))}{format_value(float(np.max(arr)))}
\n
\n" - - # --- Verbose: individual per-slice results (collapsible as a unit) --- - if verbose: - n_items = len(metrics_list) - html += f""" -
- Individual Results ({n_items} slices) -""" - for m in metrics_list: - source = extract_slice_id(m.get("source_file", "unknown")) - m_status = m.get("overall_status", "unknown") - html += f""" -
- - - {source} - - -""" - for name, data in m.get("metrics", {}).items(): - if isinstance(data, dict): - value = data.get("value", "N/A") - unit = data.get("unit", "") or "" - status = data.get("status", "info") - html += f""" - - - - -""" - html += """
- - {name} - {format_value(value)}{(" " + unit) if unit else ""}
-
-""" - html += "
\n" - - # --- Per-step preview image gallery --- - preview_category = STEP_PREVIEW_CATEGORY.get(step_name) - if preview_category: - step_imgs = images.get(preview_category, []) - if step_imgs: - html += render_image_gallery_html( - step_imgs, mode=image_mode, category=preview_category, max_width=max_thumb_width - ) - - html += "
\n" - - # Slice interpolation section (only if interpolation happened) - if interpolation: - html += _render_interpolation_section_html(interpolation, image_mode=image_mode, max_thumb_width=max_thumb_width) - - # Diagnostics section (only if diagnostic data was found) - if diagnostics: - html += '\n
\n' - html += "

Diagnostic Outputs

\n" - html += ( - '

' - "Additional diagnostic analyses enabled in the pipeline configuration.

\n" - ) - for diag_key, diag in diagnostics.items(): - label = diag["label"] - description = diag["description"] - json_data = diag.get("json_data", []) - diag_images = diag.get("images", []) - - html += '
\n' - html += f'
{label}
\n' - html += f'
{description}
\n' - - # Render key JSON fields - if json_data: - # Collect interesting numeric/scalar fields from first entry - first = json_data[0] - numeric_fields = {} - for k, v in first.items(): - if k.startswith("_") or k == "slice_id": - continue - if isinstance(v, (int, float, str, bool)): - numeric_fields[k] = v - elif isinstance(v, dict): - # like scale_factors / residuals / distortions sub-dicts - for sk, sv in v.items(): - if isinstance(sv, (int, float, str, bool)): - numeric_fields[f"{k}.{sk}"] = sv - - if numeric_fields: - html += ' \n' - for k, v in list(numeric_fields.items())[:20]: - html += ( - f" " - f"\n" - ) - html += "
{k}{format_value(v) if isinstance(v, (int, float)) else v}
\n" - - # Render diagnostic image gallery - if diag_images: - # In zip mode images are referenced via relative paths; in embed mode as data URIs - cat_key = f"diag_{diag_key}" - gallery = render_image_gallery_html( - diag_images, mode=image_mode, category=cat_key, _label=f"{label} Images", max_width=max_thumb_width - ) - html += gallery - - html += "
\n" - html += "
\n" - - html += """ - - -""" - return html - - -def generate_text_report( - aggregated: dict[str, list[dict]], - title: str, - verbose: bool = False, - interpolation: dict | None = None, -) -> str: - """Generate a plain text report from aggregated metrics.""" - aggregated = sort_steps(aggregated) - - lines = [] - lines.append("=" * 70) - lines.append(title.center(70)) - lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}".center(70)) - lines.append("=" * 70) - lines.append("") - - _, error_count, warning_count, ok_count = compute_overall_status(aggregated) - - lines.append("SUMMARY") - lines.append("-" * 70) - lines.append(f" Pipeline Steps: {len(aggregated)}") - lines.append(f" Total Metrics Files: {sum(len(v) for v in aggregated.values())}") - lines.append( - f" Status: {get_status_emoji('ok')} OK: {ok_count} " - f"{get_status_emoji('warning')} Warnings: {warning_count} " - f"{get_status_emoji('error')} Errors: {error_count}" - ) - lines.append("") - - for step_name, metrics_list in aggregated.items(): - summary = compute_summary_statistics(metrics_list) - step_status = get_step_status(metrics_list) - - lines.append("") - lines.append(f"{get_status_emoji(step_status)} {step_name.replace('_', ' ').upper()}") - lines.append("-" * 70) - lines.append(f" Items: {summary['count']} | Status: {step_status.upper()}") - - # Quality metrics stats - quality_metrics, _ = separate_metrics_by_type(metrics_list) - if quality_metrics: - lines.append("") - lines.append(" Quality Metrics:") - lines.append(f" {'Metric':<25} {'Mean':>12} {'Median':>12} {'Std':>12} {'Min':>12} {'Max':>12}") - lines.append(" " + "-" * 77) - for metric_name, mdata in quality_metrics.items(): - entries = mdata["entries"] - numeric_vals = [e["value"] for e in entries if isinstance(e.get("value"), (int, float))] - if not numeric_vals: - continue - arr = np.array(numeric_vals) - name = metric_name[:25] - lines.append( - f" {name:<25} {format_value(float(np.mean(arr))):>12} " - 
f"{format_value(float(np.median(arr))):>12} " - f"{format_value(float(np.std(arr))):>12} " - f"{format_value(float(np.min(arr))):>12} " - f"{format_value(float(np.max(arr))):>12}" - ) - - all_warnings, all_errors = collect_issues(metrics_list) - - if all_errors: - lines.append("") - lines.append(f" {get_status_emoji('error')} ERRORS:") - for g in group_issues(all_errors): - if g["count"] == 1: - lines.append(f" - {g['details'][0]}") - else: - vals = g["values"] - val_str = f"range {min(vals):.3g}–{max(vals):.3g}" if vals else f"{g['count']} occurrences" - lines.append(f" - {g['metric']}: {g['count']} slices ({val_str})") - - if all_warnings: - lines.append("") - lines.append(f" {get_status_emoji('warning')} WARNINGS:") - for g in group_issues(all_warnings): - if g["count"] == 1: - lines.append(f" - {g['details'][0]}") - else: - vals = g["values"] - val_str = f"range {min(vals):.3g}–{max(vals):.3g}" if vals else f"{g['count']} occurrences" - lines.append(f" - {g['metric']}: {g['count']} slices ({val_str})") - - if verbose: - lines.append("") - lines.append(" Individual Results:") - for m in metrics_list: - source = extract_slice_id(m.get("source_file", "unknown")) - m_status = m.get("overall_status", "unknown") - lines.append(f" {get_status_emoji(m_status)} {source}") - for name, data in m.get("metrics", {}).items(): - if isinstance(data, dict): - value = data.get("value", "N/A") - unit = data.get("unit", "") or "" - lines.append(f" {name}: {format_value(value)}{(' ' + unit) if unit else ''}") - - if interpolation: - lines.append(_render_interpolation_section_text(interpolation)) - - lines.append("") - lines.append("=" * 70) - lines.append("End of Report".center(70)) - lines.append("=" * 70) - - return "\n".join(lines) - - -def main() -> None: - """Run function.""" - parser = _build_arg_parser() - args = parser.parse_args() - - input_dir = Path(args.input_dir) - output_file = Path(args.output_report) - - if not input_dir.exists(): - parser.error(f"Input directory 
does not exist: {input_dir}") - - # Determine format - if args.format == "auto": - suffix = output_file.suffix.lower() - if suffix == ".html": - output_format = "html" - elif suffix == ".zip": - output_format = "zip" - else: - output_format = "text" - else: - output_format = args.format - - # Aggregate metrics from all subdirectories - print(f"Scanning for metrics files in: {input_dir}") - aggregated = aggregate_metrics(input_dir) - - if not aggregated: - print("No metrics files found. Checking for process subdirectories...") - for subdir in input_dir.iterdir(): - if subdir.is_dir(): - sub_aggregated = aggregate_metrics(subdir) - for step, metrics in sub_aggregated.items(): - if step not in aggregated: - aggregated[step] = [] - aggregated[step].extend(metrics) - - if not aggregated: - print("Warning: No metrics files found in the input directory.") - print("Make sure the pipeline has been run with metrics collection enabled.") - aggregated = {} - - print(f"Found {sum(len(v) for v in aggregated.values())} metrics files across {len(aggregated)} pipeline steps") - - # Discover preview images — only for zip bundles; HTML is always image-free - images: dict[str, list[Path]] = {} - if output_format == "zip" and not args.no_images: - images = discover_images(input_dir, overview_png=args.overview_png, annotated_png=args.annotated_png) - total_imgs = sum(len(v) for v in images.values()) - if total_imgs: - print(f"Found {total_imgs} preview image(s) to bundle in zip") - - # Zip bundles use relative image links; standalone HTML has no images - image_mode = "link" - - # Compute cross-slice aggregate trends - trends = compute_cross_slice_trends(aggregated) - if trends: - n_trend_groups = len(trends) - print(f"Computed {n_trend_groups} cross-slice trend group(s)") - - # Discover slice-interpolation outputs - interpolation = discover_interpolation_data(input_dir) - if interpolation: - s = interpolation["summary"] - print(f"Found interpolation output(s): {s['count']} slice(s), 
{s['n_with_fallback']} with fallback") - if output_format == "zip" and not args.no_images and interpolation.get("images"): - images["diag_interpolate_missing_slice"] = list(interpolation["images"]) - - # Discover diagnostic outputs - diagnostics = discover_diagnostic_data(input_dir) - if diagnostics: - print(f"Found {len(diagnostics)} diagnostic output(s): {', '.join(diagnostics.keys())}") - # In zip mode, include diagnostic images in the bundle - if output_format == "zip" and not args.no_images: - for diag_key, diag in diagnostics.items(): - cat_key = f"diag_{diag_key}" - diag_imgs = diag.get("images", []) - if diag_imgs: - images[cat_key] = diag_imgs - - # Generate report - output_file.parent.mkdir(parents=True, exist_ok=True) - if output_format in ("html", "zip"): - report = generate_html_report( - aggregated, - args.title, - args.verbose, - images=images, - image_mode=image_mode, - max_overview_width=args.max_overview_width, - max_thumb_width=args.max_thumb_width, - trends=trends if trends else None, - diagnostics=diagnostics if diagnostics else None, - interpolation=interpolation, - ) - if output_format == "zip": - if output_file.suffix.lower() != ".zip": - output_file = output_file.with_suffix(".zip") - generate_zip_bundle(report, images, output_file) - else: - with Path(output_file).open("w") as f: - f.write(report) - else: - report = generate_text_report(aggregated, args.title, args.verbose, interpolation=interpolation) - with Path(output_file).open("w") as f: - f.write(report) - - print(f"Report saved to: {output_file}") - - _, error_count, warning_count, _ = compute_overall_status(aggregated) - - if error_count > 0: - print(f"\n{get_status_emoji('error')} {error_count} error(s) found - please review the report") - elif warning_count > 0: - print(f"\n{get_status_emoji('warning')} {warning_count} warning(s) found - please review the report") - else: - print(f"\n{get_status_emoji('ok')} All checks passed") - - -if __name__ == "__main__": - main() diff --git 
a/scripts/tests/test_align_to_ras.py b/scripts/tests/test_align_to_ras.py new file mode 100644 index 00000000..05626cbd --- /dev/null +++ b/scripts/tests/test_align_to_ras.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +"""Tests for ``scripts/linum_align_to_ras.py``. + +The script is loaded via :mod:`importlib` so we can test its pure-Python helper +functions (no ``zarr`` I/O) without relying on the console entry point. +""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path + +import numpy as np +import pytest +import SimpleITK as sitk + +SCRIPT_PATH = Path(__file__).resolve().parents[1] / "linum_align_to_ras.py" + + +@pytest.fixture(scope="module") +def align_module(): + """Load ``linum_align_to_ras.py`` as a module.""" + spec = importlib.util.spec_from_file_location("linum_align_to_ras", SCRIPT_PATH) + assert spec is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +# --------------------------------------------------------------------------- +# CLI help +# --------------------------------------------------------------------------- + + +def test_help(script_runner): + ret = script_runner.run(["linum_align_to_ras.py", "--help"]) + assert ret.success + + +# --------------------------------------------------------------------------- +# sitk_transform_to_affine_matrix +# --------------------------------------------------------------------------- + + +class TestSitkTransformToAffine: + def test_identity_transform_yields_identity_matrix(self, align_module): + t = sitk.Euler3DTransform() + mat = align_module.sitk_transform_to_affine_matrix(t) + assert mat.shape == (4, 4) + np.testing.assert_allclose(mat, np.eye(4), atol=1e-12) + + def test_pure_translation_is_permuted_to_zyx(self, align_module): + t = sitk.Euler3DTransform() + # SITK translation in (X, Y, Z) = (1, 2, 3) + t.SetTranslation((1.0, 2.0, 3.0)) + mat = align_module.sitk_transform_to_affine_matrix(t) + # After 
conversion to NGFF (Z, Y, X) order, translation must be (3, 2, 1). + np.testing.assert_allclose(mat[:3, 3], [3.0, 2.0, 1.0], atol=1e-12) + np.testing.assert_allclose(mat[:3, :3], np.eye(3), atol=1e-12) + + def test_rotation_is_permuted_to_zyx(self, align_module): + """A pure rotation around SITK X (=numpy axis 2) should appear as a + rotation around the last axis of the NGFF matrix (axis X→row 2).""" + t = sitk.Euler3DTransform() + t.SetRotation(np.pi / 4, 0.0, 0.0) # rotate around SITK X + mat = align_module.sitk_transform_to_affine_matrix(t) + # Rotation around numpy axis 2 (X in NGFF) leaves column/row 2 unchanged. + assert mat.shape == (4, 4) + np.testing.assert_allclose(mat[2, 2], 1.0, atol=1e-9) + np.testing.assert_allclose(mat[2, :3], [0.0, 0.0, 1.0], atol=1e-9) + np.testing.assert_allclose(mat[:3, 2], [0.0, 0.0, 1.0], atol=1e-9) + + +# --------------------------------------------------------------------------- +# compute_centered_reference_and_transform +# --------------------------------------------------------------------------- + + +class TestComputeCenteredReferenceAndTransform: + @staticmethod + def _make_moving(shape=(20, 20, 20), spacing=(0.1, 0.1, 0.1)): + """A small ellipsoid brain so the resampled output has known volume.""" + z, y, x = np.indices(shape, dtype=np.float32) + cz, cy, cx = shape[0] / 2, shape[1] / 2, shape[2] / 2 + rz, ry, rx = shape[0] * 0.3, shape[1] * 0.3, shape[2] * 0.3 + mask = ((z - cz) / rz) ** 2 + ((y - cy) / ry) ** 2 + ((x - cx) / rx) ** 2 < 1 + arr = mask.astype(np.float32) + img = sitk.GetImageFromArray(arr) + img.SetSpacing((spacing[2], spacing[1], spacing[0])) # SITK XYZ + img.SetOrigin((0.0, 0.0, 0.0)) + img.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1)) + return img, int(mask.sum()) + + def test_reference_origin_is_zero(self, align_module): + moving, _ = self._make_moving() + t = sitk.Euler3DTransform() + ref, _ = align_module.compute_centered_reference_and_transform(moving, t) + assert ref.GetOrigin() ==
pytest.approx((0.0, 0.0, 0.0)) + + def test_reference_spacing_matches_moving_by_default(self, align_module): + moving, _ = self._make_moving(spacing=(0.125, 0.1, 0.2)) + t = sitk.Euler3DTransform() + ref, _ = align_module.compute_centered_reference_and_transform(moving, t) + assert ref.GetSpacing() == pytest.approx(moving.GetSpacing()) + + def test_reference_spacing_override(self, align_module): + moving, _ = self._make_moving() + t = sitk.Euler3DTransform() + ref, _ = align_module.compute_centered_reference_and_transform(moving, t, output_spacing=(0.05, 0.05, 0.05)) + assert ref.GetSpacing() == pytest.approx((0.05, 0.05, 0.05)) + + def test_identity_transform_roundtrip(self, align_module): + """For T = identity, resampling through the composite should recover + the original brain volume (no information loss).""" + moving, brain_voxels = self._make_moving() + t = sitk.Euler3DTransform() # identity + ref, composite = align_module.compute_centered_reference_and_transform(moving, t) + + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(ref) + resampler.SetTransform(composite) + resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(0.0) + out = resampler.Execute(moving) + arr = sitk.GetArrayFromImage(out) + + nonzero = (arr > 0.5).sum() + assert abs(int(nonzero) - brain_voxels) / brain_voxels < 0.05 + + def test_rotation_preserves_brain_volume(self, align_module): + """A rigid rotation + translation preserves the brain voxel count.""" + moving, brain_voxels = self._make_moving() + + t = sitk.Euler3DTransform() + t.SetRotation(np.deg2rad(15), np.deg2rad(-10), np.deg2rad(30)) + t.SetTranslation((0.3, -0.2, 0.1)) + center_mm = moving.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in moving.GetSize()]) + t.SetCenter(center_mm) + + ref, composite = align_module.compute_centered_reference_and_transform(moving, t) + + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(ref) + resampler.SetTransform(composite) + 
resampler.SetInterpolator(sitk.sitkLinear) + resampler.SetDefaultPixelValue(0.0) + out = resampler.Execute(moving) + arr = sitk.GetArrayFromImage(out) + + nonzero = (arr > 0.5).sum() + # Rigid transform preserves volume; allow 5% tolerance for interpolation. + assert abs(int(nonzero) - brain_voxels) / brain_voxels < 0.05 + + def test_composite_transform_semantics(self, align_module): + """The composite must compute T(shift(p)), not shift(T(p)). + + Sample points at the origin of output space (0, 0, 0) and along each + axis; verify the composite maps them to the same physical point as the + *manually* composed ``T ∘ shift``. + """ + moving, _ = self._make_moving() + t = sitk.Euler3DTransform() + t.SetRotation(np.deg2rad(20), 0.0, np.deg2rad(-5)) + t.SetTranslation((0.25, 0.1, -0.15)) + center_mm = moving.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in moving.GetSize()]) + t.SetCenter(center_mm) + + ref, composite = align_module.compute_centered_reference_and_transform(moving, t) + + # Rebuild the shift transform the helper used: offset == pts_min + # recovered from the composite (last-added transform is the shift). + sample_points = [ + (0.0, 0.0, 0.0), + tuple(ref.GetSpacing()), + tuple(np.array(ref.GetSize(), dtype=float) * ref.GetSpacing() / 2), + ] + for p in sample_points: + actual = np.array(composite.TransformPoint(p)) + # Compose T(shift(p)) manually. Retrieve shift from the 2-member + # composite: ITK applies transforms in reverse order, so nth = last + # added = shift. + shift = composite.GetNthTransform(1) + expected = np.array(t.TransformPoint(shift.TransformPoint(p))) + np.testing.assert_allclose(actual, expected, atol=1e-9) + + # Sanity check that the *wrong* ordering ``shift(T(p))`` does NOT + # match (unless the transform is degenerate). 
+ wrong = np.array(shift.TransformPoint(t.TransformPoint(p))) + assert not np.allclose(actual, wrong, atol=1e-4), "Composite accidentally matches the buggy order shift(T(p))" + + +# --------------------------------------------------------------------------- +# store_transform_in_metadata (smoke test via a hand-written .zattrs) +# --------------------------------------------------------------------------- + + +class TestStoreTransformInMetadata: + """Smoke test: ensure the metadata writer builds a valid affine block.""" + + def test_affine_block_written_to_zattrs(self, align_module, tmp_path): + import json + + # Create a minimal OME-Zarr v0.4 directory with a .zattrs file. + store_path = tmp_path / "test.ome.zarr" + store_path.mkdir() + initial_attrs = { + "multiscales": [ + { + "version": "0.4", + "axes": [ + {"name": "z", "type": "space", "unit": "millimeter"}, + {"name": "y", "type": "space", "unit": "millimeter"}, + {"name": "x", "type": "space", "unit": "millimeter"}, + ], + "datasets": [{"path": "0", "coordinateTransformations": []}], + } + ] + } + (store_path / ".zattrs").write_text(json.dumps(initial_attrs)) + + t = sitk.Euler3DTransform() + t.SetTranslation((0.5, 1.5, 2.5)) + + align_module.store_transform_in_metadata(str(store_path), t) + + with (store_path / ".zattrs").open() as f: + metadata = json.load(f) + + ms = metadata["multiscales"][0] + ds = ms["datasets"][0] + ctfs = ds["coordinateTransformations"] + affines = [c for c in ctfs if c.get("type") == "affine"] + assert len(affines) == 1 + mat = np.array(affines[0]["affine"]).reshape(4, 4) + assert mat.shape == (4, 4) + np.testing.assert_allclose(mat[3], [0, 0, 0, 1], atol=1e-12) + # Translation must be permuted to NGFF (Z, Y, X) ordering. + np.testing.assert_allclose(mat[:3, 3], [2.5, 1.5, 0.5], atol=1e-12)