Rootfind use different rewrite func for lstsq solvers (#220) #1
Open
jpbrodrick89 wants to merge 61 commits into main from
Conversation
…#220)
* Rootfind use different rewrite func for lstsq solvers
* root rewrite for lstsq test
`minimise()` now accepts `AbstractRootFinder` solvers (e.g. `Newton`, `Chord`) in addition to `AbstractMinimiser`. When a root finder is supplied, minimisation is performed by finding the roots of the gradient ∇f(y) = 0 using `jax.value_and_grad`, enabling true second-order methods via exact Hessian-vector products through JAX AD.

Key design points:
- `_to_grad_fn`: wraps the objective as a root-finding target, returning `(grad, (grad, aux))` so the gradient is available in aux for termination checking (mirrors `_to_minimise_fn` in `_root_find.py`)
- `_RootToMinimise` / `_ConcreteRootToMinimise`: wrappers analogous to `_MinimToRoot`; they add `||∇f(y)|| < atol` as an extra termination condition on top of the root finder's own convergence check
- Circular import broken with a local import inside `minimise()` plus a `TYPE_CHECKING` guard for the type annotation
- `tags` passes straight through: the Hessian of f equals the Jacobian of ∇f, so the semantics are consistent
- ImplicitAdjoint / JVP works correctly: root_find's rewrite_fn evaluates ∇f(y*) = 0, which is the correct optimality condition

Tests: Newton and Chord added to `_root_find_minimisers` in helpers.py and included in the `minimisers` parametrize fixture.
https://claude.ai/code/session_01XNA4GHKg4a2ixxKTfKixW6
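The gradient-as-root-target rewrite can be sketched independently of optimistix: a Newton iteration on ∇f(y) = 0, whose Jacobian is the Hessian of f, is exactly one exact-Newton minimisation step. All names here are illustrative, not the PR's code:

```python
import jax
import jax.numpy as jnp

def f(y):
    # A strictly convex quadratic with minimum at y = [1., 2.].
    return jnp.sum((y - jnp.array([1.0, 2.0])) ** 2)

grad_f = jax.grad(f)  # the root-finding target: solve grad_f(y) = 0

def newton_root_step(y):
    # The Jacobian of grad_f is the Hessian of f, so one Newton step on
    # grad_f(y) = 0 is exactly one exact-Newton minimisation step.
    g = grad_f(y)
    H = jax.jacfwd(grad_f)(y)
    return y - jnp.linalg.solve(H, g)

y = jnp.zeros(2)
for _ in range(3):
    y = newton_root_step(y)
# On a quadratic this converges in a single step.
```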
- docs/api/minimise.md: adds a dedicated section explaining that any
AbstractRootFinder can be passed to minimise(), which rewrites the
problem as ∇f(y)=0. Includes worked examples for:
- Newton-Krylov (Hessian-free, GMRES inner solver) with a note that
MINRES is the preferred solver for symmetric systems and is planned
for Lineax
- Direct solver (default lx.AutoLinearSolver) for small/medium scale,
noting modified LDL^T as the forthcoming robust option
- docs/how-to-choose.md: adds second-order methods to the minimisation
guidance, pointing readers to the new section for worked examples
Implements four new classes in optimistix/_solver/second_order.py:
- `SteihaugCGDescent`: Steihaug-Toint truncated CG for trust-region subproblems. Handles indefinite Hessians via negative-curvature detection and trust-region boundary projection; suitable for non-convex problems.
- `AbstractNewtonMinimiser`: base class for exact second-order minimisers. Computes the true Hessian via jax.hessian at each accepted step (as a materialised PyTreeLinearOperator with symmetric tag), rather than maintaining a quasi-Newton approximation.
- `LineSearchNewton`: AbstractNewtonMinimiser + NewtonDescent + BacktrackingArmijo. Good for convex / near-convex problems.
- `TrustNewton`: AbstractNewtonMinimiser + ClassicalTrustRegion, with either NewtonDescent (default) or SteihaugCGDescent (use_steihaug=True). The Steihaug variant is recommended for non-convex problems.

All four are exported from the top-level optimistix namespace.
Instead of materialising the O(n²) Hessian with jax.hessian, represent it as lx.JacobianLinearOperator(λ _y, _: jax.grad(fn_scalar)(_y), y_eval, tags=symmetric). This computes Hessian-vector products lazily via jax.jvp on the gradient, so the full matrix is never built unless the linear solver asks for it. SteihaugCGDescent now uses entirely matrix-free HVPs (O(k · cost_of_grad) per outer step for k CG iterations); NewtonDescent/Cholesky will still materialise on demand via lineax's as_matrix() path.

Also applies the same static-structure normalisation trick as AbstractGaussNewton uses for its FunctionLinearOperator, ensuring filter_cond sees matching treedefs in the accepted/rejected branches.
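The lazy Hessian-vector product underlying this is plain JAX: differentiate once with jax.grad, then push a tangent through with jax.jvp, so no n×n matrix is ever formed. A minimal sketch, with a diagonal quadratic chosen only so the result is easy to check:

```python
import jax
import jax.numpy as jnp

def f(y):
    # Quadratic with Hessian diag(1, 2, 3), used only for verification.
    return 0.5 * y @ jnp.diag(jnp.array([1.0, 2.0, 3.0])) @ y

def hvp(y, v):
    # Forward-over-reverse: the JVP of the gradient is a Hessian-vector
    # product. Costs one extra forward pass on top of the gradient;
    # O(n^2) storage is never needed.
    return jax.jvp(jax.grad(f), (y,), (v,))[1]

y = jnp.ones(3)
v = jnp.array([1.0, 0.0, 1.0])
out = hvp(y, v)  # equals H @ v = [1., 0., 3.]
```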
Instead of constructing a new JacobianLinearOperator (and running filter_closure_convert → tracing fn again) on every accepted step, just update the existing operator's x field with eqx.tree_at. The fn closure (jaxpr + captured arrays) is built once during init() via filter_eval_shape and never changes. Updating x is sufficient since JacobianLinearOperator.mv uses self.x as the jvp point directly. This removes the normalization trick entirely — filter_cond sees identical treedefs in both branches by construction.
…nearize

Mirrors the Newton root-finder pattern exactly: _HessianGradFn (like _NoAux) is a stable eqx.Module stored in _NewtonMinimiserState. In accepted(), jax.linearize(state.hessian_grad_fn, y_eval) produces hess_mv_fn whose mv(v) replays the already-computed primal residuals with v as tangent; the primal is shared across all mv calls. This matters for direct solvers (Cholesky, LU), which call mv n times to materialise the matrix: FunctionLinearOperator is O(n · tangent-only) vs JacobianLinearOperator, which would be O(n · full-JVP). The normalisation trick (same as AbstractGaussNewton) keeps filter_cond happy by ensuring both accepted/rejected branches produce the same static jaxpr.
…rad_fn)

jax.linearize(state.hessian_grad_fn, y_eval) returns (grad, hess_mv_fn) in one forward-over-reverse pass. The primal IS the gradient, for the same reason the Newton root finder uses jax.linearize(fn, y) to get f_eval alongside lin_fn. Since the gradient comes for free from the Hessian linearize, we no longer need the separate jax.linearize(fn, y) + lin_to_grad path; replace it with a direct fn(y_eval, args) call for the scalar the search step needs. Removes the autodiff_mode option and the lin_to_grad import.
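The "primal is the gradient" observation can be checked directly: jax.linearize on the gradient returns the gradient itself as the primal output, plus a linear map that replays Hessian-vector products without redoing the primal work. A small sketch with an example objective:

```python
import jax
import jax.numpy as jnp

def f(y):
    # Example objective: gradient sinh(y), Hessian diag(cosh(y)).
    return jnp.sum(jnp.cosh(y))

y = jnp.array([0.5, -0.5])
# One forward-over-reverse pass: `grad` is the primal output (the gradient
# of f), and `hess_mv` is a linear function computing Hessian-vector
# products against the stored linearisation.
grad, hess_mv = jax.linearize(jax.grad(f), y)

# Each hess_mv call reuses the stored primal trace: materialising the
# matrix column by column costs n tangent-only passes, not n full JVPs.
col0 = hess_mv(jnp.array([1.0, 0.0]))  # first Hessian column
```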
- Remove superfluous banners, docstrings, and inline notes that restate the code or document design decisions already in git history
- Fix E501: shorten module-level comparison table to fit 88-char limit
- Fix pyright reportGeneralTypeIssues: _CGState inner class used Generic[Y], shadowing the outer TypeVar; replaced with Any on field annotations
- Fix pyright reportOperatorIssue: cast tree_dot() results to Array where used in scalar arithmetic (tree_dot is typed as Array | tuple[Array, ...])
The module wrapper existed to give jax.linearize a stable Python object so filter_cond would see matching treedefs. But the normalisation trick already discards the new jaxpr and reuses the old one from state — so caching the callable in state buys nothing. Now jax.linearize receives a fresh lambda each step (fn/args come directly from step() parameters, no closed-over tracers), identical to how Newton root finder constructs its lambda. Removes _HessianGradFn class and the hessian_grad_fn field from _NewtonMinimiserState.
jax.grad(fn, has_aux=True) already takes (y, args) since fn does. _NoAuxIn fixes args so the callable only takes y; _NoAuxOut strips the (grad, aux) tuple down to just grad. Both are eqx.Modules so equinox sees args as a dynamic pytree leaf. jax.grad(...) itself is a static callable field — fine, its closure over fn's arrays is only accessed when called within the same JAX trace. Replaces the double-lambda with a clean module chain and restores the cached hessian_grad_fn field in _NewtonMinimiserState.
_NoAux(fn)(y, args) -> scalar already strips aux before jax.grad, so the chain becomes _NoAuxIn(jax.grad(_NoAux(fn)), args) — no _NoAuxOut wrapper needed.
Introduces AbstractNewtonBase (eqx.Module mixin) in quasi_newton.py with the six shared AbstractVar declarations (rtol, atol, norm, descent, search, verbose) and _NewtonBaseState with the nine common state fields. AbstractQuasiNewton and AbstractNewtonMinimiser both inherit from AbstractNewtonBase alongside their own AbstractMinimiser specialisation, removing the duplicated AbstractVar declarations. Each concrete state (_QuasiNewtonState, _NewtonMinimiserState) subclasses _NewtonBaseState and adds only its extra field (hessian_update_state / hessian_grad_fn).

terminate() and postprocess() remain in each subclass to avoid a pyright MRO conflict: AbstractMinimiser[..., ConcreteState] declares them with the concrete state type, and a mixin definition with _NewtonBaseState is flagged as an incompatible override regardless of LSP direction. AbstractNewtonBase is exported from the optimistix public API.
…ewton

Minimises the diff: the two new classes now sit directly above the abstract quasi-Newton section, where a reviewer would naturally look for shared abstractions, rather than at the top of the file.
…hooks

Makes step concrete in the shared base class. The ~80% of step logic that was duplicated between AbstractQuasiNewton and AbstractNewtonMinimiser now lives in AbstractNewtonBase.step, which is driven by two abstract hooks:
- _prepare_step: evaluates fn at state.y_eval and returns (f_eval, aux_eval, accepted_fn, hus_for_rejected). The accepted closure captures all solver-specific variables (lin_fn for QN, hessian_grad_fn for Newton). Newton returns None for hus_for_rejected, matching the pattern already used by BFGS/SR1 where HessianUpdateState=None.
- _build_new_state: constructs the concrete solver state from the common post-step values.

AbstractNewtonBase is made Generic[Y, Aux, _StateT] so step carries the proper return type tuple[Y, _StateT, Aux] through to each subclass. Base class order is swapped (AbstractNewtonBase before AbstractMinimiser) so Python's MRO resolves the concrete step before the abstract one.
…tractNewtonBase

AbstractNewtonBase now inherits from AbstractMinimiser[Y, Aux, _BoundNewtonState] directly, so AbstractQuasiNewton and AbstractNewtonMinimiser no longer need to list AbstractMinimiser explicitly. The MRO ordering hack is also no longer needed since AbstractNewtonBase is the single concrete base that carries both the solver interface and the step implementation.
…order.py

filter_cond is now called inside AbstractNewtonBase.step (quasi_newton.py) so second_order.py no longer needs to import it directly.
Reverts "Document second-order minimisation via root finders" (d4ab779) and "Allow root finders to be used as minimisers via ∇f(y) = 0" (995c117). The root-finder wrapping approach is superseded by the new AbstractNewtonBase design that implements second-order methods directly as proper AbstractMinimiser subclasses, which is cleaner and more capable.
Both AbstractQuasiNewton and AbstractNewtonMinimiser had identical terminate/postprocess implementations that only read state.terminate and state.result — fields on _NewtonBaseState, the bound for _BoundNewtonState. Moving them to the base class removes ~45 lines of duplication.
TruncatedCG is an optimistix-local CG solver for potentially indefinite operators, implementing the truncation strategy used by scipy's Newton-CG:
- Exits early on negative curvature (d^T H d <= 0), returning the current partial CG iterate as a valid descent direction.
- On the first step (iterate still zero), falls back to the initial residual direction (-g), matching scipy's steepest-descent fallback.
- Accepts an optional `"delta"` option (trust-region radius) so it can also exit when ||p|| >= delta, returning the pre-crossing iterate and the search direction in stats["direction"] for downstream boundary projection.
- Accepts an optional `"rtol"` option per call, providing the hook needed for Eisenstat-Walker tolerance scheduling without reinitialising the solver.

Uses a fori_loop with a done/result_p/result_d freeze pattern (consistent with SteihaugCGDescent) so all three exit conditions share a single loop body.

Also registers LineSearchNewton, LineSearchNewton+TruncatedCG, TrustNewton, and TrustNewton(use_steihaug=True) in tests/helpers.py _minim_only, covering both PSD and indefinite Hessian test cases.
Previous version used fori_loop (Steihaug pattern), which always runs max_steps iterations as no-ops after convergence. This version correctly follows lineax.CG's while_loop structure, making only the targeted changes:
- init(): remove PSD/NSD operator check (TruncatedCG accepts indefinite)
- body_fun(): replace NaN injection with explicit neg_curv detection; add hit_boundary exit; compute result_y/result_d at each step so the correct exit value is available regardless of which condition fires
- cond_fun(): add ~done flag so while_loop exits immediately on neg_curv or boundary, rather than running extra no-op iterations
- Options: rtol/atol/delta overrides now correctly fed through

Also kept (all from lineax.CG): stabilise_every periodic residual recomputation, the atol+rtol convergence check via not_converged, and the norm field.
TruncatedCG now handles both the Newton-CG (line-search) and Steihaug (trust-region) cases uniformly:
- For both neg_curv and hit_boundary exits: returns y_before so the caller can project onto the trust-region boundary via _find_boundary_tau.
- Exception: when delta=inf (line-search mode) and neg_curv fires at step 0 (y=0), returns d (= -g) as the steepest-descent fallback so Armijo has a non-trivial direction.
- Adds hit_boundary to stats alongside negative_curvature.

SteihaugCGDescent.step() is now ~20 lines: call TruncatedCG with delta, read the two flags from stats, and apply _find_boundary_tau if either is set. The bespoke fori_loop CG and the nested _find_boundary_tau are removed; _find_boundary_tau is promoted to module level (accepting delta_sq as a parameter) and the unused lax import is dropped from second_order.py.
TruncatedCG.compute() now handles boundary projection internally when options["delta"] is finite: both the negative-curvature and boundary-crossing early exits solve ‖y_before + τd‖ = Δ and return the projected step as .value, so callers get the correct Steihaug step directly. _find_boundary_tau is now private to truncated_cg.py. SteihaugCGDescent.step() drops to a single lx.linear_solve call with no post-processing. Remove the now-unused cast/tree_dot imports from second_order.py.
NewtonDescent.query() now computes ew_rtol = min(0.5, sqrt(||g||_2)) for EvalGradHessian f_info and passes it as options["rtol"] to lx.linear_solve. This is forwarded to TruncatedCG (and any future iterative lineax solver that honours options["rtol"]); direct solvers (Cholesky, LU, Auto) ignore it silently.

SteihaugCGDescent.step() computes the same adaptive tolerance using the pre-computed grad_norm in state and passes it alongside options["delta"]. The rtol field on SteihaugCGDescent becomes the E-W cap (default 0.5), matching scipy's trust-ncg formula: eta = min(0.5, sqrt(||g||)) * ||g||.

newton_step() gains an optional options parameter (default None) that is forwarded to lx.linear_solve; all existing callers (dogleg, LM) are unaffected. two_norm is imported into gauss_newton.py for the E-W calculation.
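The Eisenstat-Walker forcing term can be sketched in isolation (the helper name here is hypothetical, not the PR's code):

```python
import numpy as np

def eisenstat_walker_rtol(grad_norm, cap=0.5):
    # eta_k = min(cap, sqrt(||g_k||)): loose inner solves far from the
    # optimum, tightening superlinearly as the gradient shrinks, which
    # preserves the outer Newton iteration's fast local convergence
    # without over-solving early linear systems.
    return min(cap, np.sqrt(grad_norm))

# Far from the solution, the inner CG tolerance saturates at the cap...
far = eisenstat_walker_rtol(4.0)    # 0.5
# ...and near the solution it tightens with sqrt(||g||).
near = eisenstat_walker_rtol(1e-6)  # ~1e-3
```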
Previously TruncatedCG always returned the forward (positive) root when projecting a negative-curvature exit onto the trust-region sphere. Scipy's trust-ncg evaluates both intersections and returns the one that minimises the quadratic model m(p) = g^T p + 0.5 p^T H p.

Add _find_neg_curv_boundary_tau(), which computes both roots ta <= tb and picks the one minimising Δm(τ) = τγ + ½τ²(d^T H d), using the CG conjugacy identity r·d = -‖r‖² = -γ to avoid an extra dot product. The boundary-crossing exit still uses the positive root via the existing _find_boundary_tau.
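The two-root selection can be sketched directly, without the CG shortcut, by evaluating the quadratic model at both sphere intersections (helper names are hypothetical; this is a numpy stand-in, not the PR's implementation):

```python
import numpy as np

def both_roots_tau(p, d, delta):
    # Solve ||p + tau*d|| = delta: a quadratic in tau with
    # a = d.d, b = 2 p.d, c = p.p - delta^2, giving roots ta <= tb.
    a = d @ d
    b = 2.0 * (p @ d)
    c = p @ p - delta**2
    disc = np.sqrt(b * b - 4.0 * a * c)
    return (-b - disc) / (2.0 * a), (-b + disc) / (2.0 * a)

def best_boundary_step(p, d, g, H, delta):
    # Evaluate m(s) = g.s + 0.5 s.H.s at both intersections and keep
    # the lower one (scipy trust-ncg behaviour).
    ta, tb = both_roots_tau(p, d, delta)
    candidates = [p + ta * d, p + tb * d]
    model = lambda s: g @ s + 0.5 * s @ H @ s
    return min(candidates, key=model)

# Indefinite H, first step (p = 0): the negative-curvature direction
# d = e1 should be followed in the descent sense, i.e. tau < 0 here.
H = np.diag([-1.0, 2.0])
g = np.array([1.0, 0.0])
step = best_boundary_step(np.zeros(2), np.array([1.0, 0.0]), g, H, 1.0)
```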
_find_boundary_tau now takes neg_curv/gamma/inner_prod and handles both cases internally: the positive root for boundary-crossing exits, the best of two roots for negative-curvature exits. The two-function API is gone.

TruncatedCG.max_steps default raised from 10*n to 20*n to match scipy's cg_maxiter = 20*len(x0). SteihaugCGDescent.max_steps and TrustNewton.steihaug_max_steps now default to None, propagating to the TruncatedCG default rather than imposing a hard cap of 100. Users can still set an explicit integer to impose a cap on large problems. Outer Newton iterations remain independently controlled via minimise(..., max_steps=256) through iterative_solve.
Previously, when negative curvature fired on the first CG step (y=0) in line-search mode, TruncatedCG returned the bare direction d = -g, leaving Armijo to find the scale by backtracking from alpha=1 against a unit step. Scipy instead returns dri0 / (-curv) * b = ||g||^2 / |d^T H d| * (-g), the Cauchy-scaled descent step. This matches the curvature magnitude and lets the Armijo search accept alpha=1 immediately rather than backtracking.

Implemented as gamma / max(-inner_prod, eps) * d, using gamma = ||r||^2 and inner_prod = d^T H d, which are both already in scope in body_fun. The eps guard avoids a huge step when inner_prod is only marginally negative.
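The Cauchy scaling can be sketched as a standalone numpy function (hypothetical name; a stand-in for the CG loop's first step, not the PR's code):

```python
import numpy as np

def cauchy_scaled_fallback(g, H, eps=1e-12):
    # First CG step from y = 0: direction d = -g. If the curvature
    # d^T H d is negative, scale the steepest-descent direction by
    # ||g||^2 / |d^T H d|, so the step magnitude matches the local
    # curvature and a backtracking Armijo search can accept alpha = 1.
    d = -g
    gamma = g @ g        # ||r||^2 at step 0, since r = -g
    inner = d @ H @ d    # curvature along d
    assert inner <= 0.0, "fallback only fires on negative curvature"
    # eps guard: avoids a huge step when inner is only marginally negative.
    return gamma / max(-inner, eps) * d

H = np.diag([-2.0, -2.0])
g = np.array([2.0, 0.0])
step = cauchy_scaled_fallback(g, H)  # ||g||^2 / |d^T H d| * (-g)
```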
Cover the previously untested path where `tags=frozenset({lx.positive_semidefinite_tag})`
is passed to `minimise`, causing the Hessian operator in second_order.py to
carry `{symmetric_tag, positive_semidefinite_tag}`. Without the PSD tag, both
`lx.Cholesky()` and `lx.CG()` raise a hard `ValueError` at init time
(operator only gets `symmetric_tag` by default); with it, both solvers
converge correctly on a strictly-convex quadratic.
Two parametrised tests across {LineSearchNewton, TrustNewton} × {Cholesky, CG}:
- test_newton_psd_linear_solver: verifies convergence with tag present
- test_newton_psd_linear_solver_no_tag_raises: verifies error without tag
Replace the pure quadratic test function (which Cholesky solves in one step) with f(y) = sum(cosh(y_i) - 1): globally convex, H = diag(cosh(y)) > 0 everywhere, minimum at y=0, f*=0. From y0=[4,-3,2] the solvers take 33-39 Newton steps, exercising the full Cholesky/CG path non-trivially.

Assert on f(y*) rather than |y* - 0|: TrustNewton terminates on |Δf| < atol once f ≈ 0, so f(y*) converges to machine precision while |y*| is only ~3e-4. This reflects the semantics of the Cauchy stopping criterion correctly.
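The claimed properties of the replacement test function are easy to verify with jax.hessian:

```python
import jax
import jax.numpy as jnp

def f(y):
    # Globally convex: minimum f* = 0 at y = 0, Hessian diag(cosh(y)),
    # whose eigenvalues are >= 1 everywhere (cosh >= 1).
    return jnp.sum(jnp.cosh(y) - 1.0)

y0 = jnp.array([4.0, -3.0, 2.0])  # the test's hard starting point
H = jax.hessian(f)(y0)
eigs = jnp.linalg.eigvalsh(H)
# Strictly positive definite even far from the minimum, so Cholesky
# and CG are both applicable along the whole Newton path.
```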
The Newton minimiser previously stored jax.grad(_NoAux(fn)) inside a
_NoAuxIn equinox Module, which was then held in _NewtonMinimiserState.
Newer equinox raises _JaxTransformException when an eqx.Module stores a
JAX-transformed function (whose __wrapped__ points back to an eqx.Module)
and that state is processed inside a filter_custom_jvp trace context —
exactly what ImplicitAdjoint triggers via _implicit_impl.
Fix: remove _NoAuxIn and _NoAux usage entirely. Instead build the gradient
function inline as a plain lambda (`lambda y: fn(y, args)[0]`) in both
_make_hessian_f_info and _prepare_step.accepted. A plain lambda's
__wrapped__ attribute is a plain Python function, not an eqx.Module, so
equinox's transform-function check no longer fires.
Also fix TruncatedCG: wrap the 'delta' option with jnp.asarray so that
delta_sq is a JAX scalar rather than Python float('inf'), preventing a
jaxtyping TypeCheckError in _find_boundary_tau when no trust-region radius
is provided.
- Rename _NewtonBaseState → NewtonMinimiserState (public name); remove the thin _NewtonMinimiserState pass-through subclass and use NewtonMinimiserState directly in AbstractNewtonMinimiser and its concrete methods.
- Mirror NewtonChord's jax.linearize style: inline the lambda and pass has_aux=True instead of a separate grad_fn variable, in both _make_hessian_f_info and AbstractNewtonMinimiser._prepare_step.
- Move LineSearchNewton / TrustNewton from _minim_only to _general_minimisers in tests/helpers.py so they are exercised by least_squares_optimisers too.
- Rename NewtonMinimiserState -> _NewtonMinimiserState (private, consistent with _QuasiNewtonState)
- Add _NoAux wrapper in second_order.py so Hessian linearize calls use jax.grad on a no-aux fn rather than has_aux=True + discarded _
…lpers

- second_order.py: remove duplicate _NoAux class, import from newton_chord, inline _NoAux(fn) calls (no intermediate fn_no_aux variable). Change default linear_solver for LineSearchNewton and TrustNewton from AutoLinearSolver(well_posed=None) to lx.Cholesky(); update docstrings.
- helpers.py: add globally_convex test function (cosh(y)-1, globally SPD), add it to minimisation_fn_minima_init_args, define _spd_minimisation_fns set (bowl, matyas, square_minus_one, globally_convex), add _newton_needs_psd predicate, add lx.CG vanilla variants for LineSearchNewton and TrustNewton.
- test_minimise.py: remove test_newton_psd_linear_solver* tests and the associated _globally_convex / _psd_solver_linear_solver_pairs definitions (superseded by parametrised coverage via minimisation_fn_minima_init_args). Add skip/tags logic to test_minimise, test_minimise_jvp, test_forward_minimisation: skip when the solver requires a PSD Hessian but the problem is not globally convex; otherwise pass positive_semidefinite_tag.
- test_least_squares.py: same skip/tags logic for test_least_squares and test_least_squares_jvp (no least-squares problems are globally SPD, so all Newton+Cholesky/CG combinations are skipped there).
…odiff_mode in Newton minimisers

TruncatedCG._find_boundary_tau had prefer_ta = +gamma + ... > 0 when it should be -gamma + ... > 0. For the common first-step case (p=0, ta+tb=0) this always evaluated to True, picking the root that moves *away* from the minimum (in direction +g) instead of toward it (direction -g). Fix: negate gamma in the preference test so the solver correctly extends to the trust-region boundary in the descent direction under negative curvature.

AbstractNewtonMinimiser now also respects the autodiff_mode='fwd' option (mirroring quasi-Newton's lin_to_grad pattern) by routing through jax.jacfwd instead of jax.grad for the gradient computation, enabling Newton methods to work with dfx.ForwardMode() ODE solves.
…x.LU()
- Change LineSearchNewton/TrustNewton default linear_solver from lx.Cholesky()
to lx.LU(), which works for any non-singular Hessian without requiring PSD.
- Replace _newton_needs_psd with two predicates:
- _uses_vanilla_cg: True only for lx.CG (explicit PSD tag requirement).
Used in test_minimise/test_least_squares: only CG is skipped on non-SPD
problems; all other solvers run, with PSD tags added problem-dependently.
- _newton_needs_convex: True for Newton+LU/Cholesky/CG without an
  indefinite-capable solver. Used in test_forward_minimisation, where
  forward_only_ode has a negative Hessian at the starting point
  (non-convex region), making Newton+LU produce ascent directions that
  the Armijo search can never accept.
- Tags are now always problem-dependent (not solver-dependent): pass
positive_semidefinite_tag iff _fn in _spd_minimisation_fns.
Selects Cholesky when the Hessian carries positive_semidefinite_tag (passed for SPD problems) and falls back to LU otherwise, giving the best of both: efficiency on convex problems, correctness on general ones.
- Change default linear_solver for LineSearchNewton and TrustNewton from
lx.Cholesky() to lx.AutoLinearSolver(well_posed=True), which dispatches
to Cholesky for SPD-tagged operators and LU otherwise.
- Replace _newton_needs_psd with two predicates:
- _uses_vanilla_cg: True only when Newton uses lx.CG; used in convergence
tests to skip non-SPD problems (CG requires positive_semidefinite_tag)
- _newton_needs_convex: True for Newton with direct linear solver (not
SteihaugCG/TruncatedCG); used in JVP tests to avoid XLA crash from
cache exhaustion on non-convex problems
- Make tags problem-dependent (not solver-dependent):
tags = frozenset({lx.positive_semidefinite_tag}) if fn in _spd_minimisation_fns
- Add scale args to globally_convex for consistency with bowl pattern;
minimum y*=0 is independent of scale, so ImplicitAdjoint returns zero tangent
- Add NelderMead to loose-tolerance group in test_minimise_jvp
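The tag-driven dispatch being relied on here can be illustrated with a hypothetical numpy stand-in (not lineax's implementation): Cholesky when the operator is tagged positive (semi)definite, a general LU-style solve otherwise.

```python
import numpy as np

def auto_solve(H, b, psd_tag=False):
    # Hypothetical stand-in for AutoLinearSolver(well_posed=True)-style
    # dispatch: Cholesky for PSD-tagged operators, general solve otherwise.
    if psd_tag:
        L = np.linalg.cholesky(H)          # fails loudly if H is not PD
        y = np.linalg.solve(L, b)          # L y = b
        return np.linalg.solve(L.T, y)     # L^T x = y
    return np.linalg.solve(H, b)           # general (LU-based) solve

spd = np.diag([1.0, 2.0])
indef = np.diag([1.0, -2.0])
b = np.array([1.0, 2.0])
x_spd = auto_solve(spd, b, psd_tag=True)   # Cholesky path
x_indef = auto_solve(indef, b)             # general path handles indefinite H
```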
ImplicitAdjoint only uses the optimality condition at y* (not the solver internals), so Newton+LU on non-convex problems produces correct JVP results. The only reason to skip is when lx.CG requires positive_semidefinite_tag but the problem is not SPD. Remove _newton_needs_convex and replace all remaining uses with _uses_vanilla_cg in test_minimise_jvp, test_forward_minimisation, and test_least_squares_jvp.

The LLVM OOM crash ("Failed to materialize symbols") in combined runs is a JIT compilation cache exhaustion issue on constrained machines, not a code or pytree structure bug; each test passes in isolation.
…ubproblem

TrustNewton previously used NewtonDescent + ClassicalTrustRegion. When the Hessian is negative definite, NewtonDescent produces an ascent direction that neither the trust region nor the line search can correct; the trust region only scales the step size, not the direction. After ~54 halvings the step size reaches floating-point zero, Cauchy termination fires, and the solver falsely converges at the starting point.

Fix: replace NewtonDescent with IndirectDampedNewtonDescent, which solves the true trust-region subproblem (Conn, Gould, Toint §7.3) by root-finding for the Levenberg-Marquardt parameter λ such that ‖(H + λI)⁻¹g‖ = Δ. This guarantees H + λI is positive definite at the solution and that the step is a genuine descent direction.

Also fix damped_newton_step: H + λI is always symmetric when isinstance(f_info, EvalGradHessian), so always propagate symmetric_tag in addition to the existing conditional positive_semidefinite_tag. This allows AutoLinearSolver(well_posed=True) to use LDL for indefinite problems and Cholesky when the Hessian is tagged positive_semidefinite.

Fixes forward_only_ode-solver13 (TrustNewton) in test_forward_minimisation.
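The λ search can be sketched in numpy with a scalar bisection standing in for the root-finder (all names hypothetical): find λ with ‖(H + λI)⁻¹g‖ = Δ, keeping H + λI positive definite even when H is indefinite.

```python
import numpy as np

def damped_step_norm(H, g, lam):
    # ||(H + lam*I)^{-1} g||, well defined once H + lam*I is invertible.
    n = H.shape[0]
    return np.linalg.norm(np.linalg.solve(H + lam * np.eye(n), g))

def damped_newton_step(H, g, delta, hi=1e6, iters=100):
    # Bisect for lam with ||(H + lam*I)^{-1} g|| = delta. Starting the
    # bracket just right of -min(eig(H)) keeps H + lam*I positive
    # definite, so the damped step is a genuine descent direction even
    # for an indefinite Hessian. The norm decreases monotonically in lam
    # on this interval, so bisection suffices.
    lo = max(0.0, -np.linalg.eigvalsh(H).min()) + 1e-10
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if damped_step_norm(H, g, mid) > delta:
            lo = mid   # step too long: damp more
        else:
            hi = mid
    lam = hi
    return -np.linalg.solve(H + lam * np.eye(H.shape[0]), g), lam

H = np.diag([-1.0, 2.0])   # indefinite Hessian
g = np.array([1.0, 1.0])
step, lam = damped_newton_step(H, g, delta=0.5)
# step lies on the trust-region boundary and descends along -g.
```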
…haug=True

Previously, passing linear_solver alongside use_steihaug=True silently discarded the solver with no indication to the user. This was confusing because the parameter appeared to be accepted but had no effect.

Use a sentinel default for linear_solver so we can distinguish an explicit user-supplied value from the default. Raise a clear ValueError explaining that SteihaugCGDescent constructs its own TruncatedCG internally and does not accept an external linear_solver. Also updates the linear_solver docstring to document the error contract instead of the previous "Ignored when use_steihaug=True" note.
When the exact linear solver (Cholesky, LU, etc.) fails on an indefinite or near-singular Hessian, fall back to the gradient direction so the Armijo line search can still make progress. This mirrors scipy's newton-cg behaviour: on the first CG step with negative curvature (y=0), scipy returns the Cauchy-scaled steepest-descent direction rather than aborting. For exact solvers we don't have curvature information, so we fall back to the raw gradient and let Armijo find the right step size.

Fixes the forward_only_ode test for LineSearchNewton.
…ctions
The previous fallback only fired when the linear solve returned a
non-successful result. This missed the case where the solve succeeds
but the Hessian is indefinite: LU on a negative-definite H returns a
direction with g^T(H^{-1}g) < 0, which ascends rather than descends,
causing the Armijo line search to fail silently.
Extend the condition to also fall back when the solved direction is not
a descent direction (g^T(H^{-1}g) <= 0), using a dot product check that
costs no extra HVP. This fixes the forward_only_ode test where the
Hessian is negative at the initial point k=0.6.
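The descent check can be sketched in numpy (hypothetical helper, not the PR's code): a linear solve on a negative-definite Hessian succeeds, but the resulting direction ascends, and a single dot product detects this.

```python
import numpy as np

def newton_or_gradient_descent(H, g):
    # Solve H p = -g. If the result is not a descent direction
    # (g.p >= 0, equivalently g^T(H^{-1}g) <= 0, e.g. H negative
    # definite), fall back to -g so a backtracking line search can
    # still make progress. The check costs only one dot product.
    p = np.linalg.solve(H, -g)
    if g @ p >= 0.0:
        return -g, "fallback"
    return p, "newton"

# Negative-definite H: the solve succeeds, but p = -H^{-1}g ascends.
H = -np.eye(2)
g = np.array([1.0, 0.0])
p, kind = newton_or_gradient_descent(H, g)
```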
Let the non-successful result propagate rather than silently falling back to steepest descent. Only fall back on non-descent directions.
Flatten _NewtonMinimiserState into _QuasiNewtonState directly. AbstractNewtonMinimiser now uses _QuasiNewtonState with hessian_update_state=None, eliminating the separate base class.
…mments

Revert LM operator to original (lineax AddLinearOperator preserves symmetry already). Revert two _QuasiNewtonState comments to match main.
Cauchy termination needs three iterations for quadratic problems where the initial point is not the minimum: one to evaluate at init (y_diff=0 but f_diff large due to zeroed f_info), one to accept the Newton step (y_diff large = full step to minimum), and one to confirm no further movement.

Add a gradient-norm termination criterion alongside cauchy: if ||grad|| < atol + rtol * ||y|| (element-wise), the iterate is at a stationary point. For exact Newton on a quadratic this fires as soon as we arrive at the minimum, removing the extra confirmation step and reducing iteration count from 3 to 2 for bowl, matyas, diag_bowl, and glob_convex.
The gradient_termination shortcut is out of scope for this PR. The underlying one-step timing delay vs newton_chord is a deeper architectural question that warrants separate investigation.
Start from [10, -8, 6] instead of [0.4, -0.3, 0.2] to stress-test convergence behaviour from a harder starting point.
The function was incorrectly defined as a plain quadratic. Use the actual test definition: f(y) = Σ cosh(scale·y) − 1. Add glob_convex_far with init [10, -8, 6] alongside the original near-start init [0.4, -0.3, 0.2] from tests/helpers.py.
- square_minus_one: use x^2-1 (matches helpers.py) instead of (x^2-1)^2
- bowl: use PRNGKey(17) instead of PRNGKey(0) to match test matrix
- diagonal_quadratic_bowl: use tree_map(x^2*(0.1+w)) residual with random squared-normal args, matching the helpers.py definition
- SteihaugCGDescent and TruncatedCG added to descents.md
- LineSearchNewton, TrustNewton, AbstractNewtonMinimiser added to minimise.md
Remove the _uses_vanilla_cg helper and replace each call site with the equivalent one-liner: isinstance(getattr(solver.descent, "linear_solver", None), lx.CG). Condense the surrounding comments in helpers.py.
The default linear solver for LineSearchNewton and TrustNewton is AutoLinearSolver(well_posed=True), not Cholesky.