diff --git a/AGENTS.md b/AGENTS.md index 92213e3..be6c5db 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -89,11 +89,13 @@ original names (zero behavior change). Three strategies: shift), `stitch_batch_livestitch_into` (nearest-integer + touched bbox), and `stitch_batch_nearest` (plain nearest-integer, edge-clamped). -`canvas` and `counts` accumulate **in place**; callers normalize as -`canvas / np.maximum(counts, 1)` — these functions never normalize. The -Fourier-shift and livestitch paths flip each patch up-down before -placement; `stitch_batch_nearest` does not. All edge handling clamps to -the canvas (no wrap-around). Pure numpy + `scipy.fft`; depends only on +The accumulators (`stitch_batch_*`) update `canvas`/`counts` **in place** +and never normalize; `normalize_mosaic(canvas, counts, min_overlap=0.5)` is +the companion that averages them into a display mosaic (`canvas / counts`, +under-covered pixels → `NaN`, plus a median `fill_value` for `NaN`-blind +renderers). The Fourier-shift and livestitch paths flip each patch up-down +before placement; `stitch_batch_nearest` does not. All edge handling clamps +to the canvas (no wrap-around). Pure numpy + `scipy.fft`; depends only on `fourier_shift` from `preprocess.py`. **Known gotcha — the three strategies are not pixel-interchangeable** diff --git a/README.md b/README.md index 06047ae..243ec47 100644 --- a/README.md +++ b/README.md @@ -158,11 +158,11 @@ Change the spatial extent of frames. Three variants by use case. ## Stitching (`ptychoml.stitch`) Patch-placement helpers that accumulate a batch of reconstructed ViT -patches into a running `(canvas, counts)` mosaic. Both arrays accumulate -in place; the displayed/written mosaic is `canvas / np.maximum(counts, 1)` -(no normalization happens inside these functions — the caller picks the -min-overlap threshold). `positions_px` is `(N, 2)` in canvas pixel -coordinates `(y, x)` pointing at patch centers. +patches into a running `(canvas, counts)` mosaic. The accumulators don't +normalize — `normalize_mosaic(canvas, counts)` turns the pair into the +displayed/written mosaic (`canvas / counts`, under-covered pixels masked to +`NaN`). `positions_px` is `(N, 2)` in canvas pixel coordinates `(y, x)` +pointing at patch centers. | Function | Purpose | |---|---| @@ -170,8 +170,9 @@ coordinates `(y, x)` pointing at patch centers. | `stitch_batch_into(canvas, counts, patches, positions_px, *, pad=1)` | Accumulate one batch into `(canvas, counts)` using the Fourier-shift path. Scatter-add is associative, so per-batch accumulation matches one-shot stitching (up to FFT noise). | | `stitch_batch_livestitch_into(canvas, counts, patches, positions_px)` | Nearest-integer accumulation that also returns the `(y0, y1, x0, x1)` bounding box touched this batch — lets a live writer repaint only the changed sub-rectangle. Returns `(0, 0, 0, 0)` when nothing overlapped. | | `stitch_batch_nearest(canvas, counts, patches, positions_px)` | Plain nearest-integer scatter-add; clamps at canvas edges (no wrap). Simplest variant, handy as a JIT/cache warm-up kernel. | +| `normalize_mosaic(canvas, counts, min_overlap=0.5)` | Average a `(canvas, counts)` pair into a display mosaic: covered pixels (`counts >= min_overlap`) become `canvas / counts`, under-covered pixels become `NaN`. Returns `(fill_value, mosaic)` where `fill_value` is the median of covered pixels (a neutral background for renderers that treat `NaN` as zero). The normalization companion to the `stitch_batch_*` accumulators. | -**The three strategies are not pixel-interchangeable.** The Fourier-shift +**The three placement strategies are not pixel-interchangeable.** The Fourier-shift and livestitch paths flip each patch up-down before placement (matching the ptycho-vit convention) while `stitch_batch_nearest` does not, and the three use slightly different center-rounding conventions (so a patch @@ -192,7 +193,7 @@ Allocate the mosaic once, then call the same function each batch ```python import numpy as np -from ptychoml.stitch import stitch_batch_livestitch_into +from ptychoml.stitch import stitch_batch_livestitch_into, normalize_mosaic H, W = 2048, 2048 canvas = np.zeros((H, W), dtype=np.float32) # running sum of patch values @@ -203,12 +204,12 @@ for patches, positions_px in stream: # patches (B, ph, pw); positions (B canvas, counts, (y0, y1, x0, x1) = stitch_batch_livestitch_into( canvas, counts, patches, positions_px, ) - mosaic = canvas / np.maximum(counts, 1) # normalize for display/write - repaint(mosaic[y0:y1, x0:x1]) # bbox = only the region that changed this batch + fill, mosaic = normalize_mosaic(canvas, counts) # under-covered pixels -> NaN + repaint(np.where(np.isnan(mosaic), fill, mosaic)[y0:y1, x0:x1]) # bbox = changed region # --- Offline / batch: a single call with every patch --- canvas, counts, _ = stitch_batch_livestitch_into(canvas, counts, all_patches, all_positions) -mosaic = canvas / np.maximum(counts, 1) +fill, mosaic = normalize_mosaic(canvas, counts) ``` For sub-pixel accuracy use `stitch_batch_into` (drop the bbox return); diff --git a/ptychoml/__init__.py b/ptychoml/__init__.py index 38dda66..778cd0c 100644 --- a/ptychoml/__init__.py +++ b/ptychoml/__init__.py @@ -28,6 +28,7 @@ zero_pad_to_target, ) from .stitch import ( + normalize_mosaic, place_patches_fourier_shift, stitch_batch_into, stitch_batch_livestitch_into, @@ -62,6 +63,7 @@ "mask_hot_pixels", "mask_hot_pixels_by_count", "normalize_intensity", + "normalize_mosaic", "place_patches_fourier_shift", "preprocess_diffraction", "remap_positions", diff --git a/ptychoml/stitch.py b/ptychoml/stitch.py index cb2b4d0..85cf5b3 100644 --- a/ptychoml/stitch.py +++ b/ptychoml/stitch.py @@ -349,3 +349,36 @@ def stitch_batch_nearest( canvas[sy0:sy1, sx0:sx1] += patches[i, py0:py1, px0:px1] counts[sy0:sy1, sx0:sx1] += 1.0 return canvas, counts + + +def normalize_mosaic(canvas, counts, min_overlap=0.5): + """Average a stitched ``(canvas, counts)`` pair into a display mosaic. + + ``canvas`` is the running sum of placed patch values and ``counts`` the + running occupancy, both produced by the ``stitch_batch_*`` functions (which + deliberately leave normalization to the caller). A pixel is "covered" when + ``counts >= min_overlap``; covered pixels become ``canvas / counts`` and + under-covered pixels become ``NaN``. + + Returns ``(fill_value, mosaic)`` where: + + * ``mosaic`` is ``float32`` with ``NaN`` in the under-covered regions. + * ``fill_value`` is the median of the covered pixels — a neutral background + for renderers that treat ``NaN`` as zero (paint ``NaN`` with + ``fill_value`` before display). ``0.0`` when nothing is covered. + + With the default ``min_overlap=0.5`` and integer-valued ``counts``, a pixel + counts as covered once it has been written at least once (count ``>= 1``). + Raise ``min_overlap`` to require multiple overlapping patches. + + Source: holoptycho/vit_inference.py ``MosaicWriterOp._normalise_full``. + """ + canvas = np.asarray(canvas) + counts = np.asarray(counts) + valid = counts >= min_overlap + if valid.any(): + avg = canvas / np.where(valid, counts, 1.0) + fill = float(np.median(avg[valid])) + mosaic = np.where(valid, avg, np.nan).astype(np.float32) + return fill, mosaic + return 0.0, np.full(canvas.shape, np.nan, dtype=np.float32) diff --git a/tests/test_stitch.py b/tests/test_stitch.py index 020783a..186b9a0 100644 --- a/tests/test_stitch.py +++ b/tests/test_stitch.py @@ -3,6 +3,7 @@ import pytest from ptychoml.stitch import ( + normalize_mosaic, place_patches_fourier_shift, stitch_batch_into, stitch_batch_livestitch_into, @@ -228,3 +229,110 @@ def test_place_patches_fourier_shift_boundary_padding(): assert out.shape == (20, 20) assert out[10:, 10:].sum() == 0.0 # nothing leaked to the far corner assert out.sum() > 0.0 # the in-bounds portion was placed + + +# ----- normalize_mosaic ----------------------------------------------------- + +def test_normalize_mosaic_averages_covered_pixels(): + canvas = np.array([[6.0, 0.0], [9.0, 4.0]], dtype=np.float32) + counts = np.array([[3.0, 0.0], [3.0, 2.0]], dtype=np.float32) + + fill, mosaic = normalize_mosaic(canvas, counts, min_overlap=0.5) + + # covered pixels = canvas / counts + assert mosaic[0, 0] == pytest.approx(2.0) + assert mosaic[1, 0] == pytest.approx(3.0) + assert mosaic[1, 1] == pytest.approx(2.0) + # the count==0 pixel is under-covered -> NaN + assert np.isnan(mosaic[0, 1]) + + +def test_normalize_mosaic_masks_undercovered_to_nan(): + canvas = np.array([[5.0, 5.0]], dtype=np.float32) + counts = np.array([[5.0, 0.0]], dtype=np.float32) + + _, mosaic = normalize_mosaic(canvas, counts) + + assert mosaic[0, 0] == pytest.approx(1.0) + assert np.isnan(mosaic[0, 1]) + + +def test_normalize_mosaic_fill_is_median_of_covered(): + # covered averaged values are 1, 2, 3 -> median 2; the NaN pixel is excluded + canvas = np.array([[1.0, 2.0, 3.0, 7.0]], dtype=np.float32) + counts = np.array([[1.0, 1.0, 1.0, 0.0]], dtype=np.float32) + + fill, _ = normalize_mosaic(canvas, counts) + + assert fill == pytest.approx(2.0) + + +def test_normalize_mosaic_threshold_is_inclusive(): + # count exactly equal to the threshold counts as covered + canvas = np.array([[4.0]], dtype=np.float32) + counts = np.array([[2.0]], dtype=np.float32) + + _, mosaic = normalize_mosaic(canvas, counts, min_overlap=2.0) + + assert mosaic[0, 0] == pytest.approx(2.0) + assert not np.isnan(mosaic[0, 0]) + + +def test_normalize_mosaic_higher_threshold_excludes_thin_coverage(): + canvas = np.array([[2.0, 6.0]], dtype=np.float32) + counts = np.array([[1.0, 3.0]], dtype=np.float32) + + _, mosaic = normalize_mosaic(canvas, counts, min_overlap=2.0) + + assert np.isnan(mosaic[0, 0]) # count 1 < 2 -> dropped + assert mosaic[0, 1] == pytest.approx(2.0) + + +def test_normalize_mosaic_empty_returns_zero_and_all_nan(): + canvas = np.zeros((3, 3), dtype=np.float32) + counts = np.zeros((3, 3), dtype=np.float32) + + fill, mosaic = normalize_mosaic(canvas, counts) + + assert fill == 0.0 + assert mosaic.shape == (3, 3) + assert np.isnan(mosaic).all() + + +def test_normalize_mosaic_returns_float32(): + canvas = np.ones((2, 2), dtype=np.float64) + counts = np.ones((2, 2), dtype=np.float64) + + _, mosaic = normalize_mosaic(canvas, counts) + + assert mosaic.dtype == np.float32 + + +def test_normalize_mosaic_does_not_mutate_inputs(): + canvas = np.array([[4.0, 0.0]], dtype=np.float32) + counts = np.array([[2.0, 0.0]], dtype=np.float32) + canvas_ref = canvas.copy() + counts_ref = counts.copy() + + normalize_mosaic(canvas, counts) + + np.testing.assert_array_equal(canvas, canvas_ref) + np.testing.assert_array_equal(counts, counts_ref) + + +def test_normalize_mosaic_after_livestitch_recovers_patch_value(): + # stitch two non-overlapping all-ones patches, then normalize: covered + # region should read back the patch value (1.0), rest NaN. + canvas = np.zeros((40, 40), dtype=np.float32) + counts = np.zeros((40, 40), dtype=np.float32) + patches = np.ones((2, 6, 6), dtype=np.float32) + positions = np.array([[10.0, 10.0], [10.0, 30.0]]) + + canvas, counts, _ = stitch_batch_livestitch_into(canvas, counts, patches, positions) + fill, mosaic = normalize_mosaic(canvas, counts) + + covered = counts >= 0.5 + assert covered.sum() == 2 * 6 * 6 + np.testing.assert_allclose(mosaic[covered], 1.0, atol=1e-6) + assert np.isnan(mosaic[~covered]).all() + assert fill == pytest.approx(1.0)