From e687744997890222ddc2bf2b2f6a7234a8af14d4 Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 20:56:33 -0500 Subject: [PATCH 01/20] implementing polar express preconditioning and ns_coeffs --- optax/contrib/_muon.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index abaa26362..2ba765cc7 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -38,7 +38,7 @@ ReshapeFn = Callable[[jax.Array], jax.Array] -_PRECONDITIONINGS = ['frobenius', 'spectral', 'aol', 'schatten'] +_PRECONDITIONINGS = ['frobenius', 'spectral', 'aol', 'schatten', 'polar_express'] _DEFAULT_NS_COEFFS = (3.4445, -4.7750, 2.0315) _DION_NS_COEFFS = [ (4.0848, -6.8946, 2.9270), @@ -47,9 +47,30 @@ (2.8769, -3.1427, 1.2046), (2.8366, -3.0525, 1.2012), ] +_POLAR_EXPRESS_RAW_COEFFS = [ + (8.28721201814563, -23.595886519098837, 17.300387312530933), + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), + (3.9486908534822946, -2.908902115962949, 0.5518191394370137), + (3.3184196573706015, -2.488488024314874, 0.51004894012372), + (2.300652019954817, -1.6689039845747493, 0.4188073119525673), + (1.891301407787398, -1.2679958271945868, 0.37680408948524835), + (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), + (1.875, -1.25, 0.375), # subsequent coeffs equal this numerically +] +# Safety factor adjustment for numerical stability +# See Section 4.4 of Amsel et al., 2025. +_POLAR_EXPRESS_SAFETY = 1.01 +_POLAR_EXPRESS_NS_COEFFS = [ + (a / _POLAR_EXPRESS_SAFETY, + b / _POLAR_EXPRESS_SAFETY**3, + c / _POLAR_EXPRESS_SAFETY**5) + for a, b, c in _POLAR_EXPRESS_RAW_COEFFS[:-1] +] + [_POLAR_EXPRESS_RAW_COEFFS[-1]] + _NS_COEFFS_PRESET_DICT = { 'standard': _DEFAULT_NS_COEFFS, 'dion': _DION_NS_COEFFS, + 'polar_express': _POLAR_EXPRESS_NS_COEFFS, } @@ -280,7 +301,7 @@ def orthogonalize_via_newton_schulz( ns_coeffs: jax.Array, ns_steps: jax.typing.ArrayLike = 5, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten' + 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' ] = 'frobenius', eps: jax.typing.ArrayLike = 1e-8, dimension_numbers: MuonDimensionNumbers | None = None, @@ -331,6 +352,7 @@ def _orthogonalize(x): 'spectral': _base_ns_iterator, 'aol': _aol_ns_iterator, 'schatten': _schatten_ns_iterator, + 'polar_express': _base_ns_iterator, } if preconditioning not in _PRECONDITIONINGS: raise ValueError(f'Unknown preconditioning {preconditioning}') @@ -340,6 +362,8 @@ def _orthogonalize(x): x /= jnp.linalg.norm(x, ord='fro') + eps elif preconditioning == 'spectral': x /= jnp.linalg.norm(x, ord=2) + eps + elif preconditioning == 'polar_express': + x /= jnp.linalg.norm(x, ord='fro') * _POLAR_EXPRESS_SAFETY + eps else: pass @@ -391,7 +415,7 @@ def scale_by_muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten' + 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' ] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None, ) -> base.GradientTransformation: @@ -540,7 +564,7 @@ def muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten' + 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' ] = 'frobenius', adam_b1: jax.typing.ArrayLike = 0.9, adam_b2: jax.typing.ArrayLike = 0.999, @@ -653,6 +677,9 @@ def muon( if isinstance(ns_coeffs, str): if ns_coeffs not in _NS_COEFFS_PRESET_DICT: raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') + if ns_coeffs == 'polar_express' and preconditioning != 'polar_express': + print("Warning: Using 'polar_express' ns_coeffs without 'polar_express' preconditioning" + "is suboptimal.") ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] else: ns_coeffs_ = ns_coeffs From e5473f9046e7fa1cb62189f7e0cceeff0e1794ae Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:07:13 -0500 Subject: [PATCH 02/20] padding ns_coeffs for polar_express when number of steps exceeds the number of predefined coeffs. fixed ValueError when number of steps exceeds number of ns_coeffs in scale_by_muon --- optax/contrib/_muon.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 2ba765cc7..b6aa52fe4 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -487,7 +487,7 @@ def init_fn(params): f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}' ) if ns_coeffs_.ndim == 2: - if not ns_coeffs_.shape[0] <= ns_steps: + if ns_coeffs_.shape[0] < ns_steps: raise ValueError(f'Not enough coeffs to perform {ns_steps} steps') ns_coeffs_ = ns_coeffs_[-ns_steps:] @@ -681,6 +681,9 @@ def muon( print("Warning: Using 'polar_express' ns_coeffs without 'polar_express' preconditioning" "is suboptimal.") ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): + n_pad = ns_steps - len(ns_coeffs_) + ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad else: ns_coeffs_ = ns_coeffs From 6b8834e1e8dbe4c16d7092fd85f155823a49dbd1 Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:24:45 -0500 Subject: [PATCH 03/20] fixed ns_coeffs selection logic when number of steps is smaller than number of ns_coeffs --- optax/contrib/_muon.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index b6aa52fe4..2a87afbac 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -489,7 +489,7 @@ def init_fn(params): if ns_coeffs_.ndim == 2: if ns_coeffs_.shape[0] < ns_steps: raise ValueError(f'Not enough coeffs to perform {ns_steps} steps') - ns_coeffs_ = ns_coeffs_[-ns_steps:] + ns_coeffs_ = ns_coeffs_[:ns_steps] return MuonState( count=jnp.zeros([], jnp.int32), @@ -755,3 +755,14 @@ def muon_weight_dim_nums_fn(params): }, param_labels=param_labels, ) + +# %% +ns_coeffs = 'polar_express' +ns_steps = 10 +ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] +print(len(ns_coeffs_), ns_coeffs_) +if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): + n_pad = ns_steps - len(ns_coeffs_) + ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad + +print(len(ns_coeffs_), ns_coeffs_) \ No newline at end of file From ab9656c9dbb98d700cf186cc8d1a5a9ef8faa5ce Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:26:07 -0500 Subject: [PATCH 04/20] spacing --- optax/contrib/_muon.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 2a87afbac..c5ef882e7 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -678,9 +678,10 @@ def muon( if ns_coeffs not in _NS_COEFFS_PRESET_DICT: raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') if ns_coeffs == 'polar_express' and preconditioning != 'polar_express': - print("Warning: Using 'polar_express' ns_coeffs without 'polar_express' preconditioning" - "is suboptimal.") + print("Warning: Using 'polar_express' ns_coeffs without 'polar_express' preconditioning" + "is suboptimal.") ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): n_pad = ns_steps - len(ns_coeffs_) ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad From b91e97861d5ed7038e1e81fc0a9766daa917f0bf Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:44:49 -0500 Subject: [PATCH 05/20] documentation and proper warnings --- optax/contrib/_muon.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index c5ef882e7..a44c00684 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -24,6 +24,8 @@ import math from typing import Any, Callable, NamedTuple, Optional, Union, Sequence, Literal +from absl import logging + import jax import jax.numpy as jnp @@ -443,6 +445,9 @@ def scale_by_muon( - 'spectral' : Use Spectral norm rescaling before NS. - 'aol': Use AOL rescaling to improve orthogonality. - 'schatten': Use the Schatten-4 norm for rescaling. + - 'polar_express': Use Frobenius norm with a 1.01 safety factor, + designed for use with ``ns_coeffs='polar_express'``. + See . weight_dimension_numbers: An optional tree with the same structure as the params of `MuonDimensionNumbers`s, specifying how to reshape the parameters before and after the orthogonalization OR a callable returning @@ -591,7 +596,8 @@ def muon( learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. ns_coeffs: Coefficients for the Newton-schulz method (can be a string - indicator for a preset). Existing presets: `muon`, `dion`. + indicator for a preset). Existing presets: `standard`, `dion`, + `polar_express`. ns_steps: Number of Newton-schulz iterations. Ignored if `ns_coeffs` is a tuple of tuples. beta: Decay rate for the exponentially weighted average of grads. @@ -623,6 +629,10 @@ def muon( - 'schatten': Use the Schatten-4 norm for rescaling, allows for better performance with little to no extra cost. See . + - 'polar_express': Use Frobenius norm with a 1.01 safety factor + for bfloat16 stability, designed for use with + ``ns_coeffs='polar_express'``. + See . adam_b1: Exponential decay rate for Adam's first moment estimates. adam_b2: Exponential decay rate for Adam's second moment estimates. adam_eps_root: Epsilon to stabilize division in Adam, square root version. @@ -677,11 +687,21 @@ def muon( if isinstance(ns_coeffs, str): if ns_coeffs not in _NS_COEFFS_PRESET_DICT: raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') - if ns_coeffs == 'polar_express' and preconditioning != 'polar_express': - print("Warning: Using 'polar_express' ns_coeffs without 'polar_express' preconditioning" - "is suboptimal.") + ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + if ns_coeffs == 'polar_express' and preconditioning != 'polar_express': + logging.warning( + 'Using polar_express ns_coeffs without polar_express' + ' preconditioning is suboptimal and might lead to instability.' + ) + if (preconditioning == 'polar_express' + and ns_coeffs != 'polar_express'): + logging.warning( + 'Using polar_express preconditioning without polar_express' + ' ns_coeffs is not recommended.' + ) + if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): n_pad = ns_steps - len(ns_coeffs_) ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad @@ -756,14 +776,3 @@ def muon_weight_dim_nums_fn(params): }, param_labels=param_labels, ) - -# %% -ns_coeffs = 'polar_express' -ns_steps = 10 -ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] -print(len(ns_coeffs_), ns_coeffs_) -if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): - n_pad = ns_steps - len(ns_coeffs_) - ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad - -print(len(ns_coeffs_), ns_coeffs_) \ No newline at end of file From e61b4a588503c3219aca32e5fe983055f92352df Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:57:53 -0500 Subject: [PATCH 06/20] formatting --- optax/contrib/_muon.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index a44c00684..e3437d724 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -695,8 +695,7 @@ def muon( 'Using polar_express ns_coeffs without polar_express' ' preconditioning is suboptimal and might lead to instability.' ) - if (preconditioning == 'polar_express' - and ns_coeffs != 'polar_express'): + if preconditioning == 'polar_express' and ns_coeffs != 'polar_express': logging.warning( 'Using polar_express preconditioning without polar_express' ' ns_coeffs is not recommended.' From 8c94ae2dc06b234cb5f92b6a51a0246a8253e1e2 Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 21:57:59 -0500 Subject: [PATCH 07/20] tests (pass) --- optax/contrib/_muon_test.py | 59 +++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 02c02f78d..37d939009 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -118,7 +118,8 @@ def test_reshape_inverse(self, input_shape, dim_nums, expected_flat_shape): test_utils.assert_trees_all_close(reconstructed_x, x) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('polar_express', 'polar_express'), ) def test_callable_weight_dim_nums(self, preconditioning): # Case 1: a dim nums for all weights, no matter if they're muon. @@ -144,7 +145,8 @@ def weight_dim_nums_fn(params): # pylint: disable=function-redefined _, _ = opt.update(params, state, params=params) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('polar_express', 'polar_express'), ) def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( self, preconditioning @@ -164,7 +166,8 @@ def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( ) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('polar_express', 'polar_express'), ) def test_reshape_and_update_single_param(self, preconditioning): # Use 2D parameter (10, 12) with no dimension numbers as groundtruth @@ -211,7 +214,8 @@ def test_reshape_and_update_single_param(self, preconditioning): atol=1e-5) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('polar_express', 'polar_express'), ) def test_dim_nums_combinations(self, preconditioning): get_muon_mu = lambda state: state[0]['muon'][0][0][1] @@ -355,6 +359,9 @@ def f(params_state, _): ('aol_square', 'aol', (100, 100)), ('aol_tall', 'aol', (100, 50)), ('aol_wide', 'aol', (50, 100)), + ('polar_express_square', 'polar_express', (100, 100)), + ('polar_express_tall', 'polar_express', (100, 50)), + ('polar_express_wide', 'polar_express', (50, 100)), ) def test_muon_orthogonalization_modes(self, preconditioning, shape): """Tests that Muon runs and produces near-orthogonal updates.""" @@ -401,7 +408,8 @@ def _get_updates(preconditioning, **kwargs): test_utils.assert_trees_all_close(u_schatten, u_aol) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('polar_express', 'polar_express'), ) def test_orthogonality(self, preconditioning): """Ensures that updates satisfy approximate orthogonality (U^T U ≈ I).""" @@ -420,6 +428,47 @@ def test_orthogonality(self, preconditioning): self.assertLess(ortho_error, 1e-3, f'Orthogonality error too high: {ortho_error}') + def test_polar_express(self): + """Tests PolarExpress ns_coeffs with polar_express preconditioning.""" + params = {'w': jnp.eye(8) * 2.0} + opt = _muon.muon( + learning_rate=0.1, + ns_coeffs='polar_express', + preconditioning='polar_express', + ns_steps=8, + ) + updates, _ = opt.update(params, opt.init(params), params) + w_update = updates['w'] + + for leaf in jax.tree_util.tree_leaves(updates): + self.assertFalse(jnp.isnan(leaf).any(), + 'Found NaN values in polar_express updates') + + # Check orthogonality. + gram = jnp.dot(w_update.T, w_update) + gram = gram / jnp.max(gram) + ortho_error = jnp.linalg.norm(gram - jnp.eye(gram.shape[0])) + self.assertLess(ortho_error, 1e-3, + f'Orthogonality error too high: {ortho_error}') + + def test_polar_express_numerical_difference(self): + """Ensures PolarExpress produces different updates than standard Muon.""" + params = {'w': jnp.eye(8) * 2.0} + + opt_std = _muon.muon(learning_rate=0.1, preconditioning='frobenius') + u_std, _ = opt_std.update(params, opt_std.init(params), params) + + opt_pe = _muon.muon( + learning_rate=0.1, + ns_coeffs='polar_express', + preconditioning='polar_express', + ns_steps=8, + ) + u_pe, _ = opt_pe.update(params, opt_pe.init(params), params) + + with self.assertRaises(AssertionError): + test_utils.assert_trees_all_close(u_std, u_pe) + if __name__ == '__main__': absltest.main() From 9be1e2e5d641605e2c965a7d93e27d7679f69a99 Mon Sep 17 00:00:00 2001 From: Marc Date: Tue, 3 Mar 2026 23:31:48 -0500 Subject: [PATCH 08/20] formatting --- optax/contrib/_muon.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index e3437d724..70aa22d2e 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -40,7 +40,9 @@ ReshapeFn = Callable[[jax.Array], jax.Array] -_PRECONDITIONINGS = ['frobenius', 'spectral', 'aol', 'schatten', 'polar_express'] +_PRECONDITIONINGS = [ + 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express', +] _DEFAULT_NS_COEFFS = (3.4445, -4.7750, 2.0315) _DION_NS_COEFFS = [ (4.0848, -6.8946, 2.9270), From 595e0900ac0c6a431d17e2fa120e92f2d2c7a496 Mon Sep 17 00:00:00 2001 From: Marc Date: Wed, 4 Mar 2026 17:14:08 -0500 Subject: [PATCH 09/20] computing optimal polar express coefficients for specified ns_iter and muon_dtype --- optax/contrib/__init__.py | 1 + optax/contrib/_muon.py | 204 +++++++++++++++++++++++++++++------- optax/contrib/_muon_test.py | 25 +++++ 3 files changed, 195 insertions(+), 35 deletions(-) diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..79a54698e 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -53,6 +53,7 @@ from optax.contrib._muon import muon from optax.contrib._muon import MuonDimensionNumbers from optax.contrib._muon import MuonState +from optax.contrib._muon import polar_express_coeffs from optax.contrib._muon import scale_by_muon from optax.contrib._privacy import differentially_private_aggregate from optax.contrib._privacy import DifferentiallyPrivateAggregateState diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 70aa22d2e..3167ee441 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -28,6 +28,7 @@ import jax import jax.numpy as jnp +import numpy as np from optax._src import alias from optax._src import base @@ -51,30 +52,152 @@ (2.8769, -3.1427, 1.2046), (2.8366, -3.0525, 1.2012), ] -_POLAR_EXPRESS_RAW_COEFFS = [ - (8.28721201814563, -23.595886519098837, 17.300387312530933), - (4.107059111542203, -2.9478499167379106, 0.5448431082926601), - (3.9486908534822946, -2.908902115962949, 0.5518191394370137), - (3.3184196573706015, -2.488488024314874, 0.51004894012372), - (2.300652019954817, -1.6689039845747493, 0.4188073119525673), - (1.891301407787398, -1.2679958271945868, 0.37680408948524835), - (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), - (1.875, -1.25, 0.375), # subsequent coeffs equal this numerically -] -# Safety factor adjustment for numerical stability -# See Section 4.4 of Amsel et al., 2025. -_POLAR_EXPRESS_SAFETY = 1.01 -_POLAR_EXPRESS_NS_COEFFS = [ - (a / _POLAR_EXPRESS_SAFETY, - b / _POLAR_EXPRESS_SAFETY**3, - c / _POLAR_EXPRESS_SAFETY**5) - for a, b, c in _POLAR_EXPRESS_RAW_COEFFS[:-1] -] + [_POLAR_EXPRESS_RAW_COEFFS[-1]] + + +# Polar Express defaults from Amsel et al., 2025 (Section 4.4) +# and reference implementation (github.com/NoahAmsel/PolarExpress). +_POLAR_EXPRESS_SAFETY_EPS = 1e-2 +_POLAR_EXPRESS_SAFETY = 1 + _POLAR_EXPRESS_SAFETY_EPS +_POLAR_EXPRESS_CUSHION = 0.02 + + +def _optimal_quintic(l, u): + r"""Optimal quintic coefficients for the Newton-Schulz iteration. + + Uses a simplified Remez algorithm to find coefficients (a, b, c) for the + odd quintic :math:`p(x) = ax + bx^3 + cx^5` that minimizes the Chebyshev + (minimax) approximation error :math:`\max_{x \in [\ell, u]} |1 - p(x)|`. + + Args: + l: Lower bound on singular values. Must satisfy ``0 <= l <= u``. + u: Upper bound on singular values. + + Returns: + A tuple ``(a, b, c)`` of quintic iteration coefficients. + + Raises: + ValueError: If ``l < 0`` or ``l > u``. + + References: + Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their + Application to the Muon Algorithm + `_, 2025, Section 4.2. + """ + if not 0 <= l <= u: + raise ValueError(f'l must be between 0 and u, got {l}.') + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + max_iter = 100 + e, old_e, n_iter = np.inf, None, 0 + while old_e is None or abs(old_e - e) > 1e-15: + if n_iter >= max_iter: + break + n_iter += 1 + old_e = e + lhs = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, e = np.linalg.solve(lhs, np.ones(4)) + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c)) + / (10 * c) + ) + return float(a), float(b), float(c) + + +def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): + r"""Compute PolarExpress optimal Newton-Schulz coefficients. + + Computes per-iteration optimal quintic coefficients for the Newton-Schulz + matrix sign iteration. Each iteration refines the singular value interval + :math:`[\ell, u]`, producing coefficients that are optimal in the Chebyshev + sense for that iteration's interval. + + The ``l`` parameter controls the assumed lower bound on normalized singular + values after preconditioning. A good default is the machine epsilon of the + training dtype (e.g. ``1e-3`` for bfloat16, ``1e-7`` for float32). Smaller + values are more conservative but may require more iterations. + + Example:: + + # Default coefficients (bfloat16, same as 'polar_express' preset): + coeffs = polar_express_coeffs(l=1e-3, num_iters=10, + safety_factor_eps=1e-2, cushion=0.02) + + # Use with muon: + optimizer = optax.contrib.muon( + learning_rate=0.02, + ns_coeffs=coeffs, + preconditioning='polar_express', + ) + + Args: + l: Lower bound on normalized singular values. Must satisfy + ``0 <= l <= 1``. + num_iters: Number of Newton-Schulz iterations to compute coefficients + for. + safety_factor_eps: Epsilon for the safety factor ``1 + eps`` applied to + all iterations except the last. Contracts the polynomial slightly to + ensure convergence under floating-point round-off errors. See + Section 4.4 of Amsel et al., 2025. + cushion: Minimum fraction of ``u`` used as the lower bound when + computing each iteration's optimal polynomial. When + ``cushion * u > l``, a rescaler is applied to maintain the correct + mapping. Helps with numerical stability in early iterations. + + .. note:: + When using ``preconditioning='polar_express'``, the preconditioning + normalization divides by ``||X||_F * (1 + safety_factor_eps)``. This + factor is hardcoded to match the default ``safety_factor_eps=1e-2``. + If you use a custom ``safety_factor_eps``, you should use + ``preconditioning='frobenius'`` and handle the safety factor scaling + in your own normalization, or pass the coefficients directly to + :func:`scale_by_muon` with a custom preconditioning setup. + + Returns: + A list of ``num_iters`` tuples ``(a, b, c)``, where each tuple contains + the quintic Newton-Schulz coefficients for that iteration. + + Raises: + ValueError: If ``l < 0`` or ``l > 1``. + + References: + Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their + Application to the Muon Algorithm + `_, 2025 + """ + u = 1.0 + if not 0 <= l <= u: + raise ValueError(f'l must be between 0 and 1, got {l}.') + safety_factor = 1 + safety_factor_eps + coefficients = [] + for i in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if i < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + _NS_COEFFS_PRESET_DICT = { 'standard': _DEFAULT_NS_COEFFS, 'dion': _DION_NS_COEFFS, - 'polar_express': _POLAR_EXPRESS_NS_COEFFS, } @@ -447,8 +570,9 @@ def scale_by_muon( - 'spectral' : Use Spectral norm rescaling before NS. - 'aol': Use AOL rescaling to improve orthogonality. - 'schatten': Use the Schatten-4 norm for rescaling. - - 'polar_express': Use Frobenius norm with a 1.01 safety factor, - designed for use with ``ns_coeffs='polar_express'``. + - 'polar_express': Use Frobenius norm with a safety factor, + designed for use with coefficients from + :func:`polar_express_coeffs`. See . weight_dimension_numbers: An optional tree with the same structure as the params of `MuonDimensionNumbers`s, specifying how to reshape the @@ -687,25 +811,35 @@ def muon( adam_learning_rate = learning_rate if isinstance(ns_coeffs, str): - if ns_coeffs not in _NS_COEFFS_PRESET_DICT: - raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') - - ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] - - if ns_coeffs == 'polar_express' and preconditioning != 'polar_express': - logging.warning( - 'Using polar_express ns_coeffs without polar_express' - ' preconditioning is suboptimal and might lead to instability.' - ) if preconditioning == 'polar_express' and ns_coeffs != 'polar_express': logging.warning( 'Using polar_express preconditioning without polar_express' ' ns_coeffs is not recommended.' ) - if ns_coeffs == 'polar_express' and ns_steps > len(ns_coeffs_): - n_pad = ns_steps - len(ns_coeffs_) - ns_coeffs_ = list(ns_coeffs_) + [ns_coeffs_[-1]] * n_pad + if ns_coeffs == 'polar_express': + if preconditioning != 'polar_express': + logging.warning( + 'Using polar_express ns_coeffs without polar_express' + ' preconditioning is suboptimal and might lead to' + ' instability.' + ) + if mu_dtype is not None: + l = 10 ** math.floor( + math.log10(float(jnp.finfo(mu_dtype).eps)) + ) + else: + l = 1e-3 # default for bfloat16 + ns_coeffs_ = polar_express_coeffs( + l, ns_steps, + safety_factor_eps=_POLAR_EXPRESS_SAFETY_EPS, + cushion=_POLAR_EXPRESS_CUSHION, + ) + elif ns_coeffs in _NS_COEFFS_PRESET_DICT: + ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + else: + raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') + else: ns_coeffs_ = ns_coeffs diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 37d939009..687b8cb9f 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -469,6 +469,31 @@ def test_polar_express_numerical_difference(self): with self.assertRaises(AssertionError): test_utils.assert_trees_all_close(u_std, u_pe) + def test_polar_express_coeffs_match_reference(self): + """Computed coefficients match the reference implementation.""" + # Values from github.com/NoahAmsel/PolarExpress: + # optimal_composition(l=1e-3, num_iters=8, safety_factor_eps=1e-2, + # cushion=0.02) + expected = [ + (8.237312490495558, -23.157747414558205, 16.68056841144592), + (4.082441999064834, -2.893047735332586, 0.5252849256975647), + (3.926347992254655, -2.85474680347653, 0.531802242289499), + (3.2982187133085143, -2.424541981026706, 0.48632008358844075), + (2.297036943455258, -1.6366255812590327, 0.4002628455953635), + (1.8763805351440446, -1.234789657772233, 0.3589188750166889), + (1.8564423485588517, -1.2132449880877845, 0.35680034877976435), + (1.8750013458595656, -1.2500026917060685, 0.3750013458465025), + ] + computed = _muon.polar_express_coeffs( + l=1e-3, num_iters=8, + safety_factor_eps=1e-2, cushion=0.02, + ) + for i, (exp, got) in enumerate(zip(expected, computed)): + np.testing.assert_allclose( + got, exp, rtol=1e-10, + err_msg=f'Coefficient mismatch at iteration {i}', + ) + if __name__ == '__main__': absltest.main() From 48c5ce2254871ca3f6e7b081532abb8c81fe07aa Mon Sep 17 00:00:00 2001 From: Marc Date: Wed, 4 Mar 2026 19:47:47 -0500 Subject: [PATCH 10/20] relax tolerance for test --- optax/contrib/_muon.py | 2 +- optax/contrib/_muon_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 3167ee441..3ad632447 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -89,7 +89,7 @@ def _optimal_quintic(l, u): return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) q = (3 * l + u) / 4 r = (l + 3 * u) / 4 - max_iter = 100 + max_iter = 1000 e, old_e, n_iter = np.inf, None, 0 while old_e is None or abs(old_e - e) > 1e-15: if n_iter >= max_iter: diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 687b8cb9f..2fa620dd2 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -490,7 +490,7 @@ def test_polar_express_coeffs_match_reference(self): ) for i, (exp, got) in enumerate(zip(expected, computed)): np.testing.assert_allclose( - got, exp, rtol=1e-10, + got, exp, rtol=1e-6, err_msg=f'Coefficient mismatch at iteration {i}', ) From 591e39edeefae286caa8be9928598704d44754e8 Mon Sep 17 00:00:00 2001 From: Marc Date: Wed, 4 Mar 2026 20:35:13 -0500 Subject: [PATCH 11/20] test mismatch reason solved: increased cutoff for l/u ratio for which remez algorithm is skipped from 1 - 5e-6 to 1 - 5e-5. for the previous cutoff, the linear system became ill-conditioned (k~1e11), causing platform dependent deviations on the order 1e-5 --- optax/contrib/_muon.py | 2 +- optax/contrib/_muon_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 3ad632447..bd6b4a05a 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -85,7 +85,7 @@ def _optimal_quintic(l, u): """ if not 0 <= l <= u: raise ValueError(f'l must be between 0 and u, got {l}.') - if 1 - 5e-6 <= l / u: + if 1 - 5e-5 <= l / u: return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) q = (3 * l + u) / 4 r = (l + 3 * u) / 4 diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 2fa620dd2..ff35bfce0 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -482,7 +482,7 @@ def test_polar_express_coeffs_match_reference(self): (2.297036943455258, -1.6366255812590327, 0.4002628455953635), (1.8763805351440446, -1.234789657772233, 0.3589188750166889), (1.8564423485588517, -1.2132449880877845, 0.35680034877976435), - (1.8750013458595656, -1.2500026917060685, 0.3750013458465025), + (1.8749914004324066, -1.2499828009436962, 0.3749914005112891), ] computed = _muon.polar_express_coeffs( l=1e-3, num_iters=8, @@ -490,7 +490,7 @@ def test_polar_express_coeffs_match_reference(self): ) for i, (exp, got) in enumerate(zip(expected, computed)): np.testing.assert_allclose( - got, exp, rtol=1e-6, + got, exp, rtol=1e-10, err_msg=f'Coefficient mismatch at iteration {i}', ) From b50ca126dd5e00f8a8a3d4d4c237cf39260585e1 Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 10:43:45 -0500 Subject: [PATCH 12/20] testing polar express with different preconditioners on hard matrices (binary and low-rank spectrum) --- optax/contrib/_muon_test.py | 57 +++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index ff35bfce0..b4a2ae8f9 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -65,6 +65,16 @@ def obj_fn(params): return initial_params, final_params, obj_fn +def _random_matrix_with_svs(key, shape, singular_values): + """Matrix with given singular values and random singular vectors.""" + m, n = shape + k = min(m, n) + key1, key2 = jax.random.split(key) + u, _ = jnp.linalg.qr(jax.random.normal(key1, (m, k))) + v, _ = jnp.linalg.qr(jax.random.normal(key2, (n, k))) + return u @ jnp.diag(jnp.asarray(singular_values)) @ v.T + + class MuonTest(parameterized.TestCase): @parameterized.named_parameters( @@ -494,6 +504,53 @@ def test_polar_express_coeffs_match_reference(self): err_msg=f'Coefficient mismatch at iteration {i}', ) + @parameterized.named_parameters( + ('frobenius_low_rank', 'frobenius', 'low_rank'), + ('spectral_low_rank', 'spectral', 'low_rank'), + ('schatten_low_rank', 'schatten', 'low_rank'), + ('polar_express_low_rank', 'polar_express', 'low_rank'), + ('frobenius_binary', 'frobenius', 'binary'), + ('spectral_binary', 'spectral', 'binary'), + ('schatten_binary', 'schatten', 'binary'), + ('polar_express_binary', 'polar_express', 'binary'), + ) + def test_polar_express_hard_matrices(self, preconditioning, matrix_type): + """PolarExpress coefficients on hard matrices with random singular vectors. + + Tests two cases suggested by Amsel (private communication): + - low_rank: exponentially decaying singular values + - binary: singular values all 0 or 2 (blowup test; spectral + preconditioning is closest to being unstable) + """ + key = jax.random.key(42) + shape = (50, 100) + k = min(shape) + + if matrix_type == 'low_rank': + svs = 2.0 * np.exp(-0.5 * np.arange(k)) + else: + svs = np.array([2.0] * (k // 2) + [0.0] * (k - k // 2)) + mat = _random_matrix_with_svs(key, shape, svs) + + ns_coeffs = jnp.array(_muon.polar_express_coeffs( + l=1e-3, num_iters=10, + safety_factor_eps=1e-2, cushion=0.02)) + + result = _muon.orthogonalize_via_newton_schulz( + mat, ns_coeffs, ns_steps=10, + preconditioning=preconditioning, + dimension_numbers=_muon.MuonDimensionNumbers(0, 1)) + + self.assertFalse(jnp.any(jnp.isnan(result)).item()) + self.assertFalse(jnp.any(jnp.isinf(result)).item()) + self.assertEqual(result.shape, shape) + # SVs above l should converge to 1 (orthogonality). + n_above_l = int(np.sum(svs > 1e-3)) + out_svs = jnp.linalg.svd(result, compute_uv=False)[:n_above_l] + np.testing.assert_allclose( + out_svs, jnp.ones(n_above_l), atol=1e-3, + err_msg=f'SVs above l did not converge to 1 ({preconditioning})') + if __name__ == '__main__': absltest.main() From 84a5f3af032bbb4901aa6bb25a0c4a092580947e Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 11:56:45 -0500 Subject: [PATCH 13/20] compare coefficients to hard coded coefficientw as found in the original work, this required a different cushion value --- optax/contrib/_muon.py | 2 +- optax/contrib/_muon_test.py | 27 +++++++++++++-------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index bd6b4a05a..baaf28ce0 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -58,7 +58,7 @@ # and reference implementation (github.com/NoahAmsel/PolarExpress). _POLAR_EXPRESS_SAFETY_EPS = 1e-2 _POLAR_EXPRESS_SAFETY = 1 + _POLAR_EXPRESS_SAFETY_EPS -_POLAR_EXPRESS_CUSHION = 0.02 +_POLAR_EXPRESS_CUSHION = 0.02407327424182761 def _optimal_quintic(l, u): diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index b4a2ae8f9..887b2fd69 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -480,27 +480,26 @@ def test_polar_express_numerical_difference(self): test_utils.assert_trees_all_close(u_std, u_pe) def test_polar_express_coeffs_match_reference(self): - """Computed coefficients match the reference implementation.""" - # Values from github.com/NoahAmsel/PolarExpress: - # optimal_composition(l=1e-3, num_iters=8, safety_factor_eps=1e-2, - # cushion=0.02) + """Computed coefficients match the paper's hard-coded values.""" + # Hard-coded coefficients from Amsel et al., 2025, Algorithm 1 + # (before safety factor application). expected = [ - (8.237312490495558, -23.157747414558205, 16.68056841144592), - (4.082441999064834, -2.893047735332586, 0.5252849256975647), - (3.926347992254655, -2.85474680347653, 0.531802242289499), - (3.2982187133085143, -2.424541981026706, 0.48632008358844075), - (2.297036943455258, -1.6366255812590327, 0.4002628455953635), - (1.8763805351440446, -1.234789657772233, 0.3589188750166889), - (1.8564423485588517, -1.2132449880877845, 0.35680034877976435), - (1.8749914004324066, -1.2499828009436962, 0.3749914005112891), + (8.28721201814563, -23.595886519098837, 17.300387312530933), + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), + (3.9486908534822946, -2.908902115962949, 0.5518191394370137), + (3.3184196573706015, -2.488488024314874, 0.51004894012372), + (2.300652019954817, -1.6689039845747493, 0.4188073119525673), + (1.891301407787398, -1.2679958271945868, 0.37680408948524835), + (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), + (1.875, -1.25, 0.375), ] computed = _muon.polar_express_coeffs( l=1e-3, num_iters=8, - safety_factor_eps=1e-2, cushion=0.02, + safety_factor_eps=0.0, cushion=0.02407327424182761, ) for i, (exp, got) in enumerate(zip(expected, computed)): np.testing.assert_allclose( - got, exp, rtol=1e-10, + got, exp, rtol=1e-8, err_msg=f'Coefficient mismatch at iteration {i}', ) From 78d06d616623c0fda143b176d33031f965105d6d Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 12:08:25 -0500 Subject: [PATCH 14/20] changing how safety factor is applied to the polynomial coefficients. this again allows for tighter remez algorithm cutoff (now again the same as in original work) --- optax/contrib/_muon.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index baaf28ce0..18827122f 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -85,7 +85,7 @@ def _optimal_quintic(l, u): """ if not 0 <= l <= u: raise ValueError(f'l must be between 0 and u, got {l}.') - if 1 - 5e-5 <= l / u: + if 1 - 5e-6 <= l / u: return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) q = (3 * l + u) / 4 r = (l + 3 * u) / 4 @@ -174,9 +174,11 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): u = 1.0 if not 0 <= l <= u: raise ValueError(f'l must be between 0 and 1, got {l}.') - safety_factor = 1 + safety_factor_eps + # Compute raw optimal coefficients without safety factor (matches the + # paper's approach: safety factor is applied after all coefficients are + # computed, so it does not affect the interval evolution). coefficients = [] - for i in range(num_iters): + for _ in range(num_iters): a, b, c = _optimal_quintic(max(l, cushion * u), u) if cushion * u > l: pl = a * l + b * l**3 + c * l**5 @@ -185,13 +187,18 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): a *= rescaler b *= rescaler c *= rescaler - if i < num_iters - 1: - a /= safety_factor - b /= safety_factor**3 - c /= safety_factor**5 coefficients.append((a, b, c)) l = a * l + b * l**3 + c * l**5 u = 2 - l + # Apply safety factor to all but the last iteration. + safety_factor = 1 + safety_factor_eps + for i in range(num_iters - 1): + a, b, c = coefficients[i] + coefficients[i] = ( + a / safety_factor, + b / safety_factor**3, + c / safety_factor**5, + ) return coefficients From d0a72f7ec02ef89c1897df333f76b62d5d289d54 Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 12:21:08 -0500 Subject: [PATCH 15/20] use default dtype independent lower bound for polar express of 1e-3 --- optax/contrib/_muon.py | 45 ++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 18827122f..9de185b7b 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -59,6 +59,7 @@ _POLAR_EXPRESS_SAFETY_EPS = 1e-2 _POLAR_EXPRESS_SAFETY = 1 + _POLAR_EXPRESS_SAFETY_EPS _POLAR_EXPRESS_CUSHION = 0.02407327424182761 +_POLAR_EXPRESS_LOWER_BOUND = 1e-3 def _optimal_quintic(l, u): @@ -119,15 +120,19 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): sense for that iteration's interval. The ``l`` parameter controls the assumed lower bound on normalized singular - values after preconditioning. A good default is the machine epsilon of the - training dtype (e.g. ``1e-3`` for bfloat16, ``1e-7`` for float32). Smaller - values are more conservative but may require more iterations. + values after preconditioning. The default ``ns_coeffs='polar_express'`` + preset uses ``l=1e-3`` (see ``_POLAR_EXPRESS_LOWER_BOUND``), which works + well for both bfloat16 and float32 training. Call this function directly + to use a different ``l``. Example:: - # Default coefficients (bfloat16, same as 'polar_express' preset): - coeffs = polar_express_coeffs(l=1e-3, num_iters=10, - safety_factor_eps=1e-2, cushion=0.02) + # Custom lower bound (e.g. for float32 with tighter convergence): + coeffs = polar_express_coeffs( + l=1e-7, num_iters=12, + safety_factor_eps=1e-2, + cushion=0.02407327424182761, + ) # Use with muon: optimizer = optax.contrib.muon( @@ -611,8 +616,8 @@ def scale_by_muon( `_, 2025 Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their - Application to the Muon Algorithm`, - `, 2025 + Application to the Muon Algorithm + `_, 2025 """ mu_dtype = utils.canonicalize_dtype(mu_dtype) @@ -729,8 +734,11 @@ def muon( learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. ns_coeffs: Coefficients for the Newton-schulz method (can be a string - indicator for a preset). Existing presets: `standard`, `dion`, - `polar_express`. + indicator for a preset). Existing presets: ``standard``, ``dion``, + ``polar_express``. The ``polar_express`` preset uses ``l=1e-3`` and + the default cushion; for custom ``l``, ``cushion``, or + ``num_iters``, call :func:`polar_express_coeffs` directly and pass + the result as ``ns_coeffs``. ns_steps: Number of Newton-schulz iterations. Ignored if `ns_coeffs` is a tuple of tuples. beta: Decay rate for the exponentially weighted average of grads. @@ -762,8 +770,8 @@ def muon( - 'schatten': Use the Schatten-4 norm for rescaling, allows for better performance with little to no extra cost. See . - - 'polar_express': Use Frobenius norm with a 1.01 safety factor - for bfloat16 stability, designed for use with + - 'polar_express': Use Frobenius norm with a safety factor for + floating-point stability, designed for use with ``ns_coeffs='polar_express'``. See . adam_b1: Exponential decay rate for Adam's first moment estimates. @@ -810,8 +818,8 @@ def muon( `_, 2025 Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their - Application to the Muon Algorithm`, - `, 2025 + Application to the Muon Algorithm + `_, 2025 """ if adam_learning_rate is None: @@ -831,14 +839,9 @@ def muon( ' preconditioning is suboptimal and might lead to' ' instability.' ) - if mu_dtype is not None: - l = 10 ** math.floor( - math.log10(float(jnp.finfo(mu_dtype).eps)) - ) - else: - l = 1e-3 # default for bfloat16 ns_coeffs_ = polar_express_coeffs( - l, ns_steps, + l=_POLAR_EXPRESS_LOWER_BOUND, + num_iters=ns_steps, safety_factor_eps=_POLAR_EXPRESS_SAFETY_EPS, cushion=_POLAR_EXPRESS_CUSHION, ) From b522e88a11954582057fd7babc2574772a0022e1 Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 12:37:58 -0500 Subject: [PATCH 16/20] remove polar express preconditioning; use default (frobenius) and mention schatten as recommended; remove associated warnings ang logging import --- optax/contrib/_muon.py | 48 ++++++------------------------------- optax/contrib/_muon_test.py | 17 ++++--------- 2 files changed, 11 insertions(+), 54 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 9de185b7b..99b66ef17 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -24,8 +24,6 @@ import math from typing import Any, Callable, NamedTuple, Optional, Union, Sequence, Literal -from absl import logging - import jax import jax.numpy as jnp import numpy as np @@ -42,7 +40,7 @@ ReshapeFn = Callable[[jax.Array], jax.Array] _PRECONDITIONINGS = [ - 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express', + 'frobenius', 'spectral', 'aol', 'schatten', ] _DEFAULT_NS_COEFFS = (3.4445, -4.7750, 2.0315) _DION_NS_COEFFS = [ @@ -57,7 +55,6 @@ # Polar Express defaults from Amsel et al., 2025 (Section 4.4) # and reference implementation (github.com/NoahAmsel/PolarExpress). _POLAR_EXPRESS_SAFETY_EPS = 1e-2 -_POLAR_EXPRESS_SAFETY = 1 + _POLAR_EXPRESS_SAFETY_EPS _POLAR_EXPRESS_CUSHION = 0.02407327424182761 _POLAR_EXPRESS_LOWER_BOUND = 1e-3 @@ -134,11 +131,11 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): cushion=0.02407327424182761, ) - # Use with muon: + # Use with muon (Schatten-4 preconditioning recommended): optimizer = optax.contrib.muon( learning_rate=0.02, ns_coeffs=coeffs, - preconditioning='polar_express', + preconditioning='schatten', ) Args: @@ -155,15 +152,6 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): ``cushion * u > l``, a rescaler is applied to maintain the correct mapping. Helps with numerical stability in early iterations. - .. note:: - When using ``preconditioning='polar_express'``, the preconditioning - normalization divides by ``||X||_F * (1 + safety_factor_eps)``. This - factor is hardcoded to match the default ``safety_factor_eps=1e-2``. - If you use a custom ``safety_factor_eps``, you should use - ``preconditioning='frobenius'`` and handle the safety factor scaling - in your own normalization, or pass the coefficients directly to - :func:`scale_by_muon` with a custom preconditioning setup. - Returns: A list of ``num_iters`` tuples ``(a, b, c)``, where each tuple contains the quintic Newton-Schulz coefficients for that iteration. @@ -440,7 +428,7 @@ def orthogonalize_via_newton_schulz( ns_coeffs: jax.Array, ns_steps: jax.typing.ArrayLike = 5, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' + 'frobenius', 'spectral', 'aol', 'schatten', ] = 'frobenius', eps: jax.typing.ArrayLike = 1e-8, dimension_numbers: MuonDimensionNumbers | None = None, @@ -491,7 +479,6 @@ def _orthogonalize(x): 'spectral': _base_ns_iterator, 'aol': _aol_ns_iterator, 'schatten': _schatten_ns_iterator, - 'polar_express': _base_ns_iterator, } if preconditioning not in _PRECONDITIONINGS: raise ValueError(f'Unknown preconditioning {preconditioning}') @@ -501,8 +488,6 @@ def _orthogonalize(x): x /= jnp.linalg.norm(x, ord='fro') + eps elif preconditioning == 'spectral': x /= jnp.linalg.norm(x, ord=2) + eps - elif preconditioning == 'polar_express': - x /= jnp.linalg.norm(x, ord='fro') * _POLAR_EXPRESS_SAFETY + eps else: pass @@ -554,7 +539,7 @@ def scale_by_muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' + 'frobenius', 'spectral', 'aol', 'schatten', ] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None, ) -> base.GradientTransformation: @@ -582,10 +567,6 @@ def scale_by_muon( - 'spectral' : Use Spectral norm rescaling before NS. - 'aol': Use AOL rescaling to improve orthogonality. - 'schatten': Use the Schatten-4 norm for rescaling. - - 'polar_express': Use Frobenius norm with a safety factor, - designed for use with coefficients from - :func:`polar_express_coeffs`. - See . weight_dimension_numbers: An optional tree with the same structure as the params of `MuonDimensionNumbers`s, specifying how to reshape the parameters before and after the orthogonalization OR a callable returning @@ -707,7 +688,7 @@ def muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', 'polar_express' + 'frobenius', 'spectral', 'aol', 'schatten', ] = 'frobenius', adam_b1: jax.typing.ArrayLike = 0.9, adam_b2: jax.typing.ArrayLike = 0.999, @@ -769,11 +750,8 @@ def muon( See . - 'schatten': Use the Schatten-4 norm for rescaling, allows for better performance with little to no extra cost. + Recommended when using ``ns_coeffs='polar_express'``. See . - - 'polar_express': Use Frobenius norm with a safety factor for - floating-point stability, designed for use with - ``ns_coeffs='polar_express'``. - See . adam_b1: Exponential decay rate for Adam's first moment estimates. adam_b2: Exponential decay rate for Adam's second moment estimates. adam_eps_root: Epsilon to stabilize division in Adam, square root version. @@ -826,19 +804,7 @@ def muon( adam_learning_rate = learning_rate if isinstance(ns_coeffs, str): - if preconditioning == 'polar_express' and ns_coeffs != 'polar_express': - logging.warning( - 'Using polar_express preconditioning without polar_express' - ' ns_coeffs is not recommended.' - ) - if ns_coeffs == 'polar_express': - if preconditioning != 'polar_express': - logging.warning( - 'Using polar_express ns_coeffs without polar_express' - ' preconditioning is suboptimal and might lead to' - ' instability.' - ) ns_coeffs_ = polar_express_coeffs( l=_POLAR_EXPRESS_LOWER_BOUND, num_iters=ns_steps, diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 887b2fd69..4d9ac0a08 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -129,7 +129,6 @@ def test_reshape_inverse(self, input_shape, dim_nums, expected_flat_shape): @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), - ('polar_express', 'polar_express'), ) def test_callable_weight_dim_nums(self, preconditioning): # Case 1: a dim nums for all weights, no matter if they're muon. @@ -156,7 +155,6 @@ def weight_dim_nums_fn(params): # pylint: disable=function-redefined @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), - ('polar_express', 'polar_express'), ) def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( self, preconditioning @@ -177,7 +175,6 @@ def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), - ('polar_express', 'polar_express'), ) def test_reshape_and_update_single_param(self, preconditioning): # Use 2D parameter (10, 12) with no dimension numbers as groundtruth @@ -225,7 +222,6 @@ def test_reshape_and_update_single_param(self, preconditioning): @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), - ('polar_express', 'polar_express'), ) def test_dim_nums_combinations(self, preconditioning): get_muon_mu = lambda state: state[0]['muon'][0][0][1] @@ -369,9 +365,9 @@ def f(params_state, _): ('aol_square', 'aol', (100, 100)), ('aol_tall', 'aol', (100, 50)), ('aol_wide', 'aol', (50, 100)), - ('polar_express_square', 'polar_express', (100, 100)), - ('polar_express_tall', 'polar_express', (100, 50)), - ('polar_express_wide', 'polar_express', (50, 100)), + ('schatten_square', 'schatten', (100, 100)), + ('schatten_tall', 'schatten', (100, 50)), + ('schatten_wide', 'schatten', (50, 100)), ) def test_muon_orthogonalization_modes(self, preconditioning, shape): """Tests that Muon runs and produces near-orthogonal updates.""" @@ -419,7 +415,6 @@ def _get_updates(preconditioning, **kwargs): @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), - ('polar_express', 'polar_express'), ) def test_orthogonality(self, preconditioning): """Ensures that updates satisfy approximate orthogonality (U^T U ≈ I).""" @@ -439,12 +434,11 @@ def test_orthogonality(self, preconditioning): f'Orthogonality error too high: {ortho_error}') def test_polar_express(self): - """Tests PolarExpress ns_coeffs with polar_express preconditioning.""" + """Tests PolarExpress ns_coeffs with frobenius preconditioning.""" params = {'w': jnp.eye(8) * 2.0} opt = _muon.muon( learning_rate=0.1, ns_coeffs='polar_express', - preconditioning='polar_express', ns_steps=8, ) updates, _ = opt.update(params, opt.init(params), params) @@ -471,7 +465,6 @@ def test_polar_express_numerical_difference(self): opt_pe = _muon.muon( learning_rate=0.1, ns_coeffs='polar_express', - preconditioning='polar_express', ns_steps=8, ) u_pe, _ = opt_pe.update(params, opt_pe.init(params), params) @@ -507,11 +500,9 @@ def test_polar_express_coeffs_match_reference(self): ('frobenius_low_rank', 'frobenius', 'low_rank'), ('spectral_low_rank', 'spectral', 'low_rank'), ('schatten_low_rank', 'schatten', 'low_rank'), - ('polar_express_low_rank', 'polar_express', 'low_rank'), ('frobenius_binary', 'frobenius', 'binary'), ('spectral_binary', 'spectral', 'binary'), ('schatten_binary', 'schatten', 'binary'), - ('polar_express_binary', 'polar_express', 'binary'), ) def test_polar_express_hard_matrices(self, preconditioning, matrix_type): """PolarExpress coefficients on hard matrices with random singular vectors. From 4af8ab16c68751e023ed3246912cece2b9787652 Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 12:52:33 -0500 Subject: [PATCH 17/20] removed polar express constants in favor of default arguments --- optax/contrib/_muon.py | 68 +++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 99b66ef17..0edd8764a 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -52,13 +52,6 @@ ] -# Polar Express defaults from Amsel et al., 2025 (Section 4.4) -# and reference implementation (github.com/NoahAmsel/PolarExpress). -_POLAR_EXPRESS_SAFETY_EPS = 1e-2 -_POLAR_EXPRESS_CUSHION = 0.02407327424182761 -_POLAR_EXPRESS_LOWER_BOUND = 1e-3 - - def _optimal_quintic(l, u): r"""Optimal quintic coefficients for the Newton-Schulz iteration. @@ -67,22 +60,22 @@ def _optimal_quintic(l, u): (minimax) approximation error :math:`\max_{x \in [\ell, u]} |1 - p(x)|`. Args: - l: Lower bound on singular values. Must satisfy ``0 <= l <= u``. - u: Upper bound on singular values. + l: Lower bound on singular values. Must satisfy ``0 < l <= u``. + u: Upper bound on singular values. Must be positive. Returns: A tuple ``(a, b, c)`` of quintic iteration coefficients. Raises: - ValueError: If ``l < 0`` or ``l > u``. + ValueError: If ``l <= 0``, ``u <= 0``, or ``l > u``. References: Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm `_, 2025, Section 4.2. """ - if not 0 <= l <= u: - raise ValueError(f'l must be between 0 and u, got {l}.') + if not 0 < l <= u: + raise ValueError(f'l must satisfy 0 < l <= u, got l={l}, u={u}.') if 1 - 5e-6 <= l / u: return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) q = (3 * l + u) / 4 @@ -108,7 +101,13 @@ def _optimal_quintic(l, u): return float(a), float(b), float(c) -def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): +def polar_express_coeffs( + l=1e-3, + num_iters=8, + *, + safety_factor_eps=1e-2, + cushion=0.02407327424182761, +): r"""Compute PolarExpress optimal Newton-Schulz coefficients. Computes per-iteration optimal quintic coefficients for the Newton-Schulz @@ -117,47 +116,42 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): sense for that iteration's interval. The ``l`` parameter controls the assumed lower bound on normalized singular - values after preconditioning. The default ``ns_coeffs='polar_express'`` - preset uses ``l=1e-3`` (see ``_POLAR_EXPRESS_LOWER_BOUND``), which works - well for both bfloat16 and float32 training. Call this function directly - to use a different ``l``. + values after preconditioning. The default ``l=1e-3`` works well for both + bfloat16 and float32 training. Example:: # Custom lower bound (e.g. for float32 with tighter convergence): - coeffs = polar_express_coeffs( - l=1e-7, num_iters=12, - safety_factor_eps=1e-2, - cushion=0.02407327424182761, - ) + coeffs = polar_express_coeffs(l=1e-7, num_iters=12) - # Use with muon (Schatten-4 preconditioning recommended): + # Use with muon: optimizer = optax.contrib.muon( learning_rate=0.02, ns_coeffs=coeffs, - preconditioning='schatten', ) Args: l: Lower bound on normalized singular values. Must satisfy - ``0 <= l <= 1``. + ``0 < l <= 1``. Defaults to ``1e-3``. num_iters: Number of Newton-Schulz iterations to compute coefficients - for. + for. Defaults to ``8``. safety_factor_eps: Epsilon for the safety factor ``1 + eps`` applied to all iterations except the last. Contracts the polynomial slightly to ensure convergence under floating-point round-off errors. See - Section 4.4 of Amsel et al., 2025. + Section 4.4 of Amsel et al., 2025. Defaults to ``1e-2``. cushion: Minimum fraction of ``u`` used as the lower bound when computing each iteration's optimal polynomial. When ``cushion * u > l``, a rescaler is applied to maintain the correct - mapping. Helps with numerical stability in early iterations. + mapping. Helps with numerical stability of the Remez linear system + in early iterations. Defaults to ``0.024`` (from the reference + implementation). Returns: A list of ``num_iters`` tuples ``(a, b, c)``, where each tuple contains the quintic Newton-Schulz coefficients for that iteration. Raises: - ValueError: If ``l < 0`` or ``l > 1``. + ValueError: If ``l <= 0``, ``l > 1``, or ``num_iters < 1``. References: Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their @@ -165,8 +159,10 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): `_, 2025 """ u = 1.0 - if not 0 <= l <= u: - raise ValueError(f'l must be between 0 and 1, got {l}.') + if not 0 < l <= u: + raise ValueError(f'l must satisfy 0 < l <= 1, got {l}.') + if num_iters < 1: + raise ValueError(f'num_iters must be >= 1, got {num_iters}.') # Compute raw optimal coefficients without safety factor (matches the # paper's approach: safety factor is applied after all coefficients are # computed, so it does not affect the interval evolution). @@ -195,6 +191,8 @@ def polar_express_coeffs(l, num_iters, safety_factor_eps, cushion): return coefficients +# Note: 'polar_express' is not in this dict because it depends on ns_steps +# and is handled separately in muon() via polar_express_coeffs(). _NS_COEFFS_PRESET_DICT = { 'standard': _DEFAULT_NS_COEFFS, 'dion': _DION_NS_COEFFS, @@ -567,6 +565,7 @@ def scale_by_muon( - 'spectral' : Use Spectral norm rescaling before NS. - 'aol': Use AOL rescaling to improve orthogonality. - 'schatten': Use the Schatten-4 norm for rescaling. + Recommended when using PolarExpress coefficients. weight_dimension_numbers: An optional tree with the same structure as the params of `MuonDimensionNumbers`s, specifying how to reshape the parameters before and after the orthogonalization OR a callable returning @@ -805,12 +804,7 @@ def muon( if isinstance(ns_coeffs, str): if ns_coeffs == 'polar_express': - ns_coeffs_ = polar_express_coeffs( - l=_POLAR_EXPRESS_LOWER_BOUND, - num_iters=ns_steps, - safety_factor_eps=_POLAR_EXPRESS_SAFETY_EPS, - cushion=_POLAR_EXPRESS_CUSHION, - ) + ns_coeffs_ = polar_express_coeffs(num_iters=ns_steps) elif ns_coeffs in _NS_COEFFS_PRESET_DICT: ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] else: From 560a4245030a39abf25eb4d82aaba2ac1b38ed57 Mon Sep 17 00:00:00 2001 From: Marc Date: Thu, 5 Mar 2026 12:59:18 -0500 Subject: [PATCH 18/20] final formatting changes to minimize diff with main --- optax/contrib/_muon.py | 10 ++++------ optax/contrib/_muon_test.py | 10 +++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 0edd8764a..6e5123689 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -39,9 +39,7 @@ ReshapeFn = Callable[[jax.Array], jax.Array] -_PRECONDITIONINGS = [ - 'frobenius', 'spectral', 'aol', 'schatten', -] +_PRECONDITIONINGS = ['frobenius', 'spectral', 'aol', 'schatten'] _DEFAULT_NS_COEFFS = (3.4445, -4.7750, 2.0315) _DION_NS_COEFFS = [ (4.0848, -6.8946, 2.9270), @@ -426,7 +424,7 @@ def orthogonalize_via_newton_schulz( ns_coeffs: jax.Array, ns_steps: jax.typing.ArrayLike = 5, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', + 'frobenius', 'spectral', 'aol', 'schatten' ] = 'frobenius', eps: jax.typing.ArrayLike = 1e-8, dimension_numbers: MuonDimensionNumbers | None = None, @@ -537,7 +535,7 @@ def scale_by_muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', + 'frobenius', 'spectral', 'aol', 'schatten' ] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None, ) -> base.GradientTransformation: @@ -687,7 +685,7 @@ def muon( nesterov: bool = True, adaptive: bool = False, preconditioning: Literal[ - 'frobenius', 'spectral', 'aol', 'schatten', + 'frobenius', 'spectral', 'aol', 'schatten' ] = 'frobenius', adam_b1: jax.typing.ArrayLike = 0.9, adam_b2: jax.typing.ArrayLike = 0.999, diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 4d9ac0a08..5cda6e205 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -128,7 +128,7 @@ def test_reshape_inverse(self, input_shape, dim_nums, expected_flat_shape): test_utils.assert_trees_all_close(reconstructed_x, x) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') ) def test_callable_weight_dim_nums(self, preconditioning): # Case 1: a dim nums for all weights, no matter if they're muon. @@ -154,7 +154,7 @@ def weight_dim_nums_fn(params): # pylint: disable=function-redefined _, _ = opt.update(params, state, params=params) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') ) def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( self, preconditioning @@ -174,7 +174,7 @@ def test_reshape_update_for_square_parameter_matches_muon_without_dim_nums( ) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') ) def test_reshape_and_update_single_param(self, preconditioning): # Use 2D parameter (10, 12) with no dimension numbers as groundtruth @@ -221,7 +221,7 @@ def test_reshape_and_update_single_param(self, preconditioning): atol=1e-5) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') ) def test_dim_nums_combinations(self, preconditioning): get_muon_mu = lambda state: state[0]['muon'][0][0][1] @@ -414,7 +414,7 @@ def _get_updates(preconditioning, **kwargs): test_utils.assert_trees_all_close(u_schatten, u_aol) @parameterized.named_parameters( - ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten'), + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') ) def test_orthogonality(self, preconditioning): """Ensures that updates satisfy approximate orthogonality (U^T U ≈ I).""" From f973db1b29132cb5fde627edae1ab19c5aa8172c Mon Sep 17 00:00:00 2001 From: Marc Date: Sun, 8 Mar 2026 00:09:03 -0500 Subject: [PATCH 19/20] increase default safety factor to 2e-2, include safety factor inside the loop --- optax/contrib/_muon.py | 25 ++++++++----------------- optax/contrib/_muon_test.py | 8 ++++---- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 6e5123689..2c30d9bcf 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -103,15 +103,13 @@ def polar_express_coeffs( l=1e-3, num_iters=8, *, - safety_factor_eps=1e-2, + safety_factor_eps=2e-2, cushion=0.02407327424182761, ): r"""Compute PolarExpress optimal Newton-Schulz coefficients. Computes per-iteration optimal quintic coefficients for the Newton-Schulz - matrix sign iteration. Each iteration refines the singular value interval - :math:`[\ell, u]`, producing coefficients that are optimal in the Chebyshev - sense for that iteration's interval. + matrix sign iteration. The ``l`` parameter controls the assumed lower bound on normalized singular values after preconditioning. The default ``l=1e-3`` works well for both @@ -161,11 +159,9 @@ def polar_express_coeffs( raise ValueError(f'l must satisfy 0 < l <= 1, got {l}.') if num_iters < 1: raise ValueError(f'num_iters must be >= 1, got {num_iters}.') - # Compute raw optimal coefficients without safety factor (matches the - # paper's approach: safety factor is applied after all coefficients are - # computed, so it does not affect the interval evolution). + safety_factor = 1 + safety_factor_eps coefficients = [] - for _ in range(num_iters): + for i in range(num_iters): a, b, c = _optimal_quintic(max(l, cushion * u), u) if cushion * u > l: pl = a * l + b * l**3 + c * l**5 @@ -174,18 +170,13 @@ def polar_express_coeffs( a *= rescaler b *= rescaler c *= rescaler + if i < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 coefficients.append((a, b, c)) l = a * l + b * l**3 + c * l**5 u = 2 - l - # Apply safety factor to all but the last iteration. - safety_factor = 1 + safety_factor_eps - for i in range(num_iters - 1): - a, b, c = coefficients[i] - coefficients[i] = ( - a / safety_factor, - b / safety_factor**3, - c / safety_factor**5, - ) return coefficients diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 5cda6e205..4f8172f07 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -465,7 +465,8 @@ def test_polar_express_numerical_difference(self): opt_pe = _muon.muon( learning_rate=0.1, ns_coeffs='polar_express', - ns_steps=8, + ns_steps=5, + preconditioning='frobenius' ) u_pe, _ = opt_pe.update(params, opt_pe.init(params), params) @@ -507,10 +508,9 @@ def test_polar_express_coeffs_match_reference(self): def test_polar_express_hard_matrices(self, preconditioning, matrix_type): """PolarExpress coefficients on hard matrices with random singular vectors. - Tests two cases suggested by Amsel (private communication): + Tests two cases: - low_rank: exponentially decaying singular values - - binary: singular values all 0 or 2 (blowup test; spectral - preconditioning is closest to being unstable) + - binary: singular values all 0 or 2 """ key = jax.random.key(42) shape = (50, 100) From e0dbb0b9338589a0cfea124a1ac4eec4e73215bd Mon Sep 17 00:00:00 2001 From: Marc Date: Sun, 8 Mar 2026 00:41:37 -0500 Subject: [PATCH 20/20] slightly bump threshold for remez iteration cutoff for robustness across different platforms --- optax/contrib/_muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 2c30d9bcf..4bbee5773 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -74,7 +74,7 @@ def _optimal_quintic(l, u): """ if not 0 < l <= u: raise ValueError(f'l must satisfy 0 < l <= u, got l={l}, u={u}.') - if 1 - 5e-6 <= l / u: + if 1 - 1e-5 <= l / u: return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) q = (3 * l + u) / 4 r = (l + 3 * u) / 4