Merged
43 changes: 31 additions & 12 deletions README.md
@@ -29,6 +29,12 @@ If you want to run PTED on GPUs using PyTorch, then also install torch:
pip install torch
```

If you want to use JAX arrays as inputs, then also install jax:

```bash
pip install jax
```

The two functions are ``pted.pted`` and ``pted.pted_coverage_test``. For
information about each argument, just use ``help(pted.pted)`` or
``help(pted.pted_coverage_test)``.
@@ -261,8 +267,8 @@ results you are getting!

```python
def pted(
x: Union[np.ndarray, "Tensor"],
y: Union[np.ndarray, "Tensor"],
x: Union[np.ndarray, "Tensor", "jax.Array"],
y: Union[np.ndarray, "Tensor", "jax.Array"],
permutations: int = 1000,
metric: Union[str, float] = "euclidean",
return_all: bool = False,
@@ -273,10 +279,10 @@ def pted(
) -> Union[float, tuple[float, np.ndarray, float]]:
```

* **x** *(Union[np.ndarray, Tensor])*: first set of samples. Shape (N, *D)
* **y** *(Union[np.ndarray, Tensor])*: second set of samples. Shape (M, *D)
* **x** *(Union[np.ndarray, Tensor, jax.Array])*: first set of samples. Shape (N, *D)
* **y** *(Union[np.ndarray, Tensor, jax.Array])*: second set of samples. Shape (M, *D)
* **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with NumPy. With PyTorch, the metric is passed as the "p" argument to torch.cdist and is therefore a float from 0 to inf. With JAX arrays, "euclidean" maps to p = 2 and any float is passed as the "ord" to jnp.linalg.norm; string metrics other than "euclidean" are not supported.
* **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
* **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
* **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
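For intuition, the subsample-and-average scheme behind **chunk_size**/**chunk_iter** can be sketched in plain NumPy (the helper names here are illustrative, not part of the pted API):

```python
import numpy as np
from scipy.spatial.distance import cdist

def energy_distance(x, y, metric="euclidean"):
    # E = 2 E||X - Y|| - E||X - X'|| - E||Y - Y'|| (diagonal terms left in for brevity)
    return (2 * cdist(x, y, metric=metric).mean()
            - cdist(x, x, metric=metric).mean()
            - cdist(y, y, metric=metric).mean())

def energy_distance_chunked(x, y, chunk_size, chunk_iter, metric="euclidean"):
    # Average the statistic over random subsamples instead of building the full
    # (N + M)^2 distance matrix, trading estimator variance for memory.
    est = []
    for _ in range(chunk_iter):
        xc = x[np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)]
        yc = y[np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)]
        est.append(energy_distance(xc, yc, metric=metric))
    return float(np.mean(est))

rng = np.random.default_rng(0)
x = rng.normal(size=(2000, 5))
y = rng.normal(size=(2000, 5))
print(energy_distance_chunked(x, y, chunk_size=200, chunk_iter=10))  # near 0: same distribution
```

pted applies the same idea internally when `chunk_size` is not None, for both the test statistic and each permuted statistic.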
@@ -287,8 +293,8 @@ def pted(

```python
def pted_coverage_test(
g: Union[np.ndarray, "Tensor"],
s: Union[np.ndarray, "Tensor"],
g: Union[np.ndarray, "Tensor", "jax.Array"],
s: Union[np.ndarray, "Tensor", "jax.Array"],
permutations: int = 1000,
metric: Union[str, float] = "euclidean",
warn_confidence: Optional[float] = 1e-3,
@@ -301,10 +307,10 @@ def pted_coverage_test(
) -> Union[float, tuple[np.ndarray, np.ndarray, float]]:
```

* **g** *(Union[np.ndarray, Tensor])*: Ground truth samples. Shape (n_sims, *D)
* **s** *(Union[np.ndarray, Tensor])*: Posterior samples. Shape (n_samples, n_sims, *D)
* **g** *(Union[np.ndarray, Tensor, jax.Array])*: Ground truth samples. Shape (n_sims, *D)
* **s** *(Union[np.ndarray, Tensor, jax.Array])*: Posterior samples. Shape (n_samples, n_sims, *D)
* **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with NumPy. With PyTorch, the metric is passed as the "p" argument to torch.cdist and is therefore a float from 0 to inf. With JAX arrays, "euclidean" maps to p = 2 and any float is passed as the "ord" to jnp.linalg.norm; string metrics other than "euclidean" are not supported.
* **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
* **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
* **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
@@ -315,9 +321,9 @@ def pted_coverage_test(
## GPU Compatibility

PTED works on both CPU and GPU. All that is needed is to pass the `x` and `y` as
PyTorch Tensors on the appropriate device.
PyTorch Tensors or JAX Arrays on the appropriate device.

Example:
Example with PyTorch:
```python
from pted import pted
import numpy as np
@@ -330,6 +336,19 @@ p_value = pted(torch.tensor(x), torch.tensor(y))
print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1
```

Example with JAX:
```python
from pted import pted
import numpy as np
import jax.numpy as jnp

x = np.random.normal(size = (500, 10)) # (n_samples_x, n_dimensions)
y = np.random.normal(size = (400, 10)) # (n_samples_y, n_dimensions)

p_value = pted(jnp.array(x), jnp.array(y))
print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1
```

## Memory and Compute limitations

If a GPU isn't enough to get PTED running fast enough for you, or if you are
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -41,11 +41,15 @@ dev = [
"pytest-cov>=4.1,<5",
"pytest-mock>=3.12,<4",
"torch>=2.0,<3",
"jax>=0.4,<1",
"matplotlib",
]
torch = [
"torch>=2.0,<3",
]
jax = [
"jax>=0.4,<1",
]

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]
53 changes: 37 additions & 16 deletions src/pted/pted.py
@@ -4,10 +4,13 @@

from .utils import (
is_torch_tensor,
is_jax_array,
pted_torch,
pted_numpy,
pted_chunk_torch,
pted_chunk_numpy,
pted_jax,
pted_chunk_jax,
two_tailed_p,
confidence_alert,
simulation_based_calibration_histogram,
@@ -17,8 +20,8 @@


def pted(
x: Union[np.ndarray, "Tensor"],
y: Union[np.ndarray, "Tensor"],
x: Union[np.ndarray, "Tensor", "jax.Array"],
y: Union[np.ndarray, "Tensor", "jax.Array"],
permutations: int = 1000,
metric: Union[str, float] = "euclidean",
return_all: bool = False,
@@ -72,14 +75,17 @@ def pted(

Parameters
----------
x (Union[np.ndarray, Tensor]): first set of samples. Shape (N, *D)
y (Union[np.ndarray, Tensor]): second set of samples. Shape (M, *D)
x (Union[np.ndarray, Tensor, jax.Array]): first set of samples. Shape (N, *D)
y (Union[np.ndarray, Tensor, jax.Array]): second set of samples. Shape (M, *D)
permutations (int): number of permutations to run. This determines how
accurately the p-value is computed.
metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist
for the list of available metrics with numpy. See torch.cdist when
using PyTorch, note that the metric is passed as the "p" for
torch.cdist and therefore is a float from 0 to inf.
metric (Union[str, float]): distance metric to use. For NumPy inputs,
see scipy.spatial.distance.cdist for available metrics. For PyTorch
inputs, the metric is passed as the "p" argument to torch.cdist and
therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses
the squared-norm identity (p=2), and any float p uses
jnp.linalg.norm with ord=p; string metrics other than "euclidean"
are not supported for JAX.
return_all (bool): if True, return the test statistic and the permuted
statistics with the p-value. If False, just return the p-value.
bool (default: False)
@@ -140,6 +146,18 @@ def pted(
)
elif is_torch_tensor(x):
test, permute = pted_torch(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
elif is_jax_array(x) and chunk_size is not None:
test, permute = pted_chunk_jax(
x,
y,
permutations=permutations,
metric=metric,
chunk_size=int(chunk_size),
chunk_iter=int(chunk_iter),
prog_bar=prog_bar,
)
elif is_jax_array(x):
test, permute = pted_jax(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
elif chunk_size is not None:
test, permute = pted_chunk_numpy(
x,
@@ -170,8 +188,8 @@ def pted(


def pted_coverage_test(
g: Union[np.ndarray, "Tensor"],
s: Union[np.ndarray, "Tensor"],
g: Union[np.ndarray, "Tensor", "jax.Array"],
s: Union[np.ndarray, "Tensor", "jax.Array"],
permutations: int = 1000,
metric: Union[str, float] = "euclidean",
warn_confidence: Optional[float] = 1e-3,
@@ -228,14 +246,17 @@ def pted_coverage_test(

Parameters
----------
g (Union[np.ndarray, Tensor]): Ground truth samples. Shape (n_sims, *D)
s (Union[np.ndarray, Tensor]): Posterior samples. Shape (n_samples, n_sims, *D)
g (Union[np.ndarray, Tensor, jax.Array]): Ground truth samples. Shape (n_sims, *D)
s (Union[np.ndarray, Tensor, jax.Array]): Posterior samples. Shape (n_samples, n_sims, *D)
permutations (int): number of permutations to run. This determines how
accurately the p-value is computed.
metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist
for the list of available metrics with numpy. See torch.cdist when using
PyTorch, note that the metric is passed as the "p" for torch.cdist and
therefore is a float from 0 to inf.
metric (Union[str, float]): distance metric to use. For NumPy inputs,
see scipy.spatial.distance.cdist for available metrics. For PyTorch
inputs, the metric is passed as the "p" argument to torch.cdist and
therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses
the squared-norm identity (p=2), and any float p uses
jnp.linalg.norm with ord=p; string metrics other than "euclidean"
are not supported for JAX.
return_all (bool): if True, return the test statistic and the permuted
statistics with the p-value. If False, just return the p-value. bool
(default: False)
113 changes: 113 additions & 0 deletions src/pted/utils.py
@@ -16,12 +16,23 @@ class torch:
Tensor = np.ndarray


try:
import jax
import jax.numpy as jnp
except ImportError:
jax = None
jnp = None


__all__ = (
"is_torch_tensor",
"is_jax_array",
"pted_numpy",
"pted_chunk_numpy",
"pted_torch",
"pted_chunk_torch",
"pted_jax",
"pted_chunk_jax",
"two_tailed_p",
"confidence_alert",
"simulation_based_calibration_histogram",
@@ -39,6 +50,12 @@ def is_torch_tensor(o):
)


def is_jax_array(o):
if jax is None:
return False
return isinstance(o, jax.Array)


def _energy_distance_precompute(
D: Union[np.ndarray, torch.Tensor], nx: int, ny: int
) -> Union[float, torch.Tensor]:
@@ -110,6 +127,49 @@ def _energy_distance_estimate_torch(
return np.mean(E_est)


def _jax_cdist(x, y, p: float = 2.0):
if p == 2.0:
# Squared-norm identity avoids materializing the (nx, ny, d) diff tensor.
# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * x_i . y_j
x_sq = jnp.sum(x ** 2, axis=-1) # (nx,)
y_sq = jnp.sum(y ** 2, axis=-1) # (ny,)
sq_dist = x_sq[:, None] + y_sq[None, :] - 2.0 * (x @ y.T)
return jnp.sqrt(jnp.maximum(sq_dist, 0.0))
# For general p-norms use vmap to avoid the (nx, ny, d) intermediate.
return jax.vmap(lambda xi: jnp.linalg.norm(xi - y, ord=p, axis=-1))(x)

Comment on lines +130 to +140

**Copilot AI** (Mar 3, 2026): `_jax_cdist` materializes `diff = x[:, None, :] - y[None, :, :]` (shape `(nx, ny, d)`), which can be drastically larger than the output distance matrix and can OOM for moderate sizes. Consider computing pairwise distances without allocating the full `(nx, ny, d)` tensor (e.g., using the squared-norm identity / matmul approach or `vmap` over rows) so memory scales primarily with `(nx, ny)`.

**Owner:** @copilot apply changes based on this feedback

**Contributor Author:** Fixed in commit be230a4. `_jax_cdist` now uses the squared-norm identity (`||x_i - y_j||² = ||x_i||² + ||y_j||² - 2·x_i·y_j`) for the default Euclidean (p=2.0) case, which avoids materializing the `(nx, ny, d)` diff tensor entirely; memory scales as O(nx·ny) instead. For general p-norms, it falls back to `jax.vmap` over rows, which avoids the full broadcast intermediate. A test for the non-Euclidean vmap path was also added.
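The identity behind the fix is easy to verify numerically; a NumPy sketch of the p = 2 branch (not the PR's JAX code itself), checked against scipy's cdist:

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
x = rng.normal(size=(40, 7))
y = rng.normal(size=(30, 7))

# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * x_i . y_j
sq = (x ** 2).sum(axis=-1)[:, None] + (y ** 2).sum(axis=-1)[None, :] - 2.0 * (x @ y.T)
# Clamp tiny negative values produced by floating-point cancellation before the sqrt.
d_identity = np.sqrt(np.maximum(sq, 0.0))

# Matches the direct pairwise computation while only ever allocating (nx, ny) arrays.
assert np.allclose(d_identity, cdist(x, y), atol=1e-8)
```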


def _energy_distance_jax(x, y, metric: Union[str, float] = "euclidean") -> float:
nx = len(x)
ny = len(y)
z = jnp.concatenate([x, y], axis=0)
if metric == "euclidean":
metric = 2.0
D = _jax_cdist(z, z, p=metric)
return float(_energy_distance_precompute(D, nx, ny))


def _energy_distance_estimate_jax(
x,
y,
chunk_size: int,
chunk_iter: int,
metric: Union[str, float] = "euclidean",
) -> float:

E_est = []
for _ in range(chunk_iter):
# Randomly sample a chunk of data
idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
x_chunk = x[idx]
idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
y_chunk = y[idy]

# Compute the energy distance
E_est.append(_energy_distance_jax(x_chunk, y_chunk, metric=metric))
return np.mean(E_est)


def pted_chunk_numpy(
x: np.ndarray,
y: np.ndarray,
@@ -210,6 +270,59 @@ def pted_torch(
return test_stat, permute_stats


def pted_jax(
x,
y,
permutations: int = 100,
metric: Union[str, float] = "euclidean",
prog_bar: bool = False,
) -> tuple[float, list[float]]:
assert jax is not None, "JAX is not installed! try: `pip install jax`"
z = jnp.concatenate([x, y], axis=0)
assert jnp.all(jnp.isfinite(z)), "Input contains NaN or Inf!"
if metric == "euclidean":
metric = 2.0
dmatrix = _jax_cdist(z, z, p=metric)
assert jnp.all(
jnp.isfinite(dmatrix)
), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
nx = len(x)
ny = len(y)

test_stat = float(_energy_distance_precompute(dmatrix, nx, ny))
permute_stats = []
for _ in trange(permutations, disable=not prog_bar):
I = np.random.permutation(len(z))
dmatrix = dmatrix[I][:, I]
permute_stats.append(float(_energy_distance_precompute(dmatrix, nx, ny)))
return test_stat, permute_stats


def pted_chunk_jax(
x,
y,
permutations: int = 100,
metric: Union[str, float] = "euclidean",
chunk_size: int = 100,
chunk_iter: int = 10,
prog_bar: bool = False,
) -> tuple[float, list[float]]:
assert jax is not None, "JAX is not installed! try: `pip install jax`"
assert jnp.all(jnp.isfinite(x)) and jnp.all(jnp.isfinite(y)), "Input contains NaN or Inf!"
nx = len(x)

test_stat = _energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric)
permute_stats = []
for _ in trange(permutations, disable=not prog_bar):
z = jnp.concatenate([x, y], axis=0)
z = z[np.random.permutation(len(z))]
x, y = z[:nx], z[nx:]
permute_stats.append(
_energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric)
)
return test_stat, permute_stats


def two_tailed_p(chi2, df):
assert df > 2, "Degrees of freedom must be greater than 2 for two-tailed p-value calculation."
alpha = chi2_dist.pdf(chi2, df)