def _optimal_quintic(l, u):
  r"""Optimal quintic coefficients for the Newton-Schulz iteration.

  Uses a simplified Remez algorithm to find coefficients (a, b, c) for the
  odd quintic :math:`p(x) = ax + bx^3 + cx^5` that minimizes the Chebyshev
  (minimax) approximation error :math:`\max_{x \in [\ell, u]} |1 - p(x)|`.

  Args:
    l: Lower bound on singular values. Must satisfy ``0 < l <= u``.
    u: Upper bound on singular values. Must be positive.

  Returns:
    A tuple ``(a, b, c)`` of quintic iteration coefficients.

  Raises:
    ValueError: If ``l <= 0``, ``u <= 0``, or ``l > u``, or if the Remez
      iteration produces non-finite coefficients.

  References:
    Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
    Application to the Muon Algorithm
    <https://arxiv.org/abs/2505.16932>`_, 2025, Section 4.2.
  """
  if not 0 < l <= u:
    raise ValueError(f'l must satisfy 0 < l <= u, got l={l}, u={u}.')
  if 1 - 1e-5 <= l / u:
    # The interval is (numerically) a single point: skip the Remez solve and
    # return the fixed quintic (15/8, -10/8, 3/8) rescaled to the point u.
    return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
  # Initial guesses for the two interior equioscillation points.
  q = (3 * l + u) / 4
  r = (l + 3 * u) / 4
  max_iter = 1000
  e, old_e, n_iter = np.inf, None, 0
  while old_e is None or abs(old_e - e) > 1e-15:
    if n_iter >= max_iter:
      # Best effort: keep the last iterate if the error has not settled.
      break
    n_iter += 1
    old_e = e
    # Enforce p(x) = 1 -/+ e with alternating sign at the four points
    # l < q < r < u and solve for (a, b, c, e).
    lhs = np.array([
        [l, l**3, l**5, 1],
        [q, q**3, q**5, -1],
        [r, r**3, r**5, 1],
        [u, u**3, u**5, -1],
    ])
    a, b, c, e = np.linalg.solve(lhs, np.ones(4))
    # New interior points: the positive roots of p'(x) = a + 3bx^2 + 5cx^4.
    q, r = np.sqrt(
        (-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c))
        / (10 * c)
    )
  coeffs = (float(a), float(b), float(c))
  # A NaN error estimate makes the loop condition False, so without this check
  # a broken solve would look like convergence and NaN would be returned.
  if not np.all(np.isfinite(coeffs)):
    raise ValueError(
        'Remez iteration produced non-finite coefficients for '
        f'l={l}, u={u}.'
    )
  return coeffs


def polar_express_coeffs(
    l=1e-3,
    num_iters=8,
    *,
    safety_factor_eps=2e-2,
    cushion=0.02407327424182761,
):
  r"""Compute PolarExpress optimal Newton-Schulz coefficients.

  Computes per-iteration optimal quintic coefficients for the Newton-Schulz
  matrix sign iteration.

  The ``l`` parameter controls the assumed lower bound on normalized singular
  values after preconditioning. The default ``l=1e-3`` works well for both
  bfloat16 and float32 training.

  Example::

    # Custom lower bound (e.g. for float32 with tighter convergence):
    coeffs = polar_express_coeffs(l=1e-7, num_iters=12)

    # Use with muon:
    optimizer = optax.contrib.muon(
        learning_rate=0.02,
        ns_coeffs=coeffs,
    )

  Args:
    l: Lower bound on normalized singular values. Must satisfy
      ``0 < l <= 1``. Defaults to ``1e-3``.
    num_iters: Number of Newton-Schulz iterations to compute coefficients
      for. Defaults to ``8``.
    safety_factor_eps: Epsilon for the safety factor ``1 + eps`` applied to
      all iterations except the last. Contracts the polynomial slightly to
      ensure convergence under floating-point round-off errors. See
      Section 4.4 of Amsel et al., 2025. Defaults to ``2e-2``.
    cushion: Minimum fraction of ``u`` used as the lower bound when
      computing each iteration's optimal polynomial. When
      ``cushion * u > l``, a rescaler is applied to maintain the correct
      mapping. Helps with numerical stability of the Remez linear system
      in early iterations. Must not exceed ``1``. Defaults to ``0.024``
      (from the reference implementation).

  Returns:
    A list of ``num_iters`` tuples ``(a, b, c)``, where each tuple contains
    the quintic Newton-Schulz coefficients for that iteration.

  Raises:
    ValueError: If ``l <= 0``, ``l > 1``, ``num_iters < 1``, or
      ``cushion > 1``.

  References:
    Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
    Application to the Muon Algorithm
    <https://arxiv.org/abs/2505.16932>`_, 2025
  """
  u = 1.0
  if not 0 < l <= u:
    raise ValueError(f'l must satisfy 0 < l <= 1, got {l}.')
  if num_iters < 1:
    raise ValueError(f'num_iters must be >= 1, got {num_iters}.')
  if cushion > 1:
    # Otherwise the cushioned lower bound exceeds u and the helper fails with
    # a message about `l`, which is misleading for the caller.
    raise ValueError(f'cushion must be <= 1, got {cushion}.')
  safety_factor = 1 + safety_factor_eps
  coefficients = []
  for i in range(num_iters):
    a, b, c = _optimal_quintic(max(l, cushion * u), u)
    if cushion * u > l:
      # The polynomial was fitted on [cushion * u, u]; rescale it so that it
      # is centred on 1 over the true interval, i.e. p(l) + p(u) == 2.
      pl = a * l + b * l**3 + c * l**5
      pu = a * u + b * u**3 + c * u**5
      rescaler = 2 / (pl + pu)
      a *= rescaler
      b *= rescaler
      c *= rescaler
    if i < num_iters - 1:
      # Equivalent to evaluating p(x / safety_factor); skipped on the last
      # step.
      a /= safety_factor
      b /= safety_factor**3
      c /= safety_factor**5
    coefficients.append((a, b, c))
    # Interval for the next step: the image of l under this step's
    # polynomial, and its reflection about 1.
    l = a * l + b * l**3 + c * l**5
    u = 2 - l
  return coefficients
weight_dimension_numbers: An optional tree with the same structure as the params of `MuonDimensionNumbers`s, specifying how to reshape the parameters before and after the orthogonalization OR a callable returning @@ -449,8 +585,8 @@ def scale_by_muon( `_, 2025 Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their - Application to the Muon Algorithm`, - `, 2025 + Application to the Muon Algorithm + `_, 2025 """ mu_dtype = utils.canonicalize_dtype(mu_dtype) @@ -463,9 +599,9 @@ def init_fn(params): f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}' ) if ns_coeffs_.ndim == 2: - if not ns_coeffs_.shape[0] <= ns_steps: + if ns_coeffs_.shape[0] < ns_steps: raise ValueError(f'Not enough coeffs to perform {ns_steps} steps') - ns_coeffs_ = ns_coeffs_[-ns_steps:] + ns_coeffs_ = ns_coeffs_[:ns_steps] return MuonState( count=jnp.zeros([], jnp.int32), @@ -567,7 +703,11 @@ def muon( learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. ns_coeffs: Coefficients for the Newton-schulz method (can be a string - indicator for a preset). Existing presets: `muon`, `dion`. + indicator for a preset). Existing presets: ``standard``, ``dion``, + ``polar_express``. The ``polar_express`` preset uses ``l=1e-3`` and + the default cushion; for custom ``l``, ``cushion``, or + ``num_iters``, call :func:`polar_express_coeffs` directly and pass + the result as ``ns_coeffs``. ns_steps: Number of Newton-schulz iterations. Ignored if `ns_coeffs` is a tuple of tuples. beta: Decay rate for the exponentially weighted average of grads. @@ -598,6 +738,7 @@ def muon( See . - 'schatten': Use the Schatten-4 norm for rescaling, allows for better performance with little to no extra cost. + Recommended when using ``ns_coeffs='polar_express'``. See . adam_b1: Exponential decay rate for Adam's first moment estimates. adam_b2: Exponential decay rate for Adam's second moment estimates. 
@@ -643,17 +784,21 @@ def muon( `_, 2025 Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their - Application to the Muon Algorithm`, - `, 2025 + Application to the Muon Algorithm + `_, 2025 """ if adam_learning_rate is None: adam_learning_rate = learning_rate if isinstance(ns_coeffs, str): - if ns_coeffs not in _NS_COEFFS_PRESET_DICT: + if ns_coeffs == 'polar_express': + ns_coeffs_ = polar_express_coeffs(num_iters=ns_steps) + elif ns_coeffs in _NS_COEFFS_PRESET_DICT: + ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + else: raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') - ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + else: ns_coeffs_ = ns_coeffs diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 02c02f78d..4f8172f07 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -65,6 +65,16 @@ def obj_fn(params): return initial_params, final_params, obj_fn +def _random_matrix_with_svs(key, shape, singular_values): + """Matrix with given singular values and random singular vectors.""" + m, n = shape + k = min(m, n) + key1, key2 = jax.random.split(key) + u, _ = jnp.linalg.qr(jax.random.normal(key1, (m, k))) + v, _ = jnp.linalg.qr(jax.random.normal(key2, (n, k))) + return u @ jnp.diag(jnp.asarray(singular_values)) @ v.T + + class MuonTest(parameterized.TestCase): @parameterized.named_parameters( @@ -355,6 +365,9 @@ def f(params_state, _): ('aol_square', 'aol', (100, 100)), ('aol_tall', 'aol', (100, 50)), ('aol_wide', 'aol', (50, 100)), + ('schatten_square', 'schatten', (100, 100)), + ('schatten_tall', 'schatten', (100, 50)), + ('schatten_wide', 'schatten', (50, 100)), ) def test_muon_orthogonalization_modes(self, preconditioning, shape): """Tests that Muon runs and produces near-orthogonal updates.""" @@ -420,6 +433,114 @@ def test_orthogonality(self, preconditioning): self.assertLess(ortho_error, 1e-3, f'Orthogonality error too high: {ortho_error}') + def test_polar_express(self): + 
"""Tests PolarExpress ns_coeffs with frobenius preconditioning.""" + params = {'w': jnp.eye(8) * 2.0} + opt = _muon.muon( + learning_rate=0.1, + ns_coeffs='polar_express', + ns_steps=8, + ) + updates, _ = opt.update(params, opt.init(params), params) + w_update = updates['w'] + + for leaf in jax.tree_util.tree_leaves(updates): + self.assertFalse(jnp.isnan(leaf).any(), + 'Found NaN values in polar_express updates') + + # Check orthogonality. + gram = jnp.dot(w_update.T, w_update) + gram = gram / jnp.max(gram) + ortho_error = jnp.linalg.norm(gram - jnp.eye(gram.shape[0])) + self.assertLess(ortho_error, 1e-3, + f'Orthogonality error too high: {ortho_error}') + + def test_polar_express_numerical_difference(self): + """Ensures PolarExpress produces different updates than standard Muon.""" + params = {'w': jnp.eye(8) * 2.0} + + opt_std = _muon.muon(learning_rate=0.1, preconditioning='frobenius') + u_std, _ = opt_std.update(params, opt_std.init(params), params) + + opt_pe = _muon.muon( + learning_rate=0.1, + ns_coeffs='polar_express', + ns_steps=5, + preconditioning='frobenius' + ) + u_pe, _ = opt_pe.update(params, opt_pe.init(params), params) + + with self.assertRaises(AssertionError): + test_utils.assert_trees_all_close(u_std, u_pe) + + def test_polar_express_coeffs_match_reference(self): + """Computed coefficients match the paper's hard-coded values.""" + # Hard-coded coefficients from Amsel et al., 2025, Algorithm 1 + # (before safety factor application). 
+ expected = [ + (8.28721201814563, -23.595886519098837, 17.300387312530933), + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), + (3.9486908534822946, -2.908902115962949, 0.5518191394370137), + (3.3184196573706015, -2.488488024314874, 0.51004894012372), + (2.300652019954817, -1.6689039845747493, 0.4188073119525673), + (1.891301407787398, -1.2679958271945868, 0.37680408948524835), + (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), + (1.875, -1.25, 0.375), + ] + computed = _muon.polar_express_coeffs( + l=1e-3, num_iters=8, + safety_factor_eps=0.0, cushion=0.02407327424182761, + ) + for i, (exp, got) in enumerate(zip(expected, computed)): + np.testing.assert_allclose( + got, exp, rtol=1e-8, + err_msg=f'Coefficient mismatch at iteration {i}', + ) + + @parameterized.named_parameters( + ('frobenius_low_rank', 'frobenius', 'low_rank'), + ('spectral_low_rank', 'spectral', 'low_rank'), + ('schatten_low_rank', 'schatten', 'low_rank'), + ('frobenius_binary', 'frobenius', 'binary'), + ('spectral_binary', 'spectral', 'binary'), + ('schatten_binary', 'schatten', 'binary'), + ) + def test_polar_express_hard_matrices(self, preconditioning, matrix_type): + """PolarExpress coefficients on hard matrices with random singular vectors. 
+ + Tests two cases: + - low_rank: exponentially decaying singular values + - binary: singular values all 0 or 2 + """ + key = jax.random.key(42) + shape = (50, 100) + k = min(shape) + + if matrix_type == 'low_rank': + svs = 2.0 * np.exp(-0.5 * np.arange(k)) + else: + svs = np.array([2.0] * (k // 2) + [0.0] * (k - k // 2)) + mat = _random_matrix_with_svs(key, shape, svs) + + ns_coeffs = jnp.array(_muon.polar_express_coeffs( + l=1e-3, num_iters=10, + safety_factor_eps=1e-2, cushion=0.02)) + + result = _muon.orthogonalize_via_newton_schulz( + mat, ns_coeffs, ns_steps=10, + preconditioning=preconditioning, + dimension_numbers=_muon.MuonDimensionNumbers(0, 1)) + + self.assertFalse(jnp.any(jnp.isnan(result)).item()) + self.assertFalse(jnp.any(jnp.isinf(result)).item()) + self.assertEqual(result.shape, shape) + # SVs above l should converge to 1 (orthogonality). + n_above_l = int(np.sum(svs > 1e-3)) + out_svs = jnp.linalg.svd(result, compute_uv=False)[:n_above_l] + np.testing.assert_allclose( + out_svs, jnp.ones(n_above_l), atol=1e-3, + err_msg=f'SVs above l did not converge to 1 ({preconditioning})') + if __name__ == '__main__': absltest.main()