From d2e76a27472ca045f7dcdcb48fd7f132e26605c4 Mon Sep 17 00:00:00 2001 From: Garrett Bischof Date: Tue, 9 Jun 2026 12:47:24 -0400 Subject: [PATCH] feat(preprocess): add spatially_diverse_sample for representative scan subsets Bucket a 2-D scan's bounding box into a ~sqrt(n_target) grid and pick one random point per occupied cell, so a subset covers the whole scan area instead of clustering (a uniform random sample can clump where coincidental symmetry would mis-score an orientation). Returns sorted indices; optional seeded rng for reproducibility. Extracted from holoptycho/scripts/detect_orientation.py (_spatially_diverse_sample) so the orientation-detection CLI imports it; it's a generic position-sampling utility, not orientation-specific. Tested in tests/test_preprocess.py: all-when-target>=n, sorted/in-range, one-per-occupied-cell, spread-not-cluster (far corners always chosen, dense cluster contributes few), and seeded determinism. Co-authored-by: Himanshu Goel <4122621+himanshugoel2797@users.noreply.github.com> --- README.md | 1 + ptychoml/__init__.py | 2 ++ ptychoml/preprocess.py | 36 +++++++++++++++++++++++++ tests/test_preprocess.py | 58 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+) diff --git a/README.md b/README.md index 0fff0fe..eefb629 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,7 @@ Change the spatial extent of frames. Three variants by use case. | `fourier_shift(images, shifts)` | Sub-pixel shift each `(H, W)` plane by `shifts[i] = (dy, dx)` via FFT phase-ramp multiplication. | | `compute_sample_pixel_size(wavelength_m, detector_distance_m, ccd_pixel_size_m, n_pixels)` | Far-field pixel size at the sample plane: `λ z / (N · dx_detector)`. | | `inner_crop_from_probe(probe, threshold=0.5)` | Derive a ViT-patch `inner_crop` from the reconstruction probe: for a circular probe of radius `R` (at `threshold` × peak amplitude) the inscribed square gives `floor(patch/2 − R/√2)`, clamped to `[0, patch//4]`. Pass a complex or real 2-D probe (e.g. an ONNX-baked probe or `PtychoViTInference.baked_probe`). `None` if the probe is empty. | +| `spatially_diverse_sample(x, y, n_target, rng=None)` | Select up to `n_target` spatially-spread indices from 2-D scan positions by bucketing the bounding box into a `√n_target × √n_target` grid and picking one point per occupied cell (so a subset covers the scan area rather than clustering). Returns sorted indices; pass a seeded `rng` for reproducibility. | ## Stitching (`ptychoml.stitch`) diff --git a/ptychoml/__init__.py b/ptychoml/__init__.py index 633d88a..f63b689 100644 --- a/ptychoml/__init__.py +++ b/ptychoml/__init__.py @@ -26,6 +26,7 @@ preprocess_diffraction, remap_positions, resize_diffraction_patterns, + spatially_diverse_sample, zero_pad_to_target, ) from .stitch import ( @@ -71,6 +72,7 @@ "remap_positions", "resize_diffraction_patterns", "save_engine", + "spatially_diverse_sample", "stitch_batch_into", "stitch_batch_livestitch_into", "stitch_batch_nearest", diff --git a/ptychoml/preprocess.py b/ptychoml/preprocess.py index 8fd36a2..d46b1d0 100644 --- a/ptychoml/preprocess.py +++ b/ptychoml/preprocess.py @@ -648,6 +648,42 @@ def remap_positions( return out +def spatially_diverse_sample(x, y, n_target, rng=None): + """Select up to ``n_target`` spatially-spread indices from a 2-D point set. + + Buckets the bounding box of ``(x, y)`` into a roughly ``sqrt(n_target)`` × + ``sqrt(n_target)`` grid and picks one random point per occupied cell, so + the selection covers the whole scan area instead of clustering. (A uniform + random sample can clump in one region, where coincidentally symmetric + structure could make a wrong orientation score well — this guarantees + coverage.) + + ``x`` and ``y`` are 1-D position arrays of equal length. ``rng`` is an + optional ``numpy.random.Generator`` (a fresh default is used if omitted); + pass a seeded one for reproducible selection. Returns a **sorted** array of + selected indices into ``x``/``y``. If ``n_target >= len(x)`` all indices are + returned. The result has one index per *occupied* grid cell, so its length + is approximately — not exactly — ``n_target`` (fewer if points cluster). + + Source: holoptycho/scripts/detect_orientation.py ``_spatially_diverse_sample``. + """ + x = np.asarray(x) + y = np.asarray(y) + n = len(x) + if n_target >= n: + return np.arange(n) + if rng is None: + rng = np.random.default_rng() + grid_n = max(1, int(np.ceil(np.sqrt(n_target)))) + x_edges = np.linspace(x.min(), x.max() + 1e-12, grid_n + 1) + y_edges = np.linspace(y.min(), y.max() + 1e-12, grid_n + 1) + bx = np.clip(np.searchsorted(x_edges, x, side='right') - 1, 0, grid_n - 1) + by = np.clip(np.searchsorted(y_edges, y, side='right') - 1, 0, grid_n - 1) + bucket = by * grid_n + bx + chosen = [rng.choice(np.where(bucket == b)[0]) for b in np.unique(bucket)] + return np.sort(np.array(chosen)) + + def compute_sample_pixel_size( wavelength_m: float, detector_distance_m: float, diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index c18934f..4b988bb 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -20,6 +20,7 @@ preprocess_diffraction, remap_positions, resize_diffraction_patterns, + spatially_diverse_sample, zero_pad_to_target, ) @@ -698,3 +699,60 @@ def test_inner_crop_from_probe_matches_reference_geometry(): support = r[amp >= thr * amp.max()] ref = max(0, min(int(np.floor(n / 2.0 - float(support.max()) / np.sqrt(2))), n // 4)) assert inner_crop_from_probe(amp, thr) == ref + + +# ----- spatially_diverse_sample --------------------------------------------- + +def test_spatially_diverse_sample_returns_all_when_target_exceeds_n(): + x = np.array([0.0, 1.0, 2.0]) + y = np.array([0.0, 1.0, 2.0]) + out = spatially_diverse_sample(x, y, n_target=10, rng=np.random.default_rng(0)) + np.testing.assert_array_equal(out, np.arange(3)) + + +def test_spatially_diverse_sample_is_sorted_and_in_range(): + rng = np.random.default_rng(0) + x = rng.uniform(0, 10, size=200) + y = rng.uniform(0, 10, size=200) + out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(1)) + assert np.all(np.diff(out) > 0) # strictly increasing (sorted, unique) + assert out.min() >= 0 and out.max() < 200 + + +def test_spatially_diverse_sample_one_per_occupied_cell(): + # 16 points, one dead-centre in each cell of a 4x4 grid over [0,4]x[0,4]. + # n_target=16 -> grid_n=4 -> every cell occupied -> all 16 selected. + xs, ys = np.meshgrid(np.arange(4) + 0.5, np.arange(4) + 0.5) + x = xs.ravel() + y = ys.ravel() + out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(0)) + assert len(out) == 16 + np.testing.assert_array_equal(out, np.arange(16)) + + +def test_spatially_diverse_sample_covers_spread_not_cluster(): + # 100 points jammed in one corner + 4 lone points in the other corners. + # The selection must include the spread-out corners, not just the cluster. + rng = np.random.default_rng(0) + cluster_x = rng.uniform(0, 0.5, size=100) + cluster_y = rng.uniform(0, 0.5, size=100) + corners_x = np.array([9.5, 0.0, 9.5]) + corners_y = np.array([9.5, 9.5, 0.0]) + x = np.concatenate([cluster_x, corners_x]) + y = np.concatenate([cluster_y, corners_y]) + + out = spatially_diverse_sample(x, y, n_target=16, rng=np.random.default_rng(2)) + + # the three far corners (indices 100,101,102) must all be picked + assert set([100, 101, 102]).issubset(set(out.tolist())) + # and the 100-point cluster contributes at most a handful (its few cells) + assert sum(i < 100 for i in out) <= 4 + + +def test_spatially_diverse_sample_deterministic_with_seed(): + rng = np.random.default_rng(0) + x = rng.uniform(0, 10, size=300) + y = rng.uniform(0, 10, size=300) + a = spatially_diverse_sample(x, y, 25, rng=np.random.default_rng(7)) + b = spatially_diverse_sample(x, y, 25, rng=np.random.default_rng(7)) + np.testing.assert_array_equal(a, b)