Skip to content

Add JAX operators#4

Merged
andycasey merged 9 commits intomainfrom
jax
Jun 3, 2025
Merged

Add JAX operators#4
andycasey merged 9 commits intomainfrom
jax

Conversation

@andycasey
Copy link
Copy Markdown
Owner

This PR introduces JAX operators thanks to @TomHilder.

  • Added optional installation version for the JAX operators uv add nifty-solve[jax]
  • Added unit tests for JAX operators
  • Updated README instructions

@andycasey andycasey requested a review from Copilot June 3, 2025 22:20
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR integrates JAX-based FINUFFT operators into the project, updates tests to cover them, and adjusts installation and CI settings.

  • Introduce nifty_solve/jax_operators.py with JAX operator implementations.
  • Add comprehensive unit tests for the new JAX operators in tests/test_jax_operators.py and update existing tests to use tolerance parameters.
  • Update installation instructions in README.md and modify the CI workflow to install JAX extras.

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/test_operators.py Added DOTTEST_KWDS and applied to dottest calls
tests/test_jax_operators.py New JAX operator tests covering 1D/2D/3D and edge cases
src/nifty_solve/jax_operators.py Wrapped JAX FINUFFT calls and adjusted vector operations
README.md Documented nifty-solve[jax] extra and install steps
.github/workflows/ci.yml Updated CI matrix and commented out the test-run step
Comments suppressed due to low confidence (4)

src/nifty_solve/jax_operators.py:8

  • The module calls warnings.warn in the import error block but never imports the warnings module. Please add import warnings at the top.
try: # pragma: no cover

.github/workflows/ci.yml:44

  • The Run tests step is commented out, so no tests actually run in CI. Please uncomment these lines to ensure tests are executed.
#- name: "Run tests"

tests/test_jax_operators.py:348

  • This test name is duplicated earlier at line 345, which will override the first definition. Rename or remove one of them to avoid silent test collisions.
test_3d_real_operator_dottest_N_even_gt_P_oee = partial(dottest_3d_real_operator, 27, (11, 10, 8))

tests/test_jax_operators.py:218

  • [nitpick] There are hundreds of nearly identical partial-based tests. Consider using pytest.mark.parametrize to reduce duplication and improve readability.
# 1D Operator

@andycasey andycasey merged commit aad64be into main Jun 3, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants