Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ original names (zero behavior change). Three strategies:
shift), `stitch_batch_livestitch_into` (nearest-integer + touched bbox),
and `stitch_batch_nearest` (plain nearest-integer, edge-clamped).

`canvas` and `counts` accumulate **in place**; callers normalize as
`canvas / np.maximum(counts, 1)` — these functions never normalize. The
Fourier-shift and livestitch paths flip each patch up-down before
placement; `stitch_batch_nearest` does not. All edge handling clamps to
the canvas (no wrap-around). Pure numpy + `scipy.fft`; depends only on
The accumulators (`stitch_batch_*`) update `canvas`/`counts` **in place**
and never normalize; `normalize_mosaic(canvas, counts, min_overlap=0.5)` is
the companion that averages them into a display mosaic (`canvas / counts`,
under-covered pixels → `NaN`, plus a median `fill_value` for `NaN`-blind
renderers). The Fourier-shift and livestitch paths flip each patch up-down
before placement; `stitch_batch_nearest` does not. All edge handling clamps
to the canvas (no wrap-around). Pure numpy + `scipy.fft`; depends only on
`fourier_shift` from `preprocess.py`.

**Known gotcha — the three strategies are not pixel-interchangeable**
Expand Down
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,21 @@ Change the spatial extent of frames. Three variants by use case.
## Stitching (`ptychoml.stitch`)

Patch-placement helpers that accumulate a batch of reconstructed ViT
patches into a running `(canvas, counts)` mosaic. Both arrays accumulate
in place; the displayed/written mosaic is `canvas / np.maximum(counts, 1)`
(no normalization happens inside these functions — the caller picks the
min-overlap threshold). `positions_px` is `(N, 2)` in canvas pixel
coordinates `(y, x)` pointing at patch centers.
patches into a running `(canvas, counts)` mosaic. The accumulators don't
normalize — `normalize_mosaic(canvas, counts)` turns the pair into the
displayed/written mosaic (`canvas / counts`, under-covered pixels masked to
`NaN`). `positions_px` is `(N, 2)` in canvas pixel coordinates `(y, x)`
pointing at patch centers.

| Function | Purpose |
|---|---|
| `place_patches_fourier_shift(image, positions, patches, pad=1)` | Add patches into `image` with sub-pixel Fourier shifts: over-extract by `pad`, phase-ramp shift by the fractional position, center-crop, scatter-add. Highest placement accuracy. |
| `stitch_batch_into(canvas, counts, patches, positions_px, *, pad=1)` | Accumulate one batch into `(canvas, counts)` using the Fourier-shift path. Scatter-add is associative, so per-batch accumulation matches one-shot stitching (up to FFT noise). |
| `stitch_batch_livestitch_into(canvas, counts, patches, positions_px)` | Nearest-integer accumulation that also returns the `(y0, y1, x0, x1)` bounding box touched this batch — lets a live writer repaint only the changed sub-rectangle. Returns `(0, 0, 0, 0)` when nothing overlapped. |
| `stitch_batch_nearest(canvas, counts, patches, positions_px)` | Plain nearest-integer scatter-add; clamps at canvas edges (no wrap). Simplest variant, handy as a JIT/cache warm-up kernel. |
| `normalize_mosaic(canvas, counts, min_overlap=0.5)` | Average a `(canvas, counts)` pair into a display mosaic: covered pixels (`counts >= min_overlap`) become `canvas / counts`, under-covered pixels become `NaN`. Returns `(fill_value, mosaic)` where `fill_value` is the median of covered pixels (a neutral background for renderers that treat `NaN` as zero). The normalization companion to the `stitch_batch_*` accumulators. |

**The three strategies are not pixel-interchangeable.** The Fourier-shift
**The three placement strategies are not pixel-interchangeable.** The Fourier-shift
and livestitch paths flip each patch up-down before placement (matching
the ptycho-vit convention) while `stitch_batch_nearest` does not, and the
three use slightly different center-rounding conventions (so a patch
Expand All @@ -192,7 +193,7 @@ Allocate the mosaic once, then call the same function each batch

```python
import numpy as np
from ptychoml.stitch import stitch_batch_livestitch_into
from ptychoml.stitch import stitch_batch_livestitch_into, normalize_mosaic

H, W = 2048, 2048
canvas = np.zeros((H, W), dtype=np.float32) # running sum of patch values
Expand All @@ -203,12 +204,12 @@ for patches, positions_px in stream: # patches (B, ph, pw); positions (B
canvas, counts, (y0, y1, x0, x1) = stitch_batch_livestitch_into(
canvas, counts, patches, positions_px,
)
mosaic = canvas / np.maximum(counts, 1) # normalize for display/write
repaint(mosaic[y0:y1, x0:x1]) # bbox = only the region that changed this batch
fill, mosaic = normalize_mosaic(canvas, counts) # under-covered pixels -> NaN
repaint(np.where(np.isnan(mosaic), fill, mosaic)[y0:y1, x0:x1]) # bbox = changed region

# --- Offline / batch: a single call with every patch ---
canvas, counts, _ = stitch_batch_livestitch_into(canvas, counts, all_patches, all_positions)
mosaic = canvas / np.maximum(counts, 1)
fill, mosaic = normalize_mosaic(canvas, counts)
```

For sub-pixel accuracy use `stitch_batch_into` (drop the bbox return);
Expand Down
2 changes: 2 additions & 0 deletions ptychoml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
zero_pad_to_target,
)
from .stitch import (
normalize_mosaic,
place_patches_fourier_shift,
stitch_batch_into,
stitch_batch_livestitch_into,
Expand Down Expand Up @@ -62,6 +63,7 @@
"mask_hot_pixels",
"mask_hot_pixels_by_count",
"normalize_intensity",
"normalize_mosaic",
"place_patches_fourier_shift",
"preprocess_diffraction",
"remap_positions",
Expand Down
33 changes: 33 additions & 0 deletions ptychoml/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,36 @@ def stitch_batch_nearest(
canvas[sy0:sy1, sx0:sx1] += patches[i, py0:py1, px0:px1]
counts[sy0:sy1, sx0:sx1] += 1.0
return canvas, counts


def normalize_mosaic(canvas, counts, min_overlap=0.5):
"""Average a stitched ``(canvas, counts)`` pair into a display mosaic.

``canvas`` is the running sum of placed patch values and ``counts`` the
running occupancy, both produced by the ``stitch_batch_*`` functions (which
deliberately leave normalization to the caller). A pixel is "covered" when
``counts >= min_overlap``; covered pixels become ``canvas / counts`` and
under-covered pixels become ``NaN``.

Returns ``(fill_value, mosaic)`` where:

* ``mosaic`` is ``float32`` with ``NaN`` in the under-covered regions.
* ``fill_value`` is the median of the covered pixels — a neutral background
for renderers that treat ``NaN`` as zero (paint ``NaN`` with
``fill_value`` before display). ``0.0`` when nothing is covered.

With the default ``min_overlap=0.5`` and integer-valued ``counts``, a pixel
counts as covered once it has been written at least once (count ``>= 1``).
Raise ``min_overlap`` to require multiple overlapping patches.

Source: holoptycho/vit_inference.py ``MosaicWriterOp._normalise_full``.
"""
canvas = np.asarray(canvas)
counts = np.asarray(counts)
valid = counts >= min_overlap
if valid.any():
avg = canvas / np.where(valid, counts, 1.0)
fill = float(np.median(avg[valid]))
mosaic = np.where(valid, avg, np.nan).astype(np.float32)
return fill, mosaic
return 0.0, np.full(canvas.shape, np.nan, dtype=np.float32)
108 changes: 108 additions & 0 deletions tests/test_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from ptychoml.stitch import (
normalize_mosaic,
place_patches_fourier_shift,
stitch_batch_into,
stitch_batch_livestitch_into,
Expand Down Expand Up @@ -228,3 +229,110 @@ def test_place_patches_fourier_shift_boundary_padding():
assert out.shape == (20, 20)
assert out[10:, 10:].sum() == 0.0 # nothing leaked to the far corner
assert out.sum() > 0.0 # the in-bounds portion was placed


# ----- normalize_mosaic -----------------------------------------------------

def test_normalize_mosaic_averages_covered_pixels():
canvas = np.array([[6.0, 0.0], [9.0, 4.0]], dtype=np.float32)
counts = np.array([[3.0, 0.0], [3.0, 2.0]], dtype=np.float32)

fill, mosaic = normalize_mosaic(canvas, counts, min_overlap=0.5)

# covered pixels = canvas / counts
assert mosaic[0, 0] == pytest.approx(2.0)
assert mosaic[1, 0] == pytest.approx(3.0)
assert mosaic[1, 1] == pytest.approx(2.0)
# the count==0 pixel is under-covered -> NaN
assert np.isnan(mosaic[0, 1])


def test_normalize_mosaic_masks_undercovered_to_nan():
canvas = np.array([[5.0, 5.0]], dtype=np.float32)
counts = np.array([[5.0, 0.0]], dtype=np.float32)

_, mosaic = normalize_mosaic(canvas, counts)

assert mosaic[0, 0] == pytest.approx(1.0)
assert np.isnan(mosaic[0, 1])


def test_normalize_mosaic_fill_is_median_of_covered():
# covered averaged values are 1, 2, 3 -> median 2; the NaN pixel is excluded
canvas = np.array([[1.0, 2.0, 3.0, 7.0]], dtype=np.float32)
counts = np.array([[1.0, 1.0, 1.0, 0.0]], dtype=np.float32)

fill, _ = normalize_mosaic(canvas, counts)

assert fill == pytest.approx(2.0)


def test_normalize_mosaic_threshold_is_inclusive():
# count exactly equal to the threshold counts as covered
canvas = np.array([[4.0]], dtype=np.float32)
counts = np.array([[2.0]], dtype=np.float32)

_, mosaic = normalize_mosaic(canvas, counts, min_overlap=2.0)

assert mosaic[0, 0] == pytest.approx(2.0)
assert not np.isnan(mosaic[0, 0])


def test_normalize_mosaic_higher_threshold_excludes_thin_coverage():
canvas = np.array([[2.0, 6.0]], dtype=np.float32)
counts = np.array([[1.0, 3.0]], dtype=np.float32)

_, mosaic = normalize_mosaic(canvas, counts, min_overlap=2.0)

assert np.isnan(mosaic[0, 0]) # count 1 < 2 -> dropped
assert mosaic[0, 1] == pytest.approx(2.0)


def test_normalize_mosaic_empty_returns_zero_and_all_nan():
canvas = np.zeros((3, 3), dtype=np.float32)
counts = np.zeros((3, 3), dtype=np.float32)

fill, mosaic = normalize_mosaic(canvas, counts)

assert fill == 0.0
assert mosaic.shape == (3, 3)
assert np.isnan(mosaic).all()


def test_normalize_mosaic_returns_float32():
canvas = np.ones((2, 2), dtype=np.float64)
counts = np.ones((2, 2), dtype=np.float64)

_, mosaic = normalize_mosaic(canvas, counts)

assert mosaic.dtype == np.float32


def test_normalize_mosaic_does_not_mutate_inputs():
canvas = np.array([[4.0, 0.0]], dtype=np.float32)
counts = np.array([[2.0, 0.0]], dtype=np.float32)
canvas_ref = canvas.copy()
counts_ref = counts.copy()

normalize_mosaic(canvas, counts)

np.testing.assert_array_equal(canvas, canvas_ref)
np.testing.assert_array_equal(counts, counts_ref)


def test_normalize_mosaic_after_livestitch_recovers_patch_value():
# stitch two non-overlapping all-ones patches, then normalize: covered
# region should read back the patch value (1.0), rest NaN.
canvas = np.zeros((40, 40), dtype=np.float32)
counts = np.zeros((40, 40), dtype=np.float32)
patches = np.ones((2, 6, 6), dtype=np.float32)
positions = np.array([[10.0, 10.0], [10.0, 30.0]])

canvas, counts, _ = stitch_batch_livestitch_into(canvas, counts, patches, positions)
fill, mosaic = normalize_mosaic(canvas, counts)

covered = counts >= 0.5
assert covered.sum() == 2 * 6 * 6
np.testing.assert_allclose(mosaic[covered], 1.0, atol=1e-6)
assert np.isnan(mosaic[~covered]).all()
assert fill == pytest.approx(1.0)
Loading