Add PolarExpress variant to Muon optimizer#1613
Add PolarExpress variant to Muon optimizer#1613MarcMachaczek wants to merge 20 commits intogoogle-deepmind:mainfrom
Conversation
…number of predefined coeffs. fixed ValueError when number of steps exceeds number of ns_coeffs in scale_by_muon
…number of ns_coeffs
|
Happy to get your feedback @NoahAmsel, just want to make sure everything is implemented faithfully. |
… remez algorithm is skipped from 1 - 5e-6 to 1 - 5e-5. for the previous cutoff, the linear system became ill-conditioned (k~1e11), causing platform dependent deviations on the order 1e-5
|
Thanks for implementing @MarcMachaczek! Mostly looks good. Three comments
|
… (binary and low-rank spectrum)
…nal work, this required a different cushion value
|
Thanks a lot for the feedback!
I will set
I think I need a bit more feedback here: Would you suggest we use Frobenius or Schatten-4 preconditioning as the default? (We could keep 'polar_express' as a preconditioning mode, and make it an alias for one of the two). Do the optimal coefficients for the ns iterations depend on the preconditioning?
Good point! I implemented an additional test that tests the low-rank and binary spectrum case. I'll push it later with the rest. |
… this again allows for tighter remez algorithm cutoff (now again the same as in original work)
…tion schatten as recommended; remove associated warnings ang logging import
No, they just assume that the input matrix has singular values <= 1. In exact arithmetic, all the initialization methods here accomplish that goal. So you can set the default however you like. (As an aside, I like Schatten-4 but for this PR, it's probably better to leave the default alone.) Polar Express may be slightly more sensitive to numerical issues than other coefficients. I.e., if you have a singular value at 1.01, polar express might blow up when the baselines didn't. So just to be safe, let's make the default safety factor 2e-2 instead of 1e-2. That will ensure we get the same stability benefits of what you were calling "polar express preconditioning", because (sigma / (frobenius norm * 1.01)) / 1.01 is about the same as (sigma / (frobenius norm)) / 1.02
I think there is a very small difference. If you apply it in the loop, it will affect the ell and u used in the next iteration. in turn, this will slightly change the coefficients of the polynomial used in the next iteration. So it's better to do it in the loop, but the difference is really small. Actually I think there may be an even better simpler option. replace this with this I feel this might be easier to understand because it highlights the problem we're trying to solve: while the singular values are supposed to lie in [ell, u], they might have spilled out into the interval [ell, u * safety_factor] due to numerical roundoff. But basically any way you prefer is fine. |
|
Thanks for the feedback. I increased the default ell to 2e-2. I think everything is good to go now. |
Add PolarExpress variant to Muon optimizer
Addresses #1621.
Integrates the PolarExpress method (Amsel et al., 2025) into optax's Muon optimizer.
This has already been suggested here #1602. Currently, 'polar_express' is an option that can be set separately for both ns_coeffs and preconditioning. A warning is raised in case only one of them is set to 'polar_express'. The original implementation uses eps=1e-7 (current default is 1e-8). To match the original implementation, it must be set explicitly, but it likely will not make a big difference in practice; hence, I decided against raising any warnings if eps!=1e-7 when polar_express is specified.
Changes
New features:
preconditioning='polar_express': Frobenius norm normalization with a* 1.01safety factor for numerical stabilityns_coeffs='polar_express': coefficients from Amsel et al., 2025ns_steps > 8, the last coefficient(1.875, -1.25, 0.375)is repeated — matching the reference implementationabsl.loggingwhenns_coeffs='polar_express'andpreconditioningdon't agree (or vice versa)Bug fixes:
ns_coeffsslicing direction: changedns_coeffs_[-ns_steps:]tons_coeffs_[:ns_steps]inscale_by_muon.init_fn. The old behavior took the last N coefficients. The new behavior takes the first N.ns_coeffsvalidation: the original conditionnot ns_coeffs_.shape[0] <= ns_stepsraised an error when there were more coefficients than steps (inverted logic). Fixed tons_coeffs_.shape[0] < ns_steps, which correctly errors only when there are fewer coefficients than needed.Usage
Update:
Replaced hardcoded PolarExpress coefficients with the actual computation functions (_optimal_quintic, polar_express_coeffs) translated from the https://github.com/NoahAmsel/PolarExpress.
Now, optimal coefficients are computed based on the provided number of iterations
ns_steps, and from the (optional) provided mu_dtype, which determines the lower boundl(bfloat16 → 1e-3, float32 → 1e-7, etc.).Additionally, users can call polar_express_coeffs() directly with custom parameters.
Update 2:
The safety factor is applied after all optimal polynomial factors have been computed, improving numerical stability.
Revert to using mu_dtype independent lower bound of
l=1e-3according to Noah's feedback. Users who want more fine-grained control over these parameters can callpolar_express_coeffsand feed the coeffs directly to muon. The 'polar_express' preconditioner was removed. The default Frobenius is sufficient (the only difference was the additional safety factor, which, according to Noah, is unnecessary). It is mentioned in the documentation that theschattenpreconditioner is recommended. Constants have been removed in favor of defaults in thepolar_express_coeffs. Finally, a more complex test has been added for the polar express variant.