From 913aa4fa2bb38d2b5b9a107ea2f8eaf120a356c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:43:13 +0000 Subject: [PATCH 1/7] Initial plan From ab20c0fdfeae72dfbc41180bfbc420bebfbd95a0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:55:12 +0000 Subject: [PATCH 2/7] Add JAX compatibility to pted and pted_coverage_test Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- src/pted/pted.py | 15 +++++++ src/pted/utils.py | 106 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_pted.py | 62 ++++++++++++++++++++++++++ 3 files changed, 183 insertions(+) diff --git a/src/pted/pted.py b/src/pted/pted.py index 9d401ee..b2fc1ef 100644 --- a/src/pted/pted.py +++ b/src/pted/pted.py @@ -4,10 +4,13 @@ from .utils import ( is_torch_tensor, + is_jax_array, pted_torch, pted_numpy, pted_chunk_torch, pted_chunk_numpy, + pted_jax, + pted_chunk_jax, two_tailed_p, confidence_alert, simulation_based_calibration_histogram, @@ -140,6 +143,18 @@ def pted( ) elif is_torch_tensor(x): test, permute = pted_torch(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar) + elif is_jax_array(x) and chunk_size is not None: + test, permute = pted_chunk_jax( + x, + y, + permutations=permutations, + metric=metric, + chunk_size=int(chunk_size), + chunk_iter=int(chunk_iter), + prog_bar=prog_bar, + ) + elif is_jax_array(x): + test, permute = pted_jax(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar) elif chunk_size is not None: test, permute = pted_chunk_numpy( x, diff --git a/src/pted/utils.py b/src/pted/utils.py index d00e0ec..5e7ee54 100644 --- a/src/pted/utils.py +++ b/src/pted/utils.py @@ -16,12 +16,23 @@ class torch: Tensor = np.ndarray +try: + import jax + import jax.numpy as jnp +except ImportError: + jax = None + jnp = None + + __all__ = ( "is_torch_tensor", + "is_jax_array", "pted_numpy", "pted_chunk_numpy", "pted_torch", "pted_chunk_torch", + "pted_jax", + "pted_chunk_jax", "two_tailed_p", "confidence_alert", "simulation_based_calibration_histogram", @@ -39,6 +50,12 @@ def is_torch_tensor(o): ) +def is_jax_array(o): + if jax is None: + return False + return isinstance(o, jax.Array) + + def _energy_distance_precompute( D: Union[np.ndarray, torch.Tensor], nx: int, ny: int ) -> Union[float, torch.Tensor]: @@ -110,6 +127,42 @@ def _energy_distance_estimate_torch( return np.mean(E_est) +def _jax_cdist(x, y, p: float = 2.0): + diff = x[:, None, :] - y[None, :, :] + return jnp.linalg.norm(diff, ord=p, axis=-1) + + +def _energy_distance_jax(x, y, metric: Union[str, float] = "euclidean") -> float: + nx = len(x) + ny = len(y) + z = jnp.concatenate([x, y], axis=0) + if metric == "euclidean": + metric = 2.0 + D = _jax_cdist(z, z, p=metric) + return float(_energy_distance_precompute(D, nx, ny)) + + +def _energy_distance_estimate_jax( + x, + y, + chunk_size: int, + chunk_iter: int, + metric: Union[str, float] = "euclidean", +) -> float: + + E_est = [] + for _ in range(chunk_iter): + # Randomly sample a chunk of data + idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False) + x_chunk = x[idx] + idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False) + y_chunk = y[idy] + + # Compute the energy distance + E_est.append(_energy_distance_jax(x_chunk, y_chunk, metric=metric)) + return np.mean(E_est) + + def pted_chunk_numpy( x: np.ndarray, y: np.ndarray, @@ -210,6 +263,59 @@ def pted_torch( return test_stat, permute_stats +def pted_jax( + x, + y, + permutations: int = 100, + metric: Union[str, float] = "euclidean", + prog_bar: bool = False, +) -> tuple[float, list[float]]: + assert jax is not None, "JAX is not installed! try: `pip install jax`" + z = jnp.concatenate([x, y], axis=0) + assert jnp.all(jnp.isfinite(z)), "Input contains NaN or Inf!" + if metric == "euclidean": + metric = 2.0 + dmatrix = _jax_cdist(z, z, p=metric) + assert jnp.all( + jnp.isfinite(dmatrix) + ), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)." + nx = len(x) + ny = len(y) + + test_stat = float(_energy_distance_precompute(dmatrix, nx, ny)) + permute_stats = [] + for _ in trange(permutations, disable=not prog_bar): + I = np.random.permutation(len(z)) + dmatrix = dmatrix[I][:, I] + permute_stats.append(float(_energy_distance_precompute(dmatrix, nx, ny))) + return test_stat, permute_stats + + +def pted_chunk_jax( + x, + y, + permutations: int = 100, + metric: Union[str, float] = "euclidean", + chunk_size: int = 100, + chunk_iter: int = 10, + prog_bar: bool = False, +) -> tuple[float, list[float]]: + assert jax is not None, "JAX is not installed! try: `pip install jax`" + assert jnp.all(jnp.isfinite(x)) and jnp.all(jnp.isfinite(y)), "Input contains NaN or Inf!" + nx = len(x) + + test_stat = _energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric) + permute_stats = [] + for _ in trange(permutations, disable=not prog_bar): + z = jnp.concatenate([x, y], axis=0) + z = z[np.random.permutation(len(z))] + x, y = z[:nx], z[nx:] + permute_stats.append( + _energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric) + ) + return test_stat, permute_stats + + def two_tailed_p(chi2, df): assert df > 2, "Degrees of freedom must be greater than 2 for two-tailed p-value calculation." alpha = chi2_dist.pdf(chi2, df) diff --git a/tests/test_pted.py b/tests/test_pted.py index abaaa04..4153716 100644 --- a/tests/test_pted.py +++ b/tests/test_pted.py @@ -6,6 +6,14 @@ import torch except ImportError: torch = None + +try: + import jax + import jax.numpy as jnp +except ImportError: + jax = None + jnp = None + import numpy as np import pytest @@ -151,3 +159,57 @@ def test_sbc_histogram(): pted.pted_coverage_test(g, s, permutations=100, sbc_histogram="sbc_hist.pdf") os.remove("sbc_hist.pdf") + + +def test_pted_jax(): + if jax is None: + pytest.skip("jax not installed") + + # Set the random seed for reproducibility + np.random.seed(42) + + # example 2 sample test + D = 300 + for _ in range(20): + x = jnp.array(np.random.normal(size=(100, D))) + y = jnp.array(np.random.normal(size=(100, D))) + p = pted.pted(x, y) + assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))" + + x = jnp.array(np.random.normal(size=(100, D))) + y = jnp.array(np.random.uniform(size=(100, D))) + p = pted.pted(x, y, two_tailed=False) + assert p < 1e-2, f"p-value {p} is not in the expected range (~0)" + + x = jnp.array(np.random.normal(size=(100, D))) + t, p, _ = pted.pted(x, x, return_all=True) + q = 2 * min(np.sum(p > t), np.sum(p < t)) + p = (1 + q) / (len(p) + 1) + assert p < 1e-2, f"p-value {p} is not in the expected range (~0)" + + +def test_pted_chunk_jax(): + if jax is None: + pytest.skip("jax not installed") + np.random.seed(42) + + # example 2 sample test + D = 10 + x = jnp.array(np.random.normal(size=(1000, D))) + y = jnp.array(np.random.normal(size=(1000, D))) + p = pted.pted(x, y, chunk_size=100, chunk_iter=10) + assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))" + + y = jnp.array(np.random.uniform(size=(1000, D))) + p = pted.pted(x, y, chunk_size=100, chunk_iter=10) + assert p < 1e-2, f"p-value {p} is not in the expected range (~0)" + + +def test_pted_coverage_jax(): + if jax is None: + pytest.skip("jax not installed") + + g = jnp.array(np.random.normal(size=(100, 10))) + s = jnp.array(np.random.normal(size=(50, 100, 10))) + p = pted.pted_coverage_test(g, s) + assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))" From 5e17977dc352b38f08d6bac47d5f8cdad6732529 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:10:11 +0000 Subject: [PATCH 3/7] Add jax as a dev dependency Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7e8a87c..b1ea425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dev = [ "pytest-cov>=4.1,<5", "pytest-mock>=3.12,<4", "torch>=2.0,<3", + "jax>=0.4,<1", "matplotlib", ] torch = [ From 761cb4c7d1c8adc2a81b64fdccc3e5dc68f1da6c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:43:06 +0000 Subject: [PATCH 4/7] Update README for JAX support and add utils unit tests with mocking Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- README.md | 43 +++++++++++++++++------- tests/test_pted.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index fcddd20..50f660b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,12 @@ If you want to run PTED on GPUs using PyTorch, then also install torch: pip install torch ``` +If you want to use JAX arrays as inputs, then also install jax: + +```bash +pip install jax +``` + The two functions are ``pted.pted`` and ``pted.pted_coverage_test``. For information about each argument, just use ``help(pted.pted)`` or ``help(pted.pted_coverage_test)``. @@ -261,8 +267,8 @@ results you are getting! ```python def pted( - x: Union[np.ndarray, "Tensor"], - y: Union[np.ndarray, "Tensor"], + x: Union[np.ndarray, "Tensor", "jax.Array"], + y: Union[np.ndarray, "Tensor", "jax.Array"], permutations: int = 1000, metric: Union[str, float] = "euclidean", return_all: bool = False, @@ -273,10 +279,10 @@ def pted( ) -> Union[float, tuple[float, np.ndarray, float]]: ``` -* **x** *(Union[np.ndarray, Tensor])*: first set of samples. Shape (N, *D) -* **y** *(Union[np.ndarray, Tensor])*: second set of samples. Shape (M, *D) +* **x** *(Union[np.ndarray, Tensor, jax.Array])*: first set of samples. Shape (N, *D) +* **y** *(Union[np.ndarray, Tensor, jax.Array])*: second set of samples. Shape (M, *D) * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed. -* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. +* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float. * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False) * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset. * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size. @@ -287,8 +293,8 @@ def pted( ```python def pted_coverage_test( - g: Union[np.ndarray, "Tensor"], - s: Union[np.ndarray, "Tensor"], + g: Union[np.ndarray, "Tensor", "jax.Array"], + s: Union[np.ndarray, "Tensor", "jax.Array"], permutations: int = 1000, metric: Union[str, float] = "euclidean", warn_confidence: Optional[float] = 1e-3, @@ -301,10 +307,10 @@ def pted_coverage_test( ) -> Union[float, tuple[np.ndarray, np.ndarray, float]]: ``` -* **g** *(Union[np.ndarray, Tensor])*: Ground truth samples. Shape (n_sims, *D) -* **s** *(Union[np.ndarray, Tensor])*: Posterior samples. Shape (n_samples, n_sims, *D) +* **g** *(Union[np.ndarray, Tensor, jax.Array])*: Ground truth samples. Shape (n_sims, *D) +* **s** *(Union[np.ndarray, Tensor, jax.Array])*: Posterior samples. Shape (n_samples, n_sims, *D) * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed. -* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. +* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float. * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False) * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset. * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size. @@ -315,9 +321,9 @@ def pted_coverage_test( ## GPU Compatibility PTED works on both CPU and GPU. All that is needed is to pass the `x` and `y` as -PyTorch Tensors on the appropriate device. +PyTorch Tensors or JAX Arrays on the appropriate device. -Example: +Example with PyTorch: ```python from pted import pted import numpy as np @@ -330,6 +336,19 @@ p_value = pted(torch.tensor(x), torch.tensor(y)) print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1 ``` +Example with JAX: +```python +from pted import pted +import numpy as np +import jax.numpy as jnp + +x = np.random.normal(size = (500, 10)) # (n_samples_x, n_dimensions) +y = np.random.normal(size = (400, 10)) # (n_samples_y, n_dimensions) + +p_value = pted(jnp.array(x), jnp.array(y)) +print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1 +``` + ## Memory and Compute limitations If a GPU isn't enough to get PTED running fast enough for you, or if you are diff --git a/tests/test_pted.py b/tests/test_pted.py index 4153716..915e046 100644 --- a/tests/test_pted.py +++ b/tests/test_pted.py @@ -1,4 +1,5 @@ import os +import types import pted @@ -213,3 +214,86 @@ def test_pted_coverage_jax(): s = jnp.array(np.random.normal(size=(50, 100, 10))) p = pted.pted_coverage_test(g, s) assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))" + + +# --------------------------------------------------------------------------- +# Unit tests for newly-added utils functions +# --------------------------------------------------------------------------- + + +def test_is_jax_array_with_jax(): + """is_jax_array returns True for a real JAX array and False for other types.""" + if jax is None: + pytest.skip("jax not installed") + assert pted.utils.is_jax_array(jnp.zeros(3)) is True + assert pted.utils.is_jax_array(np.zeros(3)) is False + assert pted.utils.is_jax_array(42) is False + + +def test_is_jax_array_no_jax(monkeypatch): + """is_jax_array returns False when JAX is not installed.""" + monkeypatch.setattr("pted.utils.jax", None) + assert pted.utils.is_jax_array(42) is False + + +def test_jax_cdist(): + """_jax_cdist produces correct pairwise Euclidean distances.""" + if jax is None: + pytest.skip("jax not installed") + x = jnp.array([[0.0, 0.0], [3.0, 4.0]]) + y = jnp.array([[0.0, 0.0], [1.0, 0.0]]) + D = pted.utils._jax_cdist(x, y) + assert D.shape == (2, 2) + assert float(D[0, 1]) == pytest.approx(1.0) + assert float(D[1, 0]) == pytest.approx(5.0) + + +def test_energy_distance_jax(): + """_energy_distance_jax returns 0 when x and y are identical.""" + if jax is None: + pytest.skip("jax not installed") + x = jnp.array(np.random.normal(size=(50, 5))) + # Identical samples → energy distance should be ~0 + ed = pted.utils._energy_distance_jax(x, x) + assert abs(ed) < 1e-6 + + +def test_energy_distance_estimate_jax(): + """_energy_distance_estimate_jax returns a finite scalar.""" + if jax is None: + pytest.skip("jax not installed") + np.random.seed(0) + x = jnp.array(np.random.normal(size=(100, 4))) + y = jnp.array(np.random.normal(size=(100, 4))) + ed = pted.utils._energy_distance_estimate_jax(x, y, chunk_size=20, chunk_iter=5) + assert np.isfinite(ed) + + +def test_pted_jax_no_jax(monkeypatch): + """pted_jax raises AssertionError when JAX is not installed.""" + monkeypatch.setattr("pted.utils.jax", None) + with pytest.raises(AssertionError, match="JAX is not installed"): + pted.utils.pted_jax(np.zeros((5, 2)), np.zeros((5, 2))) + + +def test_pted_chunk_jax_no_jax(monkeypatch): + """pted_chunk_jax raises AssertionError when JAX is not installed.""" + monkeypatch.setattr("pted.utils.jax", None) + with pytest.raises(AssertionError, match="JAX is not installed"): + pted.utils.pted_chunk_jax(np.zeros((5, 2)), np.zeros((5, 2))) + + +def test_pted_torch_no_torch(monkeypatch): + """pted_torch raises AssertionError when torch is not installed.""" + fake_torch = types.SimpleNamespace(__version__="null") + monkeypatch.setattr("pted.utils.torch", fake_torch) + with pytest.raises(AssertionError, match="PyTorch is not installed"): + pted.utils.pted_torch(np.zeros((5, 2)), np.zeros((5, 2))) + + +def test_pted_chunk_torch_no_torch(monkeypatch): + """pted_chunk_torch raises AssertionError when torch is not installed.""" + fake_torch = types.SimpleNamespace(__version__="null") + monkeypatch.setattr("pted.utils.torch", fake_torch) + with pytest.raises(AssertionError, match="PyTorch is not installed"): + pted.utils.pted_chunk_torch(np.zeros((5, 2)), np.zeros((5, 2))) From be230a46fad811b5460f6d1ab729cee9002b5b92 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 03:52:17 +0000 Subject: [PATCH 5/7] Optimize _jax_cdist: squared-norm identity for L2, vmap for general p Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- src/pted/utils.py | 11 +++++++++-- tests/test_pted.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/pted/utils.py b/src/pted/utils.py index 5e7ee54..a450a61 100644 --- a/src/pted/utils.py +++ b/src/pted/utils.py @@ -128,8 +128,15 @@ def _energy_distance_estimate_torch( def _jax_cdist(x, y, p: float = 2.0): - diff = x[:, None, :] - y[None, :, :] - return jnp.linalg.norm(diff, ord=p, axis=-1) + if p == 2.0: + # Squared-norm identity avoids materializing the (nx, ny, d) diff tensor. + # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * x_i . y_j + x_sq = jnp.sum(x ** 2, axis=-1) # (nx,) + y_sq = jnp.sum(y ** 2, axis=-1) # (ny,) + sq_dist = x_sq[:, None] + y_sq[None, :] - 2.0 * (x @ y.T) + return jnp.sqrt(jnp.maximum(sq_dist, 0.0)) + # For general p-norms use vmap to avoid the (nx, ny, d) intermediate. + return jax.vmap(lambda xi: jnp.linalg.norm(xi - y, ord=p, axis=-1))(x) def _energy_distance_jax(x, y, metric: Union[str, float] = "euclidean") -> float: diff --git a/tests/test_pted.py b/tests/test_pted.py index 915e046..d356607 100644 --- a/tests/test_pted.py +++ b/tests/test_pted.py @@ -248,6 +248,19 @@ def test_jax_cdist(): assert float(D[1, 0]) == pytest.approx(5.0) +def test_jax_cdist_non_euclidean(): + """_jax_cdist produces correct pairwise distances for p != 2 (vmap path).""" + if jax is None: + pytest.skip("jax not installed") + x = jnp.array([[0.0, 0.0], [3.0, 4.0]]) + y = jnp.array([[0.0, 0.0], [1.0, 0.0]]) + # L1: d(x[0], y[1]) = |0-1| + |0-0| = 1; d(x[1], y[0]) = |3-0| + |4-0| = 7 + D = pted.utils._jax_cdist(x, y, p=1.0) + assert D.shape == (2, 2) + assert float(D[0, 1]) == pytest.approx(1.0) + assert float(D[1, 0]) == pytest.approx(7.0) + + def test_energy_distance_jax(): """_energy_distance_jax returns 0 when x and y are identical.""" if jax is None: From b601c3471b3ea300a9578be0f94a95e4131a20e3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 03:57:57 +0000 Subject: [PATCH 6/7] Update pted() signatures/docstrings for jax.Array; add jax optional extra in pyproject.toml Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- pyproject.toml | 3 +++ src/pted/pted.py | 38 ++++++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1ea425..0b3aee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ dev = [ torch = [ "torch>=2.0,<3", ] +jax = [ + "jax>=0.4,<1", +] [tool.hatch.metadata.hooks.requirements_txt] files = ["requirements.txt"] diff --git a/src/pted/pted.py b/src/pted/pted.py index b2fc1ef..59441ee 100644 --- a/src/pted/pted.py +++ b/src/pted/pted.py @@ -20,8 +20,8 @@ def pted( - x: Union[np.ndarray, "Tensor"], - y: Union[np.ndarray, "Tensor"], + x: Union[np.ndarray, "Tensor", "jax.Array"], + y: Union[np.ndarray, "Tensor", "jax.Array"], permutations: int = 1000, metric: Union[str, float] = "euclidean", return_all: bool = False, @@ -75,14 +75,17 @@ def pted( Parameters ---------- - x (Union[np.ndarray, Tensor]): first set of samples. Shape (N, *D) - y (Union[np.ndarray, Tensor]): second set of samples. Shape (M, *D) + x (Union[np.ndarray, Tensor, jax.Array]): first set of samples. Shape (N, *D) + y (Union[np.ndarray, Tensor, jax.Array]): second set of samples. Shape (M, *D) permutations (int): number of permutations to run. This determines how accurately the p-value is computed. - metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist - for the list of available metrics with numpy. See torch.cdist when - using PyTorch, note that the metric is passed as the "p" for - torch.cdist and therefore is a float from 0 to inf. + metric (Union[str, float]): distance metric to use. For NumPy inputs, + see scipy.spatial.distance.cdist for available metrics. For PyTorch + inputs, the metric is passed as the "p" argument to torch.cdist and + therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses + the squared-norm identity (p=2), and any float p uses + jnp.linalg.norm with ord=p; string metrics other than "euclidean" + are not supported for JAX. return_all (bool): if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False) @@ -185,8 +188,8 @@ def pted( def pted_coverage_test( - g: Union[np.ndarray, "Tensor"], - s: Union[np.ndarray, "Tensor"], + g: Union[np.ndarray, "Tensor", "jax.Array"], + s: Union[np.ndarray, "Tensor", "jax.Array"], permutations: int = 1000, metric: Union[str, float] = "euclidean", warn_confidence: Optional[float] = 1e-3, @@ -243,14 +246,17 @@ def pted_coverage_test( Parameters ---------- - g (Union[np.ndarray, Tensor]): Ground truth samples. Shape (n_sims, *D) - s (Union[np.ndarray, Tensor]): Posterior samples. Shape (n_samples, n_sims, *D) + g (Union[np.ndarray, Tensor, jax.Array]): Ground truth samples. Shape (n_sims, *D) + s (Union[np.ndarray, Tensor, jax.Array]): Posterior samples. Shape (n_samples, n_sims, *D) permutations (int): number of permutations to run. This determines how accurately the p-value is computed. - metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist - for the list of available metrics with numpy. See torch.cdist when using - PyTorch, note that the metric is passed as the "p" for torch.cdist and - therefore is a float from 0 to inf. + metric (Union[str, float]): distance metric to use. For NumPy inputs, + see scipy.spatial.distance.cdist for available metrics. For PyTorch + inputs, the metric is passed as the "p" argument to torch.cdist and + therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses + the squared-norm identity (p=2), and any float p uses + jnp.linalg.norm with ord=p; string metrics other than "euclidean" + are not supported for JAX. return_all (bool): if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False) From ed525f40b1bc3c7457a43a13e0584cba5b2317f5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 04:37:15 +0000 Subject: [PATCH 7/7] Add cross-backend consistency tests for numpy/torch/JAX energy distance utils Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- tests/test_pted.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/test_pted.py b/tests/test_pted.py index d356607..13b172b 100644 --- a/tests/test_pted.py +++ b/tests/test_pted.py @@ -310,3 +310,82 @@ def test_pted_chunk_torch_no_torch(monkeypatch): monkeypatch.setattr("pted.utils.torch", fake_torch) with pytest.raises(AssertionError, match="PyTorch is not installed"): pted.utils.pted_chunk_torch(np.zeros((5, 2)), np.zeros((5, 2))) + + +# --------------------------------------------------------------------------- +# Cross-backend consistency tests +# --------------------------------------------------------------------------- + + +def test_jax_cdist_matches_scipy(): + """_jax_cdist (L2) and scipy cdist produce the same pairwise distances.""" + if jax is None: + pytest.skip("jax not installed") + from scipy.spatial.distance import cdist as scipy_cdist + + np.random.seed(7) + x_np = np.random.normal(size=(10, 4)).astype(np.float32) + y_np = np.random.normal(size=(8, 4)).astype(np.float32) + + expected = scipy_cdist(x_np, y_np, metric="euclidean") + got = np.array(pted.utils._jax_cdist(jnp.array(x_np), jnp.array(y_np))) + np.testing.assert_allclose(got, expected, rtol=1e-5) + + +def test_energy_distance_numpy_torch_jax_agree(): + """_energy_distance_{numpy,torch,jax} return the same value for identical inputs.""" + if torch is None: + pytest.skip("torch not installed") + if jax is None: + pytest.skip("jax not installed") + + np.random.seed(99) + # Use float32 so all backends operate at the same precision + # (JAX uses float32 by default) + x_np = np.random.normal(size=(30, 5)).astype(np.float32) + y_np = np.random.normal(size=(30, 5)).astype(np.float32) + + ed_numpy = pted.utils._energy_distance_numpy(x_np, y_np) + ed_torch = pted.utils._energy_distance_torch( + torch.tensor(x_np), torch.tensor(y_np) + ) + ed_jax = pted.utils._energy_distance_jax(jnp.array(x_np), jnp.array(y_np)) + + assert ed_numpy == pytest.approx(ed_torch, rel=1e-4), ( + f"numpy ({ed_numpy}) and torch ({ed_torch}) energy distances differ" + ) + assert ed_numpy == pytest.approx(ed_jax, rel=1e-4), ( + f"numpy ({ed_numpy}) and jax ({ed_jax}) energy distances differ" + ) + + +def test_energy_distance_estimate_numpy_torch_jax_agree(): + """_energy_distance_estimate_{numpy,torch,jax} return close values for the same seed/data.""" + if torch is None: + pytest.skip("torch not installed") + if jax is None: + pytest.skip("jax not installed") + + np.random.seed(123) + # Use float32 so all backends operate at the same precision + x_np = np.random.normal(size=(200, 5)).astype(np.float32) + y_np = np.random.normal(size=(200, 5)).astype(np.float32) + + # Run with the same seed so the same chunks are sampled + np.random.seed(0) + ed_numpy = pted.utils._energy_distance_estimate_numpy(x_np, y_np, chunk_size=50, chunk_iter=5) + np.random.seed(0) + ed_torch = pted.utils._energy_distance_estimate_torch( + torch.tensor(x_np), torch.tensor(y_np), chunk_size=50, chunk_iter=5 + ) + np.random.seed(0) + ed_jax = pted.utils._energy_distance_estimate_jax( + jnp.array(x_np), jnp.array(y_np), chunk_size=50, chunk_iter=5 + ) + + assert ed_numpy == pytest.approx(ed_torch, rel=1e-4), ( + f"numpy ({ed_numpy}) and torch ({ed_torch}) energy distance estimates differ" + ) + assert ed_numpy == pytest.approx(ed_jax, rel=1e-4), ( + f"numpy ({ed_numpy}) and jax ({ed_jax}) energy distance estimates differ" + )