
Adahessian optimizer and refactor Hutchinson estimator as shared utility to be used by Adahessian, Sophia optimizer#1604

Open
viralvgupta wants to merge 7 commits into google-deepmind:main from viralvgupta:codex/hutchinson-refactor

Conversation

@viralvgupta

@viralvgupta viralvgupta commented Feb 19, 2026

Summary

This PR implements the AdaHessian optimizer in contrib. It also refactors the Hutchinson estimator into a shared utility and wires both Sophia and AdaHessian to use it, while keeping the existing public APIs intact. In addition, it adds AdaHessian tests to contrib.

Motivation

Sophia and AdaHessian rely on the same Hutchinson estimator, but the implementation was duplicated. This made maintenance more error-prone and led to subtle drift (e.g., multi-sample support only existing in AdaHessian). Centralizing the estimator makes the behavior consistent, reduces duplication, and provides a single place to evolve the estimator logic.

Algorithm details

AdaHessian approximates second‑order scaling using only the diagonal of the Hessian. At each step:

  1. Gradient + momentum
  • Compute gradient: g_t = ∇_θ L(θ_t)
  • First moment (momentum): m_t = β1 m_{t-1} + (1-β1) g_t
  • Bias‑corrected moment: m̂_t = m_t / (1 - β1^t)
  2. Hessian diagonal (Hutchinson)
  • Sample Rademacher probes v (entries in {−1, +1})
  • Compute Hessian‑vector product via JVP: Hv
  • Estimate diagonal: d_t ≈ v ⊙ (H v)
  • Optionally average over n_samples probes
  • The Hessian diagonal is recomputed every update_interval steps; otherwise the cached value is reused.
  3. Second‑order EMA
  • Track squared diagonal with EMA:
    • ν_t = β2 ν_{t-1} + (1-β2) (d_t)^2
    • ν̂_t = ν_t / (1 - β2^t)
  4. Scaling and update
  • Per‑parameter denominator:
    • denom = (ν̂_t)^(hessian_power / 2) + ε
  • Parameter update:
    • Δθ_t = m̂_t / denom
    • θ_{t+1} = θ_t - lr * Δθ_t
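The four steps above can be sketched as a single update on a flat parameter vector. This is a minimal illustrative sketch, not the PR's actual code; `adahessian_step` and its argument names are hypothetical, and it recomputes the Hessian diagonal every step (no `update_interval` caching):

```python
import jax
import jax.numpy as jnp


def adahessian_step(obj_fn, params, m, nu, t, key, lr=1e-2,
                    b1=0.9, b2=0.999, hessian_power=1.0, eps=1e-8):
  """One AdaHessian update on a flat parameter vector (illustrative only)."""
  g = jax.grad(obj_fn)(params)
  # Hutchinson diagonal estimate: d ≈ v ⊙ (H v) with a Rademacher probe v.
  v = jax.random.rademacher(key, params.shape, dtype=params.dtype)
  _, hv = jax.jvp(jax.grad(obj_fn), (params,), (v,))  # Hessian-vector product
  d = v * hv
  m = b1 * m + (1 - b1) * g            # first moment (momentum)
  nu = b2 * nu + (1 - b2) * d**2       # second-order EMA of the diagonal
  m_hat = m / (1 - b1**t)              # bias corrections
  nu_hat = nu / (1 - b2**t)
  denom = nu_hat ** (hessian_power / 2) + eps
  return params - lr * m_hat / denom, m, nu
```

On a diagonal quadratic the probe estimate v ⊙ (H v) is exact, so the denominator tracks the true Hessian diagonal.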

This PR does not change AdaHessian defaults; it only centralizes the Hutchinson estimator and tunes test‑only hyperparameters for convergence in the common test suite.

What changed

1) Shared Hutchinson estimator

  • Added optax/contrib/_hutchinson.py with:
    • HutchinsonState
    • hutchinson_estimator_diag_hessian(random_seed=None, n_samples=1)
  • The estimator:
    • Draws Rademacher probe vectors
    • Uses JVP of grad(obj_fn) to compute Hessian-vector products
    • Computes v ⊙ (H v) and averages across samples for variance reduction
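The estimator described above can be sketched as follows for a flat parameter vector. This is a hypothetical standalone version, not the code in `optax/contrib/_hutchinson.py` (which operates on pytrees and carries a `HutchinsonState`):

```python
import jax
import jax.numpy as jnp


def hutchinson_diag_hessian(obj_fn, params, key, n_samples=1):
  """Estimate diag(H) of obj_fn at params via Rademacher probes (sketch)."""

  def one_sample(k):
    # Rademacher probe: entries in {-1, +1}.
    v = jax.random.rademacher(k, params.shape, dtype=params.dtype)
    # Hessian-vector product as the JVP of the gradient function.
    _, hv = jax.jvp(jax.grad(obj_fn), (params,), (v,))
    return v * hv  # elementwise v ⊙ (H v)

  # Average over n_samples probes for variance reduction.
  keys = jax.random.split(key, n_samples)
  return jnp.mean(jax.vmap(one_sample)(keys), axis=0)
```

For a function with diagonal Hessian, each sample v ⊙ (H v) equals the diagonal exactly since v_i² = 1, so even a single probe is unbiased with zero variance in that case.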

2) Sophia / AdaHessian use the shared utility

  • Sophia’s hutchinson_estimator_diag_hessian now forwards to the shared utility with n_samples=1
  • AdaHessian’s version forwards to the shared utility with n_samples passthrough
  • Both preserve existing signatures and behavior

3) contrib re-exports

  • optax.contrib.hutchinson_estimator_diag_hessian and HutchinsonState now re-exported from the shared module
  • AdaHessian symbols (adahessian, scale_by_adahessian, AdaHessianState) added to optax.contrib.__init__ so tests and users can import them from optax.contrib

4) Tests / convergence

  • AdaHessian added to contrib test suite (_common_test.py) and uses obj_fn similarly to Sophia
  • Tuned AdaHessian test-only hyperparameters:
    • hessian_power=0.25
    • (keeps learning_rate=1e-2, update_interval=10)
  • This makes Rosenbrock convergence pass in the common test suite without any skip

5) Documentation / comments

  • Added clearer, top‑level comments in AdaHessian update logic and Hutchinson utility to explain estimator flow and stability steps
  • Added comments in AdaHessian tests to clarify expectations

Behavioral considerations

  • No change to public Sophia/AdaHessian API signatures
  • Hutchinson estimator is now defined in a shared module; behavior is equivalent to the previous implementations
  • n_samples support remains available (now shared)

Testing

  • pytest optax/contrib -vv

Notes / follow‑ups

  • If desired, we can later revisit AdaHessian default hessian_power (currently unchanged in implementation; only the test suite uses 0.25).

@google-cla

google-cla bot commented Feb 19, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@viralvgupta viralvgupta changed the title from "refactor(contrib): share Hutchinson estimator and tune AdaHessian tests" to "Adahessian optimizer and refactor Hutchinson estimator as shared utility to be used by Adahessian, Sophia optimizer" Feb 19, 2026
…ity to be used by Adahessian, Sophia optimizer

- Share Hutchinson estimator between Sophia and AdaHessian

- Add AdaHessian implementation/tests and re-export contrib symbols

- Tune AdaHessian test hyperparams for Rosenbrock convergence

- Improve comments in AdaHessian/Hutchinson utilities
@viralvgupta viralvgupta force-pushed the codex/hutchinson-refactor branch from 0f375d7 to 62aea73 Compare February 19, 2026 03:24
Collaborator

@rdyro rdyro left a comment


Thanks, I left a couple of comments!

@viralvgupta
Author

FYI @rdyro Responded to all the comments

Collaborator

@rdyro rdyro left a comment


Thanks, this looks great!

@rdyro
Collaborator

rdyro commented Feb 23, 2026

Hey, I'm rereading the implementation to merge this internally and I noticed you only included spatial averaging that the paper mentions for Conv2D, but the adahessian repo hints at other cases for spatial averaging as well.

Also, depending on which version of the implementation I'm reading, the update formula appears slightly different.

hv = hv
hv = torch.abs(hv)
hv = torch.abs(hv * vi)

Which reference implementation were you following?

@viralvgupta
Author

viralvgupta commented Feb 25, 2026

That's an excellent catch, and I think the code diverges from the PyTorch implementation for dim = 0/1/2.

Note: even though the implementations are slightly different, both are accurate; I show my reasoning below. However, I feel it's best to stay consistent with the reference PyTorch implementation, so I will make the changes.

Let

a_i = (H v)_i, v_i ∈ {−1, +1}

Jax code’s sample per element:

s_i = a_i · v_i

PyTorch code (that branch) uses:

t_i = |a_i|

Compare carefully:

s_i and t_i are NOT equal in general (the sign differs).

But after taking absolute value:

|s_i| = |a_i v_i|
= |a_i| |v_i|
= |a_i|
= t_i

because |v_i| = 1.

So the equivalence claim only holds under magnitude-based downstream use.

In AdaHessian, downstream uses a second-moment / square-like accumulation, so the sign disappears:

s_i² = (a_i v_i)² = a_i²

t_i² = |a_i|² = a_i²
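The equivalence argument can be checked numerically. A NumPy sketch (variable names here are illustrative, not from either codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=1000)                # a_i = (H v)_i
v = rng.choice([-1.0, 1.0], size=1000)   # Rademacher signs, |v_i| = 1

s = a * v        # JAX-style per-element sample: a_i · v_i
t = np.abs(a)    # PyTorch-branch sample: |a_i|

assert not np.allclose(s, t)      # signs differ in general
assert np.allclose(np.abs(s), t)  # magnitudes agree since |v_i| = 1
assert np.allclose(s**2, t**2)    # squares agree, so EMAs of squares match
```

This confirms the two formulations agree under any magnitude-based or squared downstream use, which covers AdaHessian's second-moment accumulation.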

FYI @rdyro - made the code changes
