
Add optional JAX backend for GPU-accelerated CG and MINRES solvers #19

Open
Copilot wants to merge 9 commits into main from copilot/add-jax-backend-support

Conversation

Contributor

Copilot AI commented May 1, 2026

Adds backend='jax' to Problem, routing both solve_cg and solve_minres through JAX-backed matvecs for use on Apple Silicon (M-series via jax-metal) or any JAX accelerator. The active-set outer loop stays on CPU; only the inner Krylov solve moves to JAX.

pip install fast-minimum-variance[jax]
pip install jax-metal  # Apple Silicon only

p = Problem(R, backend='jax')
w, iters = p.solve_cg()      # matvecs on Metal GPU
w, iters = p.solve_minres()  # also JAX-accelerated

API changes — api.py

  • backend: str = 'numpy' field on Problem; __post_init__ raises ValueError for unknown values
  • _jax_arrays() — private method centralising the lazy import jax.numpy and float32 conversion (jax-metal has limited float64 support); shared by both JAX solve paths
  • _minres_jax(matvec, b) — JAX port of SciPy's MINRES (Paige & Saunders 1975) implemented with jax.lax.while_loop; the entire Lanczos recurrence — scalars, vectors, and convergence tests — runs inside a single compiled loop with no host synchronisations per iteration. jax.scipy.sparse.linalg does not include MINRES as of JAX 0.10, so this port is necessary.
  • _solve_cg_jax — uses jax.scipy.sparse.linalg.cg (which internally uses lax.while_loop) so the full CG solve stays on-device. n_iters returns 0 for the JAX CG path since JAX's CG implementation does not expose iteration count.
  • _solve_minres_jax — private method implementing the KKT matvec with JAX arrays; results returned as NumPy so _constraint_active_set is unaffected
  • solve_cg and solve_minres dispatch to the JAX paths when backend='jax'; NumPy paths are untouched
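
The validation and dispatch described in the list above can be sketched as follows. This is a minimal illustration and not the PR's code: `SketchProblem`, `_VALID_BACKENDS`, and the string return values are hypothetical stand-ins, and the real `Problem` carries data and constraint fields that are omitted here.

```python
from dataclasses import dataclass

# Hypothetical constant; the PR text names 'numpy' and 'jax' as the accepted values.
_VALID_BACKENDS = ("numpy", "jax")


@dataclass
class SketchProblem:
    """Stand-in for Problem that models only the backend field."""

    backend: str = "numpy"

    def __post_init__(self) -> None:
        # Reject unknown backends at construction time, before any solve runs.
        if self.backend not in _VALID_BACKENDS:
            raise ValueError(
                f"unknown backend {self.backend!r}; expected one of {_VALID_BACKENDS}"
            )

    def solve_cg(self) -> str:
        # Dispatch on the backend; the NumPy branch stays as it was.
        if self.backend == "jax":
            return self._solve_cg_jax()
        return "numpy path"

    def _solve_cg_jax(self) -> str:
        # In the real method, JAX would be imported lazily at this point.
        return "jax path"
```

Validating in `__post_init__` means a typo such as `backend='gpu'` fails when the object is built rather than at the first solve.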

On-device execution

Both Krylov solvers use lax.while_loop-based implementations so all computation stays on-device — no Python-level host synchronisations during the solve. The only device→host transfer is extracting the MINRES iteration count at return. This is essential for Metal/GPU performance: Python float() calls on JAX arrays block until the GPU drains its command queue, which would dominate runtime at any reasonable iteration count.

On a CPU JAX backend, XLA dispatch overhead makes the JAX path slower than NumPy (as shown by the benchmark). The speedup appears on Metal/CUDA where the two GEMVs per step (X.T @ (X @ x)) run on the GPU without interruption.

New files

  • pyproject.toml: [jax] optional extra (jax>=0.4); jax-metal is platform-specific, so users install it separately
  • tests/test_jax.py: pytest.importorskip('jax') at module level; tests budget constraint, non-negativity, NumPy/JAX agreement (atol=1e-4, float32 tolerance), shape, dtype, and iteration count for both solvers; backend validation parametrized over multiple invalid strings
  • benchmarks/jax_backend.py: times solve_cg and solve_minres across six (T, N) sizes from (250, 20) to (2000, 1000), with a separate warmup_jax column for the one-off XLA compilation cost and an err column for float32 accuracy; prints a clear warning when JAX is running on CPU (not Metal/CUDA); degrades gracefully without JAX installed
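
The `atol=1e-4` tolerance used in the tests can be put in context with a little arithmetic. The sketch below uses only the standard library to round values to float32. The `kappa * eps` estimate is a textbook rule of thumb for the attainable relative accuracy of a linear solve, not a bound taken from this PR, and the helper names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


EPS32 = 2.0 ** -23  # spacing between 1.0 and the next larger float32


def max_condition_number(atol: float) -> float:
    """Condition number at which the rule-of-thumb error kappa * EPS32 reaches atol."""
    return atol / EPS32


# Storing a single weight in float32 costs at most half a unit in the last place.
w = 1.0 / 3.0
print(abs(to_float32(w) - w) / w <= EPS32 / 2)  # True

# Under the rule of thumb, 1e-4 leaves room for condition numbers in the hundreds.
print(round(max_condition_number(1e-4)))  # 839
```

For problems whose effective condition number is far above this rough threshold, agreement to 1e-4 between the float32 and float64 paths should not be taken for granted.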
Original prompt

Summary

Add an optional backend='jax' path to the Problem class so users on Apple Silicon (M4 via jax-metal) or other JAX-supported accelerators can run the Krylov solvers (solve_minres, solve_cg) with GPU-accelerated matrix-vector products.

The core insight is that the entire computational cost of both MINRES and CG lives in two matvecs per iteration:

X.T @ (X @ x)

On JAX with Metal this maps directly to accelerated GEMVs on the M-series GPU. The active-set outer loop remains on CPU (it's cheap scalar comparisons), and only the inner Krylov solve moves to JAX.
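
To make the "two matvecs" point concrete, here is a dependency-free sketch of applying `X.T @ X` to a vector without ever forming the N x N Gram matrix. It uses plain Python lists purely for illustration (the library itself works on NumPy or JAX arrays), and all function names are hypothetical. The flop counts are rough multiply-add estimates.

```python
def matvec(X, x):
    """Return X @ x for a row-major list-of-lists matrix X."""
    return [sum(a * b for a, b in zip(row, x)) for row in X]


def rmatvec(X, y):
    """Return X.T @ y without materialising the transpose."""
    n_cols = len(X[0])
    return [sum(X[t][j] * y[t] for t in range(len(X))) for j in range(n_cols)]


def gram_matvec(X, x):
    """Apply X.T @ X to x as two matvecs; the Gram matrix is never formed."""
    return rmatvec(X, matvec(X, x))


def flops_per_apply(T, N):
    """Rough cost of one application: two GEMVs of about 2*T*N flops each."""
    return 4 * T * N


def flops_to_form_gram(T, N):
    """Rough cost of forming X.T @ X explicitly, once."""
    return 2 * T * N * N


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # T=3 observations, N=2 assets
x = [1.0, -1.0]
print(gram_matvec(X, x))  # [-9.0, -12.0]
```

Under this rough count, forming the Gram matrix alone costs as much as N/2 matrix-free applications, for example 500 applications at (T, N) = (2000, 1000).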


Design

pyproject.toml

Add a new optional extra:

[project.optional-dependencies]
jax = ["jax>=0.4"]

Note: users on Apple Silicon install jax-metal themselves (it's a platform-specific package not installable cross-platform), so the extra only declares jax as the base dependency.

src/fast_minimum_variance/api.py

  1. Add a backend: str field to Problem with default 'numpy'. Accepted values: 'numpy', 'jax'.

  2. Add a private method _jax_cg_operator(active=None) that:

    • Converts self.X, self.A, self.C, self.b, self.d, self.mu to jnp.array (in float32 since jax-metal has limited float64 support)
    • Builds a JAX-compatible matvec callable (not a scipy.sparse.linalg.LinearOperator — JAX's cg accepts a plain callable)
    • Uses jax.scipy.sparse.linalg.cg for the CG solve in the null-space-reduced system
    • Returns results as NumPy arrays (via np.asarray(...)) so the rest of the active-set loop is unaffected
  3. Modify solve_cg to dispatch to the JAX path when self.backend == 'jax':

    def solve_cg(self, *, project: bool = True):
        if self.backend == 'jax':
            return self._solve_cg_jax(project=project)
        # existing numpy path unchanged
        ...
  4. Add _solve_cg_jax that:

    • Implements the same null-space reduction as the NumPy path: w = w0 + P @ v where P is an orthonormal null-space basis for A_ext^T
    • Uses jax.scipy.sparse.linalg.cg on the reduced SPD operator
    • Runs the active-set loop (same _constraint_active_set infrastructure) but with a JAX-backed solve function
    • Applies clip_and_renormalize if project=True
    • Handles the import error for missing JAX gracefully with a clear message
  5. Add a __post_init__ validation that raises ValueError for unknown backend values.
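
The lazy-import requirement with a clear error message could look roughly like the helper below. `lazy_import` is a hypothetical name, and the wording of the message is an assumption; only the `pip install fast-minimum-variance[jax]` command comes from this PR.

```python
import importlib


def lazy_import(module_name: str, extra: str):
    """Import module_name on first use, with an install hint if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # Re-raise with an actionable message while keeping the original cause.
        raise ImportError(
            f"backend requires '{module_name}', which is not installed; "
            f"try: pip install fast-minimum-variance[{extra}]"
        ) from exc
```

Calling such a helper only inside the JAX-specific methods keeps the module importable, and the NumPy path usable, when JAX is absent.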

tests/test_jax.py

Add a new test file tests/test_jax.py with:

  • A module-level pytest.importorskip('jax') so the whole file is skipped when JAX is not installed (CI without JAX extra will not fail)
  • Tests that mirror the existing test_krylov.py structure:
    • test_jax_cg_weights_sum_to_one — budget constraint satisfied
    • test_jax_cg_weights_non_negative — long-only constraint satisfied
    • test_jax_cg_agrees_with_numpy_cg — weights from backend='jax' and backend='numpy' agree to within 1e-4 (float32 tolerance)
    • test_jax_backend_invalid_raises: Problem(X, backend='gpu') raises ValueError
    • test_jax_backend_unknown_raises — unknown backend string raises ValueError

README update

Add a short section after the existing solver table:

## JAX / Metal Backend (Apple Silicon)

To run the CG solver on the Apple M-series GPU via Metal:

```bash
pip install fast-minimum-variance[jax]
pip install jax-metal          # Apple Silicon only
```

```python
p = Problem(R, backend='jax')
w, iters = p.solve_cg()        # matvecs run on Metal GPU
```

The JAX backend operates in float32. For most portfolio problems the solution quality is indistinguishable from float64, but verify residuals for ill-conditioned covariance structures.


---

## Important constraints

- The **NumPy path must be completely unchanged** — no regressions, no new imports at module level
- JAX must be a **lazy import** (imported only inside the JAX-specific methods), so the package works without JAX installed
- The `solve_minres` JAX path is **out of scope** for this PR: `jax.scipy.sparse.linalg` does not have `minres` as of JAX 0.4; only `cg`, `gmres`, and `bicgstab` are available. Do not add a broken stub.
- Do not add `jax` to the core `dependencies` in `pyproject.toml` — it must remain optional
- Follow the existing code style: `N806` noqa comments for uppercase variable names, docstrings on all public methods, doctests in the style of the existing solvers


The following is the prior conversation context from the user's chat exploration (may be truncated):

User: Which problems can you now solve?

Assistant: Here's what I can help you with for **[Jebel-Quant/fast_minimum_variance](https://github.com/Jebel-Quant/fast_minimum_variance)** — a Python library for fast computation of minimum variance portfolios (created ~3 days ago):

---

### 🔍 Code Understanding & Exploration
- Explain how the minimum variance portfolio algorithm is implemented in `src/`
- Trace data flow from inputs...





*This pull request was created from Copilot chat.*

@tschm
Member

tschm commented May 1, 2026

@copilot: jax has minres

@tschm tschm marked this pull request as ready for review May 1, 2026 10:54
Copilot AI review requested due to automatic review settings May 1, 2026 10:54
Contributor

Copilot AI left a comment


Pull request overview

Adds an optional JAX backend to the Problem API so solve_cg() can run its matvec-heavy inner loop on JAX accelerators (e.g., Apple Silicon via jax-metal) while keeping JAX as an optional dependency.

Changes:

  • Add backend: str = "numpy" to Problem with __post_init__ validation.
  • Add a JAX-dispatched CG implementation (_solve_cg_jax) and route solve_cg() to it when backend="jax".
  • Add a jax optional extra, JAX-specific tests (skipped when JAX isn’t installed), and README documentation for the JAX/Metal setup.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/fast_minimum_variance/api.py | Adds backend selection/validation and a JAX-based CG solve path. |
| tests/test_jax.py | Adds JAX backend tests guarded by pytest.importorskip("jax"). |
| pyproject.toml | Introduces fast-minimum-variance[jax] optional extra. |
| README.md | Documents how to use the JAX/Metal backend and notes float32 behavior. |


Comment on lines 423 to 426
n_iters is the total number of MINRES iterations across all active-set
steps.

Examples:

Copilot AI May 1, 2026


In the JAX CG path the solver step returns step_iters as a constant 1, so the iters returned by solve_cg() no longer represents the total number of CG iterations (as documented and as the NumPy backend does). Please return an actual iteration count (e.g., implement a small CG loop that tracks iterations in JAX, or otherwise adjust the API/docstring so the meaning of n_iters remains consistent across backends).

Suggested change
def _cg_with_iters(matvec, rhs, *, tol=1e-5, maxiter=None):
    """Conjugate-gradient solve that also returns the iteration count."""
    n = int(rhs.shape[0])
    if maxiter is None:
        maxiter = max(1, 10 * n)
    x = jnp.zeros_like(rhs)
    r = rhs - matvec(x)
    p = r
    rs_old = jnp.vdot(r, r)
    rhs_norm = float(jnp.linalg.norm(rhs))
    threshold = tol * rhs_norm
    if float(jnp.sqrt(rs_old)) <= threshold:
        return x, 0
    iters = 0
    for _ in range(maxiter):
        ap = matvec(p)
        alpha = rs_old / jnp.vdot(p, ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = jnp.vdot(r, r)
        iters += 1
        if float(jnp.sqrt(rs_new)) <= threshold:
            break
        beta = rs_new / rs_old
        p = r + beta * p
        rs_old = rs_new
    return x, iters

sol_jax, step_iters = _cg_with_iters(_matvec, rhs_jax)
w_jax = w0_jax + P_jax @ sol_jax
# Return as NumPy so the active-set loop stays backend-agnostic.
return np.asarray(w_jax, dtype=np.float64), step_iters


Returns:
Tuple (w, n_iters) where w is the weight vector of shape (N,) and
n_iters is the total number of MINRES iterations across all active-set

Copilot AI May 1, 2026


jax.scipy.sparse.linalg.cg returns a convergence status/info value as its second return. This path currently ignores it, so non-convergence could silently produce an invalid w. Please check the returned status and raise a clear error (or retry with a higher maxiter/different tolerances) when CG does not converge.

Suggested change
sol_jax, cg_info = jax_cg(_matvec, rhs_jax)
if cg_info is not None:
    cg_info = np.asarray(cg_info).item()
    if cg_info != 0:
        raise RuntimeError(
            "JAX CG failed to converge in Problem._solve_cg_jax "
            f"for the current active set (info={cg_info}). "
            "Consider increasing maxiter or adjusting tolerances."
        )
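
A status value is only useful if the solver actually fills it in. The JAX documentation I am aware of describes the second return value of `jax.scipy.sparse.linalg.cg` as a placeholder, which should be treated as an assumption to verify against the installed release. A backend-independent alternative is to check the residual directly, as in the sketch below. It is written with plain Python lists for illustration, and `check_residual` and `apply_a` are hypothetical names.

```python
import math


def check_residual(matvec, x, b, tol=1e-5):
    """Raise if ||b - A x|| > tol * ||b||; otherwise return the relative residual."""
    ax = matvec(x)
    res = math.sqrt(sum((bi - ai) ** 2 for bi, ai in zip(b, ax)))
    b_norm = math.sqrt(sum(bi * bi for bi in b))
    rel = res / b_norm if b_norm > 0.0 else res
    if rel > tol:
        raise RuntimeError(
            f"solve did not converge: relative residual {rel:.3e} > {tol:.1e}"
        )
    return rel


def apply_a(v):
    """Toy SPD operator A = diag(2, 4), applied matrix-free."""
    return [2.0 * v[0], 4.0 * v[1]]


# The exact solution of A x = [1, 1] is x = [0.5, 0.25].
print(check_residual(apply_a, [0.5, 0.25], [1.0, 1.0]))  # 0.0
```

The check costs one extra operator application per solve, which is small next to the iterations of the solve itself.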

Comment on lines 383 to 401
        return w, iters

    def solve_minres(self, *, project: bool = True):
        """Solve via MINRES on the KKT saddle-point system with active-set method.

        Iteratively promotes violated inequality constraints to equalities. At each
        outer iteration the KKT saddle-point system for all assets with the currently
        active constraints pinned as equalities

            [ 2(X^T X + gamma I)   A_ext ] [ w ]   [ rho * mu ]
            [ A_ext^T              0     ] [ λ ] = [ b_ext    ]

        is solved matrix-free via MINRES, where ``A_ext = [A, C[:, active]]`` and
        ``b_ext = [b, d[active]]``. No explicit matrix is ever formed.

        With the defaults (``A = ones``, ``b = [1]``, ``C = -I``, ``d = 0``) this
        recovers the long-only minimum variance solver of the companion paper.

        To apply Ledoit-Wolf shrinkage, pre-scale the return matrix and set gamma::

Copilot AI May 1, 2026


aa_eq, bb_eq, cc, and dd are converted to JAX arrays but then immediately converted back to NumPy inside _solve (np.asarray(...)). This causes device→host transfers and extra allocations on every active-set step. Consider keeping constraints as NumPy arrays (use self.A/self.b/self.C/self.d) and only keep the heavy matvec inputs (X, mu, P, w0, etc.) on JAX.

Comment thread tests/test_jax.py
import numpy as np
import pytest

jax = pytest.importorskip("jax")

Copilot AI May 1, 2026


jax = pytest.importorskip("jax") assigns to a name that is never used, which will fail Ruff's unused-variable check (F841) under the repo's pre-commit/CI linting. Use pytest.importorskip("jax") without assignment or assign to _ instead.

Suggested change
- jax = pytest.importorskip("jax")
+ pytest.importorskip("jax")

Copilot AI changed the title from "[WIP] Add optional backend support for jax in Problem class" to "Add optional JAX backend for GPU-accelerated CG and MINRES solvers" May 1, 2026
Copilot finished work on behalf of tschm May 1, 2026 11:06
Copilot AI requested a review from tschm May 1, 2026 11:06
…cg, clearer benchmark CPU warning

Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/52397cd8-e515-4bd9-a149-a420095bf286

Co-authored-by: tschm <2046079+tschm@users.noreply.github.com>
Contributor Author

Copilot AI commented May 1, 2026

@copilot: jax has minres

Checked against JAX 0.10.0 (latest release) and the current main branch: jax.scipy.sparse.linalg exports only cg, gmres, and bicgstab — no minres. So a custom port is necessary.

That said, the old port was wrong in a subtler way: it kept scalar Lanczos state as Python floats, meaning every float(jnp.dot(...)) call forced a device→host synchronisation per iteration. On your M4 benchmark (commit dc6560f) that showed up as 0.07–0.24× speedup — JAX was slower than NumPy even on CPU.

Fixed in dc6560f:

  • MINRES: rewritten with jax.lax.while_loop — the entire recurrence (scalars, vectors, convergence tests) runs inside a single compiled loop on-device with only one host sync at the end to extract the iteration count
  • CG: switched to jax.scipy.sparse.linalg.cg which already uses lax.while_loop internally

On CPU JAX the overhead from XLA compilation makes it slower than NumPy, which is expected. The speedup comes on Metal/CUDA where the two GEMVs per step (X.T @ (X @ x)) run on the GPU without Python interrupting the loop at each iteration.

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • jax.readthedocs.io
    • Triggering command: /home/REDACTED/work/_temp/ghcca-node/node/bin/node /home/REDACTED/work/_temp/ghcca-node/node/bin/node --enable-source-maps /home/REDACTED/work/_temp/copilot-developer-action-main/dist/index.js (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Copilot finished work on behalf of tschm May 1, 2026 11:20
- Pin jax>=0.4.34,<0.5 on Apple Silicon so jax-metal 0.1.1 is compatible
- Add jax-metal as a conditional dep (darwin + arm64 only)
- Extend SIZES in benchmark to (5000, 2500) and (10000, 5000)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@tschm tschm force-pushed the copilot/add-jax-backend-support branch from 26ea454 to b680cea Compare June 5, 2026 13:42

3 participants