Skip to content

Add JAX array support to pted and pted_coverage_test#14

Merged
ConnorStoneAstro merged 7 commits intomainfrom
copilot/make-pted-compatible-with-jax
Mar 3, 2026
Merged

Add JAX array support to pted and pted_coverage_test#14
ConnorStoneAstro merged 7 commits intomainfrom
copilot/make-pted-compatible-with-jax

Conversation

Copy link
Contributor

Copilot AI commented Mar 2, 2026

pted and pted_coverage_test accepted NumPy and PyTorch inputs but not JAX arrays. JAX is now supported as an optional backend — no crash if JAX is not installed.

Changes

src/pted/utils.py

  • Optional jax/jax.numpy import with None fallback
  • is_jax_array() — detects jax.Array instances; returns False when JAX is absent
  • _jax_cdist() — memory-efficient pairwise distance: uses the squared-norm identity (||xi - yj||² = ||xi||² + ||yj||² - 2·xi·yj) for Euclidean (p=2.0) to avoid materializing the (nx, ny, d) diff tensor; uses jax.vmap over rows for general p-norms
  • _energy_distance_jax() / _energy_distance_estimate_jax() — JAX-native energy distance (full and chunked)
  • pted_jax() / pted_chunk_jax() — permutation test implementations mirroring the existing torch equivalents

src/pted/pted.py

  • Imports new JAX functions; adds is_jax_array routing branches between the torch and numpy paths in pted()
  • Updated pted() and pted_coverage_test() type annotations to Union[np.ndarray, "Tensor", "jax.Array"]
  • Updated metric parameter docstrings to document JAX-specific semantics: "euclidean" uses the squared-norm identity (p=2); float p uses jnp.linalg.norm(ord=p); non-euclidean string metrics are not supported for JAX

tests/test_pted.py

  • test_pted_jax, test_pted_chunk_jax, test_pted_coverage_jax — mirror existing torch tests; all skip gracefully when JAX is absent
  • test_is_jax_array_with_jax, test_is_jax_array_no_jax — unit tests for is_jax_array including mocked JAX-absent case
  • test_jax_cdist, test_jax_cdist_non_euclidean, test_energy_distance_jax, test_energy_distance_estimate_jax — direct unit tests for JAX utility functions covering both the L2 (matmul) and general p-norm (vmap) paths
  • test_pted_jax_no_jax, test_pted_chunk_jax_no_jax — mocked tests verifying correct error when JAX is not installed
  • test_pted_torch_no_torch, test_pted_chunk_torch_no_torch — mocked tests verifying correct error when torch is not installed
  • test_jax_cdist_matches_scipy — verifies _jax_cdist (L2 path) produces the same pairwise distances as scipy.spatial.distance.cdist
  • test_energy_distance_numpy_torch_jax_agree — verifies _energy_distance_{numpy,torch,jax} return the same value for identical inputs across all three backends
  • test_energy_distance_estimate_numpy_torch_jax_agree — same cross-backend consistency check for the chunked estimators using a fixed random seed

pyproject.toml

  • Added jax>=0.4,<1 to [project.optional-dependencies] dev so CI installs JAX and runs the JAX tests
  • Added jax = ["jax>=0.4,<1"] optional-dependency extra (parallel to the existing torch extra) so users can pip install pted[jax]

README.md

  • Updated pted() and pted_coverage_test() function signatures and parameter descriptions to include jax.Array alongside np.ndarray and Tensor
  • Expanded the GPU Compatibility section with a JAX example
  • Added pip install jax note in the Install section

Usage

import jax.numpy as jnp
from pted import pted, pted_coverage_test

x = jnp.array(np.random.normal(size=(100, 10)))
y = jnp.array(np.random.normal(size=(100, 10)))

p = pted(x, y)

g = jnp.array(np.random.normal(size=(100, 10)))
s = jnp.array(np.random.normal(size=(50, 100, 10)))
p = pted_coverage_test(g, s)

JAX can be installed as an optional extra:

pip install pted[jax]
Original prompt

This section details on the original issue you should resolve

<issue_title>make PTED compatible with JAX</issue_title>
<issue_description>I would like the PTED functions pted and pted_coverage_test to accept JAX arrays. Currently accepts numpy and pytorch inputs so JAX should be an option too. I don't want JAX to be added as a new required dependency, so it should not crash if the user doesn't have JAX in their environment.</issue_description>

Comments on the Issue (you are @copilot in this section)


🔒 GitHub Advanced Security automatically protects Copilot coding agent pull requests. You can protect all pull requests by enabling Advanced Security for your repositories. Learn more about Advanced Security.

@codecov-commenter
Copy link

codecov-commenter commented Mar 2, 2026

Codecov Report

❌ Patch coverage is 95.52239% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.87%. Comparing base (a5a45a8) to head (ed525f4).

Files with missing lines Patch % Lines
src/pted/utils.py 95.23% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #14      +/-   ##
==========================================
- Coverage   98.47%   97.87%   -0.61%     
==========================================
  Files           4        4              
  Lines         262      329      +67     
==========================================
+ Hits          258      322      +64     
- Misses          4        7       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Copilot AI changed the title [WIP] Make PTED functions accept JAX arrays Add JAX array support to pted and pted_coverage_test Mar 2, 2026
@ConnorStoneAstro
Copy link
Owner

@copilot you will need to add jax as a dev dependency so that the jax tests will run in the CI.

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Mar 2, 2026

@copilot you will need to add jax as a dev dependency so that the jax tests will run in the CI.

Added jax>=0.4,<1 to the dev optional-dependencies in pyproject.toml in commit 5e17977.

@ConnorStoneAstro
Copy link
Owner

@copilot this looks great! Could you update the README.md file to now also state that JAX arrays may be input (it already says Torch Tensors are allowed).

Further you will need to add unit tests to cover all cases in the newly added utils functions. This may require using pytest mocking to cover the case where jax is not installed. If you could also have the test check the case where torch is not installed that would be great!

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Mar 2, 2026

@copilot this looks great! Could you update the README.md file to now also state that JAX arrays may be input (it already says Torch Tensors are allowed).

Further you will need t...

Done in commit 761cb4c:

README.md — updated both function signatures (pted and pted_coverage_test) to show jax.Array alongside np.ndarray and Tensor, updated the parameter descriptions accordingly, and expanded the GPU Compatibility section with a JAX example alongside the PyTorch one. Also added a pip install jax note in the Install section.

tests/test_pted.py — added 9 new unit tests covering all newly-added utils functions:

  • test_is_jax_array_with_jax — True for a real jax.Array, False for numpy/int
  • test_is_jax_array_no_jax — monkeypatches pted.utils.jax = None, verifies False
  • test_jax_cdist — checks pairwise distances for known inputs
  • test_energy_distance_jax — energy distance is ~0 for identical arrays
  • test_energy_distance_estimate_jax — chunked estimator returns a finite value
  • test_pted_jax_no_jax — monkeypatches pted.utils.jax = None, expects AssertionError
  • test_pted_chunk_jax_no_jax — same for pted_chunk_jax
  • test_pted_torch_no_torch — monkeypatches torch.__version__ = "null", expects AssertionError
  • test_pted_chunk_torch_no_torch — same for pted_chunk_torch

@ConnorStoneAstro ConnorStoneAstro marked this pull request as ready for review March 3, 2026 03:34
Copilot AI review requested due to automatic review settings March 3, 2026 03:34
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds optional JAX support to the PTED two-sample and coverage permutation tests so jax.Array inputs are accepted without making JAX a required dependency.

Changes:

  • Introduces JAX detection and JAX-native energy distance / permutation-test implementations in utilities.
  • Routes pted() execution to the JAX backend when inputs are JAX arrays.
  • Adds JAX-focused tests (skipping when JAX is unavailable) and includes JAX in dev dependencies; updates README to document JAX usage.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
src/pted/utils.py Adds optional JAX import, is_jax_array, JAX distance/energy distance helpers, and JAX PTED implementations.
src/pted/pted.py Routes JAX array inputs through the new JAX backend functions.
tests/test_pted.py Adds JAX integration/unit tests and mocked “backend missing” tests.
pyproject.toml Adds JAX to the dev optional dependency set.
README.md Documents JAX array support and provides installation/GPU usage notes.
Comments suppressed due to low confidence (4)

src/pted/utils.py:278

  • metric is typed as Union[str, float], but the JAX path only special-cases the string 'euclidean'; any other string (e.g. 'cityblock') will be passed to jnp.linalg.norm(ord=...) and fail with a confusing error. Please add explicit validation (reject non-'euclidean' strings with a clear message) or implement support for additional string metrics to match the NumPy behavior.
    if metric == "euclidean":
        metric = 2.0
    dmatrix = _jax_cdist(z, z, p=metric)

src/pted/pted.py:157

  • The public pted() signature/docstring still only mentions NumPy/Torch (Union[np.ndarray, "Tensor"] and the metric docs only describe SciPy/Torch), but the implementation now supports JAX arrays. Please update the function annotations and docstring parameter docs to include JAX (and clarify the JAX metric semantics) so help(pted.pted) matches the README and actual behavior.
        )
    elif is_torch_tensor(x):
        test, permute = pted_torch(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
    elif is_jax_array(x) and chunk_size is not None:
        test, permute = pted_chunk_jax(
            x,
            y,
            permutations=permutations,
            metric=metric,
            chunk_size=int(chunk_size),
            chunk_iter=int(chunk_iter),
            prog_bar=prog_bar,
        )
    elif is_jax_array(x):
        test, permute = pted_jax(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)

tests/test_pted.py:189

  • test_pted_jax runs 20 trials using the default permutations=1000 with D=300, which is likely to make the test suite significantly slower (and JAX’s current _jax_cdist implementation is particularly memory/compute heavy). Consider reducing permutations, D, and/or the loop count for this test while keeping the same assertion intent.
def test_pted_jax():
    if jax is None:
        pytest.skip("jax not installed")

    # Set the random seed for reproducibility
    np.random.seed(42)

    # example 2 sample test
    D = 300
    for _ in range(20):
        x = jnp.array(np.random.normal(size=(100, D)))
        y = jnp.array(np.random.normal(size=(100, D)))
        p = pted.pted(x, y)
        assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))"

    x = jnp.array(np.random.normal(size=(100, D)))
    y = jnp.array(np.random.uniform(size=(100, D)))
    p = pted.pted(x, y, two_tailed=False)
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

    x = jnp.array(np.random.normal(size=(100, D)))
    t, p, _ = pted.pted(x, x, return_all=True)
    q = 2 * min(np.sum(p > t), np.sum(p < t))
    p = (1 + q) / (len(p) + 1)
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

pyproject.toml:49

  • Since the project already exposes a torch extra under [project.optional-dependencies], consider adding a parallel jax = ["jax>=0.4,<1"] extra (in addition to dev) so users can do pip install pted[jax] for JAX support, matching the existing installation pattern.
[project.optional-dependencies]
dev = [
    "pytest>=8.0,<9",
    "pytest-cov>=4.1,<5",
    "pytest-mock>=3.12,<4",
    "torch>=2.0,<3",
    "jax>=0.4,<1",
    "matplotlib",
]
torch = [
    "torch>=2.0,<3",
]

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +130 to +133
def _jax_cdist(x, y, p: float = 2.0):
diff = x[:, None, :] - y[None, :, :]
return jnp.linalg.norm(diff, ord=p, axis=-1)

Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_jax_cdist materializes diff = x[:, None, :] - y[None, :, :] (shape (nx, ny, d)), which can be drastically larger than the output distance matrix and can OOM for moderate sizes. Consider computing pairwise distances without allocating the full (nx, ny, d) tensor (e.g., using the squared-norm identity / matmul approach or vmap over rows) so memory scales primarily with (nx, ny).

Copilot uses AI. Check for mistakes.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit be230a4. _jax_cdist now uses the squared-norm identity (||x_i - y_j||² = ||x_i||² + ||y_j||² - 2·x_i·y_j) for the default Euclidean (p=2.0) case, which avoids materializing the (nx, ny, d) diff tensor entirely — memory scales as O(nx·ny) instead. For general p-norms, it falls back to jax.vmap over rows, which avoids the full broadcast intermediate. A test for the non-Euclidean vmap path was also added.

@ConnorStoneAstro
Copy link
Owner

@copilot The public pted() signature/docstring still only mentions NumPy/Torch (Union[np.ndarray, "Tensor"] and the metric docs only describe SciPy/Torch), but the implementation now supports JAX arrays. Please update the function annotations and docstring parameter docs to include JAX (and clarify the JAX metric semantics) so help(pted.pted) matches the README and actual behavior.

also Since the project already exposes a torch extra under [project.optional-dependencies], consider adding a parallel jax = ["jax>=0.4,<1"] extra (in addition to dev) so users can do pip install pted[jax] for JAX support, matching the existing installation pattern.

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
…xtra in pyproject.toml

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Mar 3, 2026

@copilot The public pted() signature/docstring still only mentions NumPy/Torch (Union[np.ndarray, "Tensor"] and the metric docs only describe SciPy/Torch), but the implementation now supports JAX arra...

Done in commit b601c34:

  • pted() and pted_coverage_test() signatures updated to Union[np.ndarray, "Tensor", "jax.Array"]
  • metric parameter docs in both docstrings now describe all three backends: NumPy (scipy.spatial.distance.cdist), PyTorch (torch.cdist p-norm), and JAX ("euclidean" uses the squared-norm identity; float p uses jnp.linalg.norm(ord=p); non-euclidean strings are not supported)
  • Added jax = ["jax>=0.4,<1"] to [project.optional-dependencies] in pyproject.toml, parallel to the existing torch extra — users can now pip install pted[jax]

@ConnorStoneAstro
Copy link
Owner

@copilot can you add unit tests to check that the utils for the numpy torch and Jax versions all produce the same output where appropriate?

…ce utils

Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Mar 3, 2026

@copilot can you add unit tests to check that the utils for the numpy torch and Jax versions all produce the same output where appropriate?

Added in commit ed525f4 — three new tests at the end of tests/test_pted.py:

  • test_jax_cdist_matches_scipy — verifies _jax_cdist (L2 path) produces the same pairwise distances as scipy.spatial.distance.cdist for identical inputs
  • test_energy_distance_numpy_torch_jax_agree — verifies _energy_distance_{numpy,torch,jax} return the same value for identical float32 inputs (tolerance rel=1e-4 to account for scipy cdist accumulating in float64 vs float32 in torch/JAX)
  • test_energy_distance_estimate_numpy_torch_jax_agree — same for the chunked estimators, using a fixed np.random.seed before each call so all three backends draw the same chunks

All tests skip if torch or JAX is not installed.

@ConnorStoneAstro ConnorStoneAstro merged commit 728bd53 into main Mar 3, 2026
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

make PTED compatible with JAX

4 participants