Support true second order methods #225
jpbrodrick89 wants to merge 4 commits into patrick-kidger:main
Conversation
Note I currently don't support backward-over-backward, because that would typically mean you need a custom_vjp of a custom_vjp, which I think would be fairly rare. However, if this would actually work for a recursive checkpointing diffrax sim then maybe it's worth supporting.
Oh wow! This is fairly gigantic 😅 I have to be honest, I think this is probably out-of-scope here. I think it's much larger than I can commit to maintaining. (Plus, you are currently sending PRs faster than I can find time to review them :p I am very aware of the stack of Lineax PRs I am slowly working through...) This kind of thing might find a good home in some kind of 'optimistixtra' (made up a name). Or possibly we should introduce an …

Over on the technical side, it's also not immediately clear to me how this differs from the existing … WDYT?
I think that's a pretty fair take 😅 I guess I got a bit carried away showing how much ground this can potentially cover, rather than focussing on a digestible incremental approach. But in another sense it's probably healthy to see a preview of a minimal working API before committing to it blindly and realising it's beyond maintenance appetite. That said, if you do change your mind on this I'd like to allay any pressure on timelines, with my full understanding that this would be a >(>)6 month project that we would just chew through slowly as we have time. (I have no urgent need for this in my own work right now.)
The idea of just passing a root finder to …
While I am very keen to eventually create a lineaxtra as I pull together enough useful features, I don't have a burning desire to add an optimistixtra as a separate repo to support all on my own right now (not least because writing this PR made me realise how little I know about optimisation algorithms). One alternative would be to …

As such, before we move on it's probably worthwhile exploring "what would be the MVP to allow power users to use the existing API to more easily roll their own custom second-order optimisers, with a minimal example of how to do this that minimises maintenance" rather than "what this could potentially be given many months of work and added maintenance". These are potential options in order of increasing complexity/maintenance. Each change to the codebase could be done incrementally in a separate digestible PR.
Personally, I'd probably lean towards 2 being a sensible stopping point, but I'd still be very content with 1. Either of these would mean I could theoretically add TruncatedCG to lineaxtra and then add another example to optimistix on how to implement SteihaugCGDescent (basically 3 lines of algorithm, all the rest boilerplate/wiring). And of course, if you don't think it's valuable to add such an example at all and would keep things as they are, I am also fully sympathetic to that viewpoint and would not be overly disappointed.

Eisenstat-Walker: note how TruncatedCG allows rtol and atol to be updated in the …
I like this characterisation a lot! On your hierarchy of options: as a nit, I confess I don't love the name `AbstractNewtonBase`, what with 'Base' actually just being a synonym for 'Abstract'. :p More seriously, the current iterate-over-steps + accept/reject is something we have copy-pasted a bit between three different solvers at the moment (AbstractQuasiNewton, AbstractBFGS, AbstractGradientDescent). I'd lean towards either copy-pasting this again for Newton, or finding a way to unify these and only having this code appear once. If possible, that is... I might be missing something in why I did it this way originally. (NB, I'm also conscious of one other mistake I made here: our 'latest' iterate is actually held in …)

On rtol/atol, I lean towards not making these dynamic, as the dominant use case is actually the opposite way around, I think: see for example how Diffrax allows you to put rtol/atol just on the step size controller, letting everything else inherit them from there. Fiddling with these is kind of annoying. So if dynamic stuff is needed I'd lean towards using …
Oh wow! Indeed a fair chunk of new solvers. From a numerics perspective, have you tried different starting points for the non-quadratic bowl problems? Methods with exact Hessians mostly perform well on globally convex functions, and break quite easily on other problems. Himmelblau, depending on where you start, is fairly bowl-like, but it does have stationary points (local maxima). The reason I'm talking about stationary points is that using the exact Hessian in Newton-type solvers introduces a strong tendency for convergence to stationary points. (In the early days of Optimistix, making the existing Newton root-finder a minimiser had actually been a consideration, but it was relegated to being a root-finder only for this particular reason. Jason remarked this somewhere in our issue tracker.) If you'd like to benchmark your solvers on a more extensive collection of minimisation problems, you may want to try the selection we are currently using in our own benchmarking, e.g. optimistix/benchmarks/test_benchmarks.py (line 77 at commit 8dc7f30), or just directly add your new solvers to the benchmarking script on your branch.
Hi @johannahaffner, I tried running the solvers on your benchmark problems and the iterative second-order solvers (NewtonCG and TrustNCG) are remarkably robust, outperforming all quasi-Newton solvers in terms of success rates, iteration counts and objective value (I haven't benchmarked actual runtime though). However, while the exact solvers were the best in terms of iteration counts for easy problems, they struggled with hard problems as you suggested, with failure rates of 15%/43% for LineSearchExact/TrustExact (only problems with dim <= 600). The reasons for this are not insurmountable, but as we're not planning to offer this in optimistix proper for the foreseeable future I'm not going to try to address them just yet.
Summary of results: Dolan-Moré plots (the objective function is not a standard metric, but I thought I might as well throw it in there). The solver that wins most often has the highest y-intercept; the solver that fails least often has the highest asymptote as tau -> infinity.
CDF of accepted step counts
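As a reference for how such plots are computed: a Dolan-Moré performance profile is just a CDF of per-problem cost ratios. Below is a minimal NumPy sketch; the function name and interface are illustrative, not taken from the actual benchmark script.

```python
import numpy as np

def performance_profile(costs, taus):
    """Dolan-More performance profile.

    costs: (n_problems, n_solvers) array of e.g. iteration counts,
           with np.inf marking failures.
    Returns rho[t, s]: the fraction of problems on which solver s is
    within a factor taus[t] of the best solver for that problem.
    """
    best = costs.min(axis=1, keepdims=True)  # best solver per problem
    ratios = costs / best                    # r_{p,s} in Dolan-More notation
    return np.array([(ratios <= tau).mean(axis=0) for tau in taus])
```

At tau = 1 the profile gives each solver's win rate (the y-intercept in the plots above); as tau grows it approaches each solver's overall success rate (the asymptote).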
Another (overrunning) weekend project, inspired by realising lineax doesn't really need a HessianLinearOperator. This is a giant PR, so I fully appreciate it may take quite some time to review, iterate on and merge (assuming this is something you want to support in optimistix; I understand if not). Consider this mostly a proof of concept to show what's possible; I am very open to feedback and opinions on API and design. Note this was written with a lot of LLM support, and I didn't want to overly finesse the docs until the final architecture is agreed upon. There might also be some boilerplate/wiring code (especially in TruncatedCG) that needs further simplification/refactoring.
Note this was motivated by trying to come up with good examples for a Hessian linear operator on a lineax issue discussion. The key behind everything is the _make_hessian_f_info function; everything else is just managing negative curvature, trust regions, performance, adding linear solvers to make sure this isn't a damp squib, etc. Note that I never use JacobianLinearOperator but FunctionLinearOperator, as I need to use jax.linearize to access the gradient value. Therefore a hessian_linear_operator would not actually be used here, the main use case for Hessians. We could of course just have the below function as a helper in lineax (grad_and_hessian_op or something), but it might be tricky to document and explain clearly. I believe this would make optimistix the first JAX library with true jittable second-order methods (pretty sure optax doesn't have any, and jaxopt only has a non-accelerated wrapper of scipy.optimize.minimize).
API
This PR offers two main points of entry:
LineSearchNewton
Uses NewtonDescent and a BacktrackingArmijo linesearch with the exact Hessian operator, falling back to steepest descent in regions of negative curvature.
TrustNewton
Uses ClassicalTrustRegion with either the new SteihaugCGDescent (to replicate scipy's trust-ncg) or IndirectDampedNewtonDescent (akin to scipy's trust-exact).

We also offer AbstractNewtonMinimiser to allow users to thread custom descents and searches (e.g. allowing a DoglegNewton like DoglegBFGS).

It also introduces a new linear solver TruncatedCG that detects negative curvature and is aware of trust regions. It can be used directly in LineSearchNewton (with linear_solver=..., coming close to scipy's newton-cg) or indirectly by TrustNewton with use_steihaug=True.

Summary of scipy mapping
Newton-CG -> LineSearchNewton(linear_solver=TruncatedCG()) (gap: scipy uses a Wolfe linesearch, we use backtracking Armijo)
trust-ncg -> TrustNewton(use_steihaug=True) (pretty faithful implementation, works even for non-SPD)
trust-exact -> needs further development to properly support Moré-Sorensen, but for SPD cases solver=TrustNewton(linear_solver=lx.Cholesky()), tags=frozenset({lx.positive_semidefinite_tag}) should work pretty well
trust-krylov -> not supported yet; requires even more complex linear solvers and wiring, I believe
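For readers unfamiliar with the trust-ncg scheme, the core of Steihaug-Toint truncated CG is short. The sketch below follows the textbook algorithm (Nocedal & Wright, Alg. 7.2) in plain NumPy with Python control flow; the PR's TruncatedCG would need lax.while_loop to stay jittable, so this is illustrative only, not the PR's code.

```python
import numpy as np


def steihaug_cg(hess_mv, grad, delta, tol=1e-10, maxiter=100):
    # Approximately minimise m(p) = grad.p + 0.5 p.H p  s.t.  ||p|| <= delta.
    p = np.zeros_like(grad)
    r = grad.copy()  # residual of H p + grad
    d = -r
    for _ in range(maxiter):
        Hd = hess_mv(d)
        dHd = d @ Hd
        if dHd <= 0.0:
            # Negative curvature detected: follow d to the boundary.
            return _to_boundary(p, d, delta)
        alpha = (r @ r) / dHd
        p_next = p + alpha * d
        if np.linalg.norm(p_next) >= delta:
            # Step leaves the trust region: clip to the boundary.
            return _to_boundary(p, d, delta)
        r_next = r + alpha * Hd
        if np.linalg.norm(r_next) < tol:
            return p_next  # converged: the interior (Newton) solution
        beta = (r_next @ r_next) / (r @ r)
        p, r, d = p_next, r_next, -r_next + beta * d
    return p


def _to_boundary(p, d, delta):
    # Positive root tau of ||p + tau d||^2 = delta^2.
    a, b, c = d @ d, 2.0 * (p @ d), p @ p - delta**2
    tau = (-b + np.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)
    return p + tau * d
```

The two early exits are exactly what makes the method "aware of trust regions" and robust to non-SPD Hessians: it never needs the full Newton system solved when curvature or the radius says otherwise.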
Design
Created a new _AbstractNewtonBase which acts as a parent for _AbstractQuasiNewton and the new _AbstractNewtonMinimiser; all the child classes need to do is define init and _prepare_step.
I have not benchmarked runtime/compile time, but iteration-wise these solvers beat all existing minimisers, though they typically lose against least-squares solvers. For a quadratic bowl, all these solvers except those using TruncatedCG complete in 3 iterations (it should be 1, but optimistix requires two steps to confirm Cauchy convergence and then only checks this at the beginning of the third step; we can experiment with pure gradient-based termination to improve this if you're interested), where the best quasi-Newton solver (LBFGS) takes 48. TruncatedCG takes 5 iterations. For Beale/Himmelblau they converge in 6/8 iterations (although Steihaug takes 10 for Himmelblau) against 14/13 with the best quasi-Newton solver.
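To ground the "should be 1" claim: on any quadratic objective a single exact-Hessian Newton step lands exactly on the minimiser, so the extra iterations above are purely termination-checking overhead. A standalone JAX sketch (not optimistix code):

```python
import jax
import jax.numpy as jnp


def exact_newton_step(f, x):
    # One exact Newton step: x - H^{-1} g, using a dense Hessian solve.
    g = jax.grad(f)(x)
    H = jax.hessian(f)(x)
    return x - jnp.linalg.solve(H, g)


# Quadratic bowl 0.5 x^T A x - b^T x, minimised at A^{-1} b.
A = jnp.array([[1.0, 0.0], [0.0, 10.0]])
b = jnp.array([1.0, 2.0])
f = lambda x: 0.5 * x @ A @ x - b @ x
x1 = exact_newton_step(f, jnp.zeros(2))  # one step from the origin
```

A quasi-Newton method starting from the identity approximation has to learn the ill-conditioned curvature over many steps, which is why LBFGS needs 48 iterations on the same kind of problem.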
For completeness here are the minimiser comparisons:
And the least square comparisons:
Questions
Assuming we want to support second-order methods, these questions come to mind first as things to iron out to get this PR ready.
a) Make _AbstractNewtonMinimiser the only public API (with a better name); users can just plug and play with descents and searches.
b) (This PR) Concretise on search type only: LineSearchNewton and TrustNewton
c) Concretise very specific solvers (e.g. NewtonCG, TrustNCG, TrustExact) and maybe add even more Abstract classes
If we go with b) there is currently a bit of a gotcha: use_steihaug=True ALWAYS uses TruncatedCG and requires linear_solver to be the sentinel _UNSET, and errors if it isn't. As such, having the linear_solver argument could be a bit confusing, and separating Steihaug out as TrustNCG (option c) might be sensible.

optimistix.compat.minimize (including those you already support)

Tags
_AbstractNewtonBase minimisers to be is_symmetric (should be generally true), or just silently thread it in as I am doing currently in `_grad_hessian`.

Future Linear Solvers (out of scope for this PR)
LineSearchNewton would arguably benefit from a Moré-Sorensen approach such as ModifiedCholesky or ModifiedLDL; let me know how keen you are on having those in optimistix/lineax. I am currently working on exposing sytrf in jax to support LDL, but it performs pretty poorly on GPU compared to LU.
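For context, a CPU-side equivalent of a sytrf-backed factorisation is already available as scipy.linalg.ldl, which handles the symmetric indefinite matrices that Cholesky rejects. A small sketch of why that matters for exact-Hessian solvers near saddle points:

```python
import numpy as np
from scipy.linalg import ldl

# A symmetric indefinite matrix (eigenvalues of mixed sign): Cholesky
# fails on this, but an LDL^T factorisation does not.
A = np.array([[2.0, 1.0],
              [1.0, -3.0]])
L, D, perm = ldl(A)
# scipy returns factors satisfying A = L @ D @ L.T, with any pivoting
# permutation already folded into L; the sign pattern of D's blocks
# reveals the negative curvature directions.
```

A ModifiedCholesky/ModifiedLDL solver would perturb the blocks of D to be positive definite, which is what makes the factorisation usable as a descent-producing Newton system.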