Add optional JAX backend for GPU-accelerated CG and MINRES solvers #19
Copilot wants to merge 9 commits
Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com>
@copilot: jax has minres
Pull request overview
Adds an optional JAX backend to the Problem API so solve_cg() can run its matvec-heavy inner loop on JAX accelerators (e.g., Apple Silicon via jax-metal) while keeping JAX as an optional dependency.
Changes:
- Add `backend: str = "numpy"` to `Problem` with `__post_init__` validation.
- Add a JAX-dispatched CG implementation (`_solve_cg_jax`) and route `solve_cg()` to it when `backend="jax"`.
- Add a `jax` optional extra, JAX-specific tests (skipped when JAX isn't installed), and README documentation for the JAX/Metal setup.
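The validation and dispatch described in this list can be pictured with a minimal sketch. `ProblemSketch` and the two placeholder solver methods are hypothetical stand-ins, not the PR's `Problem` class; only the `backend` field, the `__post_init__` check, and the dispatch shape follow the overview above.

```python
from dataclasses import dataclass

# Accepted values as listed in the overview; anything else is rejected.
_VALID_BACKENDS = ("numpy", "jax")


@dataclass
class ProblemSketch:
    """Hypothetical stand-in modelling only the backend field."""

    backend: str = "numpy"

    def __post_init__(self):
        # Fail at construction time on typos such as "gpu", not at solve time.
        if self.backend not in _VALID_BACKENDS:
            raise ValueError(
                f"Unknown backend {self.backend!r}; expected one of {_VALID_BACKENDS}"
            )

    def solve_cg(self):
        # Dispatch on the validated field; solver bodies are placeholders here.
        if self.backend == "jax":
            return self._solve_cg_jax()
        return self._solve_cg_numpy()

    def _solve_cg_numpy(self):
        return "numpy path"

    def _solve_cg_jax(self):
        # A lazy JAX import would live here, which keeps JAX optional.
        return "jax path"
```

Validating in `__post_init__` keeps the error close to the call site that introduced the bad value, and the NumPy path never touches a JAX import.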
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/fast_minimum_variance/api.py` | Adds backend selection/validation and a JAX-based CG solve path. |
| `tests/test_jax.py` | Adds JAX backend tests guarded by `pytest.importorskip("jax")`. |
| `pyproject.toml` | Introduces `fast-minimum-variance[jax]` optional extra. |
| `README.md` | Documents how to use the JAX/Metal backend and notes float32 behavior. |
        n_iters is the total number of MINRES iterations across all active-set
        steps.

    Examples:
In the JAX CG path the solver step returns step_iters as a constant 1, so the iters returned by solve_cg() no longer represents the total number of CG iterations (as documented and as the NumPy backend does). Please return an actual iteration count (e.g., implement a small CG loop that tracks iterations in JAX, or otherwise adjust the API/docstring so the meaning of n_iters remains consistent across backends).
    def _cg_with_iters(matvec, rhs, *, tol=1e-5, maxiter=None):
        """Conjugate-gradient solve that also returns the iteration count."""
        n = int(rhs.shape[0])
        if maxiter is None:
            maxiter = max(1, 10 * n)
        x = jnp.zeros_like(rhs)
        r = rhs - matvec(x)
        p = r
        rs_old = jnp.vdot(r, r)
        rhs_norm = float(jnp.linalg.norm(rhs))
        threshold = tol * rhs_norm
        if float(jnp.sqrt(rs_old)) <= threshold:
            return x, 0
        iters = 0
        for _ in range(maxiter):
            ap = matvec(p)
            alpha = rs_old / jnp.vdot(p, ap)
            x = x + alpha * p
            r = r - alpha * ap
            rs_new = jnp.vdot(r, r)
            iters += 1
            if float(jnp.sqrt(rs_new)) <= threshold:
                break
            beta = rs_new / rs_old
            p = r + beta * p
            rs_old = rs_new
        return x, iters
    sol_jax, step_iters = _cg_with_iters(_matvec, rhs_jax)
    w_jax = w0_jax + P_jax @ sol_jax
    # Return as NumPy so the active-set loop stays backend-agnostic.
    return np.asarray(w_jax, dtype=np.float64), step_iters
    Returns:
        Tuple (w, n_iters) where w is the weight vector of shape (N,) and
        n_iters is the total number of MINRES iterations across all active-set
jax.scipy.sparse.linalg.cg returns a convergence status/info value as its second return. This path currently ignores it, so non-convergence could silently produce an invalid w. Please check the returned status and raise a clear error (or retry with a higher maxiter/different tolerances) when CG does not converge.
    sol_jax, cg_info = jax_cg(_matvec, rhs_jax)
    if cg_info is not None:
        cg_info = np.asarray(cg_info).item()
        if cg_info != 0:
            raise RuntimeError(
                "JAX CG failed to converge in Problem._solve_cg_jax "
                f"for the current active set (info={cg_info}). "
                "Consider increasing maxiter or adjusting tolerances."
            )
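One caveat worth checking against the installed JAX version: to my understanding, the JAX documentation describes the second return value of `jax.scipy.sparse.linalg.cg` as a placeholder that is currently `None`, in which case the `cg_info is not None` branch above would never fire. A backend-independent alternative is to verify the residual directly. The sketch below uses NumPy and a hypothetical helper name, `check_residual`; the tolerance is an illustrative choice suited to float32, not a value taken from the PR.

```python
import numpy as np


def check_residual(matvec, x, rhs, tol=1e-4):
    """Raise if x does not satisfy matvec(x) = rhs to relative tolerance tol."""
    rhs_norm = max(float(np.linalg.norm(rhs)), 1e-30)  # guard against rhs == 0
    rel = float(np.linalg.norm(rhs - matvec(x))) / rhs_norm
    if not np.isfinite(rel) or rel > tol:
        raise RuntimeError(
            f"CG did not converge: relative residual {rel:.3e} exceeds {tol:.1e}"
        )
    return rel


# Small symmetric positive-definite system used only to exercise the helper.
M = np.array([[4.0, 1.0], [1.0, 3.0]])
rhs = np.array([1.0, 2.0])
matvec = lambda v: M @ v

rel_good = check_residual(matvec, np.linalg.solve(M, rhs), rhs)
```

The check costs one extra matvec per active-set step and works the same way whether the solver reports a status or not.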
            return w, iters

        def solve_minres(self, *, project: bool = True):
            """Solve via MINRES on the KKT saddle-point system with active-set method.

            Iteratively promotes violated inequality constraints to equalities. At each
            outer iteration the KKT saddle-point system for all assets with the currently
            active constraints pinned as equalities

                [ 2(X^T X + gamma I)   A_ext ] [ w ]   [ rho * mu ]
                [ A_ext^T              0     ] [ λ ] = [ b_ext    ]

            is solved matrix-free via MINRES, where ``A_ext = [A, C[:, active]]`` and
            ``b_ext = [b, d[active]]``. No explicit matrix is ever formed.

            With the defaults (``A = ones``, ``b = [1]``, ``C = -I``, ``d = 0``) this
            recovers the long-only minimum variance solver of the companion paper.

            To apply Ledoit-Wolf shrinkage, pre-scale the return matrix and set gamma::
aa_eq, bb_eq, cc, and dd are converted to JAX arrays but then immediately converted back to NumPy inside _solve (np.asarray(...)). This causes device→host transfers and extra allocations on every active-set step. Consider keeping constraints as NumPy arrays (use self.A/self.b/self.C/self.d) and only keep the heavy matvec inputs (X, mu, P, w0, etc.) on JAX.
    import numpy as np
    import pytest

    jax = pytest.importorskip("jax")
jax = pytest.importorskip("jax") assigns to a name that is never used, which will fail Ruff's unused-variable check (F841) under the repo's pre-commit/CI linting. Use pytest.importorskip("jax") without assignment or assign to _ instead.
    - jax = pytest.importorskip("jax")
    + pytest.importorskip("jax")
…cg, clearer benchmark CPU warning Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/52397cd8-e515-4bd9-a149-a420095bf286 Co-authored-by: tschm <2046079+tschm@users.noreply.github.com>
Checked against JAX 0.10.0 (latest release) and the current […]
That said, the old port was wrong in a subtler way: it kept scalar Lanczos state as Python floats, meaning every […]
Fixed in […]
On CPU JAX the overhead from XLA compilation makes it slower than NumPy, which is expected. The speedup comes on Metal/CUDA where the two GEMVs per step […]
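The difference between Python-float state and loop-carried state can be sketched on CG, which is shorter than MINRES. This is not the PR's `_minres_jax`. It is a NumPy illustration in which `while_loop` is a pure-Python stand-in with the same contract as `jax.lax.while_loop`, and `cg_counting` is a hypothetical name. The point is structural: every scalar, including the iteration counter, lives in the carried state, the convergence test compares squared norms, and nothing is converted to a Python number until the loop has finished.

```python
import numpy as np


def while_loop(cond_fun, body_fun, init_val):
    """Pure-Python stand-in mirroring the contract of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def cg_counting(matvec, rhs, tol=1e-5, maxiter=100):
    """CG whose full state (x, r, p, r.r, k) is carried through the loop."""
    x0 = np.zeros_like(rhs)
    r0 = rhs - matvec(x0)
    # Squared threshold, so the test needs no sqrt and no float() conversion.
    thresh2 = (tol ** 2) * np.vdot(rhs, rhs)

    def cond(state):
        _, _, _, rs, k = state
        return (rs > thresh2) & (k < maxiter)

    def body(state):
        x, r, p, rs, k = state
        ap = matvec(p)
        alpha = rs / np.vdot(p, ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = np.vdot(r, r)
        p = r + (rs_new / rs) * p
        return x, r, p, rs_new, k + 1

    x, _, _, _, k = while_loop(cond, body, (x0, r0, r0, np.vdot(r0, r0), 0))
    return x, int(k)  # the only scalar extraction happens after the loop


M = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x, iters = cg_counting(lambda v: M @ v, b)
```

Under JAX the assumed translation is to swap `np` for `jax.numpy` and the stand-in for `jax.lax.while_loop`, so the whole recurrence compiles into one on-device loop and the iteration count comes back with a single transfer.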
- Pin jax>=0.4.34,<0.5 on Apple Silicon so jax-metal 0.1.1 is compatible
- Add jax-metal as a conditional dep (darwin + arm64 only)
- Extend SIZES in benchmark to (5000, 2500) and (10000, 5000)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
26ea454 to b680cea
Adds `backend='jax'` to `Problem`, routing both `solve_cg` and `solve_minres` through JAX-backed matvecs for use on Apple Silicon (M-series via `jax-metal`) or any JAX accelerator. The active-set outer loop stays on CPU; only the inner Krylov solve moves to JAX.

API changes: `api.py`
- `backend: str = 'numpy'` field on `Problem`; `__post_init__` raises `ValueError` for unknown values
- `_jax_arrays()`: private method centralising the lazy `import jax.numpy` and float32 conversion (jax-metal has limited float64 support); shared by both JAX solve paths
- `_minres_jax(matvec, b)`: JAX port of SciPy's MINRES (Paige & Saunders 1975) implemented with `jax.lax.while_loop`; the entire Lanczos recurrence (scalars, vectors, and convergence tests) runs inside a single compiled loop with no host synchronisations per iteration. `jax.scipy.sparse.linalg` does not include MINRES as of JAX 0.10, so this port is necessary.
- `_solve_cg_jax`: uses `jax.scipy.sparse.linalg.cg` (which internally uses `lax.while_loop`) so the full CG solve stays on-device. `n_iters` returns `0` for the JAX CG path since JAX's CG implementation does not expose iteration count.
- `_solve_minres_jax`: private method implementing the KKT matvec with JAX arrays; results returned as NumPy so `_constraint_active_set` is unaffected
- `solve_cg` and `solve_minres` dispatch to the JAX paths when `backend='jax'`; NumPy paths are untouched

On-device execution
Both Krylov solvers use `lax.while_loop`-based implementations so all computation stays on-device, with no Python-level host synchronisations during the solve. The only device→host transfer is extracting the MINRES iteration count at return. This is essential for Metal/GPU performance: Python `float()` calls on JAX arrays block until the GPU drains its command queue, which would dominate runtime at any reasonable iteration count.
On a CPU JAX backend, XLA dispatch overhead makes the JAX path slower than NumPy (as shown by the benchmark). The speedup appears on Metal/CUDA where the two GEMVs per step (`X.T @ (X @ x)`) run on the GPU without interruption.

New files
- `pyproject.toml`: `[jax]` optional extra (`jax>=0.4`); `jax-metal` is platform-specific so users install it separately
- `tests/test_jax.py`: `pytest.importorskip('jax')` at module level; tests budget constraint, non-negativity, NumPy/JAX agreement (`atol=1e-4`, float32 tolerance), shape, dtype, and iteration count for both solvers; backend validation parametrized over multiple invalid strings
- `benchmarks/jax_backend.py`: times `solve_cg` and `solve_minres` across six `(T, N)` sizes from `(250, 20)` to `(2000, 1000)`, with a separate `warmup_jax` column for the one-off XLA compilation cost and an `err` column for float32 accuracy; prints a clear warning when JAX is running on CPU (not Metal/CUDA); degrades gracefully without JAX installed

Original prompt
Summary
Add an optional `backend='jax'` path to the `Problem` class so users on Apple Silicon (M4 via `jax-metal`) or other JAX-supported accelerators can run the Krylov solvers (`solve_minres`, `solve_cg`) with GPU-accelerated matrix-vector products.
The core insight is that the entire computational cost of both MINRES and CG lives in two matvecs per iteration:
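A minimal sketch of those two matvecs, assuming the operator has the form `X^T X + gamma I` that appears (up to a factor of 2) in the `solve_minres` docstring. The sizes, the value of `gamma`, and the name `cov_matvec` are illustrative choices, not taken from the PR, and NumPy is used here so the sketch runs without JAX.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, gamma = 50, 8, 0.1  # illustrative sizes only
X = rng.standard_normal((T, N))


def cov_matvec(v):
    """Apply (X^T X + gamma I) to v with two GEMVs; the N x N matrix is never built."""
    return X.T @ (X @ v) + gamma * v


v = rng.standard_normal(N)
# Explicit matrix, formed here only to check the matrix-free version.
explicit = (X.T @ X + gamma * np.eye(N)) @ v
```

Each application costs two passes over `X`, roughly `2 * T * N` multiply-adds each, which is the work a GPU backend would take over.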
On JAX with Metal this maps directly to accelerated GEMVs on the M-series GPU. The active-set outer loop remains on CPU (it's cheap scalar comparisons), and only the inner Krylov solve moves to JAX.

Design
`pyproject.toml`
Add a new optional extra:
Note: users on Apple Silicon install `jax-metal` themselves (it's a platform-specific package not installable cross-platform), so the extra only declares `jax` as the base dependency.

`src/fast_minimum_variance/api.py`
Add a `backend: str` field to `Problem` with default `'numpy'`. Accepted values: `'numpy'`, `'jax'`.
Add a private method `_jax_cg_operator(active=None)` that:
- […] `self.X`, `self.A`, `self.C`, `self.b`, `self.d`, `self.mu` to `jnp.array` (in `float32` since `jax-metal` has limited `float64` support)
- […] `scipy.sparse.linalg.LinearOperator`; JAX's `cg` accepts a plain callable)
- […] `jax.scipy.sparse.linalg.cg` for the CG solve in the null-space-reduced system
- […] `np.asarray(...)`) so the rest of the active-set loop is unaffected
Modify `solve_cg` to dispatch to the JAX path when `self.backend == 'jax'`:
Add `_solve_cg_jax` that:
- […] `w = w0 + P @ v` where `P` is an orthonormal null-space basis for `A_ext^T`
- […] `jax.scipy.sparse.linalg.cg` on the reduced SPD operator
- […] `_constraint_active_set` infrastructure) but with a JAX-backed solve function
- […] `clip_and_renormalize` if `project=True`
Add a `__post_init__` validation that raises `ValueError` for unknown backend values.

`tests/test_jax.py`
Add a new test file `tests/test_jax.py` with:
- `pytest.importorskip('jax')` so the whole file is skipped when JAX is not installed (CI without JAX extra will not fail)
- […] `test_krylov.py` structure:
  - `test_jax_cg_weights_sum_to_one`: budget constraint satisfied
  - `test_jax_cg_weights_non_negative`: long-only constraint satisfied
  - `test_jax_cg_agrees_with_numpy_cg`: weights from `backend='jax'` and `backend='numpy'` agree to within `1e-4` (float32 tolerance)
  - `test_jax_backend_invalid_raises`: `Problem(X, backend='gpu')` raises `ValueError`
  - `test_jax_backend_unknown_raises`: unknown backend string raises `ValueError`

README update
Add a short section after the existing solver table:
The JAX backend operates in `float32`. For most portfolio problems the solution quality is indistinguishable from `float64`, but verify residuals for ill-conditioned covariance structures.
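One way to act on the "verify residuals" advice is sketched below, assuming a regularised system `(X^T X + gamma I) w = rhs`. The data, `gamma`, and the helper name `solve` are illustrative, and dense NumPy solves stand in for the Krylov solvers. The float32 solution is scored against the float64 operator, so the check itself is not limited by float32 rounding.

```python
import numpy as np

rng = np.random.default_rng(1)
T, N, gamma = 200, 20, 1e-3  # illustrative values only
X64 = rng.standard_normal((T, N))
rhs64 = np.ones(N)


def solve(X, rhs):
    """Dense reference solve of (X^T X + gamma I) w = rhs in the dtype of X."""
    G = X.T @ X + np.asarray(gamma, dtype=X.dtype) * np.eye(N, dtype=X.dtype)
    return np.linalg.solve(G, rhs)


w64 = solve(X64, rhs64)
w32 = solve(X64.astype(np.float32), rhs64.astype(np.float32))

# Score the float32 answer with the float64 operator.
G64 = X64.T @ X64 + gamma * np.eye(N)
rel_residual = np.linalg.norm(G64 @ w32.astype(np.float64) - rhs64) / np.linalg.norm(rhs64)
max_diff = np.max(np.abs(w32 - w64))
```

On well-conditioned data like this the two answers agree closely; a relative residual that grows toward the solver tolerance on real return matrices is the signal that float32 is no longer adequate.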