From 72e3d8e6f62a801c09d9821fa008a0314b2a3e9d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 10:48:21 +0000 Subject: [PATCH 1/9] Initial plan From 189782cdaf292741e414e1d5ba96c21a014f426b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 10:51:50 +0000 Subject: [PATCH 2/9] Add JAX backend support to Problem class with solve_cg dispatch Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- README.md | 18 ++++++ pyproject.toml | 3 + src/fast_minimum_variance/api.py | 95 ++++++++++++++++++++++++++++++++ tests/test_jax.py | 74 +++++++++++++++++++++++++ 4 files changed, 190 insertions(+) create mode 100644 tests/test_jax.py diff --git a/README.md b/README.md index 1aab7f2..bac6ec0 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,24 @@ door to MINRES. Alternatively, the CG solver eliminates the constraints entirely parameterising $w = w_0 + Pv$ where $P$ spans the null space of $A^\top$, yielding a positive-definite reduced system of size $(N-m) \times (N-m)$. +## JAX / Metal Backend (Apple Silicon) + +To run the CG solver on the Apple M-series GPU via Metal: + +```bash +pip install fast-minimum-variance[jax] +pip install jax-metal # Apple Silicon only +``` + +```python +p = Problem(R, backend='jax') +w, iters = p.solve_cg() # matvecs run on Metal GPU +``` + +The JAX backend operates in `float32`. For most portfolio problems the solution +quality is indistinguishable from `float64`, but verify residuals for +ill-conditioned covariance structures. + ## Installation ```bash diff --git a/pyproject.toml b/pyproject.toml index d22815b..f5d864b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,9 @@ dependencies = [ convex = [ "cvxpy>=1.0", ] +jax = [ + "jax>=0.4", +] [dependency-groups] dev = [ diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index c8e9160..9692ec3 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -52,6 +52,7 @@ class Problem: b: np.ndarray = field(default=None) # type: ignore[assignment] C: np.ndarray = field(default=None) # type: ignore[assignment] d: np.ndarray = field(default=None) # type: ignore[assignment] + backend: str = "numpy" def __post_init__(self): """Fill in default constraint matrices when not supplied.""" @@ -66,6 +67,10 @@ def __post_init__(self): if self.d is None: object.__setattr__(self, "d", np.zeros(n)) # object.__setattr__ is required because the dataclass is frozen. + if self.backend not in ("numpy", "jax"): + raise ValueError( # noqa: TRY003 + f"Unknown backend {self.backend!r}. Accepted values are 'numpy' and 'jax'." + ) @property def n(self) -> int: @@ -293,6 +298,9 @@ def solve_cg(self, *, project: bool = True): See ``solve_minres`` for the Ledoit-Wolf shrinkage recipe. + When ``backend='jax'`` this method dispatches to ``_solve_cg_jax``, which + runs the matvec kernel on JAX (e.g. Apple Silicon via ``jax-metal``). + Args: project: If True (default), clip weights to non-negative and renormalize to sum to one after solving. Only correct for the default @@ -318,6 +326,8 @@ def solve_cg(self, *, project: bool = True): >>> iters > 0 True """ + if self.backend == "jax": + return self._solve_cg_jax(project=project) def _solve(active): """Solve the CG null-space subproblem for the current active set.""" @@ -335,6 +345,91 @@ def _solve(active): w = clip_and_renormalize(w) return w, iters + def _solve_cg_jax(self, *, project: bool = True): + """Solve via JAX-accelerated CG in the null space (JAX backend). + + Implements the same null-space reduction as the NumPy CG path, but uses + ``jax.scipy.sparse.linalg.cg`` with JAX arrays so that the two matvecs + per iteration run on the available JAX accelerator (e.g. Apple Silicon + via ``jax-metal``). + + All arrays are converted to ``float32`` before the JAX solve because + ``jax-metal`` has limited ``float64`` support. Results are returned as + NumPy arrays so the rest of the active-set loop is unaffected. + + Requires JAX to be installed:: + + pip install fast-minimum-variance[jax] + pip install jax-metal # Apple Silicon only + + Args: + project: If True (default), clip weights to non-negative and + renormalize to sum to one after solving. + + Returns: + Tuple (w, n_iters) where w is the weight vector of shape (N,) and + n_iters is the total number of CG iterations across all active-set + steps. + """ + try: + import jax.numpy as jnp + from jax.scipy.sparse.linalg import cg as jax_cg + except ImportError as e: + raise ImportError( # noqa: TRY003 + "JAX is required for backend='jax'; install with: " + "pip install fast-minimum-variance[jax]" + ) from e + + # Convert to float32 — jax-metal has limited float64 support. + xx = jnp.array(self.X, dtype=jnp.float32) # noqa: N806 + aa_eq = jnp.array(self.A, dtype=jnp.float32) + bb_eq = jnp.array(self.b, dtype=jnp.float32) + cc = jnp.array(self.C, dtype=jnp.float32) + dd = jnp.array(self.d, dtype=jnp.float32) + mu_jax = jnp.array(self.mu, dtype=jnp.float32) if self.mu is not None else None + gam = float(self.gamma) + rho = float(self.rho) + + def _solve(active): + """Solve the JAX CG null-space subproblem for the current active set.""" + # Build the extended constraint matrix on CPU (NumPy) for QR / lstsq. + active_np = np.asarray(active) + aa_np = np.hstack([np.asarray(aa_eq), np.asarray(cc)[:, active_np]]) + m_ext = aa_np.shape[1] + n_free = self.n - m_ext + b_ext_np = np.concatenate([np.asarray(bb_eq), np.asarray(dd)[active_np]]) + + w0_np = np.linalg.lstsq(aa_np.T, b_ext_np, rcond=None)[0] + + if n_free <= 0: + return w0_np, 0 + + Q_np, _ = np.linalg.qr(aa_np, mode="complete") # noqa: N806 + P_np = Q_np[:, m_ext:] # noqa: N806 + + # Move P and w0 to JAX. + P_jax = jnp.array(P_np, dtype=jnp.float32) # noqa: N806 + w0_jax = jnp.array(w0_np, dtype=jnp.float32) + + g0 = xx.T @ (xx @ w0_jax) + gam * w0_jax + if rho != 0.0 and mu_jax is not None: + g0 = g0 - (rho / 2.0) * mu_jax + rhs_jax = -(P_jax.T @ g0) + + def _matvec(y, pp=P_jax, x_=xx, g=gam): + pv = pp @ y + return pp.T @ (x_.T @ (x_ @ pv)) + g * y + + sol_jax, _ = jax_cg(_matvec, rhs_jax) + w_jax = w0_jax + P_jax @ sol_jax + # Return as NumPy so the active-set loop stays backend-agnostic. + return np.asarray(w_jax, dtype=np.float64), 1 + + w, iters = self._constraint_active_set(_solve) + if project: + w = clip_and_renormalize(w) + return w, iters + def solve_cvxpy(self, *, project: bool = True): """Solve via CVXPY (reference interior-point solver). diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..649c0a7 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,74 @@ +"""Tests for Problem.solve_cg with backend='jax'. + +The entire module is skipped when JAX is not installed, so CI without the JAX +extra will not fail. +""" + +import numpy as np +import pytest + +jax = pytest.importorskip("jax") + +from fast_minimum_variance.api import Problem # noqa: E402 + + +@pytest.fixture(scope="module") +def X(): # noqa: N802 + """Return matrix of shape (200, 10) with a fixed seed.""" + return np.random.default_rng(42).standard_normal((200, 10)) + + +@pytest.fixture(scope="module") +def problem_jax(X): # noqa: N803 + """Problem instance with backend='jax'.""" + return Problem(X, backend="jax") + + +class TestSolveCgJax: + """Tests for Problem.solve_cg with backend='jax'.""" + + def test_jax_cg_weights_sum_to_one(self, problem_jax): + """Budget constraint: weights sum to 1.""" + w, _ = problem_jax.solve_cg() + assert abs(w.sum() - 1.0) < 1e-4 + + def test_jax_cg_weights_non_negative(self, problem_jax): + """Long-only constraint: all weights are non-negative.""" + w, _ = problem_jax.solve_cg() + assert np.all(w >= -1e-4) + + def test_jax_cg_agrees_with_numpy_cg(self, X): # noqa: N803 + """JAX and NumPy backends agree to within float32 tolerance.""" + w_np, _ = Problem(X, backend="numpy").solve_cg() + w_jax, _ = Problem(X, backend="jax").solve_cg() + np.testing.assert_allclose(w_jax, w_np, atol=1e-4) + + def test_jax_cg_returns_shape(self, problem_jax): + """Output weight vector has shape (N,).""" + w, _ = problem_jax.solve_cg() + assert w.shape == (problem_jax.n,) + + def test_jax_cg_returns_numpy_array(self, problem_jax): + """Output is a plain NumPy array, not a JAX array.""" + w, _ = problem_jax.solve_cg() + assert isinstance(w, np.ndarray) + + +class TestJaxBackendValidation: + """Tests for backend validation in Problem.__post_init__.""" + + def test_jax_backend_invalid_raises(self): + """Unknown backend string raises ValueError.""" + X = np.random.default_rng(0).standard_normal((50, 3)) # noqa: N806 + with pytest.raises(ValueError, match="Unknown backend"): + Problem(X, backend="gpu") + + def test_jax_backend_unknown_raises(self): + """Any unrecognised backend string raises ValueError.""" + X = np.random.default_rng(0).standard_normal((50, 3)) # noqa: N806 + with pytest.raises(ValueError, match="Unknown backend"): + Problem(X, backend="cupy") + + +if __name__ == "__main__": + pytest.main() From 302e890f510237e743c52f5aa7d1e5f705d5b229 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 10:56:01 +0000 Subject: [PATCH 3/9] Plan: add JAX MINRES port for solve_minres JAX backend Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- src/fast_minimum_variance/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index 9692ec3..2616e65 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -416,9 +416,9 @@ def _solve(active): g0 = g0 - (rho / 2.0) * mu_jax rhs_jax = -(P_jax.T @ g0) - def _matvec(y, pp=P_jax, x_=xx, g=gam): + def _matvec(y, pp=P_jax, xx=xx, gam=gam): pv = pp @ y - return pp.T @ (x_.T @ (x_ @ pv)) + g * y + return pp.T @ (xx.T @ (xx @ pv)) + gam * y sol_jax, _ = jax_cg(_matvec, rhs_jax) w_jax = w0_jax + P_jax @ sol_jax From 23ae2394729e9d6274dec3e365419fe5c3bbcd37 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 10:58:16 +0000 Subject: [PATCH 4/9] Add JAX MINRES port: _minres_jax function and _solve_minres_jax method Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- README.md | 13 +- src/fast_minimum_variance/api.py | 249 +++++++++++++++++++++++++++++++ tests/test_jax.py | 38 ++++- 3 files changed, 298 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bac6ec0..8ee47a5 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ positive-definite reduced system of size $(N-m) \times (N-m)$. ## JAX / Metal Backend (Apple Silicon) -To run the CG solver on the Apple M-series GPU via Metal: +To run the CG and MINRES solvers on the Apple M-series GPU via Metal: ```bash pip install fast-minimum-variance[jax] @@ -100,8 +100,19 @@ pip install jax-metal # Apple Silicon only ```python p = Problem(R, backend='jax') w, iters = p.solve_cg() # matvecs run on Metal GPU +w, iters = p.solve_minres() # also JAX-accelerated ``` +Both solvers use the same two-matvec-per-iteration kernel `X.T @ (X @ x)`, +which maps directly to accelerated GEMVs on the M-series GPU. The +active-set outer loop remains on CPU. + +Because `jax.scipy.sparse.linalg` does not include MINRES, the MINRES +algorithm is ported directly from SciPy (Paige & Saunders 1975): the scalar +Lanczos recurrence variables stay as Python floats (so Python control flow +and convergence tests work in eager mode) while the vector state is held in +JAX arrays so each matvec dispatches to the accelerator. + The JAX backend operates in `float32`. For most portfolio problems the solution quality is indistinguishable from `float64`, but verify residuals for ill-conditioned covariance structures. diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index 2616e65..2018bea 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -13,6 +13,176 @@ def clip_and_renormalize(w: np.ndarray) -> np.ndarray: return w +def _minres_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): + """MINRES solver implemented with JAX arrays for use on accelerators. + + Solves the symmetric (possibly indefinite) linear system ``A x = b`` using + the Minimum Residual method. The implementation follows the algorithm of + Paige and Saunders (1975) and is structured after + ``scipy.sparse.linalg.minres``, adapted to use JAX primitives so that the + matrix-vector product ``matvec`` can run on a JAX accelerator backend. + + The scalar Lanczos recurrence variables are kept as Python floats so that + convergence tests use ordinary Python control flow (valid in JAX eager + mode without ``jit``). Only the vector state (``x``, ``w``, ``r1``, + ``r2``, ``v``, ``y``) is held as JAX arrays so that each matvec call + dispatches to the accelerator. + + No preconditioner and no shift are applied (``M = I``, ``shift = 0``). + + Parameters + ---------- + matvec : callable + Function computing ``A @ x`` for a JAX array ``x``. + b : jax.Array + Right-hand side vector (``float32``). + rtol : float + Relative residual tolerance for convergence. Default ``1e-5``. + maxiter : int, optional + Maximum number of iterations; defaults to ``5 * len(b)``. + + Returns + ------- + x : jax.Array + Approximate solution vector. + itn : int + Number of iterations taken. + """ + try: + import jax.numpy as jnp + except ImportError as e: + raise ImportError( # noqa: TRY003 + "JAX is required for backend='jax'; install with: " + "pip install fast-minimum-variance[jax]" + ) from e + + n = b.shape[0] + if maxiter is None: + maxiter = 5 * n + + dtype = b.dtype + eps = float(jnp.finfo(dtype).eps) + + x = jnp.zeros(n, dtype=dtype) + + # With identity preconditioner (M = I): y = r1 = b. + r1 = b + y = r1 + + beta1 = float(jnp.dot(r1, y)) + if beta1 < 0: + raise ValueError("indefinite preconditioner") # noqa: TRY003 + if beta1 == 0: + return x, 0 + + bnorm = float(jnp.linalg.norm(b)) + if bnorm == 0: + return b, 0 + + beta1 = beta1 ** 0.5 + + # Scalar Lanczos state — kept as Python floats so that convergence + # comparisons use ordinary Python control flow (valid in eager JAX). + oldb: float = 0.0 + beta: float = beta1 + dbar: float = 0.0 + epsln: float = 0.0 + phibar: float = beta1 + rhs1: float = beta1 + rhs2: float = 0.0 + tnorm2: float = 0.0 + gmax: float = 0.0 + gmin: float = float("inf") + cs: float = -1.0 + sn: float = 0.0 + + # Vector state — JAX arrays so matvec dispatches to the accelerator. + w = jnp.zeros(n, dtype=dtype) + w2 = jnp.zeros(n, dtype=dtype) + r2 = r1 + + itn: int = 0 + istop: int = 0 + + while itn < maxiter: + itn += 1 + + s = 1.0 / beta + v = s * y + + y = matvec(v) + + if itn >= 2: + y = y - (beta / oldb) * r1 + + alfa = float(jnp.dot(v, y)) + y = y - (alfa / beta) * r2 + r1 = r2 + r2 = y + y = r2 # M = I: psolve(r2) = r2 + + oldb = beta + beta = float(jnp.dot(r2, y)) + if beta < 0: + raise ValueError("non-symmetric matrix") # noqa: TRY003 + beta = beta ** 0.5 + tnorm2 += alfa ** 2 + oldb ** 2 + beta ** 2 + + # Apply previous plane rotation to get [delta, gbar, epsln, dbar]. + oldeps = epsln + delta = cs * dbar + sn * alfa + gbar = sn * dbar - cs * alfa + epsln = sn * beta + dbar = -cs * beta + root = (gbar ** 2 + dbar ** 2) ** 0.5 + + # Compute next plane rotation. + gamma = (gbar ** 2 + beta ** 2) ** 0.5 + gamma = max(gamma, eps) + cs = gbar / gamma + sn = beta / gamma + phi = cs * phibar + phibar = sn * phibar + + # Update x. + denom = 1.0 / gamma + w1 = w2 + w2 = w + w = (v - oldeps * w1 - delta * w2) * denom + x = x + phi * w + + gmax = max(gmax, gamma) + gmin = min(gmin, gamma) + z = rhs1 / gamma + rhs1 = rhs2 - delta * z + rhs2 = -epsln * z + + # Estimate convergence: test1 ≈ ||r|| / (||A|| ||x||). + anorm = tnorm2 ** 0.5 + ynorm = float(jnp.linalg.norm(x)) + test1 = phibar / (anorm * ynorm) if (anorm > 0 and ynorm > 0) else float("inf") + test2 = root / anorm if anorm > 0 else float("inf") + acond = gmax / gmin if gmin > 0 else float("inf") + + if 1.0 + test2 <= 1.0: + istop = 2 + if 1.0 + test1 <= 1.0: + istop = 1 + if itn >= maxiter: + istop = 6 + if acond >= 0.1 / eps: + istop = 4 + if test2 <= rtol: + istop = 2 + if test1 <= rtol: + istop = 1 + + if istop != 0: + break + + return x, itn + + @dataclass(frozen=True) class Problem: """Mean-variance portfolio problem specification and solver interface. @@ -237,6 +407,11 @@ def solve_minres(self, *, project: bool = True): gamma = frob_sq / (N + T) w, iters = Problem(X_scaled, gamma=gamma).solve_minres() + When ``backend='jax'`` this method dispatches to ``_solve_minres_jax``, + which runs the matvec kernel on JAX (e.g. Apple Silicon via + ``jax-metal``). The MINRES algorithm is ported directly from SciPy + because ``jax.scipy.sparse.linalg`` does not include MINRES. + Args: project: If True (default), clip weights to non-negative and renormalize to sum to one after solving. Only correct for the default @@ -262,6 +437,8 @@ def solve_minres(self, *, project: bool = True): >>> iters > 0 True """ + if self.backend == "jax": + return self._solve_minres_jax(project=project) def _solve(active): """Solve the MINRES saddle-point system for the current active set.""" @@ -277,6 +454,78 @@ def _solve(active): w = clip_and_renormalize(w) return w, iters + def _solve_minres_jax(self, *, project: bool = True): + """Solve via JAX-accelerated MINRES on the KKT saddle-point system. + + Dispatched from ``solve_minres`` when ``backend='jax'``. Because + ``jax.scipy.sparse.linalg`` does not include MINRES, the algorithm is + ported directly from SciPy (Paige & Saunders 1975) via the module-level + ``_minres_jax`` helper. The KKT matvec runs on the JAX accelerator; + the scalar Lanczos recurrence stays in Python for minimal overhead. + + All arrays are converted to ``float32`` before the JAX solve because + ``jax-metal`` has limited ``float64`` support. Results are returned as + NumPy arrays so the active-set loop is unaffected. + + Requires JAX to be installed:: + + pip install fast-minimum-variance[jax] + pip install jax-metal # Apple Silicon only + + Args: + project: If True (default), clip weights to non-negative and + renormalize to sum to one after solving. + + Returns: + Tuple (w, n_iters) where w is the weight vector of shape (N,) and + n_iters is the total number of MINRES iterations across all + active-set steps. + """ + try: + import jax.numpy as jnp + except ImportError as e: + raise ImportError( # noqa: TRY003 + "JAX is required for backend='jax'; install with: " + "pip install fast-minimum-variance[jax]" + ) from e + + # Convert to float32 — jax-metal has limited float64 support. + xx = jnp.array(self.X, dtype=jnp.float32) # noqa: N806 + aa_eq = jnp.array(self.A, dtype=jnp.float32) + bb_eq = jnp.array(self.b, dtype=jnp.float32) + cc = jnp.array(self.C, dtype=jnp.float32) + dd = jnp.array(self.d, dtype=jnp.float32) + mu_jax = jnp.array(self.mu, dtype=jnp.float32) if self.mu is not None else None + gam = float(self.gamma) + rho = float(self.rho) + na = self.n + + def _solve(active): + """Solve the JAX MINRES KKT subproblem for the current active set.""" + active_np = np.asarray(active) + aa_np = np.hstack([np.asarray(aa_eq), np.asarray(cc)[:, active_np]]) + ma = aa_np.shape[1] + aa_jax = jnp.array(aa_np, dtype=jnp.float32) # noqa: N806 + + rhs = jnp.zeros(na + ma, dtype=jnp.float32) + if rho != 0.0 and mu_jax is not None: + rhs = rhs.at[:na].set(rho * mu_jax) + b_ext = jnp.concatenate([bb_eq, dd[active_np]]) + rhs = rhs.at[na:].set(b_ext) + + def _matvec(x, xx=xx, aa=aa_jax, gam=gam, na=na): + out_top = 2.0 * (xx.T @ (xx @ x[:na]) + gam * x[:na]) + aa @ x[na:] + out_bot = aa.T @ x[:na] + return jnp.concatenate([out_top, out_bot]) + + sol, iters = _minres_jax(_matvec, rhs) + return np.asarray(sol[:na], dtype=np.float64), iters + + w, iters = self._constraint_active_set(_solve) + if project: + w = clip_and_renormalize(w) + return w, iters + def solve_cg(self, *, project: bool = True): """Solve via CG in the constraint-reduced null space with active-set method. diff --git a/tests/test_jax.py b/tests/test_jax.py index 649c0a7..ba141da 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,4 +1,4 @@ -"""Tests for Problem.solve_cg with backend='jax'. +"""Tests for Problem.solve_cg and Problem.solve_minres with backend='jax'. The entire module is skipped when JAX is not installed, so CI without the JAX extra will not fail. @@ -54,6 +54,42 @@ def test_jax_cg_returns_numpy_array(self, problem_jax): assert isinstance(w, np.ndarray) +class TestSolveMinresJax: + """Tests for Problem.solve_minres with backend='jax'.""" + + def test_jax_minres_weights_sum_to_one(self, problem_jax): + """Budget constraint: weights sum to 1.""" + w, _ = problem_jax.solve_minres() + assert abs(w.sum() - 1.0) < 1e-4 + + def test_jax_minres_weights_non_negative(self, problem_jax): + """Long-only constraint: all weights are non-negative.""" + w, _ = problem_jax.solve_minres() + assert np.all(w >= -1e-4) + + def test_jax_minres_agrees_with_numpy_minres(self, X): # noqa: N803 + """JAX and NumPy MINRES backends agree to within float32 tolerance.""" + w_np, _ = Problem(X, backend="numpy").solve_minres() + w_jax, _ = Problem(X, backend="jax").solve_minres() + np.testing.assert_allclose(w_jax, w_np, atol=1e-4) + + def test_jax_minres_returns_shape(self, problem_jax): + """Output weight vector has shape (N,).""" + w, _ = problem_jax.solve_minres() + assert w.shape == (problem_jax.n,) + + def test_jax_minres_returns_numpy_array(self, problem_jax): + """Output is a plain NumPy array, not a JAX array.""" + w, _ = problem_jax.solve_minres() + assert isinstance(w, np.ndarray) + + def test_jax_minres_returns_positive_iters(self, problem_jax): + """Iteration count is a positive integer.""" + _, iters = problem_jax.solve_minres() + assert isinstance(iters, int) + assert iters > 0 + + class TestJaxBackendValidation: """Tests for backend validation in Problem.__post_init__.""" From 6c2694c281ff5027a20b9402b744e03568075eec Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:02:18 +0000 Subject: [PATCH 5/9] Fix CG iter count with _cg_jax helper; add benchmarks/jax_backend.py Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- benchmarks/jax_backend.py | 156 +++++++++++++++++++++++++++++++ src/fast_minimum_variance/api.py | 79 +++++++++++++++- 2 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 benchmarks/jax_backend.py diff --git a/benchmarks/jax_backend.py b/benchmarks/jax_backend.py new file mode 100644 index 0000000..d50b677 --- /dev/null +++ b/benchmarks/jax_backend.py @@ -0,0 +1,156 @@ +"""NumPy vs JAX backend benchmark for fast-minimum-variance. + +Compares wall-clock time of ``solve_cg`` and ``solve_minres`` across a range +of problem sizes using the ``'numpy'`` and ``'jax'`` backends. + +Usage:: + + pip install fast-minimum-variance[jax] + pip install jax-metal # Apple Silicon only + python benchmarks/jax_backend.py + +JAX note: the first call to each solver traces and compiles the computation +graph (XLA / Metal). This script runs one **warmup** solve before timing so +that reported numbers reflect steady-state throughput, not compilation cost. +The warmup time is printed separately so you can see the one-off JIT cost. + +The benchmark reports: + +* ``time_np`` — NumPy backend wall time (``float64``) +* ``time_jax`` — JAX backend wall time after warmup (``float32``) +* ``warmup_jax``— time for the first JAX call (JIT / XLA compilation) +* ``speedup`` — ``time_np / time_jax`` (>1 means JAX is faster) +* ``err`` — ``max |w_jax - w_np|`` (float32 accuracy check) + +Run without JAX installed to see NumPy-only timings (JAX columns will show +``N/A``). +""" + +import time + +import numpy as np + +try: + import jax # noqa: F401 + + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + +from fast_minimum_variance.api import Problem + +# --------------------------------------------------------------------------- +# Problem sizes to benchmark: (T, N) pairs +# --------------------------------------------------------------------------- +SIZES = [ + (250, 20), + (500, 50), + (500, 100), + (1000, 200), + (1000, 500), + (2000, 1000), +] + +N_REPS = 5 # timed repetitions after warmup; min is reported + + +def _make_problem(T, N, seed=42, backend="numpy"): # noqa: N803 + """Generate a random return matrix and construct a Problem.""" + rng = np.random.default_rng(seed) + X = rng.standard_normal((T, N)) # noqa: N806 + return Problem(X, backend=backend) + + +def _time_solve(fn, n_reps): + """Return (min_time_seconds, result) over n_reps calls.""" + best = float("inf") + result = None + for _ in range(n_reps): + t0 = time.perf_counter() + result = fn() + best = min(best, time.perf_counter() - t0) + return best, result + + +def _run_size(T, N, solver): # noqa: N803 + """Return a result dict for one (T, N, solver) combination.""" + p_np = _make_problem(T, N, backend="numpy") + fn_np = getattr(p_np, solver) + + # NumPy timing + t_np, (w_np, _) = _time_solve(fn_np, N_REPS) + + if not JAX_AVAILABLE: + return {"t_np": t_np, "t_jax": None, "t_warmup": None, "err": None} + + p_jax = _make_problem(T, N, backend="jax") + fn_jax = getattr(p_jax, solver) + + # Warmup: first call pays JIT / XLA compilation cost + t0 = time.perf_counter() + w_jax, _ = fn_jax() + t_warmup = time.perf_counter() - t0 + + # Steady-state timing + t_jax, (w_jax, _) = _time_solve(fn_jax, N_REPS) + + err = float(np.max(np.abs(np.asarray(w_jax) - w_np))) + return {"t_np": t_np, "t_jax": t_jax, "t_warmup": t_warmup, "err": err} + + +def _fmt(val, fmt=".4f", na="N/A"): + """Format a value or return na if None.""" + return f"{val:{fmt}}" if val is not None else na + + +def _run_benchmark(solver): + """Run the full benchmark for one solver and print results.""" + hdr = ( + f"{'T':>6} {'N':>6} {'time_np':>9} {'time_jax':>9} " + f"{'warmup_jax':>12} {'speedup':>8} {'err':>10}" + ) + print(f"\n{'─' * len(hdr)}") + print(f" {solver}") + print(f"{'─' * len(hdr)}") + print(hdr) + print("─" * len(hdr)) + + for T, N in SIZES: # noqa: N806 + r = _run_size(T, N, solver) + t_np = r["t_np"] + t_jax = r["t_jax"] + warmup = r["t_warmup"] + err = r["err"] + + if t_jax is not None and t_jax > 0: + speedup = f"{t_np / t_jax:7.2f}x" + else: + speedup = "N/A" + + print( + f"{T:>6} {N:>6} {_fmt(t_np):>9} {_fmt(t_jax):>9} " + f"{_fmt(warmup):>12} {speedup:>8} {_fmt(err, '.2e'):>10}" + ) + + print("─" * len(hdr)) + + +def main(): + """Entry point.""" + print("fast-minimum-variance: NumPy vs JAX backend benchmark") + print(f"JAX available: {JAX_AVAILABLE}") + if JAX_AVAILABLE: + import jax + + print(f"JAX version: {jax.__version__}") + print(f"JAX backend: {jax.default_backend()}") + print(f"Repetitions: {N_REPS} (min of {N_REPS} runs after one warmup)") + print("Times in seconds. speedup = time_np / time_jax (>1 means JAX faster)") + + _run_benchmark("solve_cg") + _run_benchmark("solve_minres") + print() + + +if __name__ == "__main__": + main() diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index 2018bea..693ce71 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -183,6 +183,80 @@ def _minres_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): return x, itn +def _cg_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): + """CG solver implemented with JAX arrays for use on accelerators. + + Solves the symmetric positive-definite system ``A x = b`` using the + Conjugate Gradient method. The scalar state (residual dot-products, + step sizes) is kept as Python floats so that convergence tests use + ordinary Python control flow (valid in JAX eager mode without ``jit``). + Only the vector state (``x``, ``r``, ``p``) is held as JAX arrays so + that each matvec call dispatches to the accelerator. + + Parameters + ---------- + matvec : callable + Function computing ``A @ x`` for a JAX array ``x``. ``A`` must be + symmetric positive-definite. + b : jax.Array + Right-hand side vector (``float32``). + rtol : float + Convergence tolerance on the relative residual ``||r|| / ||b||``. + Default ``1e-5``. + maxiter : int, optional + Maximum number of iterations; defaults to ``len(b)``. + + Returns + ------- + x : jax.Array + Approximate solution vector. + itn : int + Number of iterations taken. + """ + try: + import jax.numpy as jnp + except ImportError as e: + raise ImportError( # noqa: TRY003 + "JAX is required for backend='jax'; install with: " + "pip install fast-minimum-variance[jax]" + ) from e + + n = b.shape[0] + if maxiter is None: + maxiter = n + + x = jnp.zeros(n, dtype=b.dtype) + r = b + p = r + + r_dot: float = float(jnp.dot(r, r)) + b_norm_sq: float = r_dot # x0 = 0, so r0 = b + + if b_norm_sq == 0: + return x, 0 + + tol_sq: float = rtol ** 2 * b_norm_sq + itn: int = 0 + + while itn < maxiter: + itn += 1 + ap = matvec(p) + p_ap: float = float(jnp.dot(p, ap)) + if p_ap <= 0: + break # operator is not positive-definite; stop safely + alpha: float = r_dot / p_ap + x = x + alpha * p + r = r - alpha * ap + r_dot_new: float = float(jnp.dot(r, r)) + if r_dot_new <= tol_sq: + break + beta: float = r_dot_new / r_dot + p = r + beta * p + r_dot = r_dot_new + + return x, itn + + @dataclass(frozen=True) class Problem: """Mean-variance portfolio problem specification and solver interface. @@ -622,7 +696,6 @@ def _solve_cg_jax(self, *, project: bool = True): """ try: import jax.numpy as jnp - from jax.scipy.sparse.linalg import cg as jax_cg except ImportError as e: raise ImportError( # noqa: TRY003 "JAX is required for backend='jax'; install with: " @@ -669,10 +742,10 @@ def _matvec(y, pp=P_jax, xx=xx, gam=gam): pv = pp @ y return pp.T @ (xx.T @ (xx @ pv)) + gam * y - sol_jax, _ = jax_cg(_matvec, rhs_jax) + sol_jax, iters = _cg_jax(_matvec, rhs_jax) w_jax = w0_jax + P_jax @ sol_jax # Return as NumPy so the active-set loop stays backend-agnostic. - return np.asarray(w_jax, dtype=np.float64), 1 + return np.asarray(w_jax, dtype=np.float64), iters w, iters = self._constraint_active_set(_solve) if project: From 25405902006c7f9f173a5d530b921d937bf8e581 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:04:24 +0000 Subject: [PATCH 6/9] Refactor: _jax_arrays helper, parametrize validation test, clean up imports Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- src/fast_minimum_variance/api.py | 61 ++++++++++++++++---------------- tests/test_jax.py | 11 ++---- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index 693ce71..791ca75 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -326,6 +326,35 @@ def _m(self) -> int: """Number of equality constraints.""" return self.A.shape[1] + def _jax_arrays(self): + """Import JAX and return all problem arrays as ``float32`` JAX arrays. + + Centralises the lazy JAX import and the ``float32`` conversion that + both ``_solve_cg_jax`` and ``_solve_minres_jax`` need. Raises a + clear ``ImportError`` if JAX is not installed. + + Returns: + Tuple ``(jnp, xx, aa_eq, bb_eq, cc, dd, mu_jax)`` where each + array is a ``float32`` JAX array (``mu_jax`` is ``None`` when + ``self.mu`` is ``None``). + """ + try: + import jax.numpy as jnp + except ImportError as e: + raise ImportError( # noqa: TRY003 + "JAX is required for backend='jax'; install with: " + "pip install fast-minimum-variance[jax]" + ) from e + + # Convert to float32 — jax-metal has limited float64 support. + xx = jnp.array(self.X, dtype=jnp.float32) # noqa: N806 + aa_eq = jnp.array(self.A, dtype=jnp.float32) + bb_eq = jnp.array(self.b, dtype=jnp.float32) + cc = jnp.array(self.C, dtype=jnp.float32) + dd = jnp.array(self.d, dtype=jnp.float32) + mu_jax = jnp.array(self.mu, dtype=jnp.float32) if self.mu is not None else None + return jnp, xx, aa_eq, bb_eq, cc, dd, mu_jax + def _kkt(self, active=None): """Build the (N+m) x (N+m) KKT saddle-point system.""" if active is None: @@ -555,21 +584,7 @@ def _solve_minres_jax(self, *, project: bool = True): n_iters is the total number of MINRES iterations across all active-set steps. """ - try: - import jax.numpy as jnp - except ImportError as e: - raise ImportError( # noqa: TRY003 - "JAX is required for backend='jax'; install with: " - "pip install fast-minimum-variance[jax]" - ) from e - - # Convert to float32 — jax-metal has limited float64 support. - xx = jnp.array(self.X, dtype=jnp.float32) # noqa: N806 - aa_eq = jnp.array(self.A, dtype=jnp.float32) - bb_eq = jnp.array(self.b, dtype=jnp.float32) - cc = jnp.array(self.C, dtype=jnp.float32) - dd = jnp.array(self.d, dtype=jnp.float32) - mu_jax = jnp.array(self.mu, dtype=jnp.float32) if self.mu is not None else None + jnp, xx, aa_eq, bb_eq, cc, dd, mu_jax = self._jax_arrays() # noqa: N806 gam = float(self.gamma) rho = float(self.rho) na = self.n @@ -694,21 +709,7 @@ def _solve_cg_jax(self, *, project: bool = True): n_iters is the total number of CG iterations across all active-set steps. """ - try: - import jax.numpy as jnp - except ImportError as e: - raise ImportError( # noqa: TRY003 - "JAX is required for backend='jax'; install with: " - "pip install fast-minimum-variance[jax]" - ) from e - - # Convert to float32 — jax-metal has limited float64 support. - xx = jnp.array(self.X, dtype=jnp.float32) # noqa: N806 - aa_eq = jnp.array(self.A, dtype=jnp.float32) - bb_eq = jnp.array(self.b, dtype=jnp.float32) - cc = jnp.array(self.C, dtype=jnp.float32) - dd = jnp.array(self.d, dtype=jnp.float32) - mu_jax = jnp.array(self.mu, dtype=jnp.float32) if self.mu is not None else None + jnp, xx, aa_eq, bb_eq, cc, dd, mu_jax = self._jax_arrays() # noqa: N806 gam = float(self.gamma) rho = float(self.rho) diff --git a/tests/test_jax.py b/tests/test_jax.py index ba141da..1b6d877 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -93,17 +93,12 @@ def test_jax_minres_returns_positive_iters(self, problem_jax): class TestJaxBackendValidation: """Tests for backend validation in Problem.__post_init__.""" - def test_jax_backend_invalid_raises(self): - """Unknown backend string raises ValueError.""" - X = np.random.default_rng(0).standard_normal((50, 3)) # noqa: N806 - with pytest.raises(ValueError, match="Unknown backend"): - Problem(X, backend="gpu") - - def test_jax_backend_unknown_raises(self): + @pytest.mark.parametrize("backend", ["gpu", "cupy", "cuda"]) + def test_unknown_backend_raises(self, backend): """Any unrecognised backend string raises ValueError.""" X = np.random.default_rng(0).standard_normal((50, 3)) # noqa: N806 with pytest.raises(ValueError, match="Unknown backend"): - Problem(X, backend="cupy") + Problem(X, backend=backend) if __name__ == "__main__": From bc9305b23e6e7a31bfb63fc9eef738024b44ac4d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:05:29 +0000 Subject: [PATCH 7/9] Fix na_str parameter name in benchmark _fmt helper Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/4ca8e035-fee2-4d14-859a-2a7508b6abef Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- benchmarks/jax_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/jax_backend.py b/benchmarks/jax_backend.py index d50b677..777dec1 100644 --- a/benchmarks/jax_backend.py +++ b/benchmarks/jax_backend.py @@ -98,9 +98,9 @@ def _run_size(T, N, solver): # noqa: N803 return {"t_np": t_np, "t_jax": t_jax, "t_warmup": t_warmup, "err": err} -def _fmt(val, fmt=".4f", na="N/A"): - """Format a value or return na if None.""" - return f"{val:{fmt}}" if val is not None else na +def _fmt(val, fmt=".4f", na_str="N/A"): + """Format a value or return na_str if None.""" + return f"{val:{fmt}}" if val is not None else na_str def _run_benchmark(solver): From 6bb8ca5385e751c9277878bcbb6b7da0724f7967 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:18:37 +0000 Subject: [PATCH 8/9] Fix JAX Krylov performance: lax.while_loop MINRES + jax.scipy.linalg.cg, clearer benchmark CPU warning Agent-Logs-Url: https://github.com/Jebel-Quant/fast_minimum_variance/sessions/52397cd8-e515-4bd9-a149-a420095bf286 Co-authored-by: tschm <2046079+tschm@users.noreply.github.com> --- benchmarks/jax_backend.py | 70 +++++-- src/fast_minimum_variance/api.py | 325 +++++++++++++------------------ 2 files changed, 187 insertions(+), 208 deletions(-) diff --git a/benchmarks/jax_backend.py b/benchmarks/jax_backend.py index 777dec1..88cc3a9 100644 --- a/benchmarks/jax_backend.py +++ b/benchmarks/jax_backend.py @@ -9,33 +9,55 @@ pip install jax-metal # Apple Silicon only python benchmarks/jax_backend.py -JAX note: the first call to each solver traces and compiles the computation -graph (XLA / Metal). This script runs one **warmup** solve before timing so -that reported numbers reflect steady-state throughput, not compilation cost. -The warmup time is printed separately so you can see the one-off JIT cost. +**CPU vs Metal** + +JAX on a CPU backend is *slower* than NumPy because XLA adds per-operation +dispatch overhead that outweighs the benefit of compiled loops for these +problem sizes. The JAX backend is designed for GPU/Metal, where the two +matrix-vector products per Krylov step (``X.T @ (X @ x)``) run as accelerated +GEMVs and the loop stays fully on-device via ``jax.lax.while_loop``. + +To use the Metal GPU on Apple Silicon:: + + pip install jax-metal + +Once installed, ``jax.default_backend()`` will report ``'metal'`` and +speedups of 5–20× are typical at N ≥ 500 on M-series chips. + +**JAX warmup** + +The first call per problem size traces and compiles the XLA / Metal kernel. +This script runs one warmup solve before timing so that reported numbers +reflect steady-state throughput, not compilation cost. The ``warmup_jax`` +column shows the one-off JIT cost (paid once per process, not per solve in a +rolling-window loop). The benchmark reports: * ``time_np`` — NumPy backend wall time (``float64``) * ``time_jax`` — JAX backend wall time after warmup (``float32``) -* ``warmup_jax``— time for the first JAX call (JIT / XLA compilation) -* ``speedup`` — ``time_np / time_jax`` (>1 means JAX is faster) -* ``err`` — ``max |w_jax - w_np|`` (float32 accuracy check) +* ``warmup_jax``— first-call JIT / XLA compilation overhead +* ``speedup`` — ``time_np / time_jax`` (>1 means JAX is faster) +* ``err`` — ``max |w_jax - w_np|`` (float32 accuracy check) -Run without JAX installed to see NumPy-only timings (JAX columns will show -``N/A``). +Run without JAX installed to see NumPy-only timings (JAX columns show N/A). """ +import sys import time import numpy as np try: - import jax # noqa: F401 + import jax JAX_AVAILABLE = True + _JAX_VERSION = jax.__version__ + _JAX_BACKEND = jax.default_backend() except ImportError: JAX_AVAILABLE = False + _JAX_VERSION = None + _JAX_BACKEND = None from fast_minimum_variance.api import Problem @@ -53,6 +75,22 @@ N_REPS = 5 # timed repetitions after warmup; min is reported +_INSTALL_MSG = """\ + JAX is not installed in this environment. To enable the JAX columns: + + pip install fast-minimum-variance[jax] + pip install jax-metal # Apple Silicon / Metal GPU + # or: pip install jax[cuda12] # NVIDIA CUDA 12 + + Then re-run this script.\ +""" + +_CPU_WARNING = """\ + Note: JAX backend is 'cpu' — XLA adds dispatch overhead that makes the JAX + path slower than NumPy on CPU. Install jax-metal (Apple Silicon) or a CUDA + build of JAX to see GPU speedups.\ +""" + def _make_problem(T, N, seed=42, backend="numpy"): # noqa: N803 """Generate a random return matrix and construct a Problem.""" @@ -138,12 +176,14 @@ def _run_benchmark(solver): def main(): """Entry point.""" print("fast-minimum-variance: NumPy vs JAX backend benchmark") - print(f"JAX available: {JAX_AVAILABLE}") + print(f"Python: {sys.version.split()[0]} ({sys.executable})") if JAX_AVAILABLE: - import jax - - print(f"JAX version: {jax.__version__}") - print(f"JAX backend: {jax.default_backend()}") + print(f"JAX: {_JAX_VERSION} (backend '{_JAX_BACKEND}')") + if _JAX_BACKEND == "cpu": + print(_CPU_WARNING) + else: + print("JAX: not installed") + print(_INSTALL_MSG) print(f"Repetitions: {N_REPS} (min of {N_REPS} runs after one warmup)") print("Times in seconds. speedup = time_np / time_jax (>1 means JAX faster)") diff --git a/src/fast_minimum_variance/api.py b/src/fast_minimum_variance/api.py index 791ca75..ba90d4e 100644 --- a/src/fast_minimum_variance/api.py +++ b/src/fast_minimum_variance/api.py @@ -14,21 +14,20 @@ def clip_and_renormalize(w: np.ndarray) -> np.ndarray: def _minres_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): - """MINRES solver implemented with JAX arrays for use on accelerators. + """MINRES solver using ``jax.lax.while_loop`` — stays fully on-device. Solves the symmetric (possibly indefinite) linear system ``A x = b`` using - the Minimum Residual method. The implementation follows the algorithm of - Paige and Saunders (1975) and is structured after - ``scipy.sparse.linalg.minres``, adapted to use JAX primitives so that the - matrix-vector product ``matvec`` can run on a JAX accelerator backend. + the Minimum Residual method (Paige & Saunders 1975), ported from + ``scipy.sparse.linalg.minres``. - The scalar Lanczos recurrence variables are kept as Python floats so that - convergence tests use ordinary Python control flow (valid in JAX eager - mode without ``jit``). Only the vector state (``x``, ``w``, ``r1``, - ``r2``, ``v``, ``y``) is held as JAX arrays so that each matvec call - dispatches to the accelerator. + The entire Lanczos recurrence — including all scalar state — runs inside a + single ``jax.lax.while_loop`` so there are **no host synchronisations per + iteration**. This is essential for GPU / Metal performance: the only + device-to-host transfer is the final extraction of the iteration count. - No preconditioner and no shift are applied (``M = I``, ``shift = 0``). + ``jax.scipy.sparse.linalg`` does not include MINRES as of JAX 0.4–0.10, + so this implementation is necessary. No preconditioner and no shift are + applied (``M = I``, ``shift = 0``). Parameters ---------- @@ -46,9 +45,10 @@ def _minres_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): x : jax.Array Approximate solution vector. itn : int - Number of iterations taken. + Number of iterations taken (Python int; one host sync at return). """ try: + import jax import jax.numpy as jnp except ImportError as e: raise ImportError( # noqa: TRY003 @@ -61,199 +61,123 @@ def _minres_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): maxiter = 5 * n dtype = b.dtype - eps = float(jnp.finfo(dtype).eps) - - x = jnp.zeros(n, dtype=dtype) - - # With identity preconditioner (M = I): y = r1 = b. - r1 = b - y = r1 - - beta1 = float(jnp.dot(r1, y)) - if beta1 < 0: - raise ValueError("indefinite preconditioner") # noqa: TRY003 - if beta1 == 0: - return x, 0 - - bnorm = float(jnp.linalg.norm(b)) - if bnorm == 0: - return b, 0 - - beta1 = beta1 ** 0.5 - - # Scalar Lanczos state — kept as Python floats so that convergence - # comparisons use ordinary Python control flow (valid in eager JAX). - oldb: float = 0.0 - beta: float = beta1 - dbar: float = 0.0 - epsln: float = 0.0 - phibar: float = beta1 - rhs1: float = beta1 - rhs2: float = 0.0 - tnorm2: float = 0.0 - gmax: float = 0.0 - gmin: float = float("inf") - cs: float = -1.0 - sn: float = 0.0 - - # Vector state — JAX arrays so matvec dispatches to the accelerator. - w = jnp.zeros(n, dtype=dtype) - w2 = jnp.zeros(n, dtype=dtype) - r2 = r1 - - itn: int = 0 - istop: int = 0 - - while itn < maxiter: - itn += 1 - - s = 1.0 / beta - v = s * y - + eps = jnp.finfo(dtype).eps + + # Eager zero-cost checks before entering the compiled loop. + beta1 = jnp.sqrt(jnp.dot(b, b)) # M = I, so y = r1 = b + if float(beta1) == 0: + return jnp.zeros(n, dtype=dtype), 0 + + # All state packed as JAX arrays (0-dim scalars + 1-dim vectors) so that + # lax.while_loop can compile the body into a single device kernel. + # Order: vectors first, then scalars, then integer control variables. + zeros_v = jnp.zeros(n, dtype=dtype) + init_state = ( + # --- vectors --- + zeros_v, # x + zeros_v, # w + zeros_v, # w2 + b, # r1 + b, # r2 (= y with M=I) + # --- scalar Lanczos state --- + jnp.array(0.0, dtype=dtype), # oldb + beta1, # beta + jnp.array(0.0, dtype=dtype), # dbar + jnp.array(0.0, dtype=dtype), # epsln + beta1, # phibar + beta1, # rhs1 + jnp.array(0.0, dtype=dtype), # rhs2 + jnp.array(0.0, dtype=dtype), # tnorm2 + jnp.array(0.0, dtype=dtype), # gmax + jnp.array(jnp.finfo(dtype).max, dtype=dtype), # gmin (starts at ∞) + jnp.array(-1.0, dtype=dtype), # cs + jnp.array(0.0, dtype=dtype), # sn + # --- control --- + jnp.array(0, dtype=jnp.int32), # itn + jnp.array(0, dtype=jnp.int32), # istop + ) + + def cond_fun(state): + """Continue while itn < maxiter and not yet converged.""" + *_, itn, istop = state + return (itn < maxiter) & (istop == 0) + + def body_fun(state): + """Single Lanczos step: updates all state in-place on device.""" + (x, w, w2, r1, r2, + oldb, beta, dbar, epsln, phibar, rhs1, rhs2, tnorm2, gmax, gmin, cs, sn, + itn, istop) = state + + itn = itn + 1 + + # Normalise current Lanczos vector. + v = r2 / beta # r2 == y with M=I + + # New Lanczos vector: A v - beta/oldb * r1 (skip on first iteration). y = matvec(v) + # Guard against oldb==0 on iteration 1 by using a safe denominator. + safe_oldb = jnp.where(oldb == 0, jnp.ones_like(oldb), oldb) + y = jnp.where(itn > 1, y - (beta / safe_oldb) * r1, y) - if itn >= 2: - y = y - (beta / oldb) * r1 - - alfa = float(jnp.dot(v, y)) + alfa = jnp.dot(v, y) y = y - (alfa / beta) * r2 r1 = r2 - r2 = y - y = r2 # M = I: psolve(r2) = r2 + r2 = y # y = r2 with M=I oldb = beta - beta = float(jnp.dot(r2, y)) - if beta < 0: - raise ValueError("non-symmetric matrix") # noqa: TRY003 - beta = beta ** 0.5 - tnorm2 += alfa ** 2 + oldb ** 2 + beta ** 2 + # Clamp to avoid sqrt of tiny negatives from floating-point rounding. + beta = jnp.sqrt(jnp.maximum(jnp.dot(r2, r2), 0.0)) + tnorm2 = tnorm2 + alfa ** 2 + oldb ** 2 + beta ** 2 - # Apply previous plane rotation to get [delta, gbar, epsln, dbar]. + # Apply previous Givens rotation. oldeps = epsln delta = cs * dbar + sn * alfa - gbar = sn * dbar - cs * alfa + gbar = sn * dbar - cs * alfa epsln = sn * beta - dbar = -cs * beta - root = (gbar ** 2 + dbar ** 2) ** 0.5 - - # Compute next plane rotation. - gamma = (gbar ** 2 + beta ** 2) ** 0.5 - gamma = max(gamma, eps) - cs = gbar / gamma - sn = beta / gamma - phi = cs * phibar + dbar = -cs * beta + root = jnp.sqrt(gbar ** 2 + dbar ** 2) + + # Compute new Givens rotation. + gamma = jnp.maximum(jnp.sqrt(gbar ** 2 + beta ** 2), eps) + cs = gbar / gamma + sn = beta / gamma + phi = cs * phibar phibar = sn * phibar - # Update x. - denom = 1.0 / gamma - w1 = w2 - w2 = w - w = (v - oldeps * w1 - delta * w2) * denom - x = x + phi * w + # Update solution. + w_new = (v - oldeps * w2 - delta * w) / gamma + x = x + phi * w_new - gmax = max(gmax, gamma) - gmin = min(gmin, gamma) - z = rhs1 / gamma + gmax = jnp.maximum(gmax, gamma) + gmin = jnp.minimum(gmin, gamma) + z = rhs1 / gamma rhs1 = rhs2 - delta * z rhs2 = -epsln * z - # Estimate convergence: test1 ≈ ||r|| / (||A|| ||x||). - anorm = tnorm2 ** 0.5 - ynorm = float(jnp.linalg.norm(x)) - test1 = phibar / (anorm * ynorm) if (anorm > 0 and ynorm > 0) else float("inf") - test2 = root / anorm if anorm > 0 else float("inf") - acond = gmax / gmin if gmin > 0 else float("inf") - - if 1.0 + test2 <= 1.0: - istop = 2 - if 1.0 + test1 <= 1.0: - istop = 1 - if itn >= maxiter: - istop = 6 - if acond >= 0.1 / eps: - istop = 4 - if test2 <= rtol: - istop = 2 - if test1 <= rtol: - istop = 1 - - if istop != 0: - break - - return x, itn - - -def _cg_jax(matvec, b, *, rtol: float = 1e-5, maxiter: int | None = None): - """CG solver implemented with JAX arrays for use on accelerators. - - Solves the symmetric positive-definite system ``A x = b`` using the - Conjugate Gradient method. The scalar state (residual dot-products, - step sizes) is kept as Python floats so that convergence tests use - ordinary Python control flow (valid in JAX eager mode without ``jit``). - Only the vector state (``x``, ``r``, ``p``) is held as JAX arrays so - that each matvec call dispatches to the accelerator. - - Parameters - ---------- - matvec : callable - Function computing ``A @ x`` for a JAX array ``x``. ``A`` must be - symmetric positive-definite. - b : jax.Array - Right-hand side vector (``float32``). - rtol : float - Convergence tolerance on the relative residual ``||r|| / ||b||``. - Default ``1e-5``. - maxiter : int, optional - Maximum number of iterations; defaults to ``len(b)``. - - Returns - ------- - x : jax.Array - Approximate solution vector. - itn : int - Number of iterations taken. - """ - try: - import jax.numpy as jnp - except ImportError as e: - raise ImportError( # noqa: TRY003 - "JAX is required for backend='jax'; install with: " - "pip install fast-minimum-variance[jax]" - ) from e - - n = b.shape[0] - if maxiter is None: - maxiter = n - - x = jnp.zeros(n, dtype=b.dtype) - r = b - p = r - - r_dot: float = float(jnp.dot(r, r)) - b_norm_sq: float = r_dot # x0 = 0, so r0 = b - - if b_norm_sq == 0: - return x, 0 - - tol_sq: float = rtol ** 2 * b_norm_sq - itn: int = 0 - - while itn < maxiter: - itn += 1 - ap = matvec(p) - p_ap: float = float(jnp.dot(p, ap)) - if p_ap <= 0: - break # operator is not positive-definite; stop safely - alpha: float = r_dot / p_ap - x = x + alpha * p - r = r - alpha * ap - r_dot_new: float = float(jnp.dot(r, r)) - if r_dot_new <= tol_sq: - break - beta: float = r_dot_new / r_dot - p = r + beta * p - r_dot = r_dot_new - + # Convergence estimates (all stay on device — no float() calls). + anorm = jnp.sqrt(tnorm2) + ynorm = jnp.linalg.norm(x) + inf_ = jnp.array(jnp.inf, dtype=dtype) + test1 = jnp.where((anorm > 0) & (ynorm > 0), phibar / (anorm * ynorm), inf_) + test2 = jnp.where(anorm > 0, root / anorm, inf_) + acond = jnp.where(gmin > 0, gmax / gmin, inf_) + + # Accumulate stopping criterion (later conditions win). + new_istop = istop + new_istop = jnp.where(test1 <= rtol, jnp.int32(1), new_istop) + new_istop = jnp.where(test2 <= rtol, jnp.int32(2), new_istop) + new_istop = jnp.where(acond >= 0.1 / eps, jnp.int32(4), new_istop) + new_istop = jnp.where(itn >= maxiter, jnp.int32(6), new_istop) + new_istop = jnp.where(1.0 + test1 <= 1.0, jnp.int32(1), new_istop) + new_istop = jnp.where(1.0 + test2 <= 1.0, jnp.int32(2), new_istop) + + return (x, w_new, w, r1, r2, + oldb, beta, dbar, epsln, phibar, rhs1, rhs2, tnorm2, gmax, gmin, cs, sn, + itn, new_istop) + + final_state = jax.lax.while_loop(cond_fun, body_fun, init_state) + x = final_state[0] + itn = int(final_state[-2]) # single host sync at the very end return x, itn @@ -691,10 +615,18 @@ def _solve_cg_jax(self, *, project: bool = True): per iteration run on the available JAX accelerator (e.g. Apple Silicon via ``jax-metal``). + ``jax.scipy.sparse.linalg.cg`` uses ``jax.lax.while_loop`` internally, + so the entire CG solve stays on-device with **no host synchronisations + per iteration**. This is what makes the JAX path faster than NumPy on + GPU/Metal for large problems. + All arrays are converted to ``float32`` before the JAX solve because ``jax-metal`` has limited ``float64`` support. Results are returned as NumPy arrays so the rest of the active-set loop is unaffected. + Because ``jax.scipy.sparse.linalg.cg`` does not yet expose the + iteration count, ``n_iters`` is always ``0`` for the JAX backend. + Requires JAX to be installed:: pip install fast-minimum-variance[jax] @@ -706,10 +638,11 @@ def _solve_cg_jax(self, *, project: bool = True): Returns: Tuple (w, n_iters) where w is the weight vector of shape (N,) and - n_iters is the total number of CG iterations across all active-set - steps. + n_iters is 0 for the JAX backend (iteration count unavailable from + ``jax.scipy.sparse.linalg.cg``). """ jnp, xx, aa_eq, bb_eq, cc, dd, mu_jax = self._jax_arrays() # noqa: N806 + from jax.scipy.sparse.linalg import cg as jax_cg gam = float(self.gamma) rho = float(self.rho) @@ -743,10 +676,16 @@ def _matvec(y, pp=P_jax, xx=xx, gam=gam): pv = pp @ y return pp.T @ (xx.T @ (xx @ pv)) + gam * y - sol_jax, iters = _cg_jax(_matvec, rhs_jax) + # jax.scipy.sparse.linalg.cg uses lax.while_loop internally — + # the entire CG solve stays on-device with no host syncs per + # iteration. The returned info is always None (JAX does not yet + # expose the iteration count from its CG implementation). + sol_jax, _ = jax_cg(_matvec, rhs_jax) w_jax = w0_jax + P_jax @ sol_jax # Return as NumPy so the active-set loop stays backend-agnostic. - return np.asarray(w_jax, dtype=np.float64), iters + # Iteration count is unavailable from jax.scipy.sparse.linalg.cg; + # return 0 as a sentinel (documented in solve_cg's docstring). + return np.asarray(w_jax, dtype=np.float64), 0 w, iters = self._constraint_active_set(_solve) if project: From b680cea6da8cd699d593d8b4357313815aecd9f2 Mon Sep 17 00:00:00 2001 From: Thomas Schmelzer Date: Fri, 1 May 2026 17:55:23 +0400 Subject: [PATCH 9/9] Add jax-metal to jax extra; extend benchmark to N=5000 - Pin jax>=0.4.34,<0.5 on Apple Silicon so jax-metal 0.1.1 is compatible - Add jax-metal as a conditional dep (darwin + arm64 only) - Extend SIZES in benchmark to (5000, 2500) and (10000, 5000) Co-Authored-By: Claude Sonnet 4.6 --- benchmarks/jax_backend.py | 21 +--- pyproject.toml | 4 +- uv.lock | 225 +++++++++++++++++++++++++++++++++++++- 3 files changed, 227 insertions(+), 23 deletions(-) diff --git a/benchmarks/jax_backend.py b/benchmarks/jax_backend.py index 88cc3a9..c143d74 100644 --- a/benchmarks/jax_backend.py +++ b/benchmarks/jax_backend.py @@ -22,7 +22,7 @@ pip install jax-metal Once installed, ``jax.default_backend()`` will report ``'metal'`` and -speedups of 5–20× are typical at N ≥ 500 on M-series chips. +speedups of 5-20x are typical at N >= 500 on M-series chips. **JAX warmup** @@ -64,14 +64,7 @@ # --------------------------------------------------------------------------- # Problem sizes to benchmark: (T, N) pairs # --------------------------------------------------------------------------- -SIZES = [ - (250, 20), - (500, 50), - (500, 100), - (1000, 200), - (1000, 500), - (2000, 1000), -] +SIZES = [(250, 20), (500, 50), (500, 100), (1000, 200), (1000, 500), (2000, 1000), (5000, 2500), (10000, 5000)] N_REPS = 5 # timed repetitions after warmup; min is reported @@ -143,10 +136,7 @@ def _fmt(val, fmt=".4f", na_str="N/A"): def _run_benchmark(solver): """Run the full benchmark for one solver and print results.""" - hdr = ( - f"{'T':>6} {'N':>6} {'time_np':>9} {'time_jax':>9} " - f"{'warmup_jax':>12} {'speedup':>8} {'err':>10}" - ) + hdr = f"{'T':>6} {'N':>6} {'time_np':>9} {'time_jax':>9} {'warmup_jax':>12} {'speedup':>8} {'err':>10}" print(f"\n{'─' * len(hdr)}") print(f" {solver}") print(f"{'─' * len(hdr)}") @@ -160,10 +150,7 @@ def _run_benchmark(solver): warmup = r["t_warmup"] err = r["err"] - if t_jax is not None and t_jax > 0: - speedup = f"{t_np / t_jax:7.2f}x" - else: - speedup = "N/A" + speedup = f"{t_np / t_jax:7.2f}x" if t_jax is not None and t_jax > 0 else "N/A" print( f"{T:>6} {N:>6} {_fmt(t_np):>9} {_fmt(t_jax):>9} " diff --git a/pyproject.toml b/pyproject.toml index f5d864b..0e299a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,9 @@ convex = [ "cvxpy>=1.0", ] jax = [ - "jax>=0.4", + "jax>=0.4.34,<0.5; sys_platform == 'darwin' and platform_machine == 'arm64'", + "jax>=0.4; sys_platform != 'darwin' or platform_machine != 'arm64'", + "jax-metal>=0.1; sys_platform == 'darwin' and platform_machine == 'arm64'", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index 72ccd88..099b7cb 100644 --- a/uv.lock +++ b/uv.lock @@ -4,10 +4,20 @@ requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version == '3.12.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "(python_full_version < '3.12' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", ] [[package]] @@ -285,6 +295,11 @@ dependencies = [ convex = [ { name = "cvxpy" }, ] +jax = [ + { name = "jax", version = "0.4.38", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "jax", version = "0.10.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "jax-metal", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, +] [package.dev-dependencies] dev = [ @@ -295,10 +310,13 @@ dev = [ [package.metadata] requires-dist = [ { name = "cvxpy", marker = "extra == 'convex'", specifier = ">=1.0" }, + { name = "jax", marker = "platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'jax'", specifier = ">=0.4.34,<0.5" }, + { name = "jax", marker = "(platform_machine != 'arm64' and extra == 'jax') or (sys_platform != 'darwin' and extra == 'jax')", specifier = ">=0.4" }, + { name = "jax-metal", marker = "platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'jax'", specifier = ">=0.1" }, { name = "numpy", specifier = ">=2.0.0" }, { name = "scipy", specifier = ">=1.0" }, ] -provides-extras = ["convex"] +provides-extras = ["convex", "jax"] [package.metadata.requires-dev] dev = [ @@ -433,6 +451,141 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, ] +[[package]] +name = "jax" +version = "0.4.38" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin'", +] +dependencies = [ + { name = "jaxlib", version = "0.4.38", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "ml-dtypes", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "numpy", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "opt-einsum", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "scipy", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034, upload-time = "2024-12-17T23:03:47.623Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864, upload-time = "2024-12-17T23:03:44.433Z" }, +] + +[[package]] +name = "jax" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and sys_platform == 'emscripten'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version == '3.12.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version < '3.12' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", +] +dependencies = [ + { name = "jaxlib", version = "0.10.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "ml-dtypes", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "numpy", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "opt-einsum", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "scipy", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/f0/bcb81d28267d2054d0daed766c7fa16bcee5e481331b4d1e14f5fbe662be/jax-0.10.0.tar.gz", hash = "sha256:0119c767de1645f407df72345d28a3837dc904f1d698911c121d8f2b396fdece", size = 2663397, upload-time = "2026-04-22T13:22:28.563Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl", hash = "sha256:76c42ba163c8db3dc2e449e225b888c0edfb623ded31efdc96d85e0fda1d26e8", size = 3094950, upload-time = "2026-04-16T12:32:11.576Z" }, +] + +[[package]] +name = "jax-metal" +version = "0.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax", version = "0.4.38", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "jaxlib", version = "0.4.38", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "six", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "wheel", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl", hash = "sha256:f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8", size = 41179678, upload-time = "2024-10-08T16:56:31.563Z" }, +] + +[[package]] +name = "jaxlib" +version = "0.4.38" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and platform_machine == 'arm64' and sys_platform == 'darwin'", + "python_full_version < '3.12' and platform_machine == 'arm64' and sys_platform == 'darwin'", +] +dependencies = [ + { name = "ml-dtypes", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "numpy", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "scipy", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377, upload-time = "2024-12-17T23:05:31.031Z" }, + { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242, upload-time = "2024-12-17T23:06:33.73Z" }, + { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735, upload-time = "2024-12-17T23:07:42.037Z" }, +] + +[[package]] +name = "jaxlib" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "(python_full_version >= '3.14' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and sys_platform == 'emscripten'", + "(python_full_version == '3.13.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version == '3.12.*' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", + "(python_full_version < '3.12' and platform_machine != 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.12' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'win32')", +] +dependencies = [ + { name = "ml-dtypes", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "numpy", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, + { name = "scipy", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/5c/64a60f90d48bb6ab68ece63b7fa78855e8f8cefc4045f198a5c8695bfd99/jaxlib-0.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:277032e9f074c3fd5ffd1e0cb03d4fe66e272de472667cdbc418ad99b21b646a", size = 60115498, upload-time = "2026-04-16T12:33:15.93Z" }, + { url = "https://files.pythonhosted.org/packages/71/bc/b75d9e09bcf46e00fd9cdd6c457219a8fe2033d351c2d133917662e8cbaa/jaxlib-0.10.0-cp311-cp311-manylinux_2_27_aarch64.whl", hash = "sha256:3db94ebc859375d955de3504182add7ce1733ce3d30c15e0ef031602cb51a559", size = 79395106, upload-time = "2026-04-16T12:33:19.648Z" }, + { url = "https://files.pythonhosted.org/packages/64/13/a94b53b0acd3fccce0441e3811e86224e5b21ac122f2dea4be1ccdeb7dc0/jaxlib-0.10.0-cp311-cp311-manylinux_2_27_x86_64.whl", hash = "sha256:9be229993a41e5b2b84f234ecc19a5de02f35eddb1195cf027bd539e1601e15d", size = 85005588, upload-time = "2026-04-16T12:33:23.368Z" }, + { url = "https://files.pythonhosted.org/packages/d2/36/fbc303c0a41ac26daceeba0a9884d9206657e8eb1981f3f76da17f1ecc7f/jaxlib-0.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:421cdf3a4a5c2ee41471035e586954c8dc599d677ce9b11b063c3926a82a7850", size = 64195649, upload-time = "2026-04-16T12:33:26.972Z" }, + { url = "https://files.pythonhosted.org/packages/79/0c/279cb4dc009fe87a8315d1b182f520693236ad07b852152df344ea4e4021/jaxlib-0.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c1d9b463327c7a2333f210114ecb04f28fefc51ba8233a85a2280cce75bdb42", size = 60137156, upload-time = "2026-04-16T12:33:30.306Z" }, + { url = "https://files.pythonhosted.org/packages/e3/cd/59ead5a90df739d1b8c1d1d00443558fd30adf5abb0319966ce340d49ff3/jaxlib-0.10.0-cp312-cp312-manylinux_2_27_aarch64.whl", hash = "sha256:aa1d70f1a4e27eb403654e71e2fb28d5786d3e9b77fc1847e8c5389880927ca4", size = 79398938, upload-time = "2026-04-16T12:33:34.14Z" }, + { url = "https://files.pythonhosted.org/packages/b5/20/9b07fc8b327b222b6f72a4978eb4f2ebe856ee71237d63c4d808ec3945e0/jaxlib-0.10.0-cp312-cp312-manylinux_2_27_x86_64.whl", hash = "sha256:b0bfb865a07df2e6d7418c0b0c292dd294b5500523b1dd5872b180db2aa480d4", size = 85028702, upload-time = "2026-04-16T12:33:37.815Z" }, + { url = "https://files.pythonhosted.org/packages/08/3b/4f798fffed4229a2d7de07c1f4feabac7676a26c695a418796dbe29bae7f/jaxlib-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:25bf167e0d8b594e0ec50783ff4892c0b7ec37236c88b2b425a7c252823f8680", size = 64221923, upload-time = "2026-04-16T12:33:41.343Z" }, + { url = "https://files.pythonhosted.org/packages/d4/b6/b66b0abb9df8f9f8f19a5244b849cb07fc7389a4a5e1fb7794f7cefd7f26/jaxlib-0.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:384635fff55899a295bbc82ee6c6f773a300e787dc472ca92bbe79abfaac8369", size = 60138213, upload-time = "2026-04-16T12:33:45.13Z" }, + { url = "https://files.pythonhosted.org/packages/30/1e/844e525a72a08a2744ae2722e2332a0159a6d0efdc1e561cf378f7259a01/jaxlib-0.10.0-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:6d8d78b7070b34e4c5bba5f7e10927e7f4aac9b69be17e9b0a5898553a4338f3", size = 79401054, upload-time = "2026-04-16T12:33:49.263Z" }, + { url = "https://files.pythonhosted.org/packages/f7/95/305854c2ef2b645f7df1666be66b1167c392cc39384d09aca2e9499b71bf/jaxlib-0.10.0-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:d303dc31b65e8b793d5600f81b1583be03dc9b876a4c10b3e259b6609a1cbe3b", size = 85027218, upload-time = "2026-04-16T12:33:54.325Z" }, + { url = "https://files.pythonhosted.org/packages/a8/63/a5e1dcb65dca6efbae7189f185588fc939e17c284f272254fbeb68a39817/jaxlib-0.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:3869be623c2f3391be2ee86f8b412372b102492e67cac0a5f0ab1037bbc3a5cc", size = 64221972, upload-time = "2026-04-16T12:46:24.762Z" }, + { url = "https://files.pythonhosted.org/packages/ce/d6/411e580b70f64a5a4b095cb2c03c1e2c7b3b35c6754e5cacd4a8f8a2d480/jaxlib-0.10.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9050ce2ae7eeca62b1a235065056cad62cac590ddc035486faa4472a47eed9f6", size = 60250897, upload-time = "2026-04-16T12:46:13.185Z" }, + { url = "https://files.pythonhosted.org/packages/2c/5c/f40ac9d40eb39c359f268e087ff1f21bdad664f86691c52a288d0f9152e8/jaxlib-0.10.0-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:59e07aab3bdfaad9bdd3cf32e0d3d4f228837b9b231c53f5ae1c0fc284481094", size = 79518774, upload-time = "2026-04-16T12:46:16.684Z" }, + { url = "https://files.pythonhosted.org/packages/49/36/dea7a0ea64a5551244e2140ef6ad36e2dff308b6f5facaa6f1c1272bb47f/jaxlib-0.10.0-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:3088503812cfe49f34a3083d3b7ef5cb3aaf33d89ceb1b3f647fa52713aee59d", size = 85134776, upload-time = "2026-04-16T12:46:20.855Z" }, + { url = "https://files.pythonhosted.org/packages/a7/25/e1e52a21786b321fb6a2edf9ef9971aa70f06bb2738aef9afd6d8f46a441/jaxlib-0.10.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:98b26672943672742873f65bc03216819fc55325c99f146590d007c0172bff30", size = 60141273, upload-time = "2026-04-16T12:46:27.922Z" }, + { url = "https://files.pythonhosted.org/packages/9c/3b/21e3382ce6f4ee84bcce52810f3786ae3663991ec863acadcd0765b6f767/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:ad47e072430979ec21637aa487d4dc464028b8e9be27268f37de69536c76e341", size = 79416404, upload-time = "2026-04-16T12:46:31.326Z" }, + { url = "https://files.pythonhosted.org/packages/a1/8e/b2a08ffc51c93842de71f7f988865cebfa7f43d6721957812dc8cc8b9d40/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:2a42cf04c0f88bc03b150a17fa7ddbb2f40e096667ec8a1b840ed87913e6e735", size = 85035152, upload-time = "2026-04-16T12:46:36.129Z" }, + { url = "https://files.pythonhosted.org/packages/24/08/26e6a3ecf0a95f1ec0dcd7a668d5c9a72e581c40fe4ae51e102ca63174c5/jaxlib-0.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:450b771c01b3662c3497e2dceada3f6fc893112ae637ef85ef1dcc7dc68892a8", size = 66661443, upload-time = "2026-04-16T12:46:51.088Z" }, + { url = "https://files.pythonhosted.org/packages/37/d7/06383d19217824134c4a6119d2efe7b53cde6a0a66fb1d643d9f725d2697/jaxlib-0.10.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f62026c9fb1f05998592082a6dcb62f70b466342bc139f711802a9b184ba9a46", size = 60253088, upload-time = "2026-04-16T12:46:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/6b/ce/f66f955c01cce1ffda0cfbb1c02bb9234e0cac1d40b46fe17c315155d62f/jaxlib-0.10.0-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:e66bdc0b57ed5649950799d3f0d67a6bb67f03d06b49ea3fced0bdd6140a9943", size = 79517974, upload-time = "2026-04-16T12:46:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/5e/74/b358923d0cce13fc7608051d0cc60ce3379f14350dc42540bdbabdbffab2/jaxlib-0.10.0-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:4dccd9065b30954879869641472d5d12fe4d7914175a5cad56293af8429ce7e0", size = 85134286, upload-time = "2026-04-16T12:46:47.416Z" }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -838,6 +991,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/87/afead29192170917537934c6aff4b008c805fff7b1ccea0c79120d96beda/matplotlib-3.10.9-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3fc0364dfbe1d07f6d15c5ebd0c5bf89e126916e5a8667dd4a7a6e84c36653d4", size = 8774002, upload-time = "2026-04-24T00:14:09.816Z" }, ] +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/5e/712092cfe7e5eb667b8ad9ca7c54442f21ed7ca8979745f1000e24cf8737/ml_dtypes-0.5.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90", size = 679734, upload-time = "2025-11-17T22:31:39.223Z" }, + { url = "https://files.pythonhosted.org/packages/4f/cf/912146dfd4b5c0eea956836c01dcd2fce6c9c844b2691f5152aca196ce4f/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040", size = 5056165, upload-time = "2025-11-17T22:31:41.071Z" }, + { url = "https://files.pythonhosted.org/packages/a9/80/19189ea605017473660e43762dc853d2797984b3c7bf30ce656099add30c/ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483", size = 5034975, upload-time = "2025-11-17T22:31:42.758Z" }, + { url = "https://files.pythonhosted.org/packages/b4/24/70bd59276883fdd91600ca20040b41efd4902a923283c4d6edcb1de128d2/ml_dtypes-0.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb", size = 210742, upload-time = "2025-11-17T22:31:44.068Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c9/64230ef14e40aa3f1cb254ef623bf812735e6bec7772848d19131111ac0d/ml_dtypes-0.5.4-cp311-cp311-win_arm64.whl", hash = "sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de", size = 160709, upload-time = "2025-11-17T22:31:46.557Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b8/3c70881695e056f8a32f8b941126cf78775d9a4d7feba8abcb52cb7b04f2/ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac", size = 676927, upload-time = "2025-11-17T22:31:48.182Z" }, + { url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, + { url = "https://files.pythonhosted.org/packages/f5/f0/0cfadd537c5470378b1b32bd859cf2824972174b51b873c9d95cfd7475a5/ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7", size = 212222, upload-time = "2025-11-17T22:31:53.742Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/9acc86985bfad8f2c2d30291b27cd2bb4c74cea08695bd540906ed744249/ml_dtypes-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460", size = 160793, upload-time = "2025-11-17T22:31:55.358Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48", size = 676888, upload-time = "2025-11-17T22:31:56.907Z" }, + { url = "https://files.pythonhosted.org/packages/d3/b7/dff378afc2b0d5a7d6cd9d3209b60474d9819d1189d347521e1688a60a53/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b", size = 5036993, upload-time = "2025-11-17T22:31:58.497Z" }, + { url = "https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d", size = 5010956, upload-time = "2025-11-17T22:31:59.931Z" }, + { url = "https://files.pythonhosted.org/packages/e1/8b/200088c6859d8221454825959df35b5244fa9bdf263fd0249ac5fb75e281/ml_dtypes-0.5.4-cp313-cp313-win_amd64.whl", hash = "sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328", size = 212224, upload-time = "2025-11-17T22:32:01.349Z" }, + { url = "https://files.pythonhosted.org/packages/8f/75/dfc3775cb36367816e678f69a7843f6f03bd4e2bcd79941e01ea960a068e/ml_dtypes-0.5.4-cp313-cp313-win_arm64.whl", hash = "sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175", size = 160798, upload-time = "2025-11-17T22:32:02.864Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/e9ddb35fd1dd43b1106c20ced3f53c2e8e7fc7598c15638e9f80677f81d4/ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6", size = 702083, upload-time = "2025-11-17T22:32:04.08Z" }, + { url = "https://files.pythonhosted.org/packages/74/f5/667060b0aed1aa63166b22897fdf16dca9eb704e6b4bbf86848d5a181aa7/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d", size = 5354111, upload-time = "2025-11-17T22:32:05.546Z" }, + { url = "https://files.pythonhosted.org/packages/40/49/0f8c498a28c0efa5f5c95a9e374c83ec1385ca41d0e85e7cf40e5d519a21/ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298", size = 5366453, upload-time = "2025-11-17T22:32:07.115Z" }, + { url = "https://files.pythonhosted.org/packages/8c/27/12607423d0a9c6bbbcc780ad19f1f6baa2b68b18ce4bddcdc122c4c68dc9/ml_dtypes-0.5.4-cp313-cp313t-win_amd64.whl", hash = "sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6", size = 225612, upload-time = "2025-11-17T22:32:08.615Z" }, + { url = "https://files.pythonhosted.org/packages/e5/80/5a5929e92c72936d5b19872c5fb8fc09327c1da67b3b68c6a13139e77e20/ml_dtypes-0.5.4-cp313-cp313t-win_arm64.whl", hash = "sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1", size = 164145, upload-time = "2025-11-17T22:32:09.782Z" }, + { url = "https://files.pythonhosted.org/packages/72/4e/1339dc6e2557a344f5ba5590872e80346f76f6cb2ac3dd16e4666e88818c/ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22", size = 673781, upload-time = "2025-11-17T22:32:11.364Z" }, + { url = "https://files.pythonhosted.org/packages/04/f9/067b84365c7e83bda15bba2b06c6ca250ce27b20630b1128c435fb7a09aa/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465", size = 5036145, upload-time = "2025-11-17T22:32:12.783Z" }, + { url = "https://files.pythonhosted.org/packages/c6/bb/82c7dcf38070b46172a517e2334e665c5bf374a262f99a283ea454bece7c/ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f", size = 5010230, upload-time = "2025-11-17T22:32:14.38Z" }, + { url = "https://files.pythonhosted.org/packages/e9/93/2bfed22d2498c468f6bcd0d9f56b033eaa19f33320389314c19ef6766413/ml_dtypes-0.5.4-cp314-cp314-win_amd64.whl", hash = "sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56", size = 221032, upload-time = "2025-11-17T22:32:15.763Z" }, + { url = "https://files.pythonhosted.org/packages/76/a3/9c912fe6ea747bb10fe2f8f54d027eb265db05dfb0c6335e3e063e74e6e8/ml_dtypes-0.5.4-cp314-cp314-win_arm64.whl", hash = "sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049", size = 163353, upload-time = "2025-11-17T22:32:16.932Z" }, + { url = "https://files.pythonhosted.org/packages/cd/02/48aa7d84cc30ab4ee37624a2fd98c56c02326785750cd212bc0826c2f15b/ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9", size = 702085, upload-time = "2025-11-17T22:32:18.175Z" }, + { url = "https://files.pythonhosted.org/packages/5a/e7/85cb99fe80a7a5513253ec7faa88a65306be071163485e9a626fce1b6e84/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7", size = 5355358, upload-time = "2025-11-17T22:32:19.7Z" }, + { url = "https://files.pythonhosted.org/packages/79/2b/a826ba18d2179a56e144aef69e57fb2ab7c464ef0b2111940ee8a3a223a2/ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf", size = 5366332, upload-time = "2025-11-17T22:32:21.193Z" }, + { url = "https://files.pythonhosted.org/packages/84/44/f4d18446eacb20ea11e82f133ea8f86e2bf2891785b67d9da8d0ab0ef525/ml_dtypes-0.5.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1", size = 236612, upload-time = "2025-11-17T22:32:22.579Z" }, + { url = "https://files.pythonhosted.org/packages/ad/3f/3d42e9a78fe5edf792a83c074b13b9b770092a4fbf3462872f4303135f09/ml_dtypes-0.5.4-cp314-cp314t-win_arm64.whl", hash = "sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d", size = 168825, upload-time = "2025-11-17T22:32:23.766Z" }, +] + [[package]] name = "msgspec" version = "0.21.1" @@ -974,6 +1168,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/74/f4c001f4714c3ad9ce037e18cf2b9c64871a84951eaa0baf683a9ca9301c/numpy-2.4.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f2cf083b324a467e1ab358c105f6cad5ea950f50524668a80c486ff1db24e119", size = 12509075, upload-time = "2026-03-29T13:21:57.644Z" }, ] +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004, upload-time = "2024-09-26T14:33:24.483Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" }, +] + [[package]] name = "osqp" version = "1.1.1" @@ -1535,3 +1738,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, ] + +[[package]] +name = "wheel" +version = "0.47.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/62/75f18a0f03b4219c456652c7780e4d749b929eb605c098ce3a5b6b6bc081/wheel-0.47.0.tar.gz", hash = "sha256:cc72bd1009ba0cf63922e28f94d9d83b920aa2bb28f798a31d0691b02fa3c9b3", size = 63854, upload-time = "2026-04-22T15:51:27.727Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/1b/9e33c09813d65e248f7f773119148a612516a4bea93e9c6f545f78455b7c/wheel-0.47.0-py3-none-any.whl", hash = "sha256:212281cab4dff978f6cedd499cd893e1f620791ca6ff7107cf270781e587eced", size = 32218, upload-time = "2026-04-22T15:51:26.296Z" }, +]