Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Change the spatial extent of frames. Three variants by use case.
| `fourier_shift(images, shifts)` | Sub-pixel shift each `(H, W)` plane by `shifts[i] = (dy, dx)` via FFT phase-ramp multiplication. |
| `compute_sample_pixel_size(wavelength_m, detector_distance_m, ccd_pixel_size_m, n_pixels)` | Far-field pixel size at the sample plane: `λ z / (N · dx_detector)`. |
| `inner_crop_from_probe(probe, threshold=0.5)` | Derive a ViT-patch `inner_crop` from the reconstruction probe: for a circular probe of radius `R` (at `threshold` × peak amplitude) the inscribed square gives `floor(patch/2 − R/√2)`, clamped to `[0, patch//4]`. Pass a complex or real 2-D probe (e.g. an ONNX-baked probe or `PtychoViTInference.baked_probe`). `None` if the probe is empty. |
| `spatially_diverse_sample(x, y, n_target, rng=None)` | Select up to `n_target` spatially-spread indices from 2-D scan positions by bucketing the bounding box into a `√n_target × √n_target` grid and picking one point per occupied cell (so a subset covers the scan area rather than clustering). Returns sorted indices; pass a seeded `rng` for reproducibility. |

## Stitching (`ptychoml.stitch`)

Expand Down
2 changes: 2 additions & 0 deletions ptychoml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
preprocess_diffraction,
remap_positions,
resize_diffraction_patterns,
spatially_diverse_sample,
zero_pad_to_target,
)
from .stitch import (
Expand Down Expand Up @@ -71,6 +72,7 @@
"remap_positions",
"resize_diffraction_patterns",
"save_engine",
"spatially_diverse_sample",
"stitch_batch_into",
"stitch_batch_livestitch_into",
"stitch_batch_nearest",
Expand Down
36 changes: 36 additions & 0 deletions ptychoml/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,42 @@ def remap_positions(
return out


def spatially_diverse_sample(x, y, n_target, rng=None):
"""Select up to ``n_target`` spatially-spread indices from a 2-D point set.

Buckets the bounding box of ``(x, y)`` into a roughly ``sqrt(n_target)`` ×
``sqrt(n_target)`` grid and picks one random point per occupied cell, so
the selection covers the whole scan area instead of clustering. (A uniform
random sample can clump in one region, where coincidentally symmetric
structure could make a wrong orientation score well — this guarantees
coverage.)

``x`` and ``y`` are 1-D position arrays of equal length. ``rng`` is an
optional ``numpy.random.Generator`` (a fresh default is used if omitted);
pass a seeded one for reproducible selection. Returns a **sorted** array of
selected indices into ``x``/``y``. If ``n_target >= len(x)`` all indices are
returned. The result has one index per *occupied* grid cell, so its length
is approximately — not exactly — ``n_target`` (fewer if points cluster).

Source: holoptycho/scripts/detect_orientation.py ``_spatially_diverse_sample``.
"""
x = np.asarray(x)
y = np.asarray(y)
n = len(x)
if n_target >= n:
return np.arange(n)
if rng is None:
rng = np.random.default_rng()
grid_n = max(1, int(np.ceil(np.sqrt(n_target))))
x_edges = np.linspace(x.min(), x.max() + 1e-12, grid_n + 1)
y_edges = np.linspace(y.min(), y.max() + 1e-12, grid_n + 1)
bx = np.clip(np.searchsorted(x_edges, x, side='right') - 1, 0, grid_n - 1)
by = np.clip(np.searchsorted(y_edges, y, side='right') - 1, 0, grid_n - 1)
bucket = by * grid_n + bx
chosen = [rng.choice(np.where(bucket == b)[0]) for b in np.unique(bucket)]
return np.sort(np.array(chosen))


def compute_sample_pixel_size(
wavelength_m: float,
detector_distance_m: float,
Expand Down
58 changes: 58 additions & 0 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
preprocess_diffraction,
remap_positions,
resize_diffraction_patterns,
spatially_diverse_sample,
zero_pad_to_target,
)

Expand Down Expand Up @@ -698,3 +699,60 @@ def test_inner_crop_from_probe_matches_reference_geometry():
support = r[amp >= thr * amp.max()]
ref = max(0, min(int(np.floor(n / 2.0 - float(support.max()) / np.sqrt(2))), n // 4))
assert inner_crop_from_probe(amp, thr) == ref


# ----- spatially_diverse_sample ---------------------------------------------

def test_spatially_diverse_sample_returns_all_when_target_exceeds_n():
x = np.array([0.0, 1.0, 2.0])
y = np.array([0.0, 1.0, 2.0])
out = spatially_diverse_sample(x, y, n_target=10, rng=np.random.default_rng(0))
np.testing.assert_array_equal(out, np.arange(3))


def test_spatially_diverse_sample_is_sorted_and_in_range():
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=200)
y = rng.uniform(0, 10, size=200)
out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(1))
assert np.all(np.diff(out) > 0) # strictly increasing (sorted, unique)
assert out.min() >= 0 and out.max() < 200


def test_spatially_diverse_sample_one_per_occupied_cell():
# 16 points, one dead-centre in each cell of a 4x4 grid over [0,4]x[0,4].
# n_target=16 -> grid_n=4 -> every cell occupied -> all 16 selected.
xs, ys = np.meshgrid(np.arange(4) + 0.5, np.arange(4) + 0.5)
x = xs.ravel()
y = ys.ravel()
out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(0))
assert len(out) == 16
np.testing.assert_array_equal(out, np.arange(16))


def test_spatially_diverse_sample_covers_spread_not_cluster():
# 100 points jammed in one corner + 4 lone points in the other corners.
# The selection must include the spread-out corners, not just the cluster.
rng = np.random.default_rng(0)
cluster_x = rng.uniform(0, 0.5, size=100)
cluster_y = rng.uniform(0, 0.5, size=100)
corners_x = np.array([9.5, 0.0, 9.5])
corners_y = np.array([9.5, 9.5, 0.0])
x = np.concatenate([cluster_x, corners_x])
y = np.concatenate([cluster_y, corners_y])

out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(2))

# the three far corners (indices 100,101,102) must all be picked
assert set([100, 101, 102]).issubset(set(out.tolist()))
# and the 100-point cluster contributes at most a handful (its few cells)
assert sum(i < 100 for i in out) <= 4


def test_spatially_diverse_sample_deterministic_with_seed():
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=300)
y = rng.uniform(0, 10, size=300)
a = spatially_diverse_sample(x, y, 25, rng=np.random.default_rng(7))
b = spatially_diverse_sample(x, y, 25, rng=np.random.default_rng(7))
np.testing.assert_array_equal(a, b)
Loading