Add JAX array support to pted and pted_coverage_test #14
ConnorStoneAstro merged 7 commits into main
Codecov Report❌ Patch coverage is
Additional details and impacted files

| Coverage Diff | main | #14 | +/- |
|---|---|---|---|
| Coverage | 98.47% | 97.87% | -0.61% |
| Files | 4 | 4 | |
| Lines | 262 | 329 | +67 |
| Hits | 258 | 322 | +64 |
| Misses | 4 | 7 | +3 |
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
@copilot you will need to add jax as a dev dependency so that the JAX tests run in CI.
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
@copilot this looks great! Could you update the README.md file to also state that JAX arrays may be input (it already says Torch Tensors are allowed)? Further, you will need to add unit tests covering all cases in the newly added utils functions. This may require pytest mocking to cover the case where jax is not installed. If you could also have the tests check the case where torch is not installed, that would be great!
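One way to cover the "backend not installed" branches requested above is to rebind the module-level `jax` name to `None` for the duration of a test. The sketch below is self-contained and uses the standard library's `unittest.mock` rather than pytest-mock. `pted_jax_stub` and its error message are hypothetical stand-ins, not code from pted.

```python
from unittest import mock

# Optional import with a None fallback, as the PR describes for JAX.
try:
    import jax
except ImportError:
    jax = None


def is_jax_array(x) -> bool:
    """Return True only when JAX is importable and x is a jax.Array."""
    if jax is None:
        return False
    return isinstance(x, jax.Array)


def pted_jax_stub(x, y):
    """Hypothetical stand-in for a JAX-only entry point."""
    if jax is None:
        raise ImportError("JAX is required for this backend: pip install jax")
    return "ok"


def test_is_jax_array_no_jax():
    # Temporarily rebind the module-level name `jax` to None.
    with mock.patch.dict(globals(), {"jax": None}):
        assert is_jax_array([1.0, 2.0]) is False


def test_pted_jax_stub_no_jax():
    with mock.patch.dict(globals(), {"jax": None}):
        try:
            pted_jax_stub(None, None)
        except ImportError as err:
            assert "JAX" in str(err)
        else:
            raise AssertionError("expected ImportError")
```

In the real test suite the patch target would presumably be the `jax` name inside `pted.utils` rather than `globals()`; the same pattern applies to the torch-absent case.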
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Done in commit 761cb4c:
- README.md: updated both function signatures (
- tests/test_pted.py: added 9 new unit tests covering all newly-added utils functions:
Pull request overview
Adds optional JAX support to the PTED two-sample and coverage permutation tests so jax.Array inputs are accepted without making JAX a required dependency.
Changes:
- Introduces JAX detection and JAX-native energy distance / permutation-test implementations in utilities.
- Routes `pted()` execution to the JAX backend when inputs are JAX arrays.
- Adds JAX-focused tests (skipping when JAX is unavailable) and includes JAX in dev dependencies; updates README to document JAX usage.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/pted/utils.py | Adds optional JAX import, `is_jax_array`, JAX distance/energy distance helpers, and JAX PTED implementations. |
| src/pted/pted.py | Routes JAX array inputs through the new JAX backend functions. |
| tests/test_pted.py | Adds JAX integration/unit tests and mocked "backend missing" tests. |
| pyproject.toml | Adds JAX to the dev optional dependency set. |
| README.md | Documents JAX array support and provides installation/GPU usage notes. |
Comments suppressed due to low confidence (4)
src/pted/utils.py:278
`metric` is typed as `Union[str, float]`, but the JAX path only special-cases the string `'euclidean'`; any other string (e.g. `'cityblock'`) will be passed to `jnp.linalg.norm(ord=...)` and fail with a confusing error. Please add explicit validation (reject non-`'euclidean'` strings with a clear message) or implement support for additional string metrics to match the NumPy behavior.

    if metric == "euclidean":
        metric = 2.0
    dmatrix = _jax_cdist(z, z, p=metric)
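A minimal sketch of the explicit validation this comment asks for. `resolve_jax_metric` is a hypothetical helper name and the error message is illustrative; neither comes from the PR.

```python
from typing import Union


def resolve_jax_metric(metric: Union[str, float]) -> float:
    """Map a metric argument to a p-norm order, rejecting unsupported strings.

    Hypothetical helper, not part of pted.
    """
    if isinstance(metric, str):
        if metric == "euclidean":
            return 2.0
        # Fail early with a clear message instead of letting the string
        # reach jnp.linalg.norm(ord=...).
        raise ValueError(
            f"metric={metric!r} is not supported for JAX inputs; "
            "use 'euclidean' or a float p for the p-norm."
        )
    return float(metric)
```

Calling `resolve_jax_metric("cityblock")` raises a `ValueError` that names the offending value, while `"euclidean"` and numeric orders pass through as floats.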
src/pted/pted.py:157
The public `pted()` signature/docstring still only mentions NumPy/Torch (`Union[np.ndarray, "Tensor"]`, and the metric docs only describe SciPy/Torch), but the implementation now supports JAX arrays. Please update the function annotations and docstring parameter docs to include JAX (and clarify the JAX metric semantics) so `help(pted.pted)` matches the README and actual behavior.

        )
    elif is_torch_tensor(x):
        test, permute = pted_torch(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
    elif is_jax_array(x) and chunk_size is not None:
        test, permute = pted_chunk_jax(
            x,
            y,
            permutations=permutations,
            metric=metric,
            chunk_size=int(chunk_size),
            chunk_iter=int(chunk_iter),
            prog_bar=prog_bar,
        )
    elif is_jax_array(x):
        test, permute = pted_jax(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
tests/test_pted.py:189
`test_pted_jax` runs 20 trials using the default `permutations=1000` with `D=300`, which is likely to make the test suite significantly slower (and JAX’s current `_jax_cdist` implementation is particularly memory/compute heavy). Consider reducing `permutations`, `D`, and/or the loop count for this test while keeping the same assertion intent.

    def test_pted_jax():
        if jax is None:
            pytest.skip("jax not installed")
        # Set the random seed for reproducibility
        np.random.seed(42)
        # example 2 sample test
        D = 300
        for _ in range(20):
            x = jnp.array(np.random.normal(size=(100, D)))
            y = jnp.array(np.random.normal(size=(100, D)))
            p = pted.pted(x, y)
            assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))"
        x = jnp.array(np.random.normal(size=(100, D)))
        y = jnp.array(np.random.uniform(size=(100, D)))
        p = pted.pted(x, y, two_tailed=False)
        assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"
        x = jnp.array(np.random.normal(size=(100, D)))
        t, p, _ = pted.pted(x, x, return_all=True)
        q = 2 * min(np.sum(p > t), np.sum(p < t))
        p = (1 + q) / (len(p) + 1)
        assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"
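The last block of this test recomputes a two-tailed p-value by hand from the test statistic and the permutation statistics. The same arithmetic as a standalone function, with a hypothetical name and made-up numbers:

```python
import numpy as np


def two_tailed_p_value(test_stat, permute_stats):
    """Two-tailed permutation p-value, mirroring the arithmetic in the test."""
    permute_stats = np.asarray(permute_stats)
    n_above = np.sum(permute_stats > test_stat)
    n_below = np.sum(permute_stats < test_stat)
    # Twice the smaller tail count, with the usual +1 correction so that
    # the p-value is never exactly zero.
    q = 2 * min(n_above, n_below)
    return (1 + q) / (len(permute_stats) + 1)
```

When the test statistic falls below every permuted value, as is expected when a sample is compared against itself, one tail count is zero and the result is 1 / (N + 1). For N = 1000 permutations that is roughly 0.001, which is why the final assertion expects `p < 1e-2`.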
pyproject.toml:49
Since the project already exposes a `torch` extra under `[project.optional-dependencies]`, consider adding a parallel `jax = ["jax>=0.4,<1"]` extra (in addition to `dev`) so users can do `pip install pted[jax]` for JAX support, matching the existing installation pattern.

    [project.optional-dependencies]
    dev = [
        "pytest>=8.0,<9",
        "pytest-cov>=4.1,<5",
        "pytest-mock>=3.12,<4",
        "torch>=2.0,<3",
        "jax>=0.4,<1",
        "matplotlib",
    ]
    torch = [
        "torch>=2.0,<3",
    ]
    def _jax_cdist(x, y, p: float = 2.0):
        diff = x[:, None, :] - y[None, :, :]
        return jnp.linalg.norm(diff, ord=p, axis=-1)

_jax_cdist materializes diff = x[:, None, :] - y[None, :, :] (shape (nx, ny, d)), which can be drastically larger than the output distance matrix and can OOM for moderate sizes. Consider computing pairwise distances without allocating the full (nx, ny, d) tensor (e.g., using the squared-norm identity / matmul approach or vmap over rows) so memory scales primarily with (nx, ny).
Fixed in commit be230a4. _jax_cdist now uses the squared-norm identity (||x_i - y_j||² = ||x_i||² + ||y_j||² - 2·x_i·y_j) for the default Euclidean (p=2.0) case, which avoids materializing the (nx, ny, d) diff tensor entirely — memory scales as O(nx·ny) instead. For general p-norms, it falls back to jax.vmap over rows, which avoids the full broadcast intermediate. A test for the non-Euclidean vmap path was also added.
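The approach described in this reply can be sketched as follows. This is an illustration written from the description, not the code from commit be230a4; `cdist_sketch` is a hypothetical name, and the clamp before the square root is an added assumption to guard against round-off.

```python
import jax
import jax.numpy as jnp


def cdist_sketch(x, y, p: float = 2.0):
    """Pairwise p-norm distances between rows of x (nx, d) and y (ny, d)."""
    if p == 2.0:
        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
        x2 = jnp.sum(x * x, axis=1)[:, None]  # (nx, 1)
        y2 = jnp.sum(y * y, axis=1)[None, :]  # (1, ny)
        sq = x2 + y2 - 2.0 * (x @ y.T)        # (nx, ny)
        # Round-off can push tiny values below zero; clamp before the sqrt.
        return jnp.sqrt(jnp.maximum(sq, 0.0))

    # General p: map a per-row distance over the rows of x, as the PR
    # describes. Whether XLA avoids the (nx, ny, d) intermediate here
    # depends on fusion, so this is worth profiling rather than assuming.
    def row_dist(xi):
        return jnp.linalg.norm(xi[None, :] - y, ord=p, axis=-1)  # (ny,)

    return jax.vmap(row_dist)(x)
```

The identity trades some numerical accuracy for memory: for nearly identical rows the three terms almost cancel, so small distances are less precise than with the direct difference, especially in float32.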
@copilot The public `pted()` signature/docstring still only mentions NumPy/Torch (`Union[np.ndarray, "Tensor"]`, and the metric docs only describe SciPy/Torch), but the implementation now supports JAX arrays. Please update the function annotations and docstring parameter docs to include JAX (and clarify the JAX metric semantics) so `help(pted.pted)` matches the README and actual behavior. Also, since the project already exposes a `torch` extra under `[project.optional-dependencies]`, consider adding a parallel `jax = ["jax>=0.4,<1"]` extra (in addition to `dev`) so users can do `pip install pted[jax]` for JAX support, matching the existing installation pattern.
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
…xtra in pyproject.toml Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Done in commit b601c34:
@copilot can you add unit tests to check that the utils for the NumPy, Torch, and JAX versions all produce the same output where appropriate?
…ce utils Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Added in commit ed525f4 — three new tests at the end of
All tests skip if torch or JAX is not installed.
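A cross-backend agreement check of the kind requested can be sketched as below for NumPy and JAX. The energy distance here is the textbook form 2 E|X - Y| - E|X - X'| - E|Y - Y'| with plain means over all pairs, which is an assumption and may differ from what pted's `_energy_distance_*` helpers compute; the function names are hypothetical. The tolerance is loose because JAX defaults to float32.

```python
import numpy as np
import jax.numpy as jnp


def energy_distance(x, y, xp):
    """Textbook energy distance using array module `xp` (np or jnp)."""

    def mean_dist(a, b):
        # Broadcasts to (na, nb, d); fine for a small test, heavy for real data.
        return xp.mean(xp.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1))

    return 2.0 * mean_dist(x, y) - mean_dist(x, x) - mean_dist(y, y)


def test_backends_agree():
    rng = np.random.default_rng(0)  # fixed seed so both backends see the same data
    x = rng.normal(size=(20, 3))
    y = rng.normal(loc=0.5, size=(25, 3))
    e_np = float(energy_distance(x, y, np))
    e_jax = float(energy_distance(jnp.asarray(x), jnp.asarray(y), jnp))
    assert abs(e_np - e_jax) < 1e-4
```

Passing the array module as an argument keeps a single formula under test, so any disagreement points at the backend rather than at two diverging implementations.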
`pted` and `pted_coverage_test` accepted NumPy and PyTorch inputs but not JAX arrays. JAX is now supported as an optional backend; nothing crashes if JAX is not installed.

Changes

src/pted/utils.py
- Optional `jax` / `jax.numpy` import with `None` fallback
- `is_jax_array()`: detects `jax.Array` instances; returns `False` when JAX is absent
- `_jax_cdist()`: memory-efficient pairwise distance. Uses the squared-norm identity (||xi - yj||² = ||xi||² + ||yj||² - 2·xi·yj) for Euclidean (`p=2.0`) to avoid materializing the `(nx, ny, d)` diff tensor; uses `jax.vmap` over rows for general p-norms
- `_energy_distance_jax()` / `_energy_distance_estimate_jax()`: JAX-native energy distance (full and chunked)
- `pted_jax()` / `pted_chunk_jax()`: permutation test implementations mirroring the existing torch equivalents

src/pted/pted.py
- Added `is_jax_array` routing branches between the torch and numpy paths in `pted()`
- Updated `pted()` and `pted_coverage_test()` type annotations to `Union[np.ndarray, "Tensor", "jax.Array"]`
- Updated `metric` parameter docstrings to document JAX-specific semantics: `"euclidean"` uses the squared-norm identity (p=2); float `p` uses `jnp.linalg.norm(ord=p)`; non-euclidean string metrics are not supported for JAX

tests/test_pted.py
- `test_pted_jax`, `test_pted_chunk_jax`, `test_pted_coverage_jax`: mirror existing torch tests; all skip gracefully when JAX is absent
- `test_is_jax_array_with_jax`, `test_is_jax_array_no_jax`: unit tests for `is_jax_array`, including the mocked JAX-absent case
- `test_jax_cdist`, `test_jax_cdist_non_euclidean`, `test_energy_distance_jax`, `test_energy_distance_estimate_jax`: direct unit tests for JAX utility functions covering both the L2 (matmul) and general p-norm (vmap) paths
- `test_pted_jax_no_jax`, `test_pted_chunk_jax_no_jax`: mocked tests verifying the correct error when JAX is not installed
- `test_pted_torch_no_torch`, `test_pted_chunk_torch_no_torch`: mocked tests verifying the correct error when torch is not installed
- `test_jax_cdist_matches_scipy`: verifies `_jax_cdist` (L2 path) produces the same pairwise distances as `scipy.spatial.distance.cdist`
- `test_energy_distance_numpy_torch_jax_agree`: verifies `_energy_distance_{numpy,torch,jax}` return the same value for identical inputs across all three backends
- `test_energy_distance_estimate_numpy_torch_jax_agree`: same cross-backend consistency check for the chunked estimators, using a fixed random seed

pyproject.toml
- Added `jax>=0.4,<1` to `[project.optional-dependencies] dev` so CI installs JAX and runs the JAX tests
- Added a `jax = ["jax>=0.4,<1"]` optional-dependency extra (parallel to the existing `torch` extra) so users can `pip install pted[jax]`

README.md
- Updated `pted()` and `pted_coverage_test()` function signatures and parameter descriptions to include `jax.Array` alongside `np.ndarray` and `Tensor`
- Added a `pip install jax` note in the Install section

Usage

JAX can be installed as an optional extra with `pip install pted[jax]`.