Adahessian optimizer and refactor Hutchinson estimator as shared utility to be used by Adahessian, Sophia optimizer#1604
Conversation
…ity to be used by Adahessian, Sophia optimizer
- Share Hutchinson estimator between Sophia and AdaHessian
- Add AdaHessian implementation/tests and re-export contrib symbols
- Tune AdaHessian test hyperparams for Rosenbrock convergence
- Improve comments in AdaHessian/Hutchinson utilities
Force-pushed from 0f375d7 to 62aea73
rdyro left a comment
Thanks, I left a couple of comments!
FYI @rdyro: responded to all the comments.
rdyro left a comment
Thanks, this looks great!
Hey, I'm rereading the implementation to merge this internally, and I noticed you only included the spatial averaging that the paper mentions for Conv2D, but the adahessian repo hints at other cases for spatial averaging as well. Also, depending on which version of the implementation I'm reading, the update formula appears slightly different:

```python
hv = hv
hv = torch.abs(hv)
hv = torch.abs(hv * vi)
```

Which reference implementation were you following?
That's an excellent catch, and I think the code diverges from the PyTorch implementation for dim = 0/1/2. Note: even though the implementations are slightly different, both are accurate; I show my reasoning below. However, I feel it's best to stay consistent with the reference PyTorch implementation, so I will make the changes.

Let a_i = (H v)_i, with v_i ∈ {−1, +1}.

- The JAX code's sample per element: s_i = a_i · v_i
- The PyTorch code (that branch) uses: t_i = |a_i|

Compare carefully: s_i and t_i are NOT equal in general (the sign differs). But after taking absolute values, |s_i| = |a_i v_i| = |a_i| because |v_i| = 1, so the equivalence claim only holds under magnitude-based downstream use. In AdaHessian, the downstream use is a second-moment / square-like accumulation, so the sign disappears:

s_i² = (a_i v_i)² = a_i²
t_i² = |a_i|² = a_i²

FYI @rdyro: made the code changes.
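The sign argument above can be checked numerically. Here is a minimal NumPy sketch (the names `a`, `v`, `s`, `t` mirror the derivation; this is illustrative, not PR code):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=1000)               # stands in for (H v)_i
v = rng.choice([-1.0, 1.0], size=1000)  # Rademacher entries, so |v_i| = 1

s = a * v       # JAX-style per-element sample: s_i = a_i * v_i
t = np.abs(a)   # PyTorch-branch sample:       t_i = |a_i|

# s and t differ in sign in general, but the squares used by AdaHessian's
# second-moment accumulation agree:
print(np.allclose(s**2, t**2))  # True
```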
Summary
This PR implements the AdaHessian optimizer in contrib. It also refactors the Hutchinson estimator into a shared utility and wires both Sophia and AdaHessian to use it, while keeping the existing public APIs intact. In addition, it adds AdaHessian tests to contrib.
Motivation
Sophia and AdaHessian rely on the same Hutchinson estimator, but the implementation was duplicated. This made maintenance more error-prone and led to subtle drift (e.g., multi-sample support only existing in AdaHessian). Centralizing the estimator makes the behavior consistent, reduces duplication, and provides a single place to evolve the estimator logic.
Algorithm details
AdaHessian approximates second‑order scaling using only the diagonal of the Hessian. At each step:
- g_t = ∇_θ L(θ_t)
- m_t = β1 m_{t-1} + (1-β1) g_t
- m̂_t = m_t / (1 - β1^t)
- Draw a Rademacher probe v (entries in {−1, +1}) and form the Hutchinson diagonal estimate d_t ≈ v ⊙ (H v), averaged over n_samples probes. The estimate is recomputed every update_interval steps; otherwise the cached value is reused.
- ν_t = β2 ν_{t-1} + (1-β2) (d_t)^2
- ν̂_t = ν_t / (1 - β2^t)
- denom = (ν̂_t)^(hessian_power / 2) + ε
- Δθ_t = m̂_t / denom
- θ_{t+1} = θ_t - lr * Δθ_t

This PR does not change AdaHessian defaults; it only centralizes the Hutchinson estimator and tunes test‑only hyperparameters for convergence in the common test suite.
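As a sanity check, the update rule above can be sketched in plain NumPy (names like `lr`, `b1`, `b2`, `eps`, `hessian_power` mirror the description; this is NOT the optax implementation). The exact Hessian diagonal of a quadratic stands in for the Hutchinson estimate d_t:

```python
import numpy as np

def adahessian_step(theta, g, d, state, t, lr=0.1, b1=0.9, b2=0.999,
                    eps=1e-8, hessian_power=1.0):
    m, nu = state
    m = b1 * m + (1 - b1) * g          # first moment of gradients
    nu = b2 * nu + (1 - b2) * d**2     # second moment of Hessian-diag estimates
    m_hat = m / (1 - b1**t)            # bias corrections
    nu_hat = nu / (1 - b2**t)
    denom = nu_hat**(hessian_power / 2) + eps
    return theta - lr * m_hat / denom, (m, nu)

# Drive it on f(theta) = 0.5 * theta^T H theta with diagonal H, so the
# gradient is H * theta and the exact Hessian diagonal is H itself.
H = np.array([4.0, 1.0])
theta, state = np.array([1.0, 1.0]), (np.zeros(2), np.zeros(2))
for t in range(1, 201):
    theta, state = adahessian_step(theta, H * theta, H, state, t)
print(np.abs(theta).max())  # converges toward 0
```

Because the denominator approaches |diag(H)|, the step is roughly Newton-scaled per coordinate, which is why both the well- and ill-conditioned directions converge at a similar rate.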
What changed
1) Shared Hutchinson estimator
- New optax/contrib/_hutchinson.py with:
  - HutchinsonState
  - hutchinson_estimator_diag_hessian(random_seed=None, n_samples=1)
  - Uses grad(obj_fn) to compute Hessian-vector products
  - Computes v ⊙ (H v) and averages across samples for variance reduction

2) Sophia / AdaHessian use the shared utility
- hutchinson_estimator_diag_hessian now forwards to the shared utility with n_samples=1
- n_samples passthrough

3) contrib re-exports
- optax.contrib.hutchinson_estimator_diag_hessian and HutchinsonState now re-exported from the shared module
- New symbols (adahessian, scale_by_adahessian, AdaHessianState) added to optax.contrib.__init__ so tests and users can import them from optax.contrib

4) Tests / convergence
- AdaHessian covered by the common contrib test suite (_common_test.py) and uses obj_fn similarly to Sophia
- Test-only tuning for convergence: hessian_power=0.25, learning_rate=1e-2, update_interval=10

5) Documentation / comments
Behavioral considerations
- Multi-sample n_samples support remains available (now shared)

Testing
- pytest optax/contrib -vv

Notes / follow‑ups
- hessian_power (currently unchanged in implementation; only the test suite uses 0.25).
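For reference, the shared Hutchinson estimator described in (1) can be sketched as follows, using NumPy with an explicit Hessian-vector product (function and argument names here are illustrative, not the optax API):

```python
import numpy as np

def hutchinson_diag(hvp, dim, n_samples, rng):
    """Estimate diag(H) by averaging v ⊙ (H v) over Rademacher probes."""
    est = np.zeros(dim)
    for _ in range(n_samples):
        v = rng.choice([-1.0, 1.0], size=dim)  # Rademacher probe
        est += v * hvp(v)                      # unbiased sample of diag(H)
    return est / n_samples

H = np.array([[3.0, 0.5],
              [0.5, 2.0]])
rng = np.random.default_rng(0)
d = hutchinson_diag(lambda v: H @ v, dim=2, n_samples=10_000, rng=rng)
print(d)  # close to diag(H) = [3.0, 2.0]
```

In the PR, the Hessian-vector product would come from differentiating `grad(obj_fn)` rather than an explicit matrix, and averaging over `n_samples` probes reduces the variance contributed by the off-diagonal entries.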