Rootfind use different rewrite func for lstsq solvers (#220) #1

Open
jpbrodrick89 wants to merge 61 commits into main from claude/second-order-methods-research-xOyOu

Conversation

@jpbrodrick89
Owner

  • Rootfind use different rewrite func for lstsq solvers

  • root rewrite for lstsq test

jurasic-pf and others added 30 commits March 9, 2026 19:09
…#220)

* Rootfind use different rewrite func for lstsq solvers

* root rewrite for lstsq test
`minimise()` now accepts `AbstractRootFinder` solvers (e.g. `Newton`,
`Chord`) in addition to `AbstractMinimiser`. When a root finder is
supplied, minimisation is performed by finding the roots of the gradient
∇f(y) = 0 using `jax.value_and_grad`, enabling true second-order methods
via exact Hessian-vector products through JAX AD.

Key design points:
- `_to_grad_fn`: wraps the objective as a root-finding target, returning
  `(grad, (grad, aux))` so the gradient is available in aux for
  termination checking (mirrors `_to_minimise_fn` in `_root_find.py`)
- `_RootToMinimise` / `_ConcreteRootToMinimise`: wrapper analogous to
  `_MinimToRoot`, adds `||∇f(y)|| < atol` as an extra termination
  condition on top of the root finder's own convergence check
- Circular import broken with a local import inside `minimise()` plus a
  `TYPE_CHECKING` guard for the type annotation
- `tags` passes straight through: the Hessian of f equals the Jacobian
  of ∇f, so the semantics are consistent
- ImplicitAdjoint / JVP works correctly: root_find's rewrite_fn
  evaluates ∇f(y*)=0, which is the correct optimality condition

Tests: Newton and Chord added to `_root_find_minimisers` in helpers.py
and included in the `minimisers` parametrize fixture.
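The rewrite described above can be sketched in plain JAX. This is illustrative only, not the optimistix implementation: `minimise_via_root_find`, the fixed Newton iteration, and the dense `jnp.linalg.solve` are all assumptions for a small dense problem.

```python
import jax
import jax.numpy as jnp

def minimise_via_root_find(f, y0, steps=20):
    # Rewrite min f(y) as the root-finding problem grad f(y) = 0.
    grad_f = jax.grad(f)

    def newton_step(y, _):
        g = grad_f(y)
        # The Hessian of f is the Jacobian of grad f, so a Newton root
        # finder applied to grad f is a true second-order minimiser.
        H = jax.jacfwd(grad_f)(y)
        return y - jnp.linalg.solve(H, g), None

    y_star, _ = jax.lax.scan(newton_step, y0, None, length=steps)
    return y_star

f = lambda y: jnp.sum(jnp.cosh(y) - 1.0)  # convex, unique minimum at y = 0
y_star = minimise_via_root_find(f, jnp.array([1.0, -0.5]))
```

For cosh the Newton iteration is `y - tanh(y)`, which converges cubically near the origin, so a handful of steps suffices.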

https://claude.ai/code/session_01XNA4GHKg4a2ixxKTfKixW6
- docs/api/minimise.md: adds a dedicated section explaining that any
  AbstractRootFinder can be passed to minimise(), which rewrites the
  problem as ∇f(y)=0. Includes worked examples for:
  - Newton--Krylov (Hessian-free, GMRES inner solver) with a note that
    MINRES is the preferred solver for symmetric systems and is planned
    for Lineax
  - Direct solver (default lx.AutoLinearSolver) for small/medium scale,
    noting modified LDL^T as the forthcoming robust option

- docs/how-to-choose.md: adds second-order methods to the minimisation
  guidance, pointing readers to the new section for worked examples

Implements three new classes in optimistix/_solver/second_order.py:

- `SteihaugCGDescent`: Steihaug-Toint truncated CG for trust-region
  subproblems. Handles indefinite Hessians via negative-curvature detection
  and trust-region boundary projection; suitable for non-convex problems.

- `AbstractNewtonMinimiser`: Base class for exact second-order minimisers.
  Computes the true Hessian via jax.hessian at each accepted step (as a
  materialised PyTreeLinearOperator with symmetric tag), rather than
  maintaining a quasi-Newton approximation.

- `LineSearchNewton`: AbstractNewtonMinimiser + NewtonDescent +
  BacktrackingArmijo. Good for convex / near-convex problems.

- `TrustNewton`: AbstractNewtonMinimiser + ClassicalTrustRegion, with either
  NewtonDescent (default) or SteihaugCGDescent (use_steihaug=True). The
  Steihaug variant is recommended for non-convex problems.

All four are exported from the top-level optimistix namespace.

Instead of materialising the O(n²) Hessian with jax.hessian, represent it as
lx.JacobianLinearOperator(λ _y, _: jax.grad(fn_scalar)(_y), y_eval, tags=symmetric).
This computes Hessian-vector products lazily via jax.jvp on the gradient,
meaning the full matrix is never built unless the linear solver asks for it.

SteihaugCGDescent now uses entirely matrix-free HVPs (O(k * cost_of_grad) per
outer step for k CG iterations). NewtonDescent/Cholesky will still materialise
on demand via lineax's as_matrix() path.

Also applies the same static-structure normalisation trick as AbstractGaussNewton
uses for its FunctionLinearOperator, ensuring filter_cond sees matching treedefs
in the accepted/rejected branches.
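The lazy HVP idea reduces to a one-liner in JAX. A minimal sketch (`hvp` is an illustrative name; `lx.JacobianLinearOperator` over the gradient computes essentially this):

```python
import jax
import jax.numpy as jnp

def hvp(f, y, v):
    # Hessian-vector product as a jvp of the gradient (forward-over-reverse).
    # The full O(n^2) Hessian matrix is never materialised.
    return jax.jvp(jax.grad(f), (y,), (v,))[1]

f = lambda y: jnp.sum(y**4)   # Hessian is diag(12 * y**2)
y = jnp.array([1.0, 2.0])
out = hvp(f, y, jnp.array([1.0, 1.0]))
```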

Instead of constructing a new JacobianLinearOperator (and running
filter_closure_convert → tracing fn again) on every accepted step,
just update the existing operator's x field with eqx.tree_at.

The fn closure (jaxpr + captured arrays) is built once during init()
via filter_eval_shape and never changes. Updating x is sufficient
since JacobianLinearOperator.mv uses self.x as the jvp point directly.
This removes the normalization trick entirely — filter_cond sees
identical treedefs in both branches by construction.

…nearize

Mirrors the Newton root-finder pattern exactly:

  _HessianGradFn (like _NoAux) is a stable eqx.Module stored in
  _NewtonMinimiserState.  In accepted(), jax.linearize(state.hessian_grad_fn,
  y_eval) produces hess_mv_fn whose mv(v) replays the already-computed primal
  residuals with v as tangent — the primal is shared across all mv calls.
  This matters for direct solvers (Cholesky, LU) which call mv n times to
  materialise the matrix: FunctionLinearOperator is O(n * tangent-only) vs
  JacobianLinearOperator which would be O(n * full-JVP).

The normalisation trick (same as AbstractGaussNewton) keeps filter_cond happy
by ensuring both accepted/rejected branches produce the same static jaxpr.
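A minimal sketch of the shared-primal pattern (illustrative; the real code routes this through lineax's `FunctionLinearOperator`). Note the primal output of linearizing the gradient is the gradient itself:

```python
import jax
import jax.numpy as jnp

f = lambda y: jnp.sum(y**4)
y = jnp.array([1.0, 2.0, 3.0])

# One shared primal pass: linearizing the gradient yields the gradient value
# and a tangent-only Hessian-vector product function.
grad_val, hess_mv = jax.linearize(jax.grad(f), y)

# A direct solver materialises H with n mv calls, each tangent-only:
H = jnp.stack([hess_mv(col) for col in jnp.eye(y.shape[0])], axis=1)
```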

…rad_fn)

jax.linearize(state.hessian_grad_fn, y_eval) returns (grad, hess_mv_fn) in
one forward-over-reverse pass.  The primal IS the gradient — the same reason
Newton root finder uses jax.linearize(fn, y) to get f_eval alongside lin_fn.

Since the gradient comes for free from the Hessian linearize, we no longer
need the separate jax.linearize(fn, y) + lin_to_grad path.  Replace it with
a direct fn(y_eval, args) call for the scalar the search step needs.  Removes
the autodiff_mode option and the lin_to_grad import.

- Remove superfluous banners, docstrings, and inline notes that restate
  the code or document design decisions already in git history
- Fix E501: shorten module-level comparison table to fit 88-char limit
- Fix pyright reportGeneralTypeIssues: _CGState inner class used Generic[Y]
  shadowing the outer TypeVar; replaced with Any on field annotations
- Fix pyright reportOperatorIssue: cast tree_dot() results to Array where
  used in scalar arithmetic (tree_dot is typed as Array | tuple[Array, ...])

The module wrapper existed to give jax.linearize a stable Python object
so filter_cond would see matching treedefs.  But the normalisation trick
already discards the new jaxpr and reuses the old one from state — so
caching the callable in state buys nothing.

Now jax.linearize receives a fresh lambda each step (fn/args come directly
from step() parameters, no closed-over tracers), identical to how Newton
root finder constructs its lambda.  Removes _HessianGradFn class and the
hessian_grad_fn field from _NewtonMinimiserState.

jax.grad(fn, has_aux=True) already takes (y, args) since fn does.
_NoAuxIn fixes args so the callable only takes y; _NoAuxOut strips
the (grad, aux) tuple down to just grad.  Both are eqx.Modules so
equinox sees args as a dynamic pytree leaf.  jax.grad(...) itself
is a static callable field — fine, its closure over fn's arrays is
only accessed when called within the same JAX trace.

Replaces the double-lambda with a clean module chain and restores
the cached hessian_grad_fn field in _NewtonMinimiserState.

_NoAux(fn)(y, args) -> scalar already strips aux before jax.grad,
so the chain becomes _NoAuxIn(jax.grad(_NoAux(fn)), args) — no
_NoAuxOut wrapper needed.

Introduces AbstractNewtonBase (eqx.Module mixin) in quasi_newton.py
with the six shared AbstractVar declarations (rtol, atol, norm,
descent, search, verbose) and _NewtonBaseState with the nine common
state fields.

AbstractQuasiNewton and AbstractNewtonMinimiser both inherit from
AbstractNewtonBase alongside their own AbstractMinimiser specialisation,
removing the duplicated AbstractVar declarations. Each concrete state
(_QuasiNewtonState, _NewtonMinimiserState) subclasses _NewtonBaseState
and adds only its extra field (hessian_update_state / hessian_grad_fn).

terminate() and postprocess() remain in each subclass to avoid a
pyright MRO conflict: AbstractMinimiser[..., ConcreteState] declares
them with the concrete state type, and a mixin definition with
_NewtonBaseState is flagged as an incompatible override regardless of
LSP direction.

AbstractNewtonBase is exported from optimistix public API.

…ewton

Minimises the diff: the two new classes now sit directly above the
abstract quasi-Newton section where a reviewer would naturally look
for shared abstractions, rather than at the top of the file.

…hooks

Makes step concrete in the shared base class. The ~80% of step logic that
was duplicated between AbstractQuasiNewton and AbstractNewtonMinimiser now
lives in AbstractNewtonBase.step, which is driven by two abstract hooks:

- _prepare_step: evaluates fn at state.y_eval and returns (f_eval, aux_eval,
  accepted_fn, hus_for_rejected). The accepted closure captures all
  solver-specific variables (lin_fn for QN, hessian_grad_fn for Newton).
  Newton returns None for hus_for_rejected, matching the pattern already
  used by BFGS/SR1 where HessianUpdateState=None.

- _build_new_state: constructs the concrete solver state from the common
  post-step values.

AbstractNewtonBase is made Generic[Y, Aux, _StateT] so step carries the
proper return type tuple[Y, _StateT, Aux] through to each subclass.
Base class order is swapped (AbstractNewtonBase before AbstractMinimiser)
so Python's MRO resolves the concrete step before the abstract one.

…tractNewtonBase

AbstractNewtonBase now inherits from AbstractMinimiser[Y, Aux, _BoundNewtonState]
directly, so AbstractQuasiNewton and AbstractNewtonMinimiser no longer need to
list AbstractMinimiser explicitly. The MRO ordering hack is also no longer needed
since AbstractNewtonBase is the single concrete base that carries both the
solver interface and the step implementation.

…order.py

filter_cond is now called inside AbstractNewtonBase.step (quasi_newton.py)
so second_order.py no longer needs to import it directly.

Reverts "Document second-order minimisation via root finders" (d4ab779)
and "Allow root finders to be used as minimisers via ∇f(y) = 0" (995c117).

The root-finder wrapping approach is superseded by the new AbstractNewtonBase
design that implements second-order methods directly as proper AbstractMinimiser
subclasses, which is cleaner and more capable.

Both AbstractQuasiNewton and AbstractNewtonMinimiser had identical
terminate/postprocess implementations that only read state.terminate
and state.result — fields on _NewtonBaseState, the bound for
_BoundNewtonState. Moving them to the base class removes ~45 lines of
duplication.

TruncatedCG is an optimistix-local CG solver for potentially indefinite
operators, implementing the truncation strategy used by scipy's Newton-CG:

- Exits early on negative curvature (d^T H d <= 0), returning the current
  partial CG iterate as a valid descent direction.
- On the first step (iterate still zero), falls back to the initial residual
  direction (-g), matching scipy's steepest-descent fallback.
- Accepts an optional `"delta"` option (trust-region radius) so it can also
  exit when ||p|| >= delta, returning the pre-crossing iterate and the search
  direction in stats["direction"] for downstream boundary projection.
- Accepts an optional `"rtol"` option per-call, providing the hook needed for
  Eisenstat-Walker tolerance scheduling without reinitialising the solver.

Uses a fori_loop with done/result_p/result_d freeze pattern (consistent with
SteihaugCGDescent) so all three exit conditions share a single loop body.

Also registers LineSearchNewton, LineSearchNewton+TruncatedCG, TrustNewton,
and TrustNewton(use_steihaug=True) in tests/helpers.py _minim_only, covering
both PSD and indefinite Hessian test cases.
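A rough sketch of the truncation strategy described above, in eager-mode Python rather than the lax while_loop the real solver uses; the name `truncated_cg` and the return conventions are illustrative:

```python
import jax.numpy as jnp

def truncated_cg(hvp, g, delta=jnp.inf, max_steps=20, rtol=1e-5):
    # CG on H p = -g, exiting early on negative curvature or when the
    # iterate would cross the trust-region boundary ||p|| >= delta.
    p = jnp.zeros_like(g)
    r = -g                      # residual of H p + g at p = 0
    d = r
    for _ in range(max_steps):
        Hd = hvp(d)
        curv = d @ Hd
        if curv <= 0:
            # Negative curvature: current iterate, or -g on the first step.
            return p if p.any() else d
        alpha = (r @ r) / curv
        p_next = p + alpha * d
        if jnp.linalg.norm(p_next) >= delta:
            return p            # pre-crossing iterate; caller projects out
        r_next = r - alpha * Hd
        if jnp.linalg.norm(r_next) < rtol * jnp.linalg.norm(g):
            return p_next
        beta = (r_next @ r_next) / (r @ r)
        p, r, d = p_next, r_next, r_next + beta * d
    return p

# SPD example: H = diag(1, 2); the solution of H p = -g is p = [-1, -0.5].
p = truncated_cg(lambda v: jnp.array([1.0, 2.0]) * v, jnp.array([1.0, 1.0]))
```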

Previous version used fori_loop (Steihaug pattern) which always runs
max_steps iterations as no-ops after convergence. This version correctly
follows lineax.CG's while_loop structure, making only the targeted changes:

- init(): remove PSD/NSD operator check (TruncatedCG accepts indefinite)
- body_fun(): replace NaN injection with explicit neg_curv detection;
  add hit_boundary exit; compute result_y/result_d at each step so the
  correct exit value is available regardless of which condition fires
- cond_fun(): add ~done flag so while_loop exits immediately on neg_curv
  or boundary, rather than running extra no-op iterations
- Options: rtol/atol/delta overrides now correctly fed through

Also kept: stabilise_every periodic residual recomputation, atol+rtol
convergence check via not_converged, norm field — all from lineax.CG.

TruncatedCG now handles both the Newton-CG (line-search) and Steihaug
(trust-region) cases uniformly:

- For both neg_curv and hit_boundary exits: returns y_before so the caller
  can project onto the trust-region boundary via _find_boundary_tau.
- Exception: when delta=inf (line-search mode) and neg_curv fires at step 0
  (y=0), returns d (= -g) as the steepest-descent fallback so Armijo has a
  non-trivial direction.
- Adds hit_boundary to stats alongside negative_curvature.

SteihaugCGDescent.step() is now ~20 lines: call TruncatedCG with delta,
read the two flags from stats, apply _find_boundary_tau if either is set.
The bespoke fori_loop CG and nested _find_boundary_tau are removed.

_find_boundary_tau is promoted to module level (accepts delta_sq as a
parameter) and the unused lax import is dropped from second_order.py.

TruncatedCG.compute() now handles boundary projection internally when
options["delta"] is finite: both the negative-curvature and boundary-
crossing early exits solve ‖y_before + τd‖ = Δ and return the projected
step as .value, so callers get the correct Steihaug step directly.

_find_boundary_tau is now private to truncated_cg.py.  SteihaugCGDescent
.step() drops to a single lx.linear_solve call with no post-processing.
Remove the now-unused cast/tree_dot imports from second_order.py.

NewtonDescent.query() now computes ew_rtol = min(0.5, sqrt(||g||_2)) for
EvalGradHessian f_info and passes it as options["rtol"] to lx.linear_solve.
This is forwarded to TruncatedCG (and any future iterative lineax solver
that honours options["rtol"]); direct solvers (Cholesky, LU, Auto) ignore
it silently.

SteihaugCGDescent.step() computes the same adaptive tolerance using the
pre-computed grad_norm in state and passes it alongside options["delta"].
The rtol field on SteihaugCGDescent becomes the E-W cap (default 0.5),
matching scipy's trust-ncg formula: eta = min(0.5, sqrt(||g||)) * ||g||.

newton_step() gains an optional options parameter (default None) that is
forwarded to lx.linear_solve; all existing callers (dogleg, LM) are
unaffected.

two_norm imported into gauss_newton.py for the E-W calculation.
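The adaptive tolerance itself is a one-liner; a sketch with the illustrative name `eisenstat_walker_rtol`:

```python
import jax.numpy as jnp

def eisenstat_walker_rtol(grad, cap=0.5):
    # Inner-solve tolerance eta = min(cap, sqrt(||g||_2)): loose far from the
    # optimum, tightening as the gradient shrinks (scipy trust-ncg's choice).
    g_norm = jnp.sqrt(jnp.sum(grad**2))
    return jnp.minimum(cap, jnp.sqrt(g_norm))
```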

Previously TruncatedCG always returned the forward (positive) root when
projecting a negative-curvature exit onto the trust-region sphere.  Scipy's
trust-ncg evaluates both intersections and returns the one that minimises the
quadratic model m(p) = g^T p + 0.5 p^T H p.

Add _find_neg_curv_boundary_tau() which computes both roots ta <= tb and
picks the one minimising Δm(τ) = τγ + ½τ²(d^T H d), using the CG conjugacy
identity r·d = -‖r‖² = -γ to avoid an extra dot product.  The boundary-
crossing exit still uses the positive root via the existing _find_boundary_tau.
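A sketch of the two-root selection, following the formula in this commit message; the sign conventions and the name `boundary_tau` are taken as stated here rather than from the code, so treat this as an assumption-laden illustration:

```python
import jax.numpy as jnp

def boundary_tau(p, d, delta, curv, gamma):
    # Roots of ||p + tau d||^2 = delta^2:  a tau^2 + 2 b tau + c = 0.
    a = d @ d
    b = p @ d
    c = p @ p - delta**2
    disc = jnp.sqrt(b**2 - a * c)
    ta, tb = (-b - disc) / a, (-b + disc) / a
    # Model change along d per the commit: dm(tau) = tau*gamma + 0.5*tau^2*curv,
    # with gamma = ||r||^2 via the CG identity r.d = -gamma; pick the smaller.
    dm = lambda t: t * gamma + 0.5 * t**2 * curv
    return jnp.where(dm(ta) < dm(tb), ta, tb)

tau = boundary_tau(jnp.array([0.5, 0.0]), jnp.array([1.0, 0.0]),
                   delta=1.0, curv=-2.0, gamma=1.0)
```

By construction the returned τ lands exactly on the trust-region sphere.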

_find_boundary_tau now takes neg_curv/gamma/inner_prod and handles both
cases internally: positive root for boundary-crossing exits, best-of-two
roots for negative-curvature exits.  The two-function API is gone.

TruncatedCG.max_steps default raised from 10*n to 20*n to match scipy's
cg_maxiter = 20*len(x0).  SteihaugCGDescent.max_steps and
TrustNewton.steihaug_max_steps now default to None, propagating to the
TruncatedCG default rather than imposing a hard cap of 100.  Users can
still set an explicit integer to impose a cap on large problems.

Outer Newton iterations remain independently controlled via
minimise(..., max_steps=256) through iterative_solve.

Previously, when negative curvature fired on the first CG step (y=0) in
line-search mode, TruncatedCG returned the bare direction d = -g, leaving
Armijo to find the scale by backtracking from alpha=1 against a unit step.

Scipy instead returns dri0 / (-curv) * b = ||g||^2 / |d^T H d| * (-g),
the Cauchy-scaled descent step.  This matches the curvature magnitude and
lets the Armijo search accept alpha=1 immediately rather than backtracking.

Implemented as gamma / max(-inner_prod, eps) * d, using gamma = ||r||^2
and inner_prod = d^T H d which are both already in scope in body_fun.
The eps guard avoids a huge step when inner_prod is only marginally negative.
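As a sketch (illustrative name `cauchy_fallback`; `d`, `gamma`, and `inner_prod` are the in-scope quantities named above):

```python
import jax.numpy as jnp

def cauchy_fallback(d, gamma, inner_prod, eps=1e-12):
    # First-step negative-curvature fallback: with d = -g and gamma = ||r||^2,
    # this is ||g||^2 / |d^T H d| * (-g), the Cauchy-scaled descent step.
    # The eps guard avoids a huge step when inner_prod is barely negative.
    return gamma / jnp.maximum(-inner_prod, eps) * d

# H = -2 I, g = [2, 0]:  d = -g, gamma = 4, d^T H d = -8  ->  step [-1, 0].
step = cauchy_fallback(jnp.array([-2.0, 0.0]), gamma=4.0, inner_prod=-8.0)
```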

Cover the previously untested path where `tags=frozenset({lx.positive_semidefinite_tag})`
is passed to `minimise`, causing the Hessian operator in second_order.py to
carry `{symmetric_tag, positive_semidefinite_tag}`.  Without the PSD tag, both
`lx.Cholesky()` and `lx.CG()` raise a hard `ValueError` at init time
(operator only gets `symmetric_tag` by default); with it, both solvers
converge correctly on a strictly-convex quadratic.

Two parametrised tests across LineSearchNewton × TrustNewton × {Cholesky, CG}:
- test_newton_psd_linear_solver: verifies convergence with tag present
- test_newton_psd_linear_solver_no_tag_raises: verifies error without tag

Replace the pure quadratic test function (which Cholesky solves in one step)
with f(y) = sum(cosh(y_i) - 1): globally convex, H = diag(cosh(y)) > 0
everywhere, minimum at y=0, f*=0.  From y0=[4,-3,2] the solvers take 33-39
Newton steps, exercising the full Cholesky/CG path non-trivially.

Assert on f(y*) rather than |y* - 0|: TrustNewton terminates on |Δf| < atol
once f ≈ 0, so f(y*) converges to machine precision while |y*| is only ~3e-4.
This reflects the semantics of the Cauchy stopping criterion correctly.

https://claude.ai/code/session_01XNA4GHKg4a2ixxKTfKixW6
claude added 30 commits March 16, 2026 15:12
The Newton minimiser previously stored jax.grad(_NoAux(fn)) inside a
_NoAuxIn equinox Module, which was then held in _NewtonMinimiserState.
Newer equinox raises _JaxTransformException when an eqx.Module stores a
JAX-transformed function (whose __wrapped__ points back to an eqx.Module)
and that state is processed inside a filter_custom_jvp trace context —
exactly what ImplicitAdjoint triggers via _implicit_impl.

Fix: remove _NoAuxIn and _NoAux usage entirely. Instead build the gradient
function inline as a plain lambda (`lambda y: fn(y, args)[0]`) in both
_make_hessian_f_info and _prepare_step.accepted. A plain lambda's
__wrapped__ attribute is a plain Python function, not an eqx.Module, so
equinox's transform-function check no longer fires.

Also fix TruncatedCG: wrap the 'delta' option with jnp.asarray so that
delta_sq is a JAX scalar rather than Python float('inf'), preventing a
jaxtyping TypeCheckError in _find_boundary_tau when no trust-region radius
is provided.

- Rename _NewtonBaseState → NewtonMinimiserState (public name); remove the
  thin _NewtonMinimiserState pass-through subclass and use NewtonMinimiserState
  directly in AbstractNewtonMinimiser and its concrete methods.
- Mirror NewtonChord's jax.linearize style: inline the lambda and pass
  has_aux=True instead of a separate grad_fn variable, in both
  _make_hessian_f_info and AbstractNewtonMinimiser._prepare_step.
- Move LineSearchNewton / TrustNewton from _minim_only to _general_minimisers
  in tests/helpers.py so they are exercised by least_squares_optimisers too.

- Rename NewtonMinimiserState -> _NewtonMinimiserState (private, consistent
  with _QuasiNewtonState)
- Add _NoAux wrapper in second_order.py so Hessian linearize calls use
  jax.grad on a no-aux fn rather than has_aux=True + discarded _

…lpers

- second_order.py: remove duplicate _NoAux class, import from newton_chord,
  inline _NoAux(fn) calls (no intermediate fn_no_aux variable).
  Change default linear_solver for LineSearchNewton and TrustNewton from
  AutoLinearSolver(well_posed=None) to lx.Cholesky(); update docstrings.

- helpers.py: add globally_convex test function (cosh(y)-1, globally SPD),
  add it to minimisation_fn_minima_init_args, define _spd_minimisation_fns
  set (bowl, matyas, square_minus_one, globally_convex), add _newton_needs_psd
  predicate, add lx.CG vanilla variants for LineSearchNewton and TrustNewton.

- test_minimise.py: remove test_newton_psd_linear_solver* tests and the
  associated _globally_convex / _psd_solver_linear_solver_pairs definitions
  (superseded by parametrised coverage via minimisation_fn_minima_init_args).
  Add skip/tags logic to test_minimise, test_minimise_jvp,
  test_forward_minimisation: skip when solver requires PSD Hessian but
  problem is not globally convex; otherwise pass positive_semidefinite_tag.

- test_least_squares.py: same skip/tags logic for test_least_squares and
  test_least_squares_jvp (no least-squares problems are globally SPD so all
  Newton+Cholesky/CG combinations are skipped there).

…odiff_mode in Newton minimisers

TruncatedCG._find_boundary_tau had prefer_ta = +gamma + ... > 0 when it
should be -gamma + ... > 0.  For the common first-step case (p=0, ta+tb=0)
this always evaluated to True, picking the root that moves *away* from the
minimum (in direction +g) instead of toward it (direction -g).

Fix: negate gamma in the preference test so the solver correctly extends to
the trust-region boundary in the descent direction under negative curvature.

AbstractNewtonMinimiser now also respects the autodiff_mode='fwd' option
(mirroring quasi-Newton's lin_to_grad pattern) by routing through
jax.jacfwd instead of jax.grad for the gradient computation, enabling
Newton methods to work with dfx.ForwardMode() ODE solves.

…x.LU()

- Change LineSearchNewton/TrustNewton default linear_solver from lx.Cholesky()
  to lx.LU(), which works for any non-singular Hessian without requiring PSD.

- Replace _newton_needs_psd with two predicates:
  - _uses_vanilla_cg: True only for lx.CG (explicit PSD tag requirement).
    Used in test_minimise/test_least_squares: only CG is skipped on non-SPD
    problems; all other solvers run, with PSD tags added problem-dependently.
  - _newton_needs_convex: True for Newton+LU/Cholesky/CG without indefinite-
    capable solver. Used in test_forward_minimisation where forward_only_ode
    has a negative Hessian at the starting point (non-convex region), making
    Newton+LU produce ascent directions that the Armijo search can never accept.

- Tags are now always problem-dependent (not solver-dependent): pass
  positive_semidefinite_tag iff _fn in _spd_minimisation_fns.

Selects Cholesky when the Hessian carries positive_semidefinite_tag
(passed for SPD problems) and falls back to LU otherwise, giving the
best of both: efficiency on convex problems, correctness on general ones.

- Change default linear_solver for LineSearchNewton and TrustNewton from
  lx.Cholesky() to lx.AutoLinearSolver(well_posed=True), which dispatches
  to Cholesky for SPD-tagged operators and LU otherwise.

- Replace _newton_needs_psd with two predicates:
  - _uses_vanilla_cg: True only when Newton uses lx.CG; used in convergence
    tests to skip non-SPD problems (CG requires positive_semidefinite_tag)
  - _newton_needs_convex: True for Newton with direct linear solver (not
    SteihaugCG/TruncatedCG); used in JVP tests to avoid XLA crash from
    cache exhaustion on non-convex problems

- Make tags problem-dependent (not solver-dependent):
  tags = frozenset({lx.positive_semidefinite_tag}) if fn in _spd_minimisation_fns

- Add scale args to globally_convex for consistency with bowl pattern;
  minimum y*=0 is independent of scale, so ImplicitAdjoint returns zero tangent

- Add NelderMead to loose-tolerance group in test_minimise_jvp

ImplicitAdjoint only uses the optimality condition at y* (not the solver
internals), so Newton+LU on non-convex problems produces correct JVP
results. The only reason to skip is when lx.CG requires
positive_semidefinite_tag but the problem is not SPD.

Remove _newton_needs_convex and replace all remaining uses with
_uses_vanilla_cg in test_minimise_jvp, test_forward_minimisation, and
test_least_squares_jvp.

The LLVM OOM crash (Failed to materialize symbols) in combined runs is a
JIT compilation cache exhaustion issue on constrained machines, not a
code or pytree structure bug -- each test passes in isolation.

…ubproblem

TrustNewton previously used NewtonDescent + ClassicalTrustRegion. When the
Hessian is negative definite, NewtonDescent produces an ascent direction that
neither the trust region nor the line search can correct; the trust region only
scales the step size, not the direction. After ~54 halvings the step size
reaches floating-point zero, Cauchy termination fires, and the solver
falsely converges at the starting point.

Fix: replace NewtonDescent with IndirectDampedNewtonDescent, which solves the
true trust-region subproblem (Conn, Gould, Toint §7.3) by root-finding for the
Levenberg-Marquardt parameter λ such that ‖(H + λI)⁻¹g‖ = Δ. This guarantees
H + λI is positive definite at the solution and that the step is a genuine
descent direction.
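The subproblem can be sketched with a simple bisection on λ. Illustrative only: optimistix's IndirectDampedNewtonDescent uses its own root finder, and this sketch assumes the bracket stays in the positive-definite region (as with the SPD `H` below):

```python
import jax.numpy as jnp

def damped_newton_step(H, g, delta, iters=60):
    # Find lam with ||(H + lam I)^{-1} g|| = delta by bisection: a sketch of
    # solving the trust-region subproblem indirectly via its LM parameter.
    eye = jnp.eye(H.shape[0])
    step_norm = lambda lam: jnp.linalg.norm(jnp.linalg.solve(H + lam * eye, g))
    lo, hi = 0.0, 1.0
    while step_norm(hi) > delta:    # larger lam shrinks the step
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if step_norm(mid) > delta else (lo, mid)
    return jnp.linalg.solve(H + hi * eye, g), hi

# H = I, g = [3, 4]: unconstrained step has norm 5, so lam = 4 gives norm 1.
p, lam = damped_newton_step(jnp.eye(2), jnp.array([3.0, 4.0]), delta=1.0)
```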

Also fix damped_newton_step: H + λI is always symmetric when
isinstance(f_info, EvalGradHessian), so always propagate symmetric_tag in
addition to the existing conditional positive_semidefinite_tag. This allows
AutoLinearSolver(well_posed=True) to use LDL for indefinite problems and
Cholesky when the Hessian is tagged positive_semidefinite.

Fixes forward_only_ode-solver13 (TrustNewton) in test_forward_minimisation.

…haug=True

Previously, passing linear_solver alongside use_steihaug=True silently
discarded the solver with no indication to the user. This was confusing
because the parameter appeared to be accepted but had no effect.

Use a sentinel default for linear_solver so we can distinguish an
explicit user-supplied value from the default. Raise a clear ValueError
explaining that SteihaugCGDescent constructs its own TruncatedCG
internally and does not accept an external linear_solver.

Also updates the linear_solver docstring to document the error contract
instead of the previous "Ignored when use_steihaug=True" note.

When the exact linear solver (Cholesky, LU, etc.) fails on an
indefinite or near-singular Hessian, fall back to the gradient
direction so the Armijo line search can still make progress.

This mirrors scipy's newton-cg behaviour: on the first CG step
with negative curvature (y=0), scipy returns the Cauchy-scaled
steepest-descent direction rather than aborting.  For exact solvers
we don't have curvature information, so we fall back to the raw
gradient and let Armijo find the right step size.

Fixes the forward_only_ode test for LineSearchNewton.

…ctions

The previous fallback only fired when the linear solve returned a
non-successful result.  This missed the case where the solve succeeds
but the Hessian is indefinite: LU on a negative-definite H returns a
direction with g^T(H^{-1}g) < 0, which ascends rather than descends,
causing the Armijo line search to fail silently.

Extend the condition to also fall back when the solved direction is not
a descent direction (g^T newton <= 0), using the dot product check that
costs no extra HVP.  This fixes the forward_only_ode test where the
Hessian is negative at the initial point k=0.6.
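A sketch of the check (illustrative name `choose_direction`; here `newton = H^{-1} g` and the update subtracts the direction, matching the convention in the message):

```python
import jax.numpy as jnp

def choose_direction(grad, newton):
    # newton = H^{-1} g, applied as y_new = y - direction. Descent requires
    # g^T newton > 0; the dot product costs no extra Hessian-vector product.
    is_descent = grad @ newton > 0
    return jnp.where(is_descent, newton, grad)

# Negative-definite H flips the solve: fall back to the raw gradient.
fallback = choose_direction(jnp.array([1.0, 0.0]), jnp.array([-1.0, 0.0]))
kept = choose_direction(jnp.array([1.0, 0.0]), jnp.array([0.5, 0.0]))
```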

Let the non-successful result propagate rather than silently falling
back to steepest descent. Only fall back on non-descent directions.

Flatten _NewtonMinimiserState into _QuasiNewtonState directly.
AbstractNewtonMinimiser now uses _QuasiNewtonState with
hessian_update_state=None, eliminating the separate base class.

…mments

Revert LM operator to original (lineax AddLinearOperator preserves
symmetry already). Revert two _QuasiNewtonState comments to match main.

Cauchy termination needs three iterations for quadratic problems where
the initial point is not the minimum: one to evaluate at init (y_diff=0
but f_diff large due to zeroed f_info), one to accept the Newton step
(y_diff large = full step to minimum), and one to confirm no further
movement.

Add a gradient-norm termination criterion alongside cauchy: if
||grad|| < atol + rtol * ||y|| (element-wise), the iterate is at a
stationary point. For exact Newton on a quadratic this fires as soon as
we arrive at the minimum, removing the extra confirmation step and
reducing iteration count from 3 to 2 for bowl, matyas, diag_bowl, and
glob_convex.

The gradient_termination shortcut is out of scope for this PR.
The underlying one-step timing delay vs newton_chord is a deeper
architectural question that warrants separate investigation.

Start from [10, -8, 6] instead of [0.4, -0.3, 0.2] to stress-test
convergence behaviour from a harder starting point.

The function was incorrectly defined as a plain quadratic. Use the
actual test definition: f(y) = Σ cosh(scale·y) − 1.

Add glob_convex_far with init [10, -8, 6] alongside the original
near-start init [0.4, -0.3, 0.2] from tests/helpers.py.

- square_minus_one: use x^2-1 (matches helpers.py) instead of (x^2-1)^2
- bowl: use PRNGKey(17) instead of PRNGKey(0) to match test matrix
- diagonal_quadratic_bowl: use tree_map(x^2*(0.1+w)) residual with
  random squared-normal args, matching the helpers.py definition

- SteihaugCGDescent and TruncatedCG added to descents.md
- LineSearchNewton, TrustNewton, AbstractNewtonMinimiser added to minimise.md

Remove the _uses_vanilla_cg helper and replace each call site with the
equivalent one-liner: isinstance(getattr(solver.descent, "linear_solver", None), lx.CG).
Condense the surrounding comments in helpers.py.

The default linear solver for LineSearchNewton and TrustNewton is
AutoLinearSolver(well_posed=True), not Cholesky.
