diff --git a/linumpy/reference/allen.py b/linumpy/reference/allen.py
index fc8ea7f3..9b403992 100644
--- a/linumpy/reference/allen.py
+++ b/linumpy/reference/allen.py
@@ -1,7 +1,10 @@
"""Methods to download data from the Allen Institute."""
+from collections.abc import Callable, Sequence
from pathlib import Path
+from typing import Any
+import numpy as np
import requests
import SimpleITK as sitk
from tqdm import tqdm
@@ -9,6 +12,40 @@
AVAILABLE_RESOLUTIONS = [10, 25, 50, 100]
+def numpy_to_sitk_image(volume: np.ndarray, spacing: tuple | Sequence, cast_dtype: type | None = None) -> sitk.Image:
+ """Convert numpy array (Z, Y, X) to SimpleITK image format.
+
+ Parameters
+ ----------
+ volume : np.ndarray
+ 3D volume with shape (Z, Y, X) matching the project-wide convention
+ (axis 0 = Z/depth, axis 1 = Y/row, axis 2 = X/column).
+ spacing : tuple
+ Voxel spacing in mm as (res_z, res_y, res_x).
+ cast_dtype : numpy dtype or None
+ If provided, cast the volume to this dtype before creating the SITK image
+ (useful for registration where float32 is expected). If None, preserve
+ the input numpy dtype.
+
+ Returns
+ -------
+ sitk.Image
+ SimpleITK image with proper spacing and orientation
+ """
+ # sitk.GetImageFromArray interprets a numpy array with shape (Z, Y, X) as a
+ # SITK image with size (X, Y, Z), so no transpose is needed. The SITK call
+ # copies the buffer into its own storage, so we only allocate an extra
+ # numpy array when an explicit dtype cast is requested.
+ vol_for_sitk = volume.astype(cast_dtype, copy=False) if cast_dtype is not None else volume
+ vol_sitk = sitk.GetImageFromArray(vol_for_sitk)
+ # Spacing: SimpleITK uses (X, Y, Z) = (width, height, depth).
+ # Our spacing is (res_z, res_y, res_x), so SITK spacing is (res_x, res_y, res_z).
+ vol_sitk.SetSpacing([spacing[2], spacing[1], spacing[0]])
+ vol_sitk.SetOrigin([0, 0, 0])
+ vol_sitk.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
+ return vol_sitk
+
+
def download_template(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image:
"""Download a 3D average mouse brain.
@@ -41,7 +78,7 @@ def download_template(resolution: int, cache: bool = True, cache_dir: str = ".da
if not (nrrd_file.is_file()):
# Download the template
response = requests.get(url, stream=True)
- with nrrd_file.open("wb") as f:
+ with Path(nrrd_file).open("wb") as f:
for data in tqdm(response.iter_content()):
f.write(data)
@@ -53,3 +90,395 @@ def download_template(resolution: int, cache: bool = True, cache_dir: str = ".da
nrrd_file.unlink() # Removes the nrrd file
return vol
+
+
+def download_template_ras_aligned(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image:
+ """Download a 3D average mouse brain and align it to RAS+ orientation.
+
+ The Allen CCF v3 template is stored in PIR orientation
+ (SITK axes ``(X, Y, Z) = (AP, DV, ML)`` with ``+X = Posterior``,
+ ``+Y = Inferior``, ``+Z = Right``). Converting to RAS+
+ (``+X = Right``, ``+Y = Anterior``, ``+Z = Superior``) requires
+ ``PermuteAxes((2, 0, 1))`` followed by flipping **both** the Y and Z
+ axes (I → S and P → A).
+
+ Parameters
+ ----------
+ resolution
+ Allen template resolution in micron. Must be 10, 25, 50 or 100.
+ cache
+ Keep the downloaded volume in cache
+ cache_dir
+ Cache directory
+
+ Returns
+ -------
+ Allen average mouse brain in RAS+ orientation.
+ """
+ vol = download_template(resolution, cache, cache_dir)
+
+ # Preparing the affine to align the template in the RAS+
+ r_mm = resolution / 1e3 # Convert the resolution from micron to mm
+ vol.SetSpacing([r_mm] * 3) # Set the spacing in mm
+ # Ensure origin/direction are standardized so physical coordinates are stable
+ vol.SetOrigin([0.0, 0.0, 0.0])
+ vol.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
+
+ # Convert PIR → RAS:
+ # PermuteAxes((2, 0, 1)) maps (P, I, R) → (R, P, I)
+ # Flip Y (P → A) and Z (I → S) to reach (R, A, S).
+ vol = sitk.PermuteAxes(vol, (2, 0, 1))
+ vol = sitk.Flip(vol, (False, True, True))
+ # After permuting/flipping, also ensure origin/direction are identity/zero
+ vol.SetOrigin([0.0, 0.0, 0.0])
+ vol.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
+
+ return vol
+
+
+def register_3d_rigid_to_allen(
+ moving_image: np.ndarray,
+ moving_spacing: tuple,
+ allen_resolution: int = 100,
+ metric: str = "MI",
+ max_iterations: int = 1000,
+ verbose: bool = False,
+ progress_callback: Callable[[Any], None] | None = None,
+ initial_rotation_deg: tuple = (0.0, 0.0, 0.0),
+) -> tuple:
+ """Perform 3D rigid registration of a brain volume to the Allen atlas.
+
+ Parameters
+ ----------
+ moving_image : np.ndarray
+ 3D brain volume to register (shape: Z, Y, X)
+ moving_spacing : tuple
+ Voxel spacing in mm (res_z, res_y, res_x)
+ allen_resolution : int
+ Allen template resolution in micron (default: 100)
+ metric : str
+ Similarity metric: 'MI' (mutual information), 'MSE', 'CC' (correlation),
+ or 'AntsCC' (ANTS correlation)
+ max_iterations : int
+ Maximum number of iterations
+ verbose : bool
+ Print registration progress
+ progress_callback : callable, optional
+ Callback function called on each iteration with the registration method.
+ Function signature: callback(registration_method)
+ initial_rotation_deg : tuple, optional
+ Initial rotation in degrees (rx, ry, rz) applied before optimization.
+
+ Returns
+ -------
+ transform : sitk.Euler3DTransform
+ Rigid transform to align moving_image to Allen atlas
+ stop_condition : str
+ Optimizer stopping condition
+ error : float
+ Final registration metric value
+ """
+ # Download and prepare Allen atlas in RAS orientation
+ allen_atlas = download_template_ras_aligned(allen_resolution, cache=True)
+
+ # If the moving image is coarser than the Allen atlas along any axis,
+ # downsample the atlas to match the moving resolution. The registration
+ # cost is dominated by the fixed (Allen) image size, so downsampling the
+ # atlas up-front is much cheaper than upsampling moving to a finer grid
+ # that carries no additional information.
+ moving_min_spacing_mm = min(moving_spacing)
+ allen_spacing_mm = allen_atlas.GetSpacing()
+ allen_min_spacing_mm = min(allen_spacing_mm)
+ if moving_min_spacing_mm > allen_min_spacing_mm * 1.2:
+ target_spacing_mm = float(moving_min_spacing_mm)
+ allen_size = allen_atlas.GetSize()
+ new_size = [max(1, round(sz * sp / target_spacing_mm)) for sz, sp in zip(allen_size, allen_spacing_mm, strict=False)]
+ ref = sitk.Image(new_size, allen_atlas.GetPixelIDValue())
+ ref.SetOrigin(allen_atlas.GetOrigin())
+ ref.SetDirection(allen_atlas.GetDirection())
+ ref.SetSpacing((target_spacing_mm,) * 3)
+ resampler = sitk.ResampleImageFilter()
+ resampler.SetReferenceImage(ref)
+ resampler.SetInterpolator(sitk.sitkLinear)
+ resampler.SetDefaultPixelValue(0)
+ allen_atlas = resampler.Execute(allen_atlas)
+ if verbose:
+ print(
+ f"Downsampled Allen atlas to match moving spacing: "
+ f"{allen_spacing_mm} mm → {allen_atlas.GetSpacing()} mm, "
+ f"size {allen_size} → {allen_atlas.GetSize()}"
+ )
+
+ # Crop moving image to tissue bounding box to reduce volume size.
+ # Large motor drift during acquisition inflates the canvas with empty space,
+ # causing the Allen-domain resampling to clip away brain tissue. Cropping
+ # first keeps the volume compact so most of the brain survives resampling,
+ # giving the optimizer a much better cost-function landscape.
+ margin_voxels = 10
+ crop_origin_mm = (0.0, 0.0, 0.0) # physical offset in (Z, Y, X) order
+ nonzero_coords = np.nonzero(moving_image)
+ if len(nonzero_coords[0]) > 0:
+ bbox_slices = tuple(
+ slice(
+ max(0, int(dim.min()) - margin_voxels),
+ min(moving_image.shape[ax], int(dim.max()) + margin_voxels + 1),
+ )
+ for ax, dim in enumerate(nonzero_coords)
+ )
+ crop_origin_mm = (
+ bbox_slices[0].start * moving_spacing[0],
+ bbox_slices[1].start * moving_spacing[1],
+ bbox_slices[2].start * moving_spacing[2],
+ )
+ cropped = moving_image[bbox_slices]
+ if verbose:
+ print(f"Cropped tissue bounding box: {moving_image.shape} -> {cropped.shape}")
+ moving_image = cropped
+
+ # Convert moving image to SimpleITK format.
+ # Origin stays at (0,0,0) so the compact brain sits at the start of physical
+ # space and overlaps with the Allen atlas domain during resampling. The crop
+ # offset is added to the final transform's translation after registration so
+ # the transform remains valid for the original (uncropped) full volume.
+ moving_sitk = numpy_to_sitk_image(moving_image, moving_spacing)
+
+ # Compute a preliminary brain centre BEFORE any resampling.
+ # This is used as the fallback only when needs_resample=False (images already
+ # share the same physical space). When resampling IS needed, this value is
+ # overwritten below with the centroid of the clipped brain within the Allen
+ # domain, because the full-brain geometric centre can be tens of mm outside
+ # the Allen atlas extent and would produce a translation that maps every
+ # Allen voxel outside the resampled moving image buffer.
+ original_moving_size = moving_sitk.GetSize()
+ original_moving_center_idx = [s / 2.0 for s in original_moving_size]
+ original_moving_center = np.array(moving_sitk.TransformContinuousIndexToPhysicalPoint(original_moving_center_idx))
+
+ # Resample moving image to match Allen atlas spacing and size for better registration.
+ # NOTE: we deliberately keep the original moving center computed above so that the
+ # centre-aligned fallback initialisation is always correct even after resampling.
+ allen_spacing = allen_atlas.GetSpacing()
+ allen_size = allen_atlas.GetSize()
+ moving_spacing_sitk = moving_sitk.GetSpacing()
+ moving_size_sitk = moving_sitk.GetSize()
+
+ # Check if resampling is needed (if spacing differs significantly or sizes are very different)
+ spacing_ratio = np.array(allen_spacing) / np.array(moving_spacing_sitk)
+ size_ratio = np.array(allen_size, dtype=float) / np.array(moving_size_sitk, dtype=float)
+
+ # Resample if spacing differs by more than 10% or if volumes are very different sizes
+ needs_resample = np.any(np.abs(spacing_ratio - 1.0) > 0.1) or np.any(size_ratio < 0.5) or np.any(size_ratio > 2.0)
+
+ if needs_resample:
+ if verbose:
+ print(
+ f"Resampling moving image from {moving_spacing_sitk} mm, size {moving_size_sitk} "
+ f"to {allen_spacing} mm, size {allen_size}"
+ )
+ resampler = sitk.ResampleImageFilter()
+ resampler.SetReferenceImage(allen_atlas)
+ resampler.SetInterpolator(sitk.sitkLinear)
+ resampler.SetDefaultPixelValue(0)
+ moving_sitk = resampler.Execute(moving_sitk)
+
+ # Recompute the effective brain centre from the RESAMPLED image.
+ # The pre-resampling centre can lie far outside the Allen domain (e.g. a
+ # large 25 µm brain whose geometric centre is at ~37 mm, while the Allen
+ # atlas only spans ~11 mm). Using that centre directly gives a translation
+ # of +31 mm, which maps every Allen voxel outside the moving image buffer.
+ # Instead, use the centroid of the non-zero (brain-tissue) voxels that
+ # survived the clipping into the Allen domain.
+ moving_arr = sitk.GetArrayFromImage(moving_sitk) # shape (Z, Y, X) in numpy
+ nonzero_idx = np.argwhere(moving_arr > 0) # rows are (z, y, x)
+ if len(nonzero_idx) > 0:
+ centroid_zyx = nonzero_idx.mean(axis=0)
+ # SITK index order is (x, y, z), reverse of numpy (z, y, x)
+ centroid_xyz = [float(centroid_zyx[2]), float(centroid_zyx[1]), float(centroid_zyx[0])]
+ original_moving_center = np.array(moving_sitk.TransformContinuousIndexToPhysicalPoint(centroid_xyz))
+ if verbose:
+ print(f"Resampled brain centroid (physical): {original_moving_center} mm")
+ # If all voxels are zero (brain entirely outside Allen domain), keep
+ # the pre-resampling centre and accept a potentially poor initialization.
+
+ # Normalize images for better registration
+ fixed_image = sitk.Normalize(allen_atlas)
+ moving_image_sitk = sitk.Normalize(moving_sitk)
+
+ if verbose:
+ print(f"Fixed (Allen) image: size={fixed_image.GetSize()}, spacing={fixed_image.GetSpacing()}")
+ print(f"Moving (brain) image: size={moving_image_sitk.GetSize()}, spacing={moving_image_sitk.GetSpacing()}")
+
+ # Initialize registration
+ registration_method = sitk.ImageRegistrationMethod()
+
+ # Set metric
+ # Note: For correlation-based metrics, negative values are possible
+ # The optimizer will maximize MI/CC and minimize MSE
+ if metric.upper() == "MI":
+ registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
+ elif metric.upper() == "MSE":
+ registration_method.SetMetricAsMeanSquares()
+ elif metric.upper() == "CC":
+ registration_method.SetMetricAsCorrelation()
+ elif metric.upper() == "ANTSCC":
+ registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=20)
+ else:
+ raise ValueError(f"Unknown metric: {metric}. Choose from: MI, MSE, CC, AntsCC")
+
+ # Set metric sampling - use regular sampling for reproducibility and speed
+ registration_method.SetMetricSamplingStrategy(registration_method.REGULAR)
+ registration_method.SetMetricSamplingPercentage(0.25) # 25% of pixels is usually sufficient
+
+ # Set optimizer with conservative parameters
+ # Use smaller learning rate and steps to prevent overshooting
+ learning_rate = 0.5 # Smaller learning rate for stability
+ min_step = 0.0001
+ registration_method.SetOptimizerAsRegularStepGradientDescent(
+ learningRate=learning_rate,
+ minStep=min_step,
+ numberOfIterations=max_iterations,
+ relaxationFactor=0.5,
+ gradientMagnitudeTolerance=1e-8,
+ )
+
+ # Use physical shift for scaling - more appropriate for physical coordinate registration
+ # This computes scales based on how a 1mm shift affects the metric
+ registration_method.SetOptimizerScalesFromPhysicalShift()
+
+ # Multi-resolution approach - start coarse, refine progressively
+ # More levels for robustness
+ registration_method.SetShrinkFactorsPerLevel([8, 4, 2, 1])
+ registration_method.SetSmoothingSigmasPerLevel([4, 2, 1, 0])
+ registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
+
+ # Initialize rigid transform with guaranteed overlap.
+ # Use the ORIGINAL moving image centre (before any resampling) so that
+ # the centre-aligned fallback always produces a meaningful initial translation
+ # regardless of the resolution/size relationship between the two images.
+ initial_transform = sitk.Euler3DTransform()
+
+ # Calculate image centres in physical space
+ fixed_size = fixed_image.GetSize()
+ fixed_center_idx = [s / 2.0 for s in fixed_size]
+ fixed_center = np.array(fixed_image.TransformContinuousIndexToPhysicalPoint(fixed_center_idx))
+
+ # Translation to align brain centre with Allen centre (ensures initial overlap).
+ # ITK transform maps fixed→moving: T(p) = R(p - c) + c + t
+ # For identity rotation and c=fixed_center: T(fixed_center) = fixed_center + t
+ # We need T(fixed_center) = original_moving_center, so t = moving_center - fixed_center.
+ translation = tuple(original_moving_center - fixed_center)
+
+ # Set center of rotation to fixed image center
+ initial_transform.SetCenter(fixed_center)
+
+ # Convert initial rotation from degrees to radians
+ rx_rad = np.deg2rad(initial_rotation_deg[0])
+ ry_rad = np.deg2rad(initial_rotation_deg[1])
+ rz_rad = np.deg2rad(initial_rotation_deg[2])
+
+ # Set translation to align centers and apply initial rotation
+ initial_transform.SetTranslation(translation)
+ initial_transform.SetRotation(rx_rad, ry_rad, rz_rad)
+
+ if verbose:
+ print(f"Initial center alignment: fixed={fixed_center}, moving (original)={original_moving_center}")
+ print(f"Translation to align centers: {translation}")
+ if any(r != 0 for r in initial_rotation_deg):
+ print(f"Initial rotation (deg): {initial_rotation_deg}")
+
+ # Only try MOMENTS initialization if no initial rotation was specified
+ # (user-specified rotation takes precedence) and the image was NOT resampled
+ # into the Allen domain. After resampling, the brain occupies only a small
+ # corner of the 640³ Allen image; sitk.Normalize then gives the large
+ # zero-padded background a uniform negative value that dominates the
+ # centre-of-mass computation, producing translation ≈ 0 which places every
+ # sample point outside the brain buffer.
+ if all(r == 0 for r in initial_rotation_deg) and not needs_resample:
+ try:
+ # Use MOMENTS initialization which is more robust
+ init_transform = sitk.Euler3DTransform()
+ init_transform = sitk.CenteredTransformInitializer(
+ fixed_image, moving_image_sitk, init_transform, sitk.CenteredTransformInitializerFilter.MOMENTS
+ )
+ # Verify the initialized transform has reasonable translation
+ init_params = init_transform.GetParameters()
+ init_translation = np.array(init_params[3:6])
+
+ # Check if the initialized transform is reasonable (translation not too large)
+ # If translation is reasonable, use it; otherwise use our center-aligned one
+ translation_magnitude = np.linalg.norm(init_translation)
+ fixed_size_mm = np.array(fixed_image.GetSpacing()) * np.array(fixed_image.GetSize())
+ max_reasonable_translation = np.linalg.norm(fixed_size_mm) * 0.5 # Half the image size
+
+ if translation_magnitude < max_reasonable_translation:
+ initial_transform = init_transform
+ if verbose:
+ print(f"Using MOMENTS initialization (translation magnitude: {translation_magnitude:.2f} mm)")
+ else:
+ if verbose:
+ print(
+ f"MOMENTS initialization translation too large ({translation_magnitude:.2f} mm), using center-aligned"
+ )
+ except Exception as e:
+ if verbose:
+ print(f"MOMENTS initialization failed: {e}, using center-aligned translation")
+
+ if verbose:
+ final_params = initial_transform.GetParameters()
+ final_center = initial_transform.GetCenter()
+ print(f"Final initial transform: rotation={final_params[:3]}, translation={final_params[3:]}")
+ print(f"Transform center: {final_center}")
+
+ registration_method.SetInitialTransform(initial_transform)
+ registration_method.SetInterpolator(sitk.sitkLinear)
+
+ # Set up iteration callback
+ if verbose or progress_callback is not None:
+
+ def command_iteration(method: Any) -> None:
+ if verbose:
+ if method.GetOptimizerIteration() == 0:
+ print(f"Estimated scales: {method.GetOptimizerScales()}")
+ print(
+ f"Iteration {method.GetOptimizerIteration():3d} = "
+ f"{method.GetMetricValue():7.5f} : "
+ f"{method.GetOptimizerPosition()}"
+ )
+ if progress_callback is not None:
+ progress_callback(method)
+
+ registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
+
+ # Execute registration
+ final_transform = registration_method.Execute(fixed_image, moving_image_sitk)
+
+ stop_condition = registration_method.GetOptimizerStopConditionDescription()
+ error = registration_method.GetMetricValue()
+
+ if verbose:
+ print(f"Registration complete: {stop_condition}")
+ print(f"Final metric value: {error:.6f}")
+ final_params = final_transform.GetParameters()
+ print(f"Final transform: rotation={final_params[:3]}, translation={final_params[3:]}")
+ print(f"Fixed image size: {fixed_image.GetSize()}, spacing: {fixed_image.GetSpacing()}")
+ print(f"Moving image size: {moving_image_sitk.GetSize()}, spacing: {moving_image_sitk.GetSpacing()}")
+
+ # Restore crop offset in the translation so the transform is valid for the
+ # full original (uncropped) brain volume. Derivation:
+ # T(p) = R(p-c)+c+t maps Allen coords to cropped-brain coords (origin=0).
+ # Same tissue in full brain is at (cropped_coord + crop_origin_mm).
+ # So t_full = t_crop + crop_origin_sitk (center c cancels out).
+ if any(v != 0.0 for v in crop_origin_mm):
+ params = list(final_transform.GetParameters())
+ # SITK Euler3D params: (rx, ry, rz, tx, ty, tz) in SITK XYZ order
+ # numpy axis order (Z, Y, X) -> SITK (X, Y, Z):
+ params[3] += crop_origin_mm[2] # SITK X = numpy axis 2
+ params[4] += crop_origin_mm[1] # SITK Y = numpy axis 1
+ params[5] += crop_origin_mm[0] # SITK Z = numpy axis 0
+ final_transform.SetParameters(params)
+ if verbose:
+ print(
+ f"Adjusted translation for crop: +"
+ f"[{crop_origin_mm[2]:.3f}, {crop_origin_mm[1]:.3f}, {crop_origin_mm[0]:.3f}] mm (SITK XYZ)"
+ )
+
+ return final_transform, stop_condition, error
diff --git a/linumpy/tests/test_io_allen.py b/linumpy/tests/test_io_allen.py
new file mode 100644
index 00000000..6c33a782
--- /dev/null
+++ b/linumpy/tests/test_io_allen.py
@@ -0,0 +1,284 @@
+"""Tests for linumpy/reference/allen.py — orientation handling and registration.
+
+The Allen template download is monkey-patched to return a synthetic PIR-oriented
+volume with a deliberately asymmetric tissue distribution. That keeps these
+tests offline and lets us verify that ``download_template_ras_aligned`` really
+produces a RAS+ volume (``+X = Right``, ``+Y = Anterior``, ``+Z = Superior``).
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import SimpleITK as sitk
+
+from linumpy.reference import allen
+
+# ---------------------------------------------------------------------------
+# Synthetic PIR-oriented Allen template
+# ---------------------------------------------------------------------------
+
+
+def _make_synthetic_pir_template(resolution_um: int = 100) -> sitk.Image:
+ """Build a small synthetic volume that mimics the Allen CCF nrrd layout.
+
+ Allen CCF v3 stores the template in PIR:
+ nrrd axis 0 = AP (+=Posterior)
+ nrrd axis 1 = DV (+=Inferior)
+ nrrd axis 2 = ML (+=Right)
+
+ ``sitk.ReadImage`` maps nrrd axis k to SITK axis k, so the returned
+ SITK image has ``(X, Y, Z) = (AP, DV, ML)``. Each axis is given a
+ unique, monotonically increasing gradient so we can identify the
+ resulting orientation unambiguously after the RAS reorientation.
+ """
+ # Pick axis sizes that are all distinct so permutations are detectable.
+ ap_size, dv_size, ml_size = 12, 8, 10
+
+ # numpy shape (Z, Y, X) for sitk.GetImageFromArray:
+ # numpy Z ↔ SITK Z = ML
+ # numpy Y ↔ SITK Y = DV
+ # numpy X ↔ SITK X = AP
+ ap = np.arange(ap_size, dtype=np.float32)[None, None, :] * 1.0 # unit step
+ dv = np.arange(dv_size, dtype=np.float32)[None, :, None] * 100.0
+ ml = np.arange(ml_size, dtype=np.float32)[:, None, None] * 10000.0
+
+ arr = ap + dv + ml # each axis contributes a distinct decimal place
+
+ vol = sitk.GetImageFromArray(arr)
+ r_mm = resolution_um / 1e3
+ vol.SetSpacing((r_mm, r_mm, r_mm))
+ vol.SetOrigin((0.0, 0.0, 0.0))
+ vol.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1))
+ return vol
+
+
+# ---------------------------------------------------------------------------
+# download_template_ras_aligned — orientation
+# ---------------------------------------------------------------------------
+
+
+class TestDownloadTemplateRasAligned:
+ """Verify the RAS reorientation of the Allen template."""
+
+ @pytest.fixture
+ def ras_template(self, monkeypatch):
+ def fake_download_template(resolution, cache=True, cache_dir=".data/"):
+ return _make_synthetic_pir_template(resolution)
+
+ monkeypatch.setattr(allen, "download_template", fake_download_template)
+ return allen.download_template_ras_aligned(100)
+
+ def test_spacing_is_isotropic_and_in_mm(self, ras_template):
+ spacing = ras_template.GetSpacing()
+ assert spacing == pytest.approx((0.1, 0.1, 0.1))
+
+ def test_origin_is_zero(self, ras_template):
+ assert ras_template.GetOrigin() == pytest.approx((0.0, 0.0, 0.0))
+
+ def test_direction_is_identity(self, ras_template):
+ assert ras_template.GetDirection() == pytest.approx((1, 0, 0, 0, 1, 0, 0, 0, 1))
+
+ def test_size_reflects_permutation(self, ras_template):
+ """After ``PermuteAxes((2, 0, 1))`` the SITK size becomes (ML, AP, DV)."""
+ # Input sizes: AP=12, DV=8, ML=10 → output (ML, AP, DV) = (10, 12, 8)
+ assert ras_template.GetSize() == (10, 12, 8)
+
+ def test_positive_x_is_right(self, ras_template):
+ """+X must point toward Right (originally +ML in nrrd)."""
+ arr = sitk.GetArrayFromImage(ras_template)
+ # numpy axis 2 = SITK X; ML gradient was the `10000` coefficient.
+ col = arr[0, 0, :]
+ diffs = np.diff(col)
+ # Gradient along X in RAS-aligned volume should increase monotonically.
+ assert np.all(diffs > 0), f"+X is not monotonic along ML (Right): {col}"
+
+ def test_positive_y_is_anterior(self, ras_template):
+ """+Y must point toward Anterior (originally -AP in nrrd).
+
+ Raw AP gradient increases with +Posterior, so after reorientation the
+ AP gradient should DECREASE along +Y (since +Y = Anterior).
+ """
+ arr = sitk.GetArrayFromImage(ras_template)
+ # numpy axis 1 = SITK Y; AP gradient was the `1.0` coefficient.
+ # Extract AP component by taking the modulo-100 decimal of a single X,Z column.
+ col = arr[0, :, 0] % 100.0 # keep only AP contribution (0 .. 11)
+ diffs = np.diff(col)
+ assert np.all(diffs < 0), f"+Y is not anterior (AP should decrease): {col}"
+
+ def test_positive_z_is_superior(self, ras_template):
+ """+Z must point toward Superior (originally -DV in nrrd).
+
+ Raw DV gradient increases with +Inferior, so after reorientation the
+ DV gradient should DECREASE along +Z (since +Z = Superior).
+ """
+ arr = sitk.GetArrayFromImage(ras_template)
+ # numpy axis 0 = SITK Z; DV gradient was the `100` coefficient.
+ # Extract DV component using (value % 10000) // 100.
+ col = (arr[:, 0, 0] % 10000.0) // 100.0 # 0 .. 7
+ diffs = np.diff(col)
+ assert np.all(diffs < 0), f"+Z is not superior (DV should decrease): {col}"
+
+
+# ---------------------------------------------------------------------------
+# numpy_to_sitk_image
+# ---------------------------------------------------------------------------
+
+
+class TestNumpyToSitkImage:
+ def test_roundtrip_preserves_values(self):
+ arr = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
+ img = allen.numpy_to_sitk_image(arr, spacing=(0.1, 0.2, 0.3))
+ back = sitk.GetArrayFromImage(img)
+ np.testing.assert_array_equal(back, arr)
+
+ def test_spacing_is_permuted_to_xyz(self):
+ arr = np.zeros((2, 3, 4), dtype=np.float32)
+ img = allen.numpy_to_sitk_image(arr, spacing=(0.1, 0.2, 0.3))
+ # spacing=(res_z, res_y, res_x) → SITK GetSpacing=(res_x, res_y, res_z)
+ assert img.GetSpacing() == pytest.approx((0.3, 0.2, 0.1))
+
+ def test_size_is_reversed_from_numpy_shape(self):
+ arr = np.zeros((2, 3, 4), dtype=np.float32)
+ img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0))
+ assert img.GetSize() == (4, 3, 2)
+
+ def test_origin_and_direction_are_identity(self):
+ arr = np.zeros((2, 3, 4), dtype=np.float32)
+ img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0))
+ assert img.GetOrigin() == (0.0, 0.0, 0.0)
+ assert img.GetDirection() == (1, 0, 0, 0, 1, 0, 0, 0, 1)
+
+ def test_cast_dtype_produces_float32(self):
+ arr = np.ones((2, 3, 4), dtype=np.uint16)
+ img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0), cast_dtype=np.float32)
+ assert img.GetPixelID() == sitk.sitkFloat32
+
+ def test_no_cast_preserves_dtype(self):
+ arr = np.ones((2, 3, 4), dtype=np.uint16)
+ img = allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0))
+ assert img.GetPixelID() == sitk.sitkUInt16
+
+ def test_input_array_not_modified(self):
+ arr = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
+ original = arr.copy()
+ allen.numpy_to_sitk_image(arr, spacing=(1.0, 1.0, 1.0), cast_dtype=np.float32)
+ np.testing.assert_array_equal(arr, original)
+
+
+# ---------------------------------------------------------------------------
+# register_3d_rigid_to_allen — end-to-end self-registration
+# ---------------------------------------------------------------------------
+
+
+def _make_synthetic_brain(shape=(24, 24, 24), spacing=(0.2, 0.2, 0.2)):
+ """Small asymmetric synthetic brain with a unique intensity pattern per axis."""
+ z, y, x = np.indices(shape, dtype=np.float32)
+ # Ellipsoid mask offset from centre, asymmetric along each axis.
+ cz, cy, cx = shape[0] * 0.55, shape[1] * 0.5, shape[2] * 0.45
+ rz, ry, rx = shape[0] * 0.35, shape[1] * 0.3, shape[2] * 0.4
+ mask = ((z - cz) / rz) ** 2 + ((y - cy) / ry) ** 2 + ((x - cx) / rx) ** 2 < 1
+ brain = np.zeros(shape, dtype=np.float32)
+ # Distinct gradient along each axis so registration has more than a single
+ # rotationally symmetric blob to work with.
+ brain[mask] = 1.0 + 0.3 * (z[mask] / shape[0]) + 0.5 * (y[mask] / shape[1]) + 0.7 * (x[mask] / shape[2])
+ return brain
+
+
+class TestRegisterRigidToAllen:
+ """End-to-end registration tests using a synthetic Allen template."""
+
+ @pytest.fixture(autouse=True)
+ def patch_allen(self, monkeypatch):
+ def fake_download_template(resolution, cache=True, cache_dir=".data/"):
+ return _make_synthetic_pir_template(resolution)
+
+ monkeypatch.setattr(allen, "download_template", fake_download_template)
+
+ def test_self_registration_recovers_identity(self):
+ """Registering the RAS Allen template against itself yields ~identity."""
+ target = allen.download_template_ras_aligned(100)
+ moving = sitk.GetArrayFromImage(target) # numpy (Z, Y, X)
+ # SITK spacing is (X, Y, Z); moving_spacing is (res_z, res_y, res_x)
+ sx, sy, sz = target.GetSpacing()
+ transform, stop, _err = allen.register_3d_rigid_to_allen(
+ moving_image=moving,
+ moving_spacing=(sz, sy, sx),
+ allen_resolution=100,
+ metric="MSE",
+ max_iterations=50,
+ verbose=False,
+ )
+ params = transform.GetParameters()
+ rotation = np.array(params[:3])
+ translation = np.array(params[3:6])
+ # The MSE minimum is at identity; allow generous tolerances because the
+ # synthetic volume is tiny.
+ assert np.max(np.abs(rotation)) < 0.1, f"Rotation too large: {rotation}"
+ assert np.max(np.abs(translation)) < 1.0, f"Translation too large: {translation}"
+ assert stop # non-empty stop-condition string
+
+ def test_downsamples_allen_when_moving_is_coarser(self, capsys):
+ """If moving resolution > allen resolution, allen must be downsampled."""
+ # Moving at 200 µm, allen synthetic at 100 µm → expect downsampling.
+ shape = (10, 10, 10)
+ moving = _make_synthetic_brain(shape, spacing=(0.2, 0.2, 0.2))
+ _, _, _ = allen.register_3d_rigid_to_allen(
+ moving_image=moving,
+ moving_spacing=(0.2, 0.2, 0.2),
+ allen_resolution=100,
+ metric="MSE",
+ max_iterations=3,
+ verbose=True,
+ )
+ captured = capsys.readouterr().out
+ assert "Downsampled Allen atlas" in captured
+
+ def test_does_not_downsample_when_already_coarse(self, capsys):
+ """If moving resolution ≤ allen resolution, allen must NOT be downsampled."""
+ shape = (10, 10, 10)
+ moving = _make_synthetic_brain(shape, spacing=(0.05, 0.05, 0.05))
+ _, _, _ = allen.register_3d_rigid_to_allen(
+ moving_image=moving,
+ moving_spacing=(0.05, 0.05, 0.05),
+ allen_resolution=100,
+ metric="MSE",
+ max_iterations=3,
+ verbose=True,
+ )
+ captured = capsys.readouterr().out
+ assert "Downsampled Allen atlas" not in captured
+
+ def test_crop_offset_reported_in_verbose_output(self, capsys):
+ """The ``crop_origin_mm`` restoration must add an offset proportional to
+ the leading zero-padding of the moving volume. We use a plain cube so
+ the non-zero bounding box equals the cube's shape exactly, making the
+ expected crop origin easy to compute.
+ """
+ # A fully filled cube — nonzero bbox equals the full cube shape.
+ cube_size = 12
+ cube = np.ones((cube_size, cube_size, cube_size), dtype=np.float32)
+ leading_pad = (20, 15, 25) # (pad_z, pad_y, pad_x); each > 10 (margin)
+ canvas = np.pad(cube, [(p, 5) for p in leading_pad], mode="constant", constant_values=0)
+
+ _, _, _ = allen.register_3d_rigid_to_allen(
+ moving_image=canvas,
+ moving_spacing=(0.1, 0.1, 0.1),
+ allen_resolution=100,
+ metric="MSE",
+ max_iterations=0,
+ verbose=True,
+ )
+ captured = capsys.readouterr().out
+ # Expected crop start per numpy axis (voxels): pad_axis - margin = pad - 10.
+ margin = 10
+ spacing = 0.1
+ expected_numpy = tuple((p - margin) * spacing for p in leading_pad)
+ # SITK XYZ = numpy axes (X=2, Y=1, Z=0)
+ expected_sitk_xyz = (expected_numpy[2], expected_numpy[1], expected_numpy[0])
+ expected_log = (
+ "Adjusted translation for crop: +["
+ f"{expected_sitk_xyz[0]:.3f}, {expected_sitk_xyz[1]:.3f}, {expected_sitk_xyz[2]:.3f}"
+ "] mm (SITK XYZ)"
+ )
+ assert expected_log in captured, f"Expected log not found. Got:\n{captured}"
diff --git a/scripts/linum_align_to_ras.py b/scripts/linum_align_to_ras.py
new file mode 100755
index 00000000..ac8812cb
--- /dev/null
+++ b/scripts/linum_align_to_ras.py
@@ -0,0 +1,1079 @@
+#!/usr/bin/env python3
+
+"""
+Align a 3D brain volume to RAS orientation using rigid registration to the Allen atlas.
+
+This script computes a rigid transform from the input brain volume to a RAS-aligned
+version by registering it to the Allen Brain Atlas. The transform can be applied
+directly to the zarr file (resampling) or stored in OME-Zarr metadata.
+"""
+
+# Configure thread limits before numpy/scipy imports
+import linumpy.config.threads # noqa: F401
+
+import argparse
+import json
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import SimpleITK as sitk
+from tqdm.auto import tqdm
+
+from linumpy.imaging.orientation import (
+ apply_orientation_transform,
+ parse_orientation_code,
+ reorder_resolution,
+)
+from linumpy.io.zarr import AnalysisOmeZarrWriter, read_omezarr
+from linumpy.reference import allen
+
+matplotlib.use("Agg") # Non-interactive backend
+
+# Constants
+DEFAULT_ALLEN_RESOLUTION = 100
+DEFAULT_MAX_ITERATIONS = 1000
+DEFAULT_METRIC = "MI"
+
+
+def _debug_log(message: str, **fields: Any) -> None:
+ """Append an NDJSON line describing a slicing/labelling decision.
+
+ Active only when ``LINUMPY_DEBUG_LOG`` is set, so production runs pay
+ nothing. Used to capture runtime evidence of which volume conventions
+ each preview function actually receives.
+ """
+ import os
+
+ path = os.environ.get("LINUMPY_DEBUG_LOG")
+ if not path:
+ return
+ try:
+ import time
+
+ entry = {
+ "id": f"log_{int(time.time() * 1000)}_panels",
+ "timestamp": int(time.time() * 1000),
+ "sessionId": "6fa1b3",
+ "runId": "panels-fix",
+ "hypothesisId": "H1",
+ "location": "linum_align_to_ras.py",
+ "message": message,
+ "data": fields,
+ }
+ with Path(path).open("a") as f:
+ f.write(json.dumps(entry) + "\n")
+ except Exception:
+ pass
+
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser.

    Option naming was inconsistent (``--allen-resolution`` vs
    ``--pyramid_resolutions``); the dash-separated spellings are now accepted
    as aliases while the historical underscore spellings are kept first so
    existing scripts and help output remain unchanged. ``dest`` is pinned
    explicitly so the ``args`` attribute names cannot drift with alias order.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    p.add_argument("input_zarr", help="Input OME-Zarr file from 3D reconstruction pipeline")
    p.add_argument("output_zarr", help="Output OME-Zarr file (RAS-aligned)")
    p.add_argument(
        "--allen-resolution",
        type=int,
        default=DEFAULT_ALLEN_RESOLUTION,
        choices=allen.AVAILABLE_RESOLUTIONS,
        help="Allen atlas resolution in micron [%(default)s]",
    )
    p.add_argument(
        "--metric",
        type=str,
        default=DEFAULT_METRIC,
        choices=["MI", "MSE", "CC", "AntsCC"],
        help="Registration metric [%(default)s]",
    )
    p.add_argument(
        "--max-iterations",
        type=int,
        default=DEFAULT_MAX_ITERATIONS,
        help="Maximum registration iterations [%(default)s]",
    )
    p.add_argument(
        "--store-transform-only", action="store_true", help="Store transform in metadata only (don't resample volume)"
    )
    p.add_argument("--level", type=int, default=0, help="Pyramid level for registration (0 = full resolution) [%(default)s]")
    p.add_argument(
        "--chunks", type=int, nargs=3, default=None, help="Chunk size for output zarr. Uses input chunks when None."
    )
    p.add_argument(
        "--n-levels", type=int, default=None, help="Number of pyramid levels for output. Uses Allen atlas levels when None."
    )
    p.add_argument(
        "--pyramid_resolutions",
        "--pyramid-resolutions",
        dest="pyramid_resolutions",
        type=float,
        nargs="+",
        default=None,
        help="Target pyramid resolution levels in µm (e.g. 10 25 50 100).\n"
        "If omitted, inherits levels from input zarr metadata or uses Allen resolutions.",
    )
    p.add_argument(
        "--make_isotropic",
        "--make-isotropic",
        dest="make_isotropic",
        action="store_true",
        default=True,
        help="Resample to isotropic voxels at each pyramid level.",
    )
    p.add_argument("--no_isotropic", "--no-isotropic", dest="make_isotropic", action="store_false")
    p.add_argument("--verbose", action="store_true", help="Print registration progress")
    p.add_argument("--preview", type=str, default=None, help="Generate preview image showing alignment comparison")
    p.add_argument(
        "--input-orientation",
        type=str,
        default=None,
        help="Input volume orientation code (3 letters: R/L, A/P, S/I)\nExamples: 'RAS' (Allen), 'LPI', 'PIR'",
    )
    p.add_argument(
        "--initial-rotation",
        type=float,
        nargs=3,
        default=[0.0, 0.0, 0.0],
        metavar=("RX", "RY", "RZ"),
        help="Initial rotation angles in degrees (Rx, Ry, Rz).\nUse to provide initial orientation hint for registration.",
    )
    p.add_argument("--preview-only", action="store_true", help="Only generate preview of input volume (no registration)")
    p.add_argument(
        "--orientation-preview",
        type=str,
        default=None,
        metavar="PATH",
        help="Save a 3-panel preview of the volume after --input-orientation and\n"
        "--initial-rotation are applied. Use to verify these parameters\n"
        "before committing to a full registration run.",
    )
    p.add_argument(
        "--orientation-preview-only",
        action="store_true",
        help="Generate --orientation-preview and exit without running registration.",
    )
    return p
+
+
+# =============================================================================
+# Orientation utilities — imported from linumpy.imaging.orientation
+# =============================================================================
+
+
def create_registration_progress_callback(
    max_iterations: int,
    n_resolution_levels: int = 3,
    pbar: tqdm | None = None,
    registration_start_step: int = 0,
    registration_steps: int = 0,
) -> Callable:
    """
    Create a progress callback for registration.

    Parameters
    ----------
    max_iterations : int
        Maximum iterations per level
    n_resolution_levels : int
        Number of resolution levels in the registration pyramid
    pbar : tqdm, optional
        Progress bar to update
    registration_start_step : int
        Step number where registration starts in progress bar
    registration_steps : int
        Number of steps allocated for registration

    Returns
    -------
    callable
        Progress callback function compatible with SimpleITK registration
    """
    # Mutable closure state: total iterations seen, current pyramid level,
    # and the last iteration number (used to detect level transitions).
    state = {"total": 0, "level": 0, "last_iter": -1}
    # Worst-case iteration budget; only used as the progress-bar denominator.
    budget = float(max_iterations * n_resolution_levels)

    def callback(method: Any) -> None:
        """Update progress during registration iterations."""
        current_iter = method.GetOptimizerIteration()
        metric_value = method.GetMetricValue()

        # SimpleITK resets the iteration counter to 0 when it starts the next
        # pyramid level, so a decrease marks a level transition.
        if current_iter < state["last_iter"]:
            state["level"] += 1
        state["last_iter"] = current_iter
        state["total"] += 1

        if pbar is None:
            return
        # Blend "within-level" progress with completed levels so the bar
        # advances smoothly across resolutions and does not stall when a
        # level converges early or hits max_iterations.
        frac_in_level = min(1.0, (current_iter + 1) / max_iterations)
        by_level = (state["level"] + frac_in_level) / n_resolution_levels
        overall = min(1.0, max(by_level, state["total"] / budget))
        step = registration_start_step + int(registration_steps * overall)
        if step > pbar.n:
            pbar.n = step
        pbar.set_postfix_str(f"metric={metric_value:.6f} level={state['level'] + 1}/{n_resolution_levels}")
        pbar.refresh()

    return callback
+
+
+# =============================================================================
+# Transform utilities
+# =============================================================================
+
+
def sitk_transform_to_affine_matrix(transform: sitk.Transform) -> np.ndarray:
    """
    Convert SimpleITK transform to 4x4 affine matrix.

    Parameters
    ----------
    transform : sitk.Transform
        SimpleITK Euler3DTransform or AffineTransform

    Returns
    -------
    np.ndarray
        4x4 affine matrix in (Z, Y, X) coordinate ordering, matching the
        OME-NGFF axis declaration used by the pipeline.

    Raises
    ------
    ValueError
        If the transform is neither an Euler3DTransform nor an AffineTransform.
    """
    if not isinstance(transform, (sitk.Euler3DTransform, sitk.AffineTransform)):
        raise ValueError(f"Unsupported transform type: {type(transform)}")

    # Both accepted types expose the composed linear part via GetMatrix().
    # Using it directly (instead of re-deriving the Euler rotation by hand)
    # avoids silently disagreeing with ITK's angle-composition convention
    # (ComputeZYX on/off), which a hard-coded Rz @ Ry @ Rx product assumes.
    r = np.array(transform.GetMatrix()).reshape(3, 3)
    translation = np.array(transform.GetTranslation())
    center = np.array(transform.GetCenter())

    matrix = np.eye(4)
    matrix[:3, :3] = r
    # Fold the rotation center into the translation column:
    # p' = R (p - c) + c + t  =>  offset = t + c - R c.
    matrix[:3, 3] = translation + center - r @ center

    # Permute from SimpleITK (X, Y, Z) to our (Z, Y, X) ordering (OME-NGFF axis order).
    permute = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]])
    return permute @ matrix @ permute.T
+
+
def store_transform_in_metadata(zarr_path: Path, transform: sitk.Transform) -> None:
    """Store transform in OME-Zarr metadata as affine coordinate transformation."""
    matrix_4x4 = sitk_transform_to_affine_matrix(transform)
    attrs_file = Path(zarr_path) / ".zattrs"

    if not attrs_file.exists():
        raise FileNotFoundError(f".zattrs not found: {zarr_path}")

    metadata = json.loads(attrs_file.read_text(encoding="utf-8"))

    entry = {"type": "affine", "affine": matrix_4x4.flatten().tolist()}

    scales = metadata.get("multiscales", [])
    if not scales:
        raise ValueError("No multiscales entry found in metadata")

    # Prepend the affine so it is applied before any existing transforms.
    for dataset in scales[0].get("datasets", []):
        dataset["coordinateTransformations"] = [entry, *dataset.get("coordinateTransformations", [])]

    attrs_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

    print(f"Stored affine transform in metadata: {attrs_file}")
+
+
+# =============================================================================
+# Resolution utilities
+# =============================================================================
+
+
+def get_pyramid_resolutions_from_zarr(zarr_path: Path) -> list[float] | None:
+ """
+ Extract pyramid resolution levels from OME-Zarr metadata.
+
+ Parameters
+ ----------
+ zarr_path : Path
+ Path to OME-Zarr file
+
+ Returns
+ -------
+ list of float or None
+ Target resolutions in microns, or None if not found
+ """
+ for metadata_file in ["zarr.json", ".zattrs"]:
+ metadata_path = zarr_path / metadata_file
+ if not metadata_path.exists():
+ continue
+
+ try:
+ with Path(metadata_path).open(encoding="utf-8") as f:
+ metadata = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ continue
+
+ multiscales = metadata.get("multiscales", [])
+ if not multiscales:
+ continue
+
+ resolutions = []
+ for dataset in multiscales[0].get("datasets", []):
+ transforms = dataset.get("coordinateTransformations", [])
+ for tr in transforms:
+ if tr.get("type") == "scale" and "scale" in tr:
+ # Get finest spatial dimension, convert mm to µm
+ scale = tr["scale"][-3:]
+ res_um = min(float(s) for s in scale) * 1000
+ resolutions.append(res_um)
+ break
+
+ if resolutions:
+ return resolutions
+
+ return None
+
+
+# =============================================================================
+# Core processing functions
+# =============================================================================
+
+
def compute_centered_reference_and_transform(
    moving_sitk: sitk.Image, transform: sitk.Transform, output_spacing: tuple | None = None
) -> tuple[sitk.Image, sitk.Transform]:
    """
    Compute a reference image and modified transform that centers the output volume.

    This creates an output that is centered in the volume (brain in the middle),
    preserving the original resolution.

    Parameters
    ----------
    moving_sitk : sitk.Image
        The input moving image
    transform : sitk.Transform
        Transform to apply (moving -> fixed/RAS space)
    output_spacing : tuple, optional
        Output voxel spacing. If None, uses moving image spacing.

    Returns
    -------
    ref : sitk.Image
        Reference image for resampling, with origin at 0
    composite_transform : sitk.Transform
        Modified transform that maps moving image to centered output
    """
    spacing = np.array(moving_sitk.GetSpacing() if output_spacing is None else output_spacing)

    # Enumerate all eight voxel-index corners of the moving image.
    sx, sy, sz = moving_sitk.GetSize()
    corner_indices = [
        (i, j, k)
        for i in (0, sx - 1)
        for j in (0, sy - 1)
        for k in (0, sz - 1)
    ]

    # The registration transform maps fixed→moving (ResampleImageFilter
    # convention); its inverse (moving→fixed) tells us where the brain corners
    # land in the fixed (RAS/Allen) coordinate system.
    to_fixed = transform.GetInverse()
    fixed_pts = np.array(
        [to_fixed.TransformPoint(moving_sitk.TransformContinuousIndexToPhysicalPoint(c)) for c in corner_indices]
    )
    lo = fixed_pts.min(axis=0)
    hi = fixed_pts.max(axis=0)

    # Output grid just large enough to cover the transformed brain extent.
    n_vox = np.ceil((hi - lo) / spacing).astype(int)

    # Reference image: origin at (0,0,0), spanning [0, n_vox*spacing].
    # Output voxel p maps to fixed-space coordinate (p + lo).
    ref = sitk.Image([int(v) for v in n_vox], moving_sitk.GetPixelIDValue())
    ref.SetSpacing(tuple(spacing))
    ref.SetOrigin((0.0, 0.0, 0.0))
    ref.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1))  # Identity direction (RAS)

    # Shift transform: output space → fixed space (translate by lo), i.e. the
    # output origin lands on the brain's fixed-space bounding-box minimum.
    shift = sitk.TranslationTransform(3)
    shift.SetOffset(tuple(lo))

    # Composite transform for resampling:
    #   output point → (shift) → fixed space → (transform) → moving space
    # CompositeTransform applies transforms in REVERSE order of addition
    # (the most recently added runs first), so `transform` is added first and
    # `shift` last to obtain transform(shift(p)).
    composite = sitk.CompositeTransform(3)
    composite.AddTransform(transform)
    composite.AddTransform(shift)

    return ref, composite
+
+
def apply_transform_to_zarr(
    input_path: Path,
    output_path: Path,
    transform: sitk.Transform,
    chunks: tuple | None = None,
    n_levels: int | None = None,
    pyramid_resolutions: list | None = None,
    make_isotropic: bool = True,
    orientation_permutation: tuple | None = None,
    orientation_flips: tuple | None = None,
    pbar: tqdm | None = None,
) -> None:
    """
    Apply transform to zarr file by resampling into RAS-aligned space.

    The output is centered on the transformed brain volume, preserving the
    original resolution. This corrects any rotation/off-axis alignment without
    placing the brain in the Allen atlas coordinate system.

    Parameters
    ----------
    input_path: Path
        Path to input OME-Zarr
    output_path: Path
        Path to output OME-Zarr
    transform : sitk.Transform
        Transform to apply
    chunks : tuple, optional
        Chunk size for output
    n_levels : int, optional
        Number of pyramid levels (if None, use source pyramid or Allen resolutions)
    orientation_permutation : tuple, optional
        Axis permutation for orientation correction
    orientation_flips : tuple, optional
        Axis flips for orientation correction
    pbar : tqdm, optional
        Progress bar
    pyramid_resolutions : list, optional
        Explicit list of resolutions for the output pyramid
    make_isotropic : bool
        If True, resample output to isotropic resolution
    """

    def update_pbar() -> None:
        # Coarse-grained progress: one tick per major stage below.
        if pbar:
            pbar.update(1)

    # Load volume at full resolution (level 0) and capture its actual spacing.
    # base_resolution comes from the downsampled registration level, so we must
    # read the level-0 spacing from the file to get the correct physical extent.
    vol_zarr, level0_resolution = read_omezarr(input_path, level=0)
    # Inherit the input's chunking when available; otherwise fall back to a
    # fixed 128-per-axis chunk shape.
    if chunks is None:
        chunks = getattr(vol_zarr, "chunks", None)
    if chunks is None:
        chunks = (128,) * len(vol_zarr.shape)

    vol = np.asarray(vol_zarr[:])
    original_dtype = vol.dtype
    update_pbar()

    # Apply orientation correction (axis permutation + flips), keeping the
    # per-axis spacing consistent with the permuted data.
    resolution = level0_resolution
    if orientation_permutation is not None:
        vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips)
        resolution = reorder_resolution(resolution, orientation_permutation)

    # Compute a tissue-representative background value on the numpy array
    # BEFORE allocating the (potentially large) SimpleITK float32 copy. Using
    # this as the default pixel value avoids black borders that would skew
    # downstream normalization and visualization.
    nonzero_mask = vol > 0
    bg_value = float(np.percentile(vol[nonzero_mask], 1)) if nonzero_mask.any() else 0.0
    del nonzero_mask

    # Convert to SimpleITK
    vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32)
    del vol  # free original volume before resampling
    update_pbar()

    # Compute reference image and modified transform that centers the output
    reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(bg_value)
    resampler.SetTransform(centered_transform)

    transformed_sitk = resampler.Execute(vol_sitk)
    del vol_sitk  # free input before allocating output array
    transformed = sitk.GetArrayFromImage(transformed_sitk)
    del transformed_sitk  # free SimpleITK image after extracting numpy array
    update_pbar()

    # GetArrayFromImage already yields numpy (Z, Y, X) matching our convention.
    update_pbar()

    # Convert back to original dtype. For integer outputs, round first and
    # clip to the representable range so interpolated values outside it
    # cannot wrap around on the cast.
    if np.issubdtype(original_dtype, np.integer):
        info = np.iinfo(original_dtype)
        transformed = np.clip(np.rint(transformed), info.min, info.max).astype(original_dtype)
    else:
        transformed = transformed.astype(original_dtype)

    # Write output
    writer = AnalysisOmeZarrWriter(
        output_path,
        shape=transformed.shape,
        chunk_shape=chunks,
        dtype=transformed.dtype,
        overwrite=True,
    )
    writer[:] = transformed

    if n_levels is not None:
        writer.finalize(list(resolution), n_levels=n_levels)
    else:
        if pyramid_resolutions is not None:
            target_resolutions = pyramid_resolutions
        else:
            # Fallback: inherit levels from input zarr metadata, or use Allen resolutions
            target_resolutions = get_pyramid_resolutions_from_zarr(Path(input_path))
            if target_resolutions is None:
                target_resolutions = list(allen.AVAILABLE_RESOLUTIONS)
        writer.finalize(list(resolution), target_resolutions_um=target_resolutions, make_isotropic=make_isotropic)

    update_pbar()
+
+
+# =============================================================================
+# Preview generation
+# =============================================================================
+
+
def create_input_preview(input_path: Path, output_path: Path, level: int = 0) -> None:
    """Create preview of input volume to help determine orientation."""
    data_zarr, resolution = read_omezarr(input_path, level=level)
    vol = np.asarray(data_zarr[:])

    # Midpoint index along each numpy axis; labelled generically as
    # dim0/dim1/dim2 because the anatomical orientation is still unknown here.
    mid0 = vol.shape[0] // 2
    mid1 = vol.shape[1] // 2
    mid2 = vol.shape[2] // 2

    vmin, vmax = np.percentile(vol, [1, 99])

    fig, axes = plt.subplots(2, 2, figsize=(14, 14))
    fig.suptitle(f"Input Volume Preview\nShape: {vol.shape} (Z, Y, X), Resolution: {resolution} mm", fontsize=14, y=0.98)

    # One orthogonal mid-volume slice per panel; the fourth panel holds the
    # orientation guide text.
    panels = [
        (axes[0, 0], vol[mid0, :, :].T, "Slice at dim0 midpoint\nShows: dim1 × dim2", "dim1 →", "dim2 →"),
        (axes[0, 1], vol[::-1, mid1, :], "Slice at dim1 midpoint\nShows: dim2 × dim0", "dim2 →", "dim0 →"),
        (axes[1, 0], vol[::-1, :, mid2], "Slice at dim2 midpoint\nShows: dim1 × dim0", "dim1 →", "dim0 →"),
    ]
    for ax, img, title, xlabel, ylabel in panels:
        ax.imshow(img, cmap="gray", origin="lower", vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # Help text
    axes[1, 1].axis("off")
    help_text = """
ORIENTATION GUIDE (Allen Atlas = RAS+)

Allen RAS+ convention:
  • R (Right): +X direction
  • A (Anterior): +Y direction (nose)
  • S (Superior): +Z direction (top)

For each dimension, identify the anatomical direction:
  R/L for right/left
  A/P for anterior/posterior
  S/I for superior/inferior

Example:
  dim0→Superior, dim1→Anterior, dim2→Right
  → orientation code = 'SAR'
"""
    axes[1, 1].text(
        0.02,
        0.98,
        help_text,
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5},
    )

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Input preview saved to: {output_path}")
+
+
def create_alignment_preview(
    input_path: Path,
    output_path: Path | None,
    transform: sitk.Transform,
    resolution: tuple,
    preview_path: str,
    allen_resolution: int = DEFAULT_ALLEN_RESOLUTION,
    level: int = 0,
    orientation_permutation: tuple | None = None,
    orientation_flips: tuple | None = None,
    pbar: tqdm | None = None,
) -> None:
    """Create preview comparing original, aligned, and Allen template.

    Shows center slices from each volume in their own coordinate frames.
    The Allen template is shown for reference but may not spatially align
    with the brain volume since we're not placing it in Allen coordinate space.

    Parameters
    ----------
    input_path : Path
        Input OME-Zarr volume.
    output_path : Path or None
        Previously written aligned OME-Zarr; when it exists, the aligned
        volume is read from it instead of being recomputed.
    transform : sitk.Transform
        Fixed→moving transform from registration, used to compute the
        aligned volume when ``output_path`` is absent.
    resolution : tuple
        Voxel spacing passed to ``allen.numpy_to_sitk_image`` when the
        aligned volume must be recomputed.
    preview_path : str
        Output PNG path.
    allen_resolution : int
        Allen template resolution in micron.
    level : int
        Pyramid level to load from the zarr files.
    orientation_permutation, orientation_flips : tuple, optional
        Orientation correction from ``parse_orientation_code``.
    pbar : tqdm, optional
        Progress bar (one tick per major stage).
    """

    def update_pbar() -> None:
        # Coarse-grained progress: one tick per major stage below.
        if pbar:
            pbar.update(1)

    # Load original
    vol_original, orig_res = read_omezarr(input_path, level=level)
    vol_original = np.asarray(vol_original[:])

    if orientation_permutation is not None:
        vol_original = apply_orientation_transform(vol_original, orientation_permutation, orientation_flips)
        orig_res = reorder_resolution(tuple(orig_res), orientation_permutation)

    # apply_orientation_transform yields linumpy convention (S, R, A): dim0=S,
    # dim1=R, dim2=A. The aligned and Allen-template volumes below are in
    # standard RAS — numpy (S, A, R): dim0=S, dim1=A, dim2=R. Permute the
    # original to (S, A, R) here so all three columns share one convention and
    # a single set of "Axial / Coronal / Sagittal" labels applies uniformly.
    vol_original = np.transpose(vol_original, (0, 2, 1))
    orig_res = (orig_res[0], orig_res[2], orig_res[1])
    update_pbar()

    # Load aligned volume from output file, or compute it
    if output_path and Path(output_path).exists():
        vol_aligned, _aligned_res = read_omezarr(output_path, level=level)
        vol_aligned = np.asarray(vol_aligned[:])
    else:
        # Compute aligned volume using the transform
        # NOTE(review): `vol_original` was permuted to (S, A, R) above, but the
        # `resolution` parameter is used here unpermuted — confirm callers pass
        # a spacing consistent with the permuted axis order.
        vol_sitk = allen.numpy_to_sitk_image(vol_original, resolution)
        # Create reference and centered transform
        reference, centered_transform = compute_centered_reference_and_transform(vol_sitk, transform)

        # Background fill from the 1st percentile of non-zero voxels so the
        # resampled borders do not show up as black stripes.
        vol_arr = sitk.GetArrayViewFromImage(vol_sitk)
        nonzero = vol_arr[vol_arr > 0]
        bg_value = float(np.percentile(nonzero, 1)) if len(nonzero) > 0 else 0.0
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(reference)
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(bg_value)
        resampler.SetTransform(centered_transform)
        transformed_sitk = resampler.Execute(vol_sitk)
        vol_aligned = sitk.GetArrayFromImage(transformed_sitk)
    update_pbar()

    # Load Allen template at native resolution for reference
    # We'll just show it as a reference, not spatially aligned
    allen_sitk = allen.download_template_ras_aligned(allen_resolution, cache=True)
    allen_template = sitk.GetArrayFromImage(allen_sitk)
    # GetArrayFromImage already yields numpy (Z, Y, X) matching our convention.
    update_pbar()

    # Helper functions
    def get_center_slices(vol: Any) -> Any:
        """Get center slices in each plane."""
        z, y, x = vol.shape[0] // 2, vol.shape[1] // 2, vol.shape[2] // 2
        return vol[z, :, :], vol[:, y, :], vol[:, :, x]

    def get_display_range(vol: Any) -> Any:
        """Get display range from non-zero values."""
        nonzero = vol[vol > 0]
        if len(nonzero) > 0:
            return np.percentile(nonzero, [1, 99])
        return 0, 1

    def find_content_center_slices(vol: Any) -> Any:
        """Find the slice with maximum content independently for each axis.

        Using a shared 3D centroid for all three views fails when the brain is
        asymmetric (e.g. cut at 45°): the centroid lands near the cut boundary,
        so one or more of the orthogonal slice views passes through the cut plane
        and shows a black stripe. Instead, pick each index independently as the
        slice with the highest total signal along that axis.
        """
        if vol.max() == 0:
            return get_center_slices(vol)
        # NOTE: despite the names, `x` indexes axis 1 and `y` indexes axis 2;
        # the returned slice order (axis0, axis1, axis2) mirrors
        # get_center_slices.
        z = int(np.argmax(vol.sum(axis=(1, 2))))
        x = int(np.argmax(vol.sum(axis=(0, 2))))
        y = int(np.argmax(vol.sum(axis=(0, 1))))
        return vol[z, :, :], vol[:, x, :], vol[:, :, y]

    # Get slices - use content-centered slices for aligned volume
    orig_slices = get_center_slices(vol_original)
    aligned_slices = find_content_center_slices(vol_aligned)
    allen_slices = get_center_slices(allen_template)

    orig_vmin, orig_vmax = get_display_range(vol_original)
    align_vmin, align_vmax = get_display_range(vol_aligned)
    allen_vmin, allen_vmax = get_display_range(allen_template)

    # Create figure
    fig, axes = plt.subplots(3, 3, figsize=(18, 18))
    fig.suptitle("Alignment Preview: Original vs Aligned vs Allen Template (Reference)", fontsize=16)

    # All three volumes are in standard RAS, numpy (S, A, R):
    # dim0=S (Superior), dim1=A (Anterior), dim2=R (Right).
    # Slicing → anatomical plane:
    #   vol[z, :, :] fixes S → AXIAL    (rows=A, cols=R)
    #   vol[:, y, :] fixes A → CORONAL  (rows=S, cols=R)
    #   vol[:, :, x] fixes R → SAGITTAL (rows=S, cols=A)
    plane_names = ["Axial (AR)", "Coronal (SR)", "Sagittal (SA)"]

    _debug_log(
        "create_alignment_preview: shapes & labels",
        original_shape=list(vol_original.shape),
        aligned_shape=list(vol_aligned.shape),
        allen_shape=list(allen_template.shape),
        plane_names=plane_names,
    )

    for row, plane_name in enumerate(plane_names):
        # Original - use .T for row 0 (XY plane) to match display convention
        data = orig_slices[row].T if row == 0 else orig_slices[row][::-1, :]
        axes[row, 0].imshow(data, cmap="gray", origin="lower", vmin=orig_vmin, vmax=orig_vmax)
        axes[row, 0].set_title(f"Original - {plane_name}")
        axes[row, 0].axis("off")

        # Aligned
        data = aligned_slices[row].T if row == 0 else aligned_slices[row][::-1, :]
        axes[row, 1].imshow(data, cmap="gray", origin="lower", vmin=align_vmin, vmax=align_vmax)
        axes[row, 1].set_title(f"Aligned - {plane_name}")
        axes[row, 1].axis("off")

        # Allen template reference column
        data = allen_slices[row].T if row == 0 else allen_slices[row][::-1, :]
        axes[row, 2].imshow(data, cmap="gray", origin="lower", vmin=allen_vmin, vmax=allen_vmax)
        axes[row, 2].set_title(f"Allen {allen_resolution}µm - {plane_name}")
        axes[row, 2].axis("off")

    # Add info text
    info_text = (
        f"Original shape: {vol_original.shape}\nAligned shape: {vol_aligned.shape}\nAllen shape: {allen_template.shape}"
    )
    bbox_props = {"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5}
    fig.text(0.02, 0.02, info_text, fontsize=10, family="monospace", bbox=bbox_props)

    plt.tight_layout()
    Path(preview_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(preview_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    update_pbar()

    print(f"Alignment preview saved to: {preview_path}")
+
+
+# =============================================================================
+# Preview generation (orientation check)
+# =============================================================================
+
+
def create_orientation_preview(
    input_path: Path,
    preview_path: str,
    level: int = 0,
    orientation_permutation: tuple | None = None,
    orientation_flips: tuple | None = None,
    initial_rotation_deg: tuple = (0.0, 0.0, 0.0),
) -> None:
    """
    Save a 3-panel orthogonal preview of the volume after orientation correction and initial rotation are applied.

    Axes are labelled in RAS space (Z=S, X=R, Y=A) so the result can be
    inspected directly against the Allen atlas orientation.

    Parameters
    ----------
    input_path: Path
        Path to input OME-Zarr.
    preview_path : str
        Output PNG path.
    level : int
        Pyramid level to load (lower = higher resolution but slower).
    orientation_permutation : tuple, optional
        Axis permutation from ``parse_orientation_code``.
    orientation_flips : tuple, optional
        Axis flips from ``parse_orientation_code``.
    initial_rotation_deg : tuple of float
        (Rx, Ry, Rz) initial rotation angles in degrees applied after orientation.
    """
    vol_zarr, resolution = read_omezarr(input_path, level=level)
    vol = np.asarray(vol_zarr[:]).astype(np.float32)

    # Apply orientation permutation + flips, keeping the per-axis spacing
    # consistent with the permuted data.
    if orientation_permutation is not None:
        vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips)
        resolution = list(reorder_resolution(tuple(resolution), orientation_permutation))

    # Apply initial rotation via SimpleITK (same path as the registration uses)
    if any(r != 0.0 for r in initial_rotation_deg):
        vol_sitk = allen.numpy_to_sitk_image(vol, resolution, cast_dtype=np.float32)
        # Rotate about the geometric center of the volume.
        center = vol_sitk.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in vol_sitk.GetSize()])
        rx, ry, rz = [np.deg2rad(a) for a in initial_rotation_deg]
        t = sitk.Euler3DTransform()
        t.SetCenter(center)
        t.SetRotation(rx, ry, rz)
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(vol_sitk)
        # Resampling expects the output→input mapping (same convention noted in
        # compute_centered_reference_and_transform), so pass the inverse to
        # make the content appear rotated by `t`.
        resampler.SetTransform(t.GetInverse())
        resampler.SetInterpolator(sitk.sitkLinear)
        vol = sitk.GetArrayFromImage(resampler.Execute(vol_sitk))

    # Display range from non-zero voxels (falls back to the full volume when
    # nothing is non-zero).
    nonzero = vol[vol > 0]
    vmin, vmax = np.percentile(nonzero if len(nonzero) else vol.ravel(), [1, 99])

    # Build title
    applied = []
    if orientation_permutation is not None:
        applied.append("orientation")
    if any(r != 0.0 for r in initial_rotation_deg):
        applied.append(f"rotation {list(initial_rotation_deg)}°")
    subtitle = f"({', '.join(applied)} applied)" if applied else "(no corrections applied)"

    z_mid = vol.shape[0] // 2
    y_mid = vol.shape[1] // 2
    x_mid = vol.shape[2] // 2

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(
        f"Orientation Preview — {subtitle}\n"
        f"Shape: {vol.shape} | After corrections: dim0=S (Superior), dim1=R (Right), dim2=A (Anterior)",
        fontsize=11,
    )

    # After apply_orientation_transform the volume is in linumpy convention
    # (S, R, A): dim0=S (Superior), dim1=R (Right), dim2=A (Anterior).
    # Slicing → anatomical plane:
    #   vol[z, :, :] fixes S → AXIAL    (rows=R, cols=A)
    #   vol[:, y, :] fixes R → SAGITTAL (rows=S, cols=A)
    #   vol[:, :, x] fixes A → CORONAL  (rows=S, cols=R)
    # `.T` on the axial view + row reversal on the others orients the figure
    # so Superior is up and Right/Anterior point in the natural directions.
    axes[0].imshow(vol[z_mid, :, :].T, cmap="gray", origin="lower", vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Axial (dim0=S={z_mid})")
    axes[0].set_xlabel("dim1=R (← L R →)")
    axes[0].set_ylabel("dim2=A (← P A →)")

    axes[1].imshow(vol[::-1, y_mid, :], cmap="gray", origin="lower", vmin=vmin, vmax=vmax)
    axes[1].set_title(f"Sagittal (dim1=R={y_mid})")
    axes[1].set_xlabel("dim2=A (← P A →)")
    axes[1].set_ylabel("dim0=S (← I S →)")

    axes[2].imshow(vol[::-1, :, x_mid], cmap="gray", origin="lower", vmin=vmin, vmax=vmax)
    axes[2].set_title(f"Coronal (dim2=A={x_mid})")
    axes[2].set_xlabel("dim1=R (← L R →)")
    axes[2].set_ylabel("dim0=S (← I S →)")

    _debug_log(
        "create_orientation_preview: slicing decisions",
        vol_shape=list(vol.shape),
        panels=[
            {"axes": 0, "slice": f"vol[{z_mid}, :, :].T", "fixed_axis": "dim0=S", "plane": "Axial"},
            {"axes": 1, "slice": f"vol[::-1, {y_mid}, :]", "fixed_axis": "dim1=R", "plane": "Sagittal"},
            {"axes": 2, "slice": f"vol[::-1, :, {x_mid}]", "fixed_axis": "dim2=A", "plane": "Coronal"},
        ],
    )

    plt.tight_layout()
    Path(preview_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(preview_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Orientation preview saved to: {preview_path}")
+
+
+# =============================================================================
+# Main entry point
+# =============================================================================
+
+
+def main() -> None:
+ """Run the script: parse arguments and run the alignment workflow."""
+ parser = _build_arg_parser()
+ args = parser.parse_args()
+
+ input_path = Path(args.input_zarr)
+ output_path = Path(args.output_zarr)
+
+ if not input_path.exists():
+ raise FileNotFoundError(f"Input zarr not found: {input_path}")
+
+ # Preview-only mode
+ if args.preview_only:
+ preview_path = Path(args.preview) if args.preview else Path("input_preview.png")
+ create_input_preview(input_path, preview_path, level=args.level)
+ return
+
+ # Parse orientation
+ orientation_permutation = None
+ orientation_flips = None
+ if args.input_orientation:
+ try:
+ orientation_permutation, orientation_flips = parse_orientation_code(args.input_orientation)
+ print(f"Input orientation '{args.input_orientation}':")
+ print(f" Axis permutation: {orientation_permutation}")
+ print(f" Axis flips: {orientation_flips}")
+ except ValueError as e:
+ parser.error(str(e))
+
+ # Orientation + initial-rotation preview (can exit before registration)
+ if args.orientation_preview or args.orientation_preview_only:
+ preview_out = args.orientation_preview or "orientation_preview.png"
+ create_orientation_preview(
+ input_path,
+ preview_out,
+ level=args.level,
+ orientation_permutation=orientation_permutation,
+ orientation_flips=orientation_flips,
+ initial_rotation_deg=tuple(args.initial_rotation),
+ )
+ if args.orientation_preview_only:
+ return
+
+ # Load input volume
+ vol_zarr, zarr_resolution = read_omezarr(Path(input_path), level=args.level)
+ resolution = tuple(zarr_resolution)
+
+ # Progress bar - allocate steps for each phase
+ registration_steps = 3 # Steps allocated for registration progress
+ base_steps = 2 if args.store_transform_only else 5 # Load + save steps
+ total_steps = base_steps + registration_steps
+ if args.preview:
+ total_steps += 4
+ pbar = tqdm(total=total_steps, desc="Aligning to RAS")
+
+ vol = np.asarray(vol_zarr[:])
+ pbar.update(1)
+
+ if args.verbose:
+ print(f"Volume shape: {vol.shape}, Resolution: {resolution} mm")
+
+ # Apply orientation correction for registration
+ if orientation_permutation is not None:
+ vol = apply_orientation_transform(vol, orientation_permutation, orientation_flips)
+ resolution = reorder_resolution(resolution, orientation_permutation)
+
+ # Create progress callback for registration
+ registration_start_step = pbar.n
+ progress_callback = create_registration_progress_callback(
+ max_iterations=args.max_iterations,
+ n_resolution_levels=3,
+ pbar=pbar,
+ registration_start_step=registration_start_step,
+ registration_steps=registration_steps,
+ )
+
+ # Register to Allen atlas
+ pbar.set_postfix_str("registering...")
+ transform, stop_condition, error = allen.register_3d_rigid_to_allen(
+ moving_image=vol,
+ moving_spacing=resolution,
+ allen_resolution=args.allen_resolution,
+ metric=args.metric,
+ max_iterations=args.max_iterations,
+ verbose=args.verbose,
+ progress_callback=progress_callback,
+ initial_rotation_deg=tuple(args.initial_rotation),
+ )
+ # Ensure progress bar reaches end of registration steps
+ pbar.n = registration_start_step + registration_steps
+ pbar.refresh()
+
+ print(f"Registration complete: {stop_condition}")
+ print(f"Final metric value: {error:.6f}")
+ del vol # free registration-level volume before loading full-resolution data
+
+ # Apply or store transform
+ if args.store_transform_only:
+ store_transform_in_metadata(input_path, transform)
+ pbar.update(1)
+ else:
+ apply_transform_to_zarr(
+ input_path,
+ output_path,
+ transform,
+ chunks=tuple(args.chunks) if args.chunks else None,
+ n_levels=args.n_levels,
+ pyramid_resolutions=args.pyramid_resolutions,
+ make_isotropic=args.make_isotropic,
+ orientation_permutation=orientation_permutation,
+ orientation_flips=orientation_flips,
+ pbar=pbar,
+ )
+ print(f"Aligned volume saved to: {output_path}")
+
+ # Save transform file
+ # Strip the compound .ome.zarr extension (Path.stem only removes the last suffix)
+ stem = output_path.with_suffix("").with_suffix("").name
+ transform_path = output_path.parent / f"{stem}_transform.tfm"
+ sitk.WriteTransform(transform, str(transform_path))
+ print(f"Transform saved to: {transform_path}")
+ pbar.update(1)
+
+ # Generate preview
+ if args.preview:
+ pbar.set_postfix_str("generating preview...")
+ create_alignment_preview(
+ input_path,
+ output_path if not args.store_transform_only else None,
+ transform,
+ resolution,
+ args.preview,
+ allen_resolution=args.allen_resolution,
+ level=args.level,
+ orientation_permutation=orientation_permutation,
+ orientation_flips=orientation_flips,
+ pbar=pbar,
+ )
+
+ pbar.set_postfix_str("complete")
+ pbar.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/linum_analyze_shifts.py b/scripts/linum_analyze_shifts.py
deleted file mode 100644
index 03b97e7d..00000000
--- a/scripts/linum_analyze_shifts.py
+++ /dev/null
@@ -1,298 +0,0 @@
-#!/usr/bin/env python3
-"""
-Analyze XY shifts from a shifts file and generate a drift analysis report.
-
-Produces:
-- Summary statistics of pairwise shifts
-- Outlier detection using IQR method
-- Cumulative drift analysis
-- Visualization of drift patterns
-
-Useful for debugging alignment issues and understanding sample drift during acquisition.
-"""
-
-import linumpy.config.threads # noqa: F401
-
-import argparse
-import logging
-from pathlib import Path
-from typing import Any
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-
-from linumpy.cli.args import add_overwrite_arg, assert_output_exists
-
-logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
-logger = logging.getLogger(__name__)
-
-
-def _build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- p.add_argument("in_shifts", help="Input shifts CSV file (shifts_xy.csv)")
- p.add_argument("out_directory", help="Output directory for analysis results")
-
- p.add_argument(
- "--resolution", type=float, default=10.0, help="Resolution in µm/pixel for converting mm to pixels [%(default)s]"
- )
- p.add_argument("--iqr_multiplier", type=float, default=1.5, help="IQR multiplier for outlier detection [%(default)s]")
- p.add_argument("--slice_config", default=None, help="Optional slice config file to mark excluded slices")
-
- add_overwrite_arg(p)
- return p
-
-
-def load_shifts(shifts_path: Path) -> Any:
- """Load shifts CSV file.
-
- Rows are sorted by ``moving_id`` so that every ``cumsum`` downstream
- reflects slice order rather than CSV row order.
- """
- df = pd.read_csv(shifts_path)
- required_cols = ["fixed_id", "moving_id", "x_shift_mm", "y_shift_mm"]
- for col in required_cols:
- if col not in df.columns:
- raise ValueError(f"Missing required column: {col}")
- return df.sort_values("moving_id").reset_index(drop=True)
-
-
-def detect_outliers(df: Any, iqr_multiplier: float = 1.5) -> Any:
- """Detect outliers using IQR method on shift magnitude."""
- shift_mag = np.sqrt(df["x_shift_mm"] ** 2 + df["y_shift_mm"] ** 2)
- q1 = shift_mag.quantile(0.25)
- q3 = shift_mag.quantile(0.75)
- iqr = q3 - q1
- upper_bound = q3 + iqr_multiplier * iqr
- outlier_mask = shift_mag > upper_bound
- return outlier_mask, upper_bound, q1, q3, iqr
-
-
-def filter_with_local_median(df: Any, outlier_mask: Any) -> Any:
- """Replace outliers with local median of neighbors."""
- df_filtered = df.copy()
- for idx in df[outlier_mask].index:
- pos = df.index.get_loc(idx)
- neighbors_x, neighbors_y = [], []
- for offset in [-2, -1, 1, 2]:
- neighbor_pos = pos + offset
- if 0 <= neighbor_pos < len(df):
- neighbor_idx = df.index[neighbor_pos]
- if not outlier_mask[neighbor_idx]:
- neighbors_x.append(df.loc[neighbor_idx, "x_shift_mm"])
- neighbors_y.append(df.loc[neighbor_idx, "y_shift_mm"])
- if neighbors_x:
- df_filtered.loc[idx, "x_shift_mm"] = np.median(neighbors_x)
- df_filtered.loc[idx, "y_shift_mm"] = np.median(neighbors_y)
- return df_filtered
-
-
-def generate_report(df: Any, df_filtered: Any, outlier_mask: Any, stats: Any, resolution: Any, output_dir: Path) -> str:
- """Generate text report."""
- px_per_mm = 1000 / resolution
-
- report_lines = [
- "=" * 60,
- "SHIFTS ANALYSIS REPORT",
- "=" * 60,
- "",
- "OVERVIEW",
- "-" * 40,
- f"Total shift pairs: {len(df)}",
- f"Resolution: {resolution} µm/pixel",
- "",
- "PAIRWISE SHIFT STATISTICS",
- "-" * 40,
- f"X shift (mm): Mean={df['x_shift_mm'].mean():.4f}, Std={df['x_shift_mm'].std():.4f}",
- f"Y shift (mm): Mean={df['y_shift_mm'].mean():.4f}, Std={df['y_shift_mm'].std():.4f}",
- "",
- "OUTLIER DETECTION (IQR Method)",
- "-" * 40,
- f"Q1={stats['q1']:.4f}, Q3={stats['q3']:.4f}, IQR={stats['iqr']:.4f}",
- f"Upper bound: {stats['upper_bound']:.4f} mm",
- f"Outliers detected: {outlier_mask.sum()}",
- ]
-
- if outlier_mask.sum() > 0:
- report_lines.append("")
- report_lines.append("Outlier shifts:")
- shift_mag = np.sqrt(df["x_shift_mm"] ** 2 + df["y_shift_mm"] ** 2)
- for idx in df[outlier_mask].index:
- row = df.loc[idx]
- mag = shift_mag[idx]
- report_lines.append(
- f" {int(row['fixed_id'])}->{int(row['moving_id'])}: "
- f"({row['x_shift_mm']:.3f}, {row['y_shift_mm']:.3f}) mm, mag={mag:.3f} mm"
- )
-
- # Cumulative drift
- cumsum_x_orig = df["x_shift_mm"].cumsum()
- cumsum_y_orig = df["y_shift_mm"].cumsum()
- cumsum_x_filt = df_filtered["x_shift_mm"].cumsum()
- cumsum_y_filt = df_filtered["y_shift_mm"].cumsum()
-
- report_lines.extend(
- [
- "",
- "CUMULATIVE DRIFT",
- "-" * 40,
- f"Before filtering: X={cumsum_x_orig.iloc[-1]:.3f} mm, Y={cumsum_y_orig.iloc[-1]:.3f} mm",
- f"After filtering: X={cumsum_x_filt.iloc[-1]:.3f} mm, Y={cumsum_y_filt.iloc[-1]:.3f} mm",
- "",
- f"In pixels (at {resolution} µm/pixel):",
- f" Before: X={cumsum_x_orig.iloc[-1] * px_per_mm:.0f} px, Y={cumsum_y_orig.iloc[-1] * px_per_mm:.0f} px",
- f" After: X={cumsum_x_filt.iloc[-1] * px_per_mm:.0f} px, Y={cumsum_y_filt.iloc[-1] * px_per_mm:.0f} px",
- ]
- )
-
- # Centered drift
- mid_idx = len(cumsum_x_filt) // 2
- centered_x = cumsum_x_filt - cumsum_x_filt.iloc[mid_idx]
- centered_y = cumsum_y_filt - cumsum_y_filt.iloc[mid_idx]
-
- report_lines.extend(
- [
- "",
- f"CENTERED DRIFT (around slice {mid_idx})",
- "-" * 40,
- f"X range: {centered_x.min() * px_per_mm:.0f} to {centered_x.max() * px_per_mm:.0f} px",
- f"Y range: {centered_y.min() * px_per_mm:.0f} to {centered_y.max() * px_per_mm:.0f} px",
- "",
- "=" * 60,
- ]
- )
-
- report_text = "\n".join(report_lines)
-
- # Save report
- report_path = Path(output_dir) / "shifts_analysis.txt"
- with report_path.open("w") as f:
- f.write(report_text)
-
- return report_text
-
-
-def generate_plots(df: Any, df_filtered: Any, _outlier_mask: Any, stats: Any, resolution: Any, output_dir: Path) -> Path:
- """Generate visualization plots."""
- px_per_mm = 1000 / resolution
- upper_bound = stats["upper_bound"]
-
- # Calculate cumulative drift
- cumsum_x_orig = df["x_shift_mm"].cumsum()
- cumsum_y_orig = df["y_shift_mm"].cumsum()
- cumsum_x_filt = df_filtered["x_shift_mm"].cumsum()
- cumsum_y_filt = df_filtered["y_shift_mm"].cumsum()
-
- mid_idx = len(cumsum_x_filt) // 2
- centered_x = cumsum_x_filt - cumsum_x_filt.iloc[mid_idx]
- centered_y = cumsum_y_filt - cumsum_y_filt.iloc[mid_idx]
-
- # Create figure
- fig, axes = plt.subplots(2, 2, figsize=(14, 10))
-
- # Plot 1: Pairwise shifts
- ax = axes[0, 0]
- ax.plot(df["moving_id"], df["x_shift_mm"], "b.-", label="X shift (original)", alpha=0.7)
- ax.plot(df["moving_id"], df["y_shift_mm"], "r.-", label="Y shift (original)", alpha=0.7)
- ax.plot(df["moving_id"], df_filtered["x_shift_mm"], "b-", label="X shift (filtered)", linewidth=2)
- ax.plot(df["moving_id"], df_filtered["y_shift_mm"], "r-", label="Y shift (filtered)", linewidth=2)
- ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
- ax.axhline(y=upper_bound, color="g", linestyle=":", label=f"IQR threshold ({upper_bound:.2f}mm)")
- ax.axhline(y=-upper_bound, color="g", linestyle=":")
- ax.set_xlabel("Slice ID")
- ax.set_ylabel("Shift (mm)")
- ax.set_title("Pairwise Shifts")
- ax.legend(fontsize=8)
- ax.grid(True, alpha=0.3)
-
- # Plot 2: Cumulative drift
- ax = axes[0, 1]
- ax.plot(df["moving_id"], cumsum_x_orig, "b--", label="X original", alpha=0.5)
- ax.plot(df["moving_id"], cumsum_y_orig, "r--", label="Y original", alpha=0.5)
- ax.plot(df["moving_id"], cumsum_x_filt, "b-", label="X filtered", linewidth=2)
- ax.plot(df["moving_id"], cumsum_y_filt, "r-", label="Y filtered", linewidth=2)
- ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
- ax.set_xlabel("Slice ID")
- ax.set_ylabel("Cumulative Drift (mm)")
- ax.set_title("Cumulative Drift")
- ax.legend()
- ax.grid(True, alpha=0.3)
-
- # Plot 3: Centered cumulative drift in pixels
- ax = axes[1, 0]
- ax.plot(df["moving_id"], centered_x * px_per_mm, "b-", label="X (centered)", linewidth=2)
- ax.plot(df["moving_id"], centered_y * px_per_mm, "r-", label="Y (centered)", linewidth=2)
- ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
- ax.set_xlabel("Slice ID")
- ax.set_ylabel(f"Centered Drift (pixels at {resolution}µm)")
- ax.set_title("Centered Cumulative Drift")
- ax.legend()
- ax.grid(True, alpha=0.3)
-
- # Plot 4: Drift trajectory
- ax = axes[1, 1]
- ax.plot(cumsum_x_filt * px_per_mm, cumsum_y_filt * px_per_mm, "g-", linewidth=2)
- ax.plot(cumsum_x_filt.iloc[0] * px_per_mm, cumsum_y_filt.iloc[0] * px_per_mm, "go", markersize=10, label="Start")
- ax.plot(cumsum_x_filt.iloc[-1] * px_per_mm, cumsum_y_filt.iloc[-1] * px_per_mm, "ro", markersize=10, label="End")
- ax.plot(
- cumsum_x_filt.iloc[mid_idx] * px_per_mm, cumsum_y_filt.iloc[mid_idx] * px_per_mm, "ko", markersize=10, label="Middle"
- )
- ax.set_xlabel("X position (pixels)")
- ax.set_ylabel("Y position (pixels)")
- ax.set_title("Drift Trajectory (filtered)")
- ax.legend()
- ax.grid(True, alpha=0.3)
- ax.axis("equal")
-
- plt.tight_layout()
-
- # Save plot
- plot_path = Path(output_dir) / "drift_analysis.png"
- fig.savefig(plot_path, dpi=150, bbox_inches="tight")
- plt.close(fig)
-
- logger.info("Saved plot: %s", plot_path)
- return plot_path
-
-
-def main() -> None:
- """Run function."""
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- # Create output directory
- assert_output_exists(args.out_directory, parser, args)
- Path(args.out_directory).mkdir(parents=True)
-
- # Load shifts
- logger.info("Loading shifts from %s", args.in_shifts)
- df = load_shifts(args.in_shifts)
- logger.info("Loaded %s shift pairs", len(df))
-
- # Detect outliers
- outlier_mask, upper_bound, q1, q3, iqr = detect_outliers(df, args.iqr_multiplier)
- logger.info("Detected %s outliers (IQR bound: %.3f mm)", outlier_mask.sum(), upper_bound)
-
- # Filter outliers
- df_filtered = filter_with_local_median(df, outlier_mask)
-
- # Statistics
- stats = {"q1": q1, "q3": q3, "iqr": iqr, "upper_bound": upper_bound}
-
- # Generate report
- report = generate_report(df, df_filtered, outlier_mask, stats, args.resolution, args.out_directory)
- print(report)
-
- # Generate plots
- generate_plots(df, df_filtered, outlier_mask, stats, args.resolution, args.out_directory)
-
- # Save filtered shifts (useful for debugging)
- filtered_path = Path(args.out_directory) / "shifts_filtered.csv"
- df_filtered.to_csv(filtered_path, index=False)
- logger.info("Saved filtered shifts: %s", filtered_path)
-
- logger.info("Analysis complete. Results saved to %s", args.out_directory)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/linum_assess_slice_quality.py b/scripts/linum_assess_slice_quality.py
deleted file mode 100644
index c1867341..00000000
--- a/scripts/linum_assess_slice_quality.py
+++ /dev/null
@@ -1,400 +0,0 @@
-#!/usr/bin/env python3
-"""
-Assess slice quality for 3D mosaic grids and optionally update slice configuration.
-
-This script analyzes mosaic grid slices to detect quality issues that might affect
-reconstruction. It uses multiple metrics to identify problematic slices:
-
-1. **SSIM (Structural Similarity)**: Compares each slice to its neighbors
-2. **Edge Preservation**: Detects if edge structures are preserved compared to neighbors
-3. **Variance Consistency**: Checks for unusual signal variance (data loss/corruption)
-4. **First Slice Detection**: Automatically identifies calibration slices (thicker/different)
-
-GPU acceleration is used when available (--use_gpu, default on) for SSIM and
-edge-detection computations. Falls back to CPU automatically if no GPU is detected.
-
-The output can be:
-- A new slice_config.csv with quality scores and recommendations
-- An update to an existing slice_config.csv with quality assessments
-- A quality report for review
-
-Example usage:
- # Assess quality and create/update slice config
- linum_assess_slice_quality.py /path/to/mosaics slice_config.csv
-
- # Assess and exclude low-quality slices automatically
- linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --min_quality 0.3
-
- # Exclude first N calibration slices
- linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --exclude_first 1
-
- # Update existing config with quality info
- linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --update_existing
-
- # Force CPU fallback
- linum_assess_slice_quality.py /path/to/mosaics slice_config.csv --no-use_gpu
-"""
-
-from __future__ import annotations
-
-# Configure thread limits before numpy/scipy imports
-import linumpy.config.threads # noqa: F401
-
-import argparse
-import re
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-from tqdm.auto import tqdm
-
-if TYPE_CHECKING:
- import numpy as np
-
-from linumpy.cli.args import add_overwrite_arg, assert_output_exists
-from linumpy.gpu import GPU_AVAILABLE
-from linumpy.gpu.image_quality import (
- assess_slice_quality_gpu,
- clear_gpu_memory,
-)
-from linumpy.io import slice_config as slice_config_io
-from linumpy.io.zarr import read_omezarr
-from linumpy.metrics.image_quality import (
- assess_slice_quality,
- detect_calibration_slice,
-)
-
-
-def _build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- p.add_argument("input", help="Input directory containing mosaic grids (*.ome.zarr)")
- p.add_argument("output_file", help="Output slice configuration CSV file")
-
- gpu_group = p.add_argument_group("GPU Options")
- gpu_group.add_argument(
- "--use_gpu",
- default=True,
- action=argparse.BooleanOptionalAction,
- help="Use GPU acceleration if available. [%(default)s]",
- )
- gpu_group.add_argument("--gpu_id", type=int, default=0, help="GPU device ID to use. [%(default)s]")
-
- quality_group = p.add_argument_group("Quality Assessment")
- quality_group.add_argument(
- "--min_quality",
- type=float,
- default=0.0,
- help="Minimum quality score to include slice (0-1). Default: 0.0 (include all, just report)",
- )
- quality_group.add_argument(
- "--sample_depth",
- type=int,
- default=5,
- help="Number of z-planes to sample per slice for faster assessment. Default: 5 (0=all)",
- )
- quality_group.add_argument(
- "--pyramid_level",
- type=int,
- default=0,
- help="Pyramid level to use for assessment (0=full res). Higher levels are faster but less accurate. Default: 0",
- )
- quality_group.add_argument(
- "--roi_size",
- type=int,
- default=0,
- help="Side length of center crop in XY (pixels) used for "
- "all quality metrics. 0 = full plane (slow for large "
- "single-resolution mosaics). Recommended: 1024.",
- )
- quality_group.add_argument(
- "--processes",
- type=int,
- default=1,
- help="Number of parallel workers for slice assessment (CPU mode only).\n"
- "Each worker reads its own zarr planes concurrently.\n"
- "Default: 1 (sequential). Set to params.processes.",
- )
-
- calib_group = p.add_argument_group("Calibration Slice Handling")
- calib_group.add_argument(
- "--exclude_first",
- type=int,
- default=1,
- help="Exclude first N slices as calibration slices. Default: 1 (first slice is usually calibration)",
- )
- calib_group.add_argument(
- "--detect_calibration",
- action="store_true",
- help="Automatically detect calibration slices by their different thickness/structure",
- )
- calib_group.add_argument(
- "--calibration_thickness_ratio",
- type=float,
- default=1.5,
- help="Slices with thickness ratio > this are flagged as calibration. Default: 1.5",
- )
-
- update_group = p.add_argument_group("Update Existing Config")
- update_group.add_argument(
- "--update_existing", action="store_true", help="Update an existing slice_config.csv with quality info"
- )
- update_group.add_argument("--existing_config", type=str, default=None, help="Path to existing slice config to update")
-
- output_group = p.add_argument_group("Output Options")
- output_group.add_argument("--report_only", action="store_true", help="Only print report, don't write config file")
- output_group.add_argument("-v", "--verbose", action="store_true", help="Print detailed quality metrics per slice")
-
- add_overwrite_arg(p)
- return p
-
-
-def get_mosaic_files(directory: Path) -> dict[int, Path]:
- """Find all mosaic grid files and extract slice IDs."""
- pattern = r".*z(\d+).*\.ome\.zarr$"
- mosaics = {}
-
- for f in directory.iterdir():
- if f.is_dir() and f.suffix == ".zarr":
- match = re.match(pattern, f.name)
- if match:
- slice_id = int(match.group(1))
- mosaics[slice_id] = f
-
- return dict(sorted(mosaics.items()))
-
-
-def read_existing_config(config_path: Path) -> dict[int, dict[str, Any]]:
- """Read an existing slice configuration file keyed by integer ``slice_id``."""
- rows = slice_config_io.read(config_path)
- return {int(sid): dict(row) for sid, row in rows.items()}
-
-
-def write_slice_config_with_quality(
- output_file: Path,
- slice_ids: list[int],
- quality_results: dict[int, dict[str, Any]],
- exclude_ids: list[int],
- existing_config: dict[int, dict[str, Any]] | None = None,
-) -> None:
- """Write ``slice_config.csv`` with the decision columns set from the quality.
-
- assessment. Raw per-metric scores (ssim_mean / edge_score / variance_score /
- depth) intentionally stay out of the CSV — they live in the pipeline report
- and per-stage diagnostics JSON, not in the per-slice decision trace.
- """
- out_rows: list[dict[str, object]] = []
- for slice_id in slice_ids:
- quality = quality_results.get(slice_id, {})
- use = "true"
- reason = ""
- if slice_id in exclude_ids:
- use = "false"
- if quality.get("is_calibration", False):
- reason = "calibration_slice"
- elif quality.get("overall", 1.0) < quality.get("min_threshold", 0):
- reason = "low_quality"
- elif quality.get("exclude_first", False):
- reason = "first_slice_excluded"
- else:
- reason = "manually_excluded"
-
- existing = existing_config.get(slice_id, {}) if existing_config else {}
- if existing.get("use", "true").lower() in ["false", "0", "no"]:
- use = "false"
- if not reason:
- reason = existing.get("exclude_reason") or existing.get("notes") or "previously_excluded"
-
- row: dict[str, object] = {
- "slice_id": f"{slice_id:02d}",
- "use": use,
- "quality_score": f"{float(quality.get('overall', 0.0)):.3f}",
- "exclude_reason": reason,
- }
- if existing.get("galvo_confidence", ""):
- row["galvo_confidence"] = existing["galvo_confidence"]
- if existing.get("galvo_fix", ""):
- row["galvo_fix"] = existing["galvo_fix"]
- for carry in ("notes",):
- val = existing.get(carry)
- if val:
- row[carry] = val
- out_rows.append(row)
-
- slice_config_io.write(output_file, out_rows)
-
-
-def main() -> None:
- """Run function operation."""
- p = _build_arg_parser()
- args = p.parse_args()
-
- input_path = Path(args.input)
- output_file = Path(args.output_file)
-
- if not args.report_only:
- assert_output_exists(output_file, p, args)
-
- if not input_path.is_dir():
- p.error(f"Input directory not found: {input_path}")
-
- use_gpu = args.use_gpu and GPU_AVAILABLE
- if args.use_gpu and not GPU_AVAILABLE:
- print("Warning: GPU requested but not available. Using CPU.")
- elif use_gpu:
- try:
- import cupy as cp
-
- cp.cuda.Device(args.gpu_id).use()
- print(f"Using GPU device {args.gpu_id}")
- except Exception as e:
- print(f"Warning: Could not select GPU {args.gpu_id}: {e}. Using default.")
-
- print(f"Scanning for mosaic grids in: {input_path}")
- mosaic_files = get_mosaic_files(input_path)
-
- if not mosaic_files:
- p.error(f"No mosaic grid files found in {input_path}")
-
- slice_ids = sorted(mosaic_files.keys())
- print(f"Found {len(slice_ids)} slices: {[f'{s:02d}' for s in slice_ids]}")
-
- existing_config = None
- if args.update_existing:
- config_to_load = args.existing_config if args.existing_config else output_file
- if Path(config_to_load).exists():
- existing_config = read_existing_config(Path(config_to_load))
- print(f"Loaded existing config with {len(existing_config)} entries")
-
- exclude_ids = set()
-
- if args.exclude_first > 0:
- first_slices = slice_ids[: args.exclude_first]
- exclude_ids.update(first_slices)
- print(f"Excluding first {args.exclude_first} slice(s) as calibration: {first_slices}")
-
- print(f"\nLoading slices (pyramid_level={args.pyramid_level})...")
- volumes: dict[int, np.ndarray | None] = {}
- for slice_id in tqdm(slice_ids, desc="Loading slices"):
- try:
- vol, _ = read_omezarr(mosaic_files[slice_id], level=args.pyramid_level)
- volumes[slice_id] = vol
- except Exception as e:
- print(f" Warning: Could not load slice {slice_id:02d}: {e}")
- volumes[slice_id] = None
-
- calibration_slices = []
- if args.detect_calibration:
- print(f"Detecting calibration slices (thickness ratio > {args.calibration_thickness_ratio})...")
- valid_volumes = {sid: vol for sid, vol in volumes.items() if vol is not None}
- calibration_slices = detect_calibration_slice(valid_volumes, args.calibration_thickness_ratio)
- if calibration_slices:
- exclude_ids.update(calibration_slices)
- print(f"Detected calibration slices: {calibration_slices}")
-
- print(f"\nAssessing slice quality (GPU={'enabled' if use_gpu else 'disabled'}, sample_depth={args.sample_depth})...")
- quality_results: dict[int, dict[str, Any]] = {}
-
- if use_gpu:
- for i, slice_id in enumerate(tqdm(slice_ids, desc="Assessing quality")):
- vol = volumes.get(slice_id)
- if vol is None:
- quality_results[slice_id] = {
- "overall": 0.0,
- "ssim_mean": 0.0,
- "edge_score": 0.0,
- "variance_score": 0.0,
- "depth": 0,
- "has_data": False,
- "error": "load_failed",
- }
- continue
- vol_before = volumes.get(slice_ids[i - 1]) if i > 0 else None
- vol_after = volumes.get(slice_ids[i + 1]) if i < len(slice_ids) - 1 else None
- overall, metrics = assess_slice_quality_gpu(vol, vol_before, vol_after, args.sample_depth)
- metrics["is_calibration"] = slice_id in calibration_slices
- metrics["exclude_first"] = slice_id in slice_ids[: args.exclude_first]
- metrics["min_threshold"] = args.min_quality
- quality_results[slice_id] = metrics
- if args.min_quality > 0 and overall < args.min_quality:
- exclude_ids.add(slice_id)
- clear_gpu_memory()
- else:
- from concurrent.futures import ThreadPoolExecutor, as_completed
-
- def _assess_one(idx_and_id: tuple) -> Any:
- i, slice_id = idx_and_id
- vol = volumes.get(slice_id)
- if vol is None:
- return slice_id, {
- "overall": 0.0,
- "ssim_mean": 0.0,
- "edge_score": 0.0,
- "variance_score": 0.0,
- "depth": 0,
- "has_data": False,
- "error": "load_failed",
- }
- vol_before = volumes.get(slice_ids[i - 1]) if i > 0 else None
- vol_after = volumes.get(slice_ids[i + 1]) if i < len(slice_ids) - 1 else None
- _overall, metrics = assess_slice_quality(vol, vol_before, vol_after, args.sample_depth, xy_roi=args.roi_size)
- metrics["is_calibration"] = slice_id in calibration_slices
- metrics["exclude_first"] = slice_id in slice_ids[: args.exclude_first]
- metrics["min_threshold"] = args.min_quality
- return slice_id, metrics
-
- tasks = list(enumerate(slice_ids))
- with ThreadPoolExecutor(max_workers=args.processes) as executor:
- futures = {executor.submit(_assess_one, t): t for t in tasks}
- for future in tqdm(as_completed(futures), total=len(futures), desc="Assessing quality"):
- slice_id, metrics = future.result()
- quality_results[slice_id] = metrics
- if args.min_quality > 0 and metrics.get("overall", 0.0) < args.min_quality:
- exclude_ids.add(slice_id)
-
- print("\n" + "=" * 70)
- print(f"SLICE QUALITY REPORT{' (GPU-accelerated)' if use_gpu else ' (CPU)'}")
- print("=" * 70)
- print(f"{'Slice':<8} {'Quality':<10} {'SSIM':<10} {'Edge':<10} {'Var':<10} {'Depth':<8} {'Status':<15}")
- print("-" * 70)
-
- for slice_id in slice_ids:
- q = quality_results.get(slice_id, {})
- status = []
- if slice_id in exclude_ids:
- if q.get("is_calibration"):
- status.append("CALIBRATION")
- elif q.get("exclude_first"):
- status.append("FIRST_SLICE")
- elif q.get("overall", 1.0) < args.min_quality:
- status.append("LOW_QUALITY")
- else:
- status.append("EXCLUDED")
- else:
- status.append("OK")
-
- status_str = ",".join(status)
- print(
- f"{slice_id:02d} {q.get('overall', 0):.3f} "
- f"{q.get('ssim_mean', 0):.3f} {q.get('edge_score', 0):.3f} "
- f"{q.get('variance_score', 0):.3f} {q.get('depth', 0):<8} {status_str}"
- )
-
- print("-" * 70)
- print(f"Total slices: {len(slice_ids)}")
- print(f"Excluded: {len(exclude_ids)}")
- print(f"Included: {len(slice_ids) - len(exclude_ids)}")
-
- if args.min_quality > 0:
- low_quality = [s for s in slice_ids if quality_results.get(s, {}).get("overall", 1.0) < args.min_quality]
- if low_quality:
- print(f"Low quality slices (< {args.min_quality}): {low_quality}")
-
- if not args.report_only:
- write_slice_config_with_quality(output_file, slice_ids, quality_results, list(exclude_ids), existing_config)
- print(f"\nSlice configuration written to: {output_file}")
-
- if exclude_ids:
- print(f"\nExcluded slice IDs: {sorted(exclude_ids)}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/linum_generate_pipeline_report.py b/scripts/linum_generate_pipeline_report.py
deleted file mode 100644
index b00c23df..00000000
--- a/scripts/linum_generate_pipeline_report.py
+++ /dev/null
@@ -1,2028 +0,0 @@
-#!/usr/bin/env python3
-"""
-Generate a quality report from pipeline metrics.
-
-This script aggregates metrics from various pipeline steps and generates
-a comprehensive report in HTML or text format to help identify potential
-issues in the 3D reconstruction pipeline.
-"""
-
-# Configure thread limits before numpy/scipy imports
-import linumpy.config.threads # noqa: F401
-
-import argparse
-import base64
-import io as _io
-import json
-import re
-import zipfile
-from collections import defaultdict
-from datetime import datetime
-from pathlib import Path
-
-try:
- from PIL import Image as _PILImage
-
- _PIL_AVAILABLE = True
-except ImportError:
- _PIL_AVAILABLE = False
-
-from typing import Any
-
-import numpy as np
-
-from linumpy.metrics import aggregate_metrics, compute_summary_statistics
-
-# Logical pipeline step ordering
-STEP_ORDER = [
- "stitch_3d",
- "xy_transform_estimation",
- "normalize_intensities",
- "psf_compensation",
- "crop_interface",
- "pairwise_registration",
- "stack_slices",
-]
-
-# Human-readable display names (step_name → display label)
-STEP_DISPLAY_NAMES = {
- "stitch_3d": "Stitch 3D",
- "xy_transform_estimation": "XY Transform Estimation",
- "normalize_intensities": "Normalize Intensities",
- "psf_compensation": "PSF Compensation",
- "crop_interface": "Crop Interface",
- "pairwise_registration": "Pairwise Registration",
- "stack_slices": "Stack Slices",
-}
-
-# Human-readable descriptions for pipeline steps
-STEP_DESCRIPTIONS = {
- "stitch_3d": "Stitches individual mosaic tiles into a single 2D slice.",
- "xy_transform_estimation": "Estimates the affine transformation for tile overlap correction.",
- "normalize_intensities": "Normalizes per-slice intensities using agarose background.",
- "psf_compensation": "Compensates for beam profile / PSF attenuation along the optical axis.",
- "crop_interface": "Detects and crops the tissue-agarose interface.",
- "pairwise_registration": "Registers consecutive serial sections to align the 3D volume.",
- "stack_slices": "Stacks registered slices into the final 3D volume.",
-}
-
-# Maps pipeline step_name → image category shown in that step section
-STEP_PREVIEW_CATEGORY = {
- "stitch_3d": "stitch_preview",
- "pairwise_registration": "common_space_preview",
-}
-
-
-def _build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- p.add_argument("input_dir", help="Input directory containing pipeline output with metrics files.")
- p.add_argument("output_report", help="Output report file path (.html, .zip, or .txt)")
- p.add_argument(
- "--format",
- choices=["html", "text", "zip", "auto"],
- default="auto",
- help="Output format. 'auto' infers from extension. [%(default)s]",
- )
- p.add_argument("--title", default="Pipeline Quality Report", help="Report title. [%(default)s]")
- p.add_argument("--verbose", action="store_true", help="Include all metric details in the report.")
- p.add_argument(
- "--overview_png", type=Path, default=None, help="Path to the main volume PNG screenshot (embedded in summary)."
- )
- p.add_argument(
- "--annotated_png", type=Path, default=None, help="Path to the annotated volume PNG screenshot (embedded in summary)."
- )
- p.add_argument("--max_overview_width", type=int, default=900, help="Max pixel width for overview images. [%(default)s]")
- p.add_argument("--max_thumb_width", type=int, default=380, help="Max pixel width for gallery thumbnails. [%(default)s]")
- p.add_argument("--no_images", action="store_true", help="Disable image discovery for zip bundles.")
- return p
-
-
-def get_status_color(status: str) -> str:
- """Get HTML color for status."""
- colors = {
- "ok": "#28a745", # green
- "warning": "#ffc107", # yellow/amber
- "error": "#dc3545", # red
- "info": "#17a2b8", # blue
- "unknown": "#6c757d", # gray
- }
- return colors.get(status, colors["unknown"])
-
-
-def get_status_emoji(status: str) -> str:
- """Get emoji for status in text format."""
- emojis = {"ok": "✓", "warning": "⚠", "error": "✗", "info": "ℹ", "unknown": "?"}
- return emojis.get(status, "?")
-
-
-def format_value(value: float, precision: int = 4) -> str:
- """Format a value for display."""
- if isinstance(value, float):
- if abs(value) < 0.0001 or abs(value) > 10000:
- return f"{value:.{precision}e}"
- return f"{value:.{precision}f}"
- elif isinstance(value, list) and len(value) > 5:
- return f"[{len(value)} items]"
- return str(value)
-
-
-def sort_steps(aggregated: dict) -> dict:
- """Sort pipeline steps in logical execution order."""
-
- def step_key(step_name: str) -> Any:
- try:
- return (0, STEP_ORDER.index(step_name))
- except ValueError:
- return (1, step_name)
-
- return dict(sorted(aggregated.items(), key=lambda x: step_key(x[0])))
-
-
-def extract_slice_id(source_file: str) -> str:
- """Extract a meaningful slice identifier from a source file path."""
- path = Path(source_file)
- # Search path components for a slice pattern like z01, z002, slice_3
- for part in reversed(path.parts):
- m = re.search(r"(z\d+|slice_z?\d+)", part, re.IGNORECASE)
- if m:
- return m.group(1)
- return path.stem
-
-
-def parse_issue(issue_str: str) -> dict:
- """Parse an issue string of the form 'source: metric: value op threshold (level)'."""
- parts = issue_str.split(": ", 2)
- if len(parts) < 3:
- return {"source": parts[0] if parts else "", "metric": "", "raw": issue_str, "value": None, "threshold": None}
- source, metric, rest = parts[0], parts[1], parts[2]
- m = re.match(r"([+-]?[\d.e+-]+)\s*([><]=?)\s*([+-]?[\d.e+-]+)", rest)
- if m:
- return {
- "source": source,
- "metric": metric,
- "raw": issue_str,
- "value": float(m.group(1)),
- "op": m.group(2),
- "threshold": float(m.group(3)),
- }
- return {"source": source, "metric": metric, "raw": issue_str, "value": None, "threshold": None}
-
-
-def group_issues(issues: list[str]) -> list[dict]:
- """
- Group issues by metric name.
-
- Returns a list of dicts with keys: metric, count, values, threshold, details.
- """
- groups = defaultdict(list)
- for issue in issues:
- parsed = parse_issue(issue)
- key = parsed["metric"] if parsed["metric"] else "__other__"
- groups[key].append(parsed)
-
- result = []
- for metric, items in groups.items():
- values = [i["value"] for i in items if i.get("value") is not None]
- threshold = items[0].get("threshold") if items else None
- op = items[0].get("op", ">") if items else ">"
- result.append(
- {
- "metric": metric if metric != "__other__" else "",
- "count": len(items),
- "values": values,
- "threshold": threshold,
- "op": op,
- "details": [i["raw"] for i in items],
- }
- )
- return result
-
-
-def separate_metrics_by_type(metrics_list: list[dict]) -> tuple[dict, dict]:
- """
- Separate metrics into quality metrics and info/parameter fields.
-
- Returns
- -------
- tuple
- quality_metrics: {name: {'entries': [{value, status}], 'unit': str}}
- info_fields: {name: {'values': [v], 'description': str, 'is_constant': bool, 'display_value': any}}
- """
- quality_metrics: dict = {}
- info_fields: dict = {}
-
- for m in metrics_list:
- for name, data in m.get("metrics", {}).items():
- if not isinstance(data, dict):
- continue
- status = data.get("status", "ok")
- value = data.get("value")
- unit = data.get("unit") or ""
- desc = data.get("description") or ""
-
- if status == "info":
- if name not in info_fields:
- info_fields[name] = {"values": [], "description": desc, "unit": unit}
- info_fields[name]["values"].append(value)
- else:
- if name not in quality_metrics:
- quality_metrics[name] = {"entries": [], "unit": unit, "description": desc}
- quality_metrics[name]["entries"].append({"value": value, "status": status})
-
- # Determine if each info field is constant across all files
- for info in info_fields.values():
- vals = info["values"]
- try:
- numeric = [v for v in vals if isinstance(v, (int, float))]
- if numeric and len(numeric) == len(vals):
- is_const = float(np.std(numeric)) < 1e-10
- else:
- is_const = len({str(v) for v in vals}) <= 1
- except Exception:
- is_const = len({str(v) for v in vals}) <= 1
- info["is_constant"] = is_const
- info["display_value"] = vals[0] if vals else None
-
- return quality_metrics, info_fields
-
-
-def generate_sparkline_svg(values: list, statuses: list[str] | None = None, width: int = 160, height: int = 36) -> str:
- """Generate an inline SVG bar-chart sparkline for a list of values."""
- numeric = [(i, v) for i, v in enumerate(values) if isinstance(v, (int, float))]
- if len(numeric) < 2:
- return ""
-
- all_vals = [v for _, v in numeric]
- min_val, max_val = min(all_vals), max(all_vals)
- val_range = max_val - min_val or 1.0
-
- if statuses is None:
- statuses = ["ok"] * len(values)
-
- n = len(values)
- bar_w = width / n
- rects = []
- for i, v in numeric:
- h = max(2.0, (v - min_val) / val_range * (height - 4))
- y = height - h
- color = get_status_color(statuses[i]) if i < len(statuses) else get_status_color("ok")
- rects.append(
- f' '
- )
-
- title = f"Min: {min_val:.3g} Max: {max_val:.3g} n={len(numeric)}"
- return (
- f'' + "".join(rects) + " "
- )
-
-
-def generate_trend_line_svg(
- values: list,
- _labels: list[str] | None = None,
- width: int = 420,
- height: int = 90,
- show_trend: bool = True,
- color: str = "#4a90d9",
-) -> str:
- """Generate an inline SVG line chart for cross-slice trend visualisation."""
- numeric = [(i, float(v)) for i, v in enumerate(values) if isinstance(v, (int, float))]
- if len(numeric) < 2:
- return ""
-
- xs = [p[0] for p in numeric]
- ys = [p[1] for p in numeric]
- min_y, max_y = min(ys), max(ys)
- y_range = max_y - min_y or 1.0
- pad_x, pad_y = 30, 10
-
- def to_svg_x(i: Any) -> Any:
- return pad_x + (i / (len(values) - 1)) * (width - 2 * pad_x)
-
- def to_svg_y(v: Any) -> Any:
- return height - pad_y - ((v - min_y) / y_range) * (height - 2 * pad_y)
-
- # Build polyline points
- pts = " ".join(f"{to_svg_x(i):.1f},{to_svg_y(v):.1f}" for i, v in numeric)
-
- elements = [
- f' ',
- ]
-
- # Dots at each data point
- for i, v in numeric:
- elements.append(f' ')
-
- # Trend line (least squares)
- if show_trend and len(xs) >= 3:
- x_arr = np.array(xs, dtype=float)
- y_arr = np.array(ys, dtype=float)
- slope = (np.mean(x_arr * y_arr) - np.mean(x_arr) * np.mean(y_arr)) / (np.mean(x_arr**2) - np.mean(x_arr) ** 2 + 1e-12)
- intercept = np.mean(y_arr) - slope * np.mean(x_arr)
- x0, x1 = xs[0], xs[-1]
- y0, y1 = slope * x0 + intercept, slope * x1 + intercept
- elements.append(
- f' '
- )
-
- # Y-axis labels
- elements.append(
- f'{max_y:.3g} '
- )
- elements.append(
- f'{min_y:.3g} '
- )
-
- title_text = f"n={len(numeric)}, range [{min_y:.3g}, {max_y:.3g}]"
- return (
- f'' + "".join(elements) + " "
- )
-
-
-def compute_cross_slice_trends(aggregated: dict[str, list[dict]]) -> dict:
- """
- Compute cross-slice aggregate trends from aggregated metrics.
-
- Returns a dict with trend groups, each containing:
- 'label', 'description', 'series': [{name, values, unit}]
- """
- trends = {}
-
- def _extract(metrics_list: Any, key: str) -> list:
- """Extract sorted numerical values for a given metric key."""
- pairs = []
- for m in metrics_list:
- src = m.get("source_file", "")
- val = m.get("metrics", {}).get(key, {}).get("value")
- if isinstance(val, (int, float)):
- pairs.append((src, val))
- pairs.sort(key=lambda p: p[0]) # sort by source file path
- return [v for _, v in pairs]
-
- # XY tile transform: scale and shear across slices
- if "xy_transform_estimation" in aggregated:
- ml = aggregated["xy_transform_estimation"]
- t00 = _extract(ml, "transform_00")
- t11 = _extract(ml, "transform_11")
- rms = _extract(ml, "rms_residual")
- acc_sys = _extract(ml, "accumulated_systematic_error_px")
- acc_rnd = _extract(ml, "accumulated_random_error_px")
- series = []
- if t00:
- series.append({"name": "Step Y (px)", "values": t00, "unit": "px"})
- if t11:
- series.append({"name": "Step X (px)", "values": t11, "unit": "px"})
- if rms:
- series.append({"name": "RMS residual (px)", "values": rms, "unit": "px"})
- if acc_sys:
- series.append({"name": "Accum. systematic error (px)", "values": acc_sys, "unit": "px"})
- if acc_rnd:
- series.append({"name": "Accum. random error (px)", "values": acc_rnd, "unit": "px"})
- if series:
- trends["xy_transform"] = {
- "label": "XY Tile Transform Consistency",
- "description": (
- "Tile step sizes and fitting residuals across slices. Large variation indicates unstable tile positioning."
- ),
- "series": series,
- }
-
- # Pairwise registration: cumulative drift
- if "pairwise_registration" in aggregated:
- ml = aggregated["pairwise_registration"]
- tx = _extract(ml, "translation_x")
- ty = _extract(ml, "translation_y")
- rot = _extract(ml, "rotation")
- series = []
- if tx:
- cum_tx = list(np.cumsum(tx))
- series.append({"name": "Cumulative tx (px)", "values": cum_tx, "unit": "px"})
- if ty:
- cum_ty = list(np.cumsum(ty))
- series.append({"name": "Cumulative ty (px)", "values": cum_ty, "unit": "px"})
- if rot:
- cum_rot = list(np.cumsum(rot))
- series.append({"name": "Cumulative rotation (deg)", "values": cum_rot, "unit": "deg"})
- if series:
- trends["registration_drift"] = {
- "label": "Cumulative Registration Drift",
- "description": (
- "Accumulated translation and rotation across all slices. "
- "A large net drift indicates systematic 3D volume distortion."
- ),
- "series": series,
- }
-
- # Interface depth trend
- if "crop_interface" in aggregated:
- ml = aggregated["crop_interface"]
- depth = _extract(ml, "detected_interface_depth_um")
- if depth:
- trends["interface_depth"] = {
- "label": "Interface Depth Trend",
- "description": (
- "Detected tissue-agarose interface depth across slices. "
- "A systematic slope may indicate progressive tissue deformation."
- ),
- "series": [{"name": "Interface depth (µm)", "values": depth, "unit": "µm"}],
- }
-
- # Background normalization drift
- if "normalize_intensities" in aggregated:
- ml = aggregated["normalize_intensities"]
- bg = _extract(ml, "mean_background")
- if bg:
- trends["background_drift"] = {
- "label": "Background Level Trend",
- "description": (
- "Mean agarose background level across slices. "
- "A strong trend indicates illumination drift during acquisition."
- ),
- "series": [{"name": "Mean background", "values": bg, "unit": ""}],
- }
-
- return trends
-
-
-# =============================================================================
-# Diagnostic data discovery
-# =============================================================================
-
-
-def discover_interpolation_data(input_dir: Path) -> dict | None:
- """
- Discover slice-interpolation outputs.
-
- Reads per-slice diagnostic JSONs written by ``linum_interpolate_missing_slice.py``
- (``slice_z*_interpolated_diagnostics.json``) and the preview PNGs.
- ``slice_config_final.csv`` (produced by ``finalise_interpolation``) is
- read via :mod:`linumpy.io.slice_config` to enrich the rows with the
- per-slice trace fields (``interpolated``, ``interpolation_method_used``,
- ``interpolation_fallback_reason``, ``use``, ``auto_excluded``, ...).
-
- Returns
- -------
- dict or None
- ``None`` when no interpolation happened. Otherwise a dict with keys
- ``rows`` (list of per-slice dicts), ``images`` (list of preview
- PNG paths), ``slice_config_final`` (path or None) and
- ``summary`` (aggregated stats).
- """
- from linumpy.io import slice_config as slice_config_io
-
- interp_dir = input_dir / "interpolate_missing_slice"
- if not interp_dir.is_dir():
- return None
-
- diag_files = sorted(interp_dir.glob("slice_z*_interpolated_diagnostics.json"))
- if not diag_files:
- return None
-
- rows: list[dict] = []
- for path in diag_files:
- try:
- with path.open() as fh:
- data = json.load(fh)
- except Exception:
- continue
- rows.append(
- {
- "slice_id": str(data.get("slice_id") or "").strip(),
- "method": str(data.get("method") or "unknown"),
- "method_used": (
- ""
- if data.get("interpolation_failed") is True
- else str(data.get("method_used") or data.get("method") or "unknown")
- ),
- "fallback_reason": str(data.get("fallback_reason") or ""),
- "interpolation_failed": bool(data.get("interpolation_failed", False)),
- "pre_reg_ncc": data.get("pre_reg_ncc"),
- "post_reg_ncc": data.get("post_reg_ncc"),
- "ncc_improvement": data.get("ncc_improvement"),
- "affine_determinant": data.get("affine_determinant"),
- "output_path": str(data.get("output_path") or ""),
- "diagnostics_path": str(path),
- }
- )
-
- if not rows:
- return None
-
- # Enrich from slice_config_final.csv when available (single source of truth).
- slice_config_final = input_dir / "slice_config_final.csv"
- if slice_config_final.exists():
- try:
- sc_rows = slice_config_io.read(slice_config_final)
- for r in rows:
- sid = slice_config_io.normalize_slice_id(r["slice_id"])
- sc_row = sc_rows.get(sid)
- if sc_row is not None:
- r["slice_config_use"] = sc_row.get("use", "")
- r["slice_config_interpolated"] = sc_row.get("interpolated", "")
- r["slice_config_interpolation_failed"] = sc_row.get("interpolation_failed", "")
- r["slice_config_auto_excluded"] = sc_row.get("auto_excluded", "")
- r["slice_config_notes"] = sc_row.get("notes", "")
- except Exception:
- slice_config_final = None
-
- images: list[Path] = sorted(interp_dir.glob("slice_z*_interpolated_preview.png"))
-
- method_counts: dict[str, int] = {}
- method_used_counts: dict[str, int] = {}
- fallback_counts: dict[str, int] = {}
- pre_nccs: list[float] = []
- post_nccs: list[float] = []
- improvements: list[float] = []
-
- def _to_float(value: object) -> float | None:
- if not isinstance(value, (int, float, str, bytes, bytearray)):
- return None
- try:
- return float(value)
- except (TypeError, ValueError):
- return None
-
- for r in rows:
- method = (r.get("method") or "unknown").strip() or "unknown"
- method_used = (r.get("method_used") or method).strip() or method
- fallback = (r.get("fallback_reason") or "").strip()
- method_counts[method] = method_counts.get(method, 0) + 1
- method_used_counts[method_used] = method_used_counts.get(method_used, 0) + 1
- if fallback:
- fallback_counts[fallback] = fallback_counts.get(fallback, 0) + 1
- pre = _to_float(r.get("pre_reg_ncc"))
- post = _to_float(r.get("post_reg_ncc"))
- imp = _to_float(r.get("ncc_improvement"))
- if pre is not None:
- pre_nccs.append(pre)
- if post is not None:
- post_nccs.append(post)
- if imp is not None:
- improvements.append(imp)
-
- n_failed = sum(1 for r in rows if r.get("interpolation_failed"))
-
- summary = {
- "count": len(rows),
- "n_succeeded": len(rows) - n_failed,
- "n_failed": n_failed,
- "method_counts": method_counts,
- "method_used_counts": method_used_counts,
- "fallback_counts": fallback_counts,
- "n_with_fallback": sum(fallback_counts.values()),
- "pre_reg_ncc_mean": float(np.mean(pre_nccs)) if pre_nccs else None,
- "post_reg_ncc_mean": float(np.mean(post_nccs)) if post_nccs else None,
- "ncc_improvement_mean": float(np.mean(improvements)) if improvements else None,
- }
-
- return {
- "rows": rows,
- "images": images,
- "slice_config_final": slice_config_final if (slice_config_final and slice_config_final.exists()) else None,
- "summary": summary,
- }
-
-
-def discover_diagnostic_data(input_dir: Path) -> dict[str, dict]:
- """
- Discover diagnostic outputs in the pipeline output directory.
-
- Looks for known diagnostic subdirectories and reads their JSON data.
-
- Returns
- -------
- dict
- Maps diagnostic_name → {'label', 'description', 'json_data': [...], 'images': [Path]}
- """
- import json as _json
-
- diagnostics: dict[str, dict] = {}
-
- diag_dir = input_dir / "diagnostics"
- if not diag_dir.exists():
- return diagnostics
-
- # Define known diagnostics: (subdir, label, description)
- known = [
- ("dilation_analysis", "Tile Dilation Analysis", "Per-slice scale factors and mosaic positioning accuracy."),
- ("aggregated_dilation", "Aggregated Dilation Analysis", "Cross-slice tile dilation summary."),
- ("rotation_analysis", "Rotation Drift Analysis", "Rotation angle drift across slices."),
- ("acquisition_rotation", "Acquisition Rotation Analysis", "In-plane rotation estimated from acquisition metadata."),
- (
- "motor_only_stitch",
- "Motor-Only Stitching (comparison)",
- "Stitched mosaic using motor positions only (no registration correction).",
- ),
- (
- "motor_only_stack",
- "Motor-Only Stack (comparison)",
- "Volume stacked without pairwise registration (motor positions only).",
- ),
- (
- "stitch_comparison",
- "Stitching Comparison",
- "Side-by-side comparison of registration-based vs motor-based stitching.",
- ),
- ]
-
- for subdir_name, label, description in known:
- subdir = diag_dir / subdir_name
- if not subdir.exists():
- continue
-
- json_data = []
- images = []
-
- # Collect all JSON files (recursively for per-slice diagnostics)
- for json_file in sorted(subdir.rglob("*.json")):
- try:
- with Path(json_file).open() as f:
- data = _json.load(f)
- data["_source"] = str(json_file)
- json_data.append(data)
- except Exception:
- pass
-
- # Collect PNG images
- images.extend(sorted(subdir.rglob("*.png")))
-
- if json_data or images:
- diagnostics[subdir_name] = {
- "label": label,
- "description": description,
- "json_data": json_data,
- "images": images,
- }
-
- return diagnostics
-
-
-def discover_images(
- input_dir: Path, overview_png: Path | None = None, annotated_png: Path | None = None
-) -> dict[str, list[Path]]:
- """
- Discover preview images in the pipeline output directory.
-
- Returns a dict mapping category → sorted list of image paths:
- 'overview' – main volume screenshots (up to 2)
- 'stitch_preview' – per-slice stitched previews
- 'common_space_preview' – common-space alignment previews
- 'diag_*' – images found in diagnostics/ subdirs
- """
- images: dict[str, list[Path]] = {
- "overview": [],
- "stitch_preview": [],
- "common_space_preview": [],
- }
-
- # Overview images from CLI (staged in Nextflow work dir)
- for p in [overview_png, annotated_png]:
- if p and Path(p).exists():
- images["overview"].append(Path(p))
-
- # Stitched slice previews
- stitch_dir = input_dir / "previews" / "stitched_slices"
- if stitch_dir.exists():
- images["stitch_preview"] = sorted(stitch_dir.glob("*.png"))
-
- # Common-space alignment previews
- cs_dir = input_dir / "common_space_previews"
- if cs_dir.exists():
- images["common_space_preview"] = sorted(cs_dir.glob("*.png"))
-
- # Auto-detect overview from stack output directories if not provided via CLI
- if not images["overview"]:
- for stack_dir_name in ("stack_motor", "stack", "normalize_z_intensity"):
- d = input_dir / stack_dir_name
- if d.exists():
- pngs = sorted(d.glob("*.png"))
- if pngs:
- images["overview"] = pngs[:2] # at most overview + annotated
- break
-
- # Diagnostic images: add one category per diagnostics subdir
- diag_dir = input_dir / "diagnostics"
- if diag_dir.exists():
- for subdir in sorted(diag_dir.iterdir()):
- if subdir.is_dir():
- pngs = sorted(subdir.rglob("*.png"))
- if pngs:
- cat_key = f"diag_{subdir.name}"
- images[cat_key] = pngs
-
- return images
-
-
-def image_to_data_uri(path: Path, max_width: int | None = None) -> str:
- """Encode a PNG image as a base64 data URI, optionally resizing."""
- if max_width and _PIL_AVAILABLE:
- with _PILImage.open(path) as img:
- if img.width > max_width:
- ratio = max_width / img.width
- new_size = (max_width, int(img.height * ratio))
- img = img.resize(new_size, _PILImage.Resampling.LANCZOS)
- buf = _io.BytesIO()
- img.save(buf, format="PNG", optimize=True)
- data_bytes = buf.getvalue()
- else:
- data_bytes = path.read_bytes()
- b64 = base64.b64encode(data_bytes).decode("ascii")
- return f"data:image/png;base64,{b64}"
-
-
-def render_image_gallery_html(
- images: list[Path], mode: str = "embed", category: str = "images", _label: str = "Preview Images", max_width: int = 380
-) -> str:
- """
- Render a collapsible image gallery section.
-
- Parameters
- ----------
- images : list of Path
- Image file paths to include in the gallery.
- mode : str
- Embedding mode: 'embed' (base64 in HTML) or 'link' (relative path for zip mode).
- category : str
- Image category name, used as subfolder in zip mode.
- max_width : int
- Maximum image width in pixels for embedded previews.
- """
- if not images:
- return ""
-
- items = []
- for p in images:
- src = image_to_data_uri(p, max_width=max_width) if mode == "embed" else f"previews/{category}/{p.name}"
- name = p.stem
- items.append(
- f''
- f''
- f' '
- f" "
- f"{name} "
- f" "
- )
-
- return f"""
-
- Preview Images ({len(images)})
-
- {"".join(items)}
-
-
-"""
-
-
-def generate_zip_bundle(html: str, images: dict[str, list[Path]], output_path: Path) -> None:
- """Bundle the HTML report and all image files into a zip archive."""
- with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
- zf.writestr("index.html", html)
- for category, paths in images.items():
- for p in paths:
- zf.write(p, f"previews/{category}/{p.name}")
-
-
-def compute_overall_status(aggregated: dict[str, list[dict]]) -> tuple:
- """
- Compute overall status counts from aggregated metrics.
-
- Returns
- -------
- tuple
- (all_statuses, error_count, warning_count, ok_count)
- """
- all_statuses = [m.get("overall_status", "unknown") for step_metrics in aggregated.values() for m in step_metrics]
-
- error_count = all_statuses.count("error")
- warning_count = all_statuses.count("warning")
- ok_count = all_statuses.count("ok")
-
- return all_statuses, error_count, warning_count, ok_count
-
-
-def get_step_status(metrics_list: list[dict]) -> str:
- """Get the overall status for a step based on its metrics."""
- step_statuses = [m.get("overall_status", "unknown") for m in metrics_list]
- if "error" in step_statuses:
- return "error"
- elif "warning" in step_statuses:
- return "warning"
- return "ok"
-
-
-def collect_issues(metrics_list: list[dict]) -> tuple:
- """
- Collect all warnings and errors from a metrics list.
-
- Returns
- -------
- tuple
- (all_warnings, all_errors)
- """
- all_warnings = []
- all_errors = []
- for m in metrics_list:
- source = Path(m.get("source_file", "unknown")).stem
- all_warnings.extend(f"{source}: {w}" for w in m.get("warnings", []))
- all_errors.extend(f"{source}: {e}" for e in m.get("errors", []))
- return all_warnings, all_errors
-
-
-def _render_grouped_issues_html(grouped: list[dict], color_class: str, label: str) -> str:
- """Render a collapsible grouped-issues section in HTML."""
- total = sum(g["count"] for g in grouped)
- html = f"""
-
-
- {label}
- {total}
-
-
-"""
- for g in grouped:
- if g["count"] == 1:
- html += f'
{g["details"][0]}
\n'
- else:
- vals = g["values"]
- val_str = f"range {min(vals):.3g} – {max(vals):.3g}" if vals else f"{g['count']} occurrences"
- thresh_str = f", threshold: {g['threshold']:.3g}" if g["threshold"] is not None else ""
- summary_line = f"
{g['metric']} : {g['count']} slices affected ({val_str}{thresh_str})"
- html += '
\n'
- html += f' {summary_line} \n'
- html += ' \n'
- for detail in g["details"]:
- html += f'
{detail}
\n'
- html += "
\n"
- html += " \n"
- html += "
\n \n"
- return html
-
-
-def _render_interpolation_section_html(
- interpolation: dict,
- image_mode: str = "link",
- max_thumb_width: int = 380,
-) -> str:
- """Render the slice-interpolation section of the HTML report."""
- summary = interpolation.get("summary", {})
- rows = interpolation.get("rows", [])
- images = interpolation.get("images", [])
- slice_config_final = interpolation.get("slice_config_final")
-
- count = summary.get("count", 0)
- n_failed = summary.get("n_failed", 0)
- n_succeeded = summary.get("n_succeeded", count - n_failed)
- method_counts = summary.get("method_counts", {})
- method_used_counts = summary.get("method_used_counts", {})
- fallback_counts = summary.get("fallback_counts", {})
- pre_mean = summary.get("pre_reg_ncc_mean")
- post_mean = summary.get("post_reg_ncc_mean")
- imp_mean = summary.get("ncc_improvement_mean")
-
- status = "ok"
- if n_failed > 0 and count > 0:
- status = "warning" if n_failed < count else "error"
-
- html = '\n
\n'
- html += "
Slice Interpolation \n"
- html += (
- '
'
- "Missing slices reconstructed from their neighbours via zmorph. "
- "Successful interpolations stamp interpolated=true and are flagged "
- "reliable=0 in downstream pairwise registration. When quality gates "
- "fail the slice is hard-skipped (interpolation_failed=true) "
- "and the slot stays a genuine gap in the stacked volume \u2014 no blended volume is "
- "written. See docs/SLICE_INTERPOLATION_FEATURE.md.
\n"
- )
-
- html += '
\n'
- html += f'
{count}
'
- html += '
Gaps Detected
\n'
- ok_color = get_status_color("ok")
- html += (
- f'
'
- f'{n_succeeded}
Successfully Interpolated
\n'
- )
- html += (
- f'
'
- f'{n_failed}
Hard-Skipped (Gap)
\n'
- )
- if pre_mean is not None:
- html += f'
{pre_mean:.3f}
'
- html += '
Mean Pre-Reg NCC
\n'
- if post_mean is not None:
- html += f'
{post_mean:.3f}
'
- html += '
Mean Post-Reg NCC
\n'
- if imp_mean is not None:
- html += f'
{imp_mean:+.3f}
'
- html += '
Mean NCC Improvement
\n'
- html += "
\n"
-
- # Method breakdown
- html += '
\n'
- html += '
Methods
\n'
- html += '
\n'
- html += " Method requested "
- html += ", ".join(f"{k}: {v}" for k, v in sorted(method_counts.items())) or "(none)"
- html += " \n"
- html += " Method actually used "
- html += ", ".join(f"{k}: {v}" for k, v in sorted(method_used_counts.items())) or "(none)"
- html += " \n"
- if fallback_counts:
- html += " Hard-skip reasons "
- html += ", ".join(f"{k}: {v}" for k, v in sorted(fallback_counts.items()))
- html += " \n"
- if slice_config_final is not None:
- html += " Per-slice trace file "
- html += f"{slice_config_final.name} \n"
- html += "
\n"
- html += "
\n"
-
- # Per-slice table (cap to 50 rows; more than that is rare)
- if rows:
- html += '
\n'
- html += ' '
- html += f"Per-slice interpolation diagnostics ({len(rows)} slice(s)) \n"
- html += ' \n'
- html += (
- " "
- "Slice Status Method Used "
- "Reason Pre NCC Post NCC "
- "ΔNCC |det| "
- " \n"
- )
- for r in rows[:50]:
- sid = r.get("slice_id", "") or "?"
- failed = bool(r.get("interpolation_failed"))
- status_label = "SKIPPED" if failed else "OK"
- method_used = r.get("method_used", "") or ("—" if failed else "")
- fb = r.get("fallback_reason", "") or ""
- pre = r.get("pre_reg_ncc", "")
- post = r.get("post_reg_ncc", "")
- imp = r.get("ncc_improvement", "")
- det = r.get("affine_determinant", "")
-
- pre_fmt = f"{float(pre):.3f}" if pre not in ("", None) else "-"
- post_fmt = f"{float(post):.3f}" if post not in ("", None) else "-"
- imp_fmt = f"{float(imp):+.3f}" if imp not in ("", None) else "-"
- det_fmt = f"{float(det):.3f}" if det not in ("", None) else "-"
-
- if failed:
- row_style = ' style="background:#ffe5e5;"'
- elif fb:
- row_style = ' style="background:#fff8e1;"'
- else:
- row_style = ""
- html += (
- f" "
- f"{sid} {status_label} {method_used} "
- f"{fb} {pre_fmt} {post_fmt} "
- f"{imp_fmt} {det_fmt} "
- " \n"
- )
- if len(rows) > 50:
- html += (
- f' (showing first 50 of {len(rows)} rows) \n'
- )
- html += "
\n"
- html += " \n"
-
- # Preview image gallery (shown in zip/link mode only; embed mode skips images)
- if images:
- gallery = render_image_gallery_html(
- images,
- mode=image_mode,
- category="diag_interpolate_missing_slice",
- _label="Interpolation Previews",
- max_width=max_thumb_width,
- )
- html += gallery
-
- html += "
\n"
- return html
-
-
-def _render_interpolation_section_text(interpolation: dict) -> str:
- """Render the slice-interpolation section of the text report."""
- summary = interpolation.get("summary", {})
- rows = interpolation.get("rows", [])
- count = summary.get("count", 0)
- n_failed = summary.get("n_failed", 0)
- n_succeeded = summary.get("n_succeeded", count - n_failed)
- pre_mean = summary.get("pre_reg_ncc_mean")
- post_mean = summary.get("post_reg_ncc_mean")
- imp_mean = summary.get("ncc_improvement_mean")
-
- lines = []
- lines.append("")
- lines.append(f"{get_status_emoji('info')} SLICE INTERPOLATION")
- lines.append("-" * 70)
- lines.append(f" Gaps detected : {count}")
- lines.append(f" Successfully interp'd : {n_succeeded}")
- lines.append(f" Hard-skipped (gap) : {n_failed}")
- if pre_mean is not None:
- lines.append(f" Mean pre-reg NCC : {pre_mean:.3f}")
- if post_mean is not None:
- lines.append(f" Mean post-reg NCC : {post_mean:.3f}")
- if imp_mean is not None:
- lines.append(f" Mean NCC improvement : {imp_mean:+.3f}")
-
- method_used_counts = summary.get("method_used_counts", {})
- if method_used_counts:
- mu_parts = ", ".join(f"{k}: {v}" for k, v in sorted(method_used_counts.items()))
- lines.append(f" Methods used : {mu_parts}")
- fallback_counts = summary.get("fallback_counts", {})
- if fallback_counts:
- fb_parts = ", ".join(f"{k}: {v}" for k, v in sorted(fallback_counts.items()))
- lines.append(f" Hard-skip reasons : {fb_parts}")
-
- if rows:
- lines.append("")
- lines.append(f" {'Slice':<6} {'Status':<8} {'Used':<14} {'Reason':<28} {'PreNCC':>7} {'PostNCC':>7}")
- lines.append(" " + "-" * 80)
- for r in rows[:50]:
- sid = (r.get("slice_id", "") or "?")[:6]
- failed = bool(r.get("interpolation_failed"))
- status = "SKIP" if failed else "OK"
- method_used = (r.get("method_used", "") or ("—" if failed else ""))[:14]
- fb = (r.get("fallback_reason", "") or "")[:28]
- pre = r.get("pre_reg_ncc", "")
- post = r.get("post_reg_ncc", "")
- try:
- pre_fmt = f"{float(pre):.3f}" if pre not in ("", None) else "-"
- except (TypeError, ValueError):
- pre_fmt = "-"
- try:
- post_fmt = f"{float(post):.3f}" if post not in ("", None) else "-"
- except (TypeError, ValueError):
- post_fmt = "-"
- lines.append(f" {sid:<6} {status:<8} {method_used:<14} {fb:<28} {pre_fmt:>7} {post_fmt:>7}")
- if len(rows) > 50:
- lines.append(f" ... ({len(rows) - 50} more row(s) not shown)")
-
- return "\n".join(lines)
-
-
-def generate_html_report(
- aggregated: dict[str, list[dict]],
- title: str,
- verbose: bool = False,
- images: dict[str, list[Path]] | None = None,
- image_mode: str = "embed",
- max_overview_width: int = 900,
- max_thumb_width: int = 380,
- trends: dict | None = None,
- diagnostics: dict | None = None,
- interpolation: dict | None = None,
-) -> str:
- """Generate an HTML report from aggregated metrics."""
- aggregated = sort_steps(aggregated)
- images = images or {}
-
- _, error_count, warning_count, ok_count = compute_overall_status(aggregated)
-
- if error_count > 0:
- overall_status = "error"
- overall_message = f"{error_count} error(s), {warning_count} warning(s)"
- elif warning_count > 0:
- overall_status = "warning"
- overall_message = f"{warning_count} warning(s)"
- else:
- overall_status = "ok"
- overall_message = "All checks passed"
-
- html = f"""
-
-
-
-
- {title}
-
-
-
-
-
-
-
Summary
-
- {overall_message}
-
-
-
-
{len(aggregated)}
-
Pipeline Steps
-
-
-
{sum(len(v) for v in aggregated.values())}
-
Total Metrics Files
-
-
-
-
{warning_count}
-
Warnings
-
-
-
{error_count}
-
Errors
-
-
-
-"""
-
- # Overview images in the summary section
- overview_imgs = images.get("overview", [])
- if overview_imgs:
- html += ' \n'
- html += '
Volume Overview
\n'
- html += '
\n'
- for p in overview_imgs:
- if image_mode == "embed":
- src = image_to_data_uri(p, max_width=max_overview_width)
- else:
- src = f"previews/overview/{p.name}"
- html += (
- f"
"
- f''
- f' '
- f"{p.stem} \n"
- )
- html += "
\n
\n"
-
- # Cross-slice trends section
- if trends:
- colors = ["#4a90d9", "#e67e22", "#27ae60", "#8e44ad", "#c0392b"]
- html += '\n \n'
- html += "
Cross-Slice Trends \n"
- html += (
- '
'
- "Aggregate quality indicators computed across all slices. "
- "Red dashed lines show the linear trend.
\n"
- )
- html += '
\n'
- for trend in trends.values():
- html += '
\n'
- html += f' \n'
- html += f'
{trend["description"]}
\n'
- for ci, series in enumerate(trend["series"]):
- col = colors[ci % len(colors)]
- svg = generate_trend_line_svg(series["values"], color=col)
- html += '
\n'
- html += f'
{series["name"]}
\n'
- html += f" {svg}\n"
- html += "
\n"
- html += "
\n"
- html += "
\n
\n"
-
- # Generate section for each step
- for step_name, metrics_list in aggregated.items():
- summary = compute_summary_statistics(metrics_list)
- step_status = get_step_status(metrics_list)
- description = STEP_DESCRIPTIONS.get(step_name, "")
-
- # Separate quality metrics from info/parameter fields
- quality_metrics, info_fields = separate_metrics_by_type(metrics_list)
-
- html += f"""
-
-
-"""
- if description:
- html += f'
{description}
\n'
-
- # --- Quality metrics stats table with sparklines ---
- if quality_metrics:
- html += '
Quality Metrics
\n'
- html += """
-
- Metric
- Mean
- Median
- Std
- Min
- Max
- Distribution
-
-"""
- for metric_name, mdata in quality_metrics.items():
- entries = mdata["entries"]
- numeric_vals = [e["value"] for e in entries if isinstance(e.get("value"), (int, float))]
- if not numeric_vals:
- continue
- arr = np.array(numeric_vals)
- mean_v = float(np.mean(arr))
- median_v = float(np.median(arr))
- std_v = float(np.std(arr))
- min_v = float(np.min(arr))
- max_v = float(np.max(arr))
- statuses = [e.get("status", "ok") for e in entries]
- unit = mdata.get("unit", "")
- unit_str = f" {unit}" if unit else ""
-
- # Worst status in this metric
- if "error" in statuses:
- metric_status = "error"
- elif "warning" in statuses:
- metric_status = "warning"
- else:
- metric_status = "ok"
-
- sparkline = generate_sparkline_svg([e.get("value") for e in entries], statuses)
-
- html += f"""
-
-
- {metric_name}{unit_str}
-
- {format_value(mean_v)}
- {format_value(median_v)}
- {format_value(std_v)}
- {format_value(min_v)}
- {format_value(max_v)}
- {sparkline}
-
-"""
- html += "
\n"
-
- # --- Errors and warnings (grouped, collapsible) ---
- all_warnings, all_errors = collect_issues(metrics_list)
-
- if all_errors:
- grouped_errors = group_issues(all_errors)
- html += _render_grouped_issues_html(grouped_errors, "error", "Errors")
-
- if all_warnings:
- grouped_warnings = group_issues(all_warnings)
- html += _render_grouped_issues_html(grouped_warnings, "warning", "Warnings")
-
- # --- Info / parameter fields (collapsed) ---
- if info_fields:
- constant_params = {k: v for k, v in info_fields.items() if v["is_constant"]}
- variable_infos = {k: v for k, v in info_fields.items() if not v["is_constant"]}
-
- if constant_params:
- html += """
-
- Pipeline Parameters
-
-"""
- for name, info in constant_params.items():
- val = info["display_value"]
- unit = info.get("unit", "")
- unit_str = f" {unit}" if unit else ""
- html += f"""
- {name}
- {format_value(val)}{unit_str}
-
-"""
- html += "
\n \n"
-
- if variable_infos:
- html += """
-
- Variable Info Fields (per-slice)
-
-
- Field
- Mean
- Std
- Min
- Max
-
-"""
- for name, info in variable_infos.items():
- numeric = [v for v in info["values"] if isinstance(v, (int, float))]
- if not numeric:
- continue
- arr = np.array(numeric)
- unit = info.get("unit", "")
- unit_str = f" {unit}" if unit else ""
- html += f"""
- {name}{unit_str}
- {format_value(float(np.mean(arr)))}
- {format_value(float(np.std(arr)))}
- {format_value(float(np.min(arr)))}
- {format_value(float(np.max(arr)))}
-
-"""
- html += "
\n \n"
-
- # --- Verbose: individual per-slice results (collapsible as a unit) ---
- if verbose:
- n_items = len(metrics_list)
- html += f"""
-
- Individual Results ({n_items} slices)
-"""
- for m in metrics_list:
- source = extract_slice_id(m.get("source_file", "unknown"))
- m_status = m.get("overall_status", "unknown")
- html += f"""
-
-
-
- {source}
-
-
-"""
- for name, data in m.get("metrics", {}).items():
- if isinstance(data, dict):
- value = data.get("value", "N/A")
- unit = data.get("unit", "") or ""
- status = data.get("status", "info")
- html += f"""
-
-
-
- {name}
-
- {format_value(value)}{(" " + unit) if unit else ""}
-
-"""
- html += """
-
-"""
- html += " \n"
-
- # --- Per-step preview image gallery ---
- preview_category = STEP_PREVIEW_CATEGORY.get(step_name)
- if preview_category:
- step_imgs = images.get(preview_category, [])
- if step_imgs:
- html += render_image_gallery_html(
- step_imgs, mode=image_mode, category=preview_category, max_width=max_thumb_width
- )
-
- html += "
\n"
-
- # Slice interpolation section (only if interpolation happened)
- if interpolation:
- html += _render_interpolation_section_html(interpolation, image_mode=image_mode, max_thumb_width=max_thumb_width)
-
- # Diagnostics section (only if diagnostic data was found)
- if diagnostics:
- html += '\n \n'
- html += "
Diagnostic Outputs \n"
- html += (
- '
'
- "Additional diagnostic analyses enabled in the pipeline configuration.
\n"
- )
- for diag_key, diag in diagnostics.items():
- label = diag["label"]
- description = diag["description"]
- json_data = diag.get("json_data", [])
- diag_images = diag.get("images", [])
-
- html += '
\n'
- html += f' \n'
- html += f'
{description}
\n'
-
- # Render key JSON fields
- if json_data:
- # Collect interesting numeric/scalar fields from first entry
- first = json_data[0]
- numeric_fields = {}
- for k, v in first.items():
- if k.startswith("_") or k == "slice_id":
- continue
- if isinstance(v, (int, float, str, bool)):
- numeric_fields[k] = v
- elif isinstance(v, dict):
- # like scale_factors / residuals / distortions sub-dicts
- for sk, sv in v.items():
- if isinstance(sv, (int, float, str, bool)):
- numeric_fields[f"{k}.{sk}"] = sv
-
- if numeric_fields:
- html += '
\n'
- for k, v in list(numeric_fields.items())[:20]:
- html += (
- f" {k} "
- f"{format_value(v) if isinstance(v, (int, float)) else v} \n"
- )
- html += "
\n"
-
- # Render diagnostic image gallery
- if diag_images:
- # In zip mode images are referenced via relative paths; in embed mode as data URIs
- cat_key = f"diag_{diag_key}"
- gallery = render_image_gallery_html(
- diag_images, mode=image_mode, category=cat_key, _label=f"{label} Images", max_width=max_thumb_width
- )
- html += gallery
-
- html += "
\n"
- html += "
\n"
-
- html += """
-
-
-"""
- return html
-
-
-def generate_text_report(
- aggregated: dict[str, list[dict]],
- title: str,
- verbose: bool = False,
- interpolation: dict | None = None,
-) -> str:
- """Generate a plain text report from aggregated metrics."""
- aggregated = sort_steps(aggregated)
-
- lines = []
- lines.append("=" * 70)
- lines.append(title.center(70))
- lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}".center(70))
- lines.append("=" * 70)
- lines.append("")
-
- _, error_count, warning_count, ok_count = compute_overall_status(aggregated)
-
- lines.append("SUMMARY")
- lines.append("-" * 70)
- lines.append(f" Pipeline Steps: {len(aggregated)}")
- lines.append(f" Total Metrics Files: {sum(len(v) for v in aggregated.values())}")
- lines.append(
- f" Status: {get_status_emoji('ok')} OK: {ok_count} "
- f"{get_status_emoji('warning')} Warnings: {warning_count} "
- f"{get_status_emoji('error')} Errors: {error_count}"
- )
- lines.append("")
-
- for step_name, metrics_list in aggregated.items():
- summary = compute_summary_statistics(metrics_list)
- step_status = get_step_status(metrics_list)
-
- lines.append("")
- lines.append(f"{get_status_emoji(step_status)} {step_name.replace('_', ' ').upper()}")
- lines.append("-" * 70)
- lines.append(f" Items: {summary['count']} | Status: {step_status.upper()}")
-
- # Quality metrics stats
- quality_metrics, _ = separate_metrics_by_type(metrics_list)
- if quality_metrics:
- lines.append("")
- lines.append(" Quality Metrics:")
- lines.append(f" {'Metric':<25} {'Mean':>12} {'Median':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
- lines.append(" " + "-" * 77)
- for metric_name, mdata in quality_metrics.items():
- entries = mdata["entries"]
- numeric_vals = [e["value"] for e in entries if isinstance(e.get("value"), (int, float))]
- if not numeric_vals:
- continue
- arr = np.array(numeric_vals)
- name = metric_name[:25]
- lines.append(
- f" {name:<25} {format_value(float(np.mean(arr))):>12} "
- f"{format_value(float(np.median(arr))):>12} "
- f"{format_value(float(np.std(arr))):>12} "
- f"{format_value(float(np.min(arr))):>12} "
- f"{format_value(float(np.max(arr))):>12}"
- )
-
- all_warnings, all_errors = collect_issues(metrics_list)
-
- if all_errors:
- lines.append("")
- lines.append(f" {get_status_emoji('error')} ERRORS:")
- for g in group_issues(all_errors):
- if g["count"] == 1:
- lines.append(f" - {g['details'][0]}")
- else:
- vals = g["values"]
- val_str = f"range {min(vals):.3g}–{max(vals):.3g}" if vals else f"{g['count']} occurrences"
- lines.append(f" - {g['metric']}: {g['count']} slices ({val_str})")
-
- if all_warnings:
- lines.append("")
- lines.append(f" {get_status_emoji('warning')} WARNINGS:")
- for g in group_issues(all_warnings):
- if g["count"] == 1:
- lines.append(f" - {g['details'][0]}")
- else:
- vals = g["values"]
- val_str = f"range {min(vals):.3g}–{max(vals):.3g}" if vals else f"{g['count']} occurrences"
- lines.append(f" - {g['metric']}: {g['count']} slices ({val_str})")
-
- if verbose:
- lines.append("")
- lines.append(" Individual Results:")
- for m in metrics_list:
- source = extract_slice_id(m.get("source_file", "unknown"))
- m_status = m.get("overall_status", "unknown")
- lines.append(f" {get_status_emoji(m_status)} {source}")
- for name, data in m.get("metrics", {}).items():
- if isinstance(data, dict):
- value = data.get("value", "N/A")
- unit = data.get("unit", "") or ""
- lines.append(f" {name}: {format_value(value)}{(' ' + unit) if unit else ''}")
-
- if interpolation:
- lines.append(_render_interpolation_section_text(interpolation))
-
- lines.append("")
- lines.append("=" * 70)
- lines.append("End of Report".center(70))
- lines.append("=" * 70)
-
- return "\n".join(lines)
-
-
-def main() -> None:
- """Run function."""
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- input_dir = Path(args.input_dir)
- output_file = Path(args.output_report)
-
- if not input_dir.exists():
- parser.error(f"Input directory does not exist: {input_dir}")
-
- # Determine format
- if args.format == "auto":
- suffix = output_file.suffix.lower()
- if suffix == ".html":
- output_format = "html"
- elif suffix == ".zip":
- output_format = "zip"
- else:
- output_format = "text"
- else:
- output_format = args.format
-
- # Aggregate metrics from all subdirectories
- print(f"Scanning for metrics files in: {input_dir}")
- aggregated = aggregate_metrics(input_dir)
-
- if not aggregated:
- print("No metrics files found. Checking for process subdirectories...")
- for subdir in input_dir.iterdir():
- if subdir.is_dir():
- sub_aggregated = aggregate_metrics(subdir)
- for step, metrics in sub_aggregated.items():
- if step not in aggregated:
- aggregated[step] = []
- aggregated[step].extend(metrics)
-
- if not aggregated:
- print("Warning: No metrics files found in the input directory.")
- print("Make sure the pipeline has been run with metrics collection enabled.")
- aggregated = {}
-
- print(f"Found {sum(len(v) for v in aggregated.values())} metrics files across {len(aggregated)} pipeline steps")
-
- # Discover preview images — only for zip bundles; HTML is always image-free
- images: dict[str, list[Path]] = {}
- if output_format == "zip" and not args.no_images:
- images = discover_images(input_dir, overview_png=args.overview_png, annotated_png=args.annotated_png)
- total_imgs = sum(len(v) for v in images.values())
- if total_imgs:
- print(f"Found {total_imgs} preview image(s) to bundle in zip")
-
- # Zip bundles use relative image links; standalone HTML has no images
- image_mode = "link"
-
- # Compute cross-slice aggregate trends
- trends = compute_cross_slice_trends(aggregated)
- if trends:
- n_trend_groups = len(trends)
- print(f"Computed {n_trend_groups} cross-slice trend group(s)")
-
- # Discover slice-interpolation outputs
- interpolation = discover_interpolation_data(input_dir)
- if interpolation:
- s = interpolation["summary"]
- print(f"Found interpolation output(s): {s['count']} slice(s), {s['n_with_fallback']} with fallback")
- if output_format == "zip" and not args.no_images and interpolation.get("images"):
- images["diag_interpolate_missing_slice"] = list(interpolation["images"])
-
- # Discover diagnostic outputs
- diagnostics = discover_diagnostic_data(input_dir)
- if diagnostics:
- print(f"Found {len(diagnostics)} diagnostic output(s): {', '.join(diagnostics.keys())}")
- # In zip mode, include diagnostic images in the bundle
- if output_format == "zip" and not args.no_images:
- for diag_key, diag in diagnostics.items():
- cat_key = f"diag_{diag_key}"
- diag_imgs = diag.get("images", [])
- if diag_imgs:
- images[cat_key] = diag_imgs
-
- # Generate report
- output_file.parent.mkdir(parents=True, exist_ok=True)
- if output_format in ("html", "zip"):
- report = generate_html_report(
- aggregated,
- args.title,
- args.verbose,
- images=images,
- image_mode=image_mode,
- max_overview_width=args.max_overview_width,
- max_thumb_width=args.max_thumb_width,
- trends=trends if trends else None,
- diagnostics=diagnostics if diagnostics else None,
- interpolation=interpolation,
- )
- if output_format == "zip":
- if output_file.suffix.lower() != ".zip":
- output_file = output_file.with_suffix(".zip")
- generate_zip_bundle(report, images, output_file)
- else:
- with Path(output_file).open("w") as f:
- f.write(report)
- else:
- report = generate_text_report(aggregated, args.title, args.verbose, interpolation=interpolation)
- with Path(output_file).open("w") as f:
- f.write(report)
-
- print(f"Report saved to: {output_file}")
-
- _, error_count, warning_count, _ = compute_overall_status(aggregated)
-
- if error_count > 0:
- print(f"\n{get_status_emoji('error')} {error_count} error(s) found - please review the report")
- elif warning_count > 0:
- print(f"\n{get_status_emoji('warning')} {warning_count} warning(s) found - please review the report")
- else:
- print(f"\n{get_status_emoji('ok')} All checks passed")
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tests/test_align_to_ras.py b/scripts/tests/test_align_to_ras.py
new file mode 100644
index 00000000..05626cbd
--- /dev/null
+++ b/scripts/tests/test_align_to_ras.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+"""Tests for ``scripts/linum_align_to_ras.py``.
+
+The script is loaded via :mod:`importlib` so we can test its pure-Python helper
+functions (no ``zarr`` I/O) without relying on the console entry point.
+"""
+
+from __future__ import annotations
+
+import importlib.util
+from pathlib import Path
+
+import numpy as np
+import pytest
+import SimpleITK as sitk
+
+SCRIPT_PATH = Path(__file__).resolve().parents[1] / "linum_align_to_ras.py"
+
+
+@pytest.fixture(scope="module")
+def align_module():
+ """Load ``linum_align_to_ras.py`` as a module."""
+ spec = importlib.util.spec_from_file_location("linum_align_to_ras", SCRIPT_PATH)
+ assert spec is not None
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
+
+
+# ---------------------------------------------------------------------------
+# CLI help
+# ---------------------------------------------------------------------------
+
+
+def test_help(script_runner):
+ ret = script_runner.run(["linum_align_to_ras.py", "--help"])
+ assert ret.success
+
+
+# ---------------------------------------------------------------------------
+# sitk_transform_to_affine_matrix
+# ---------------------------------------------------------------------------
+
+
+class TestSitkTransformToAffine:
+ def test_identity_transform_yields_identity_matrix(self, align_module):
+ t = sitk.Euler3DTransform()
+ mat = align_module.sitk_transform_to_affine_matrix(t)
+ assert mat.shape == (4, 4)
+ np.testing.assert_allclose(mat, np.eye(4), atol=1e-12)
+
+ def test_pure_translation_is_permuted_to_zyx(self, align_module):
+ t = sitk.Euler3DTransform()
+ # SITK translation in (X, Y, Z) = (1, 2, 3)
+ t.SetTranslation((1.0, 2.0, 3.0))
+ mat = align_module.sitk_transform_to_affine_matrix(t)
+ # After conversion to NGFF (Z, Y, X) order, translation must be (3, 2, 1).
+ np.testing.assert_allclose(mat[:3, 3], [3.0, 2.0, 1.0], atol=1e-12)
+ np.testing.assert_allclose(mat[:3, :3], np.eye(3), atol=1e-12)
+
+ def test_rotation_is_permuted_to_zyx(self, align_module):
+ """A pure rotation around SITK X (=numpy axis 2) should appear as a
+ rotation around the last axis of the NGFF matrix (axis Z→row 2)."""
+ t = sitk.Euler3DTransform()
+ t.SetRotation(np.pi / 4, 0.0, 0.0) # rotate around SITK X
+ mat = align_module.sitk_transform_to_affine_matrix(t)
+ # Rotation around numpy axis 2 (X in NGFF) leaves column/row 2 unchanged.
+ assert mat.shape == (4, 4)
+ np.testing.assert_allclose(mat[2, 2], 1.0, atol=1e-9)
+ np.testing.assert_allclose(mat[2, :3], [0.0, 0.0, 1.0], atol=1e-9)
+ np.testing.assert_allclose(mat[:3, 2], [0.0, 0.0, 1.0], atol=1e-9)
+
+
+# ---------------------------------------------------------------------------
+# compute_centered_reference_and_transform
+# ---------------------------------------------------------------------------
+
+
+class TestComputeCenteredReferenceAndTransform:
+ @staticmethod
+ def _make_moving(shape=(20, 20, 20), spacing=(0.1, 0.1, 0.1)):
+ """A small ellipsoid brain so the resampled output has known volume."""
+ z, y, x = np.indices(shape, dtype=np.float32)
+ cz, cy, cx = shape[0] / 2, shape[1] / 2, shape[2] / 2
+ rz, ry, rx = shape[0] * 0.3, shape[1] * 0.3, shape[2] * 0.3
+ mask = ((z - cz) / rz) ** 2 + ((y - cy) / ry) ** 2 + ((x - cx) / rx) ** 2 < 1
+ arr = mask.astype(np.float32)
+ img = sitk.GetImageFromArray(arr)
+ img.SetSpacing((spacing[2], spacing[1], spacing[0])) # SITK XYZ
+ img.SetOrigin((0.0, 0.0, 0.0))
+ img.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1))
+ return img, int(mask.sum())
+
+ def test_reference_origin_is_zero(self, align_module):
+ moving, _ = self._make_moving()
+ t = sitk.Euler3DTransform()
+ ref, _ = align_module.compute_centered_reference_and_transform(moving, t)
+ assert ref.GetOrigin() == pytest.approx((0.0, 0.0, 0.0))
+
+ def test_reference_spacing_matches_moving_by_default(self, align_module):
+ moving, _ = self._make_moving(spacing=(0.125, 0.1, 0.2))
+ t = sitk.Euler3DTransform()
+ ref, _ = align_module.compute_centered_reference_and_transform(moving, t)
+ assert ref.GetSpacing() == pytest.approx(moving.GetSpacing())
+
+ def test_reference_spacing_override(self, align_module):
+ moving, _ = self._make_moving()
+ t = sitk.Euler3DTransform()
+ ref, _ = align_module.compute_centered_reference_and_transform(moving, t, output_spacing=(0.05, 0.05, 0.05))
+ assert ref.GetSpacing() == pytest.approx((0.05, 0.05, 0.05))
+
+ def test_identity_transform_roundtrip(self, align_module):
+ """For T = identity, resampling through the composite should recover
+ the original brain volume (no information loss)."""
+ moving, brain_voxels = self._make_moving()
+ t = sitk.Euler3DTransform() # identity
+ ref, composite = align_module.compute_centered_reference_and_transform(moving, t)
+
+ resampler = sitk.ResampleImageFilter()
+ resampler.SetReferenceImage(ref)
+ resampler.SetTransform(composite)
+ resampler.SetInterpolator(sitk.sitkLinear)
+ resampler.SetDefaultPixelValue(0.0)
+ out = resampler.Execute(moving)
+ arr = sitk.GetArrayFromImage(out)
+
+ nonzero = (arr > 0.5).sum()
+ assert abs(int(nonzero) - brain_voxels) / brain_voxels < 0.05
+
+ def test_rotation_preserves_brain_volume(self, align_module):
+ """A rigid rotation + translation preserves the brain voxel count."""
+ moving, brain_voxels = self._make_moving()
+
+ t = sitk.Euler3DTransform()
+ t.SetRotation(np.deg2rad(15), np.deg2rad(-10), np.deg2rad(30))
+ t.SetTranslation((0.3, -0.2, 0.1))
+ center_mm = moving.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in moving.GetSize()])
+ t.SetCenter(center_mm)
+
+ ref, composite = align_module.compute_centered_reference_and_transform(moving, t)
+
+ resampler = sitk.ResampleImageFilter()
+ resampler.SetReferenceImage(ref)
+ resampler.SetTransform(composite)
+ resampler.SetInterpolator(sitk.sitkLinear)
+ resampler.SetDefaultPixelValue(0.0)
+ out = resampler.Execute(moving)
+ arr = sitk.GetArrayFromImage(out)
+
+ nonzero = (arr > 0.5).sum()
+ # Rigid transform preserves volume; allow 5% tolerance for interpolation.
+ assert abs(int(nonzero) - brain_voxels) / brain_voxels < 0.05
+
+ def test_composite_transform_semantics(self, align_module):
+ """The composite must compute T(shift(p)), not shift(T(p)).
+
+ Sample points at the origin of output space (0, 0, 0) and along each
+ axis; verify the composite maps them to the same physical point as the
+ *manually* composed ``T ∘ shift``.
+ """
+ moving, _ = self._make_moving()
+ t = sitk.Euler3DTransform()
+ t.SetRotation(np.deg2rad(20), 0.0, np.deg2rad(-5))
+ t.SetTranslation((0.25, 0.1, -0.15))
+ center_mm = moving.TransformContinuousIndexToPhysicalPoint([s / 2.0 for s in moving.GetSize()])
+ t.SetCenter(center_mm)
+
+ ref, composite = align_module.compute_centered_reference_and_transform(moving, t)
+
+ # Rebuild the shift transform the helper used: offset == pts_min
+ # recovered from the composite (last-added transform is the shift).
+ sample_points = [
+ (0.0, 0.0, 0.0),
+ tuple(ref.GetSpacing()),
+ tuple(np.array(ref.GetSize(), dtype=float) * ref.GetSpacing() / 2),
+ ]
+ for p in sample_points:
+ actual = np.array(composite.TransformPoint(p))
+ # Compose T(shift(p)) manually. Retrieve shift from the 2-member
+ # composite: ITK applies transforms in reverse order, so nth = last
+ # added = shift.
+ shift = composite.GetNthTransform(1)
+ expected = np.array(t.TransformPoint(shift.TransformPoint(p)))
+ np.testing.assert_allclose(actual, expected, atol=1e-9)
+
+ # Sanity check that the *wrong* ordering ``shift(T(p))`` does NOT
+ # match (unless the transform is degenerate).
+ wrong = np.array(shift.TransformPoint(t.TransformPoint(p)))
+ assert not np.allclose(actual, wrong, atol=1e-4), "Composite accidentally matches the buggy order shift(T(p))"
+
+
+# ---------------------------------------------------------------------------
+# store_transform_in_metadata (skipped unless zarr fixture available)
+# ---------------------------------------------------------------------------
+
+
+class TestStoreTransformInMetadata:
+ """Smoke test: ensure the metadata writer builds a valid affine block."""
+
+ def test_affine_block_written_to_zattrs(self, align_module, tmp_path):
+ import json
+
+ # Create a minimal OME-Zarr v0.4 directory with a .zattrs file.
+ store_path = tmp_path / "test.ome.zarr"
+ store_path.mkdir()
+ initial_attrs = {
+ "multiscales": [
+ {
+ "version": "0.4",
+ "axes": [
+ {"name": "z", "type": "space", "unit": "millimeter"},
+ {"name": "y", "type": "space", "unit": "millimeter"},
+ {"name": "x", "type": "space", "unit": "millimeter"},
+ ],
+ "datasets": [{"path": "0", "coordinateTransformations": []}],
+ }
+ ]
+ }
+ (store_path / ".zattrs").write_text(json.dumps(initial_attrs))
+
+ t = sitk.Euler3DTransform()
+ t.SetTranslation((0.5, 1.5, 2.5))
+
+ align_module.store_transform_in_metadata(str(store_path), t)
+
+ with (store_path / ".zattrs").open() as f:
+ metadata = json.load(f)
+
+ ms = metadata["multiscales"][0]
+ ds = ms["datasets"][0]
+ ctfs = ds["coordinateTransformations"]
+ affines = [c for c in ctfs if c.get("type") == "affine"]
+ assert len(affines) == 1
+ mat = np.array(affines[0]["affine"]).reshape(4, 4)
+ assert mat.shape == (4, 4)
+ np.testing.assert_allclose(mat[3], [0, 0, 0, 1], atol=1e-12)
+ # Translation must be permuted to NGFF (Z, Y, X) ordering.
+ np.testing.assert_allclose(mat[:3, 3], [2.5, 1.5, 0.5], atol=1e-12)