
Add PolarExpress variant to Muon optimizer#1613

Open
MarcMachaczek wants to merge 20 commits into google-deepmind:main from MarcMachaczek:feat/polar-express

Conversation


@MarcMachaczek MarcMachaczek commented Mar 4, 2026

Add PolarExpress variant to Muon optimizer

Addresses #1621.

Integrates the PolarExpress method (Amsel et al., 2025) into optax's Muon optimizer.
This has already been suggested here #1602. Currently, 'polar_express' is an option that can be set separately for both ns_coeffs and preconditioning. A warning is raised in case only one of them is set to 'polar_express'. The original implementation uses eps=1e-7 (current default is 1e-8). To match the original implementation, it must be set explicitly, but it likely will not make a big difference in practice; hence, I decided against raising any warnings if eps!=1e-7 when polar_express is specified.

Changes

New features:

  • preconditioning='polar_express': Frobenius-norm normalization with an additional 1.01 safety factor for numerical stability
  • ns_coeffs='polar_express': coefficients from Amsel et al., 2025
  • Coefficient padding: when ns_steps > 8, the last coefficient tuple (1.875, -1.25, 0.375) is repeated, matching the reference implementation
  • Mismatch warnings via absl.logging when ns_coeffs='polar_express' but preconditioning is not, or vice versa
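The padding rule can be sketched in a few lines. This is an illustrative stand-in, not the actual optax code; the first coefficient tuple below is a dummy value, and only the final tuple (1.875, -1.25, 0.375) comes from the PR description.

```python
def pad_ns_coeffs(coeffs, ns_steps):
    """Repeat the final (a, b, c) tuple so there is one tuple per NS step.

    Sketch of the padding behavior described above; not the optax implementation.
    """
    if len(coeffs) >= ns_steps:
        return list(coeffs[:ns_steps])
    return list(coeffs) + [coeffs[-1]] * (ns_steps - len(coeffs))

# With ns_steps larger than the number of tuples, the last tuple is repeated:
dummy = [(8.0, -22.0, 15.0), (1.875, -1.25, 0.375)]  # first tuple is a dummy
padded = pad_ns_coeffs(dummy, 4)  # last three entries are (1.875, -1.25, 0.375)
```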

Bug fixes:

  • ns_coeffs slicing direction: changed ns_coeffs_[-ns_steps:] to ns_coeffs_[:ns_steps] in scale_by_muon.init_fn. The old behavior took the last N coefficients. The new behavior takes the first N.
  • ns_coeffs validation: the original condition not ns_coeffs_.shape[0] <= ns_steps raised an error when there were more coefficients than steps (inverted logic). Fixed to ns_coeffs_.shape[0] < ns_steps, which correctly errors only when there are fewer coefficients than needed.
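Both fixes can be illustrated together. The function below is a hypothetical standalone reproduction of the corrected logic, not the actual scale_by_muon code:

```python
import numpy as np

def select_ns_coeffs(ns_coeffs_, ns_steps):
    """Sketch of the corrected validation and slicing (not the optax code)."""
    # Fixed validation: error only when there are FEWER coefficient tuples
    # than steps. The old check `not shape[0] <= ns_steps` was inverted and
    # fired when there were MORE tuples than steps.
    if ns_coeffs_.shape[0] < ns_steps:
        raise ValueError("fewer coefficient tuples than ns_steps")
    # Fixed slicing: take the FIRST ns_steps tuples
    # (previously ns_coeffs_[-ns_steps:] took the LAST ones).
    return ns_coeffs_[:ns_steps]

coeffs = np.array([[8.0, -22.0, 15.0], [4.0, -7.0, 3.0], [1.875, -1.25, 0.375]])
selected = select_ns_coeffs(coeffs, 2)  # first two rows
```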

Usage

import optax

optimizer = optax.contrib.muon(
    learning_rate=0.02,
    ns_coeffs='polar_express',
    preconditioning='polar_express',
    ns_steps=8,
    eps=1e-7,
)
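For context, each coefficient tuple (a, b, c) parameterizes one quintic Newton-Schulz step, X ← aX + b(XXᵀ)X + c(XXᵀ)²X, which maps each singular value σ to aσ + bσ³ + cσ⁵ while leaving the singular vectors untouched. A minimal NumPy sketch (not the optax implementation):

```python
import numpy as np

def ns_quintic_step(X, a, b, c):
    """One quintic Newton-Schulz iteration: X <- a*X + b*(X X^T) X + c*(X X^T)^2 X."""
    A = X @ X.T
    return a * X + b * (A @ X) + c * (A @ A @ X)

# The final PolarExpress tuple (1.875, -1.25, 0.375) fixes sigma = 1
# (1.875 - 1.25 + 0.375 = 1) and polishes nearby singular values toward 1.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.normal(size=(3, 3)))
V, _ = np.linalg.qr(rng.normal(size=(3, 3)))
X = U @ np.diag([0.5, 0.8, 1.0]) @ V.T  # singular values already in (0, 1]
for _ in range(6):
    X = ns_quintic_step(X, 1.875, -1.25, 0.375)
# X is now close to the orthogonal polar factor U @ V.T
```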

Update:

Replaced hardcoded PolarExpress coefficients with the actual computation functions (_optimal_quintic, polar_express_coeffs), translated from https://github.com/NoahAmsel/PolarExpress.

Now, optimal coefficients are computed from the provided number of iterations ns_steps and from the (optional) mu_dtype, which determines the lower bound l (bfloat16 → 1e-3, float32 → 1e-7, etc.).
Additionally, users can call polar_express_coeffs() directly with custom parameters.

Update 2:

The safety factor is now applied after all optimal polynomial factors have been computed, improving numerical stability.

  • Reverted to a mu_dtype-independent lower bound of l=1e-3, per Noah's feedback. Users who want more fine-grained control over these parameters can call polar_express_coeffs and feed the coefficients directly to muon.
  • Removed the 'polar_express' preconditioner: the default Frobenius normalization is sufficient (the only difference was the additional safety factor, which, per Noah, is unnecessary). The documentation now notes that the Schatten preconditioner is recommended.
  • Removed the constants in favor of defaults in polar_express_coeffs.
  • Added a more thorough test for the polar express variant.

@MarcMachaczek
Author

Happy to get your feedback @NoahAmsel, just want to make sure everything is implemented faithfully.

… Remez algorithm is skipped from 1 - 5e-6 to 1 - 5e-5. For the previous cutoff, the linear system became ill-conditioned (κ ~ 1e11), causing platform-dependent deviations on the order of 1e-5
@NoahAmsel

Thanks for implementing, @MarcMachaczek! Mostly looks good. Three comments:

  • I can see why you set the default ell to be 1e-7 for float32, but even in this case, it's probably better to use something much closer to 1e-3. We didn't try training in float32, but for the Muon application, tiny singular values are probably just noise anyway. Take a look at appendix H.1 of our latest revision. Better to do a good job on the range [1e-3, 1] than a subpar job on the range [1e-7, 1].
  • I'm not sure that we need to use what you've called the polar express "preconditioning" with the polar express polynomials. @gowerrobert correct me if i'm wrong, but if we use a safety factor -- replacing each Polar Express polynomial p(x) with p(x/1.01) -- then we don't also need to divide by ||X||_F * 1.01. We did it in our code because there's no harm, but if you have a configurable "preconditioning" method then it's probably overkill. If you're worried, you can always just raise the safety factor to 1.02 when computing the polar express coefficients and get the same effect. I'd guess we should be using Schatten-4.
  • In your tests, you probably want to use a numerically more difficult matrix than 2*Id. Diagonal matrices are unrealistically easy. You should also try a matrix with random singular vectors, perhaps one that's numerically low rank (exponentially decaying singular values). To test for blowups, I'd suggest a matrix with all singular values equal either to 0 or to 2, random singular vectors, and the "spectral norm preconditioning", which is numerically the closest to being unstable.
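The suggested stress matrices can be built in a few lines of NumPy; this is a minimal illustrative construction (names and sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 16
# Random singular vectors via QR of Gaussian matrices.
U, _ = np.linalg.qr(rng.normal(size=(n, n)))
V, _ = np.linalg.qr(rng.normal(size=(n, n)))

# Blow-up test: all singular values equal to 0 or to 2.
s_binary = np.where(np.arange(n) < n // 2, 2.0, 0.0)
X_binary = U @ np.diag(s_binary) @ V.T

# Numerically low-rank test: exponentially decaying singular values.
s_decay = 2.0 ** -np.arange(n, dtype=float)
X_lowrank = U @ np.diag(s_decay) @ V.T
```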

@MarcMachaczek
Author

Thanks a lot for the feedback!

  • I can see why you set the default ell to be 1e-7 for float32, but even in this case, it's probably better to use something much closer to 1e-3. We didn't try training in float32, but for the Muon application, tiny singular values are probably just noise anyway. Take a look at appendix H.1 of our latest revision. Better to do a good job on the range [1e-3, 1] than a subpar job on the range [1e-7, 1].

I will set l to 1e-3 with a note in the documentation. If a user wants something different or more fine-grained, they can just call the polar express coefficient function directly with custom parameters and feed it into muon.

  • I'm not sure that we need to use what you've called the polar express "preconditioning" with the polar express polynomials. @gowerrobert correct me if i'm wrong, but if we use a safety factor -- replacing each Polar Express polynomial p(x) with p(x/1.01) -- then we don't also need to divide by ||X||_F * 1.01. We did it in our code because there's no harm, but if you have a configurable "preconditioning" method then it's probably overkill. If you're worried, you can always just raise the safety factor to 1.02 when computing the polar express coefficients and get the same effect. I'd guess we should be using Schatten-4.

I think I need a bit more feedback here: Would you suggest we use Frobenius or Schatten-4 preconditioning as the default? (We could keep 'polar_express' as a preconditioning mode, and make it an alias for one of the two). Do the optimal coefficients for the ns iterations depend on the preconditioning?
Moreover, when applying the safety factor to the coefficients, does it make a relevant difference whether this is done inside the loop (the way it is done in your repo) or afterwards (the way you do it in Implementation 1)?
I found that the latter is more stable (lower condition numbers in the linear system solved inside _optimal_quintic).

  • In your tests, you probably want to use a numerically more difficult matrix than 2*Id. Diagonal matrices are unrealistically easy. You should also try a matrix with random singular vectors, perhaps one that's numerically low rank (exponentially decaying singular values). To test for blowups, I'd suggest a matrix with all singular values equal either to 0 or to 2, random singular vectors, and the "spectral norm preconditioning", which is numerically the closest to being unstable.

Good point! I implemented an additional test that tests the low-rank and binary spectrum case. I'll push it later with the rest.

@MarcMachaczek changed the title from "Feat/polar express" to "Add PolarExpress variant to Muon optimizer" on Mar 5, 2026

NoahAmsel commented Mar 7, 2026

Do the optimal coefficients for the ns iterations depend on the preconditioning?

No, they just assume that the input matrix has singular values <= 1. In exact arithmetic, all the initialization methods here accomplish that goal. So you can set the default however you like. (As an aside, I like Schatten-4, but for this PR it's probably better to leave the default alone.) Polar Express may be slightly more sensitive to numerical issues than other coefficients. I.e., if you have a singular value at 1.01, polar express might blow up when the baselines didn't. So just to be safe, let's make the default safety factor 2e-2 instead of 1e-2. That will ensure we get the same stability benefits of what you were calling "polar express preconditioning", because (sigma / (frobenius norm * 1.01)) / 1.01 is about the same as (sigma / frobenius norm) / 1.02.
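A quick numeric check of the equivalence claimed above, with arbitrary illustrative values for the singular value and the Frobenius norm:

```python
sigma, fro = 0.9, 1.7  # arbitrary illustrative values

# Frobenius normalization with a 1.01 factor, plus a 1.01 safety factor
# baked into the polynomials:
lhs = (sigma / (fro * 1.01)) / 1.01

# Plain Frobenius normalization with a single 2e-2 safety factor:
rhs = (sigma / fro) / 1.02

# The two agree to about one part in ten thousand, since 1.01**2 = 1.0201.
assert abs(lhs - rhs) / rhs < 1e-4
```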

Moreover, when applying the safety factor to the coefficients, does it make a relevant difference whether this is done inside the loop (the way it is done in your repo) or afterwards (the way you do it in Implementation 1)?

I think there is a very small difference. If you apply it in the loop, it will affect the ell and u used in the next iteration. In turn, this will slightly change the coefficients of the polynomial used in the next iteration. So it's better to do it in the loop, but the difference is really small. Actually, I think there may be an even better, simpler option: replace this

  for _ in range(num_iters):
    a, b, c = _optimal_quintic(max(l, cushion * u), u)

with this

  for iter in range(num_iters):
    if iter < num_iters - 1: u *= safety_factor
    a, b, c = _optimal_quintic(max(l, cushion * u), u)

I feel this might be easier to understand because it highlights the problem we're trying to solve: while the singular values are supposed to lie in [ell, u], they might have spilled out into the interval [ell, u * safety_factor] due to numerical roundoff. But basically any way you prefer is fine.

Author

MarcMachaczek commented Mar 8, 2026

Thanks for the feedback. I increased the default safety factor to 2e-2.
I reverted to including the safety factor inside the loop. I chose to stick to the original variant, however. This should give more consistency with your reference torch implementations. The alternative option you mentioned definitely makes the problem the safety factor is mitigating more obvious, but it's functionally different.
Finally, I slightly bumped the threshold at which the Remez iterations are cut off (and the exact result is used) from 1 - 5e-6 to 1 - 1e-5. I found that this helps make results more stable and consistent across platforms (some earlier GitHub CI tests failed because of this). When using ell=1e-3, for example, the condition number at the 7th iteration reaches up to 4.5e11. This slight bump seems to be enough to stabilize things, while the approximation error remains essentially at double precision.

I think everything is good to go now.
