Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
e687744
implementing polar express preconditioning and ns_coeffs
MarcMachaczek Mar 4, 2026
e5473f9
padding ns_coeffs for polar_express when number of steps exceeds the …
MarcMachaczek Mar 4, 2026
6b8834e
fixed ns_coeffs selection logic when number of steps is smaller than …
MarcMachaczek Mar 4, 2026
ab9656c
spacing
MarcMachaczek Mar 4, 2026
b91e978
documentation and proper warnings
MarcMachaczek Mar 4, 2026
e61b4a5
formatting
MarcMachaczek Mar 4, 2026
8c94ae2
tests (pass)
MarcMachaczek Mar 4, 2026
9be1e2e
formatting
MarcMachaczek Mar 4, 2026
595e090
computing optimal polar express coefficients for specified ns_iter an…
MarcMachaczek Mar 4, 2026
48c5ce2
relax tolerance for test
MarcMachaczek Mar 5, 2026
591e39e
test mismatch reason solved: increased cutoff for l/u ratio for which…
MarcMachaczek Mar 5, 2026
b50ca12
testing polar express with different preconditioners on hard matrices…
MarcMachaczek Mar 5, 2026
84a5f3a
compare coefficients to hard coded coefficientw as found in the origi…
MarcMachaczek Mar 5, 2026
78d06d6
changing how safety factor is applied to the polynomial coefficients.…
MarcMachaczek Mar 5, 2026
d0a72f7
use default dtype independent lower bound for polar express of 1e-3
MarcMachaczek Mar 5, 2026
b522e88
remove polar express preconditioning; use default (frobenius) and men…
MarcMachaczek Mar 5, 2026
4af8ab1
removed polar express constants in favor of default arguments
MarcMachaczek Mar 5, 2026
560a424
final formatting changes to minimize diff with main
MarcMachaczek Mar 5, 2026
f973db1
increase default safety factor to 2e-2, include safety factor inside …
MarcMachaczek Mar 8, 2026
e0dbb0b
slightly bump threshold for remez iteration cutoff for robustness acr…
MarcMachaczek Mar 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optax/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from optax.contrib._muon import muon
from optax.contrib._muon import MuonDimensionNumbers
from optax.contrib._muon import MuonState
from optax.contrib._muon import polar_express_coeffs
from optax.contrib._muon import scale_by_muon
from optax.contrib._privacy import differentially_private_aggregate
from optax.contrib._privacy import DifferentiallyPrivateAggregateState
Expand Down
163 changes: 154 additions & 9 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import jax
import jax.numpy as jnp
import numpy as np

from optax._src import alias
from optax._src import base
Expand All @@ -47,6 +48,140 @@
(2.8769, -3.1427, 1.2046),
(2.8366, -3.0525, 1.2012),
]


def _optimal_quintic(l, u):
r"""Optimal quintic coefficients for the Newton-Schulz iteration.

Uses a simplified Remez algorithm to find coefficients (a, b, c) for the
odd quintic :math:`p(x) = ax + bx^3 + cx^5` that minimizes the Chebyshev
(minimax) approximation error :math:`\max_{x \in [\ell, u]} |1 - p(x)|`.

Args:
l: Lower bound on singular values. Must satisfy ``0 < l <= u``.
u: Upper bound on singular values. Must be positive.

Returns:
A tuple ``(a, b, c)`` of quintic iteration coefficients.

Raises:
ValueError: If ``l <= 0``, ``u <= 0``, or ``l > u``.

References:
Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm
<https://arxiv.org/abs/2505.16932>`_, 2025, Section 4.2.
"""
if not 0 < l <= u:
raise ValueError(f'l must satisfy 0 < l <= u, got l={l}, u={u}.')
if 1 - 1e-5 <= l / u:
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
q = (3 * l + u) / 4
r = (l + 3 * u) / 4
max_iter = 1000
e, old_e, n_iter = np.inf, None, 0
while old_e is None or abs(old_e - e) > 1e-15:
if n_iter >= max_iter:
break
n_iter += 1
old_e = e
lhs = np.array([
[l, l**3, l**5, 1],
[q, q**3, q**5, -1],
[r, r**3, r**5, 1],
[u, u**3, u**5, -1],
])
a, b, c, e = np.linalg.solve(lhs, np.ones(4))
q, r = np.sqrt(
(-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c))
/ (10 * c)
)
return float(a), float(b), float(c)


def polar_express_coeffs(
l=1e-3,
num_iters=8,
*,
safety_factor_eps=2e-2,
cushion=0.02407327424182761,
):
r"""Compute PolarExpress optimal Newton-Schulz coefficients.

Computes per-iteration optimal quintic coefficients for the Newton-Schulz
matrix sign iteration.

The ``l`` parameter controls the assumed lower bound on normalized singular
values after preconditioning. The default ``l=1e-3`` works well for both
bfloat16 and float32 training.

Example::

# Custom lower bound (e.g. for float32 with tighter convergence):
coeffs = polar_express_coeffs(l=1e-7, num_iters=12)

# Use with muon:
optimizer = optax.contrib.muon(
learning_rate=0.02,
ns_coeffs=coeffs,
)

Args:
l: Lower bound on normalized singular values. Must satisfy
``0 < l <= 1``. Defaults to ``1e-3``.
num_iters: Number of Newton-Schulz iterations to compute coefficients
for. Defaults to ``8``.
safety_factor_eps: Epsilon for the safety factor ``1 + eps`` applied to
all iterations except the last. Contracts the polynomial slightly to
ensure convergence under floating-point round-off errors. See
Section 4.4 of Amsel et al., 2025. Defaults to ``1e-2``.
cushion: Minimum fraction of ``u`` used as the lower bound when
computing each iteration's optimal polynomial. When
``cushion * u > l``, a rescaler is applied to maintain the correct
mapping. Helps with numerical stability of the Remez linear system
in early iterations. Defaults to ``0.024`` (from the reference
implementation).

Returns:
A list of ``num_iters`` tuples ``(a, b, c)``, where each tuple contains
the quintic Newton-Schulz coefficients for that iteration.

Raises:
ValueError: If ``l <= 0``, ``l > 1``, or ``num_iters < 1``.

References:
Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm
<https://arxiv.org/abs/2505.16932>`_, 2025
"""
u = 1.0
if not 0 < l <= u:
raise ValueError(f'l must satisfy 0 < l <= 1, got {l}.')
if num_iters < 1:
raise ValueError(f'num_iters must be >= 1, got {num_iters}.')
safety_factor = 1 + safety_factor_eps
coefficients = []
for i in range(num_iters):
a, b, c = _optimal_quintic(max(l, cushion * u), u)
if cushion * u > l:
pl = a * l + b * l**3 + c * l**5
pu = a * u + b * u**3 + c * u**5
rescaler = 2 / (pl + pu)
a *= rescaler
b *= rescaler
c *= rescaler
if i < num_iters - 1:
a /= safety_factor
b /= safety_factor**3
c /= safety_factor**5
coefficients.append((a, b, c))
l = a * l + b * l**3 + c * l**5
u = 2 - l
return coefficients


# Note: 'polar_express' is not in this dict because it depends on ns_steps
# and is handled separately in muon() via polar_express_coeffs().
_NS_COEFFS_PRESET_DICT = {
'standard': _DEFAULT_NS_COEFFS,
'dion': _DION_NS_COEFFS,
Expand Down Expand Up @@ -419,6 +554,7 @@ def scale_by_muon(
- 'spectral' : Use Spectral norm rescaling before NS.
- 'aol': Use AOL rescaling to improve orthogonality.
- 'schatten': Use the Schatten-4 norm for rescaling.
Recommended when using PolarExpress coefficients.
weight_dimension_numbers: An optional tree with the same structure as the
params of `MuonDimensionNumbers`s, specifying how to reshape the
parameters before and after the orthogonalization OR a callable returning
Expand Down Expand Up @@ -449,8 +585,8 @@ def scale_by_muon(
<https://arxiv.org/abs/2506.10935>`_, 2025

Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm`,
<https://arxiv.org/pdf/2505.16932>`, 2025
Application to the Muon Algorithm
<https://arxiv.org/abs/2505.16932>`_, 2025
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)

Expand All @@ -463,9 +599,9 @@ def init_fn(params):
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)
if ns_coeffs_.ndim == 2:
if not ns_coeffs_.shape[0] <= ns_steps:
if ns_coeffs_.shape[0] < ns_steps:
raise ValueError(f'Not enough coeffs to perform {ns_steps} steps')
ns_coeffs_ = ns_coeffs_[-ns_steps:]
ns_coeffs_ = ns_coeffs_[:ns_steps]

return MuonState(
count=jnp.zeros([], jnp.int32),
Expand Down Expand Up @@ -567,7 +703,11 @@ def muon(
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
ns_coeffs: Coefficients for the Newton-schulz method (can be a string
indicator for a preset). Existing presets: `muon`, `dion`.
indicator for a preset). Existing presets: ``standard``, ``dion``,
``polar_express``. The ``polar_express`` preset uses ``l=1e-3`` and
the default cushion; for custom ``l``, ``cushion``, or
``num_iters``, call :func:`polar_express_coeffs` directly and pass
the result as ``ns_coeffs``.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a tuple of tuples.
beta: Decay rate for the exponentially weighted average of grads.
Expand Down Expand Up @@ -598,6 +738,7 @@ def muon(
See <https://arxiv.org/abs/2512.04632>.
- 'schatten': Use the Schatten-4 norm for rescaling,
allows for better performance with little to no extra cost.
Recommended when using ``ns_coeffs='polar_express'``.
See <https://arxiv.org/abs/2506.10935>.
adam_b1: Exponential decay rate for Adam's first moment estimates.
adam_b2: Exponential decay rate for Adam's second moment estimates.
Expand Down Expand Up @@ -643,17 +784,21 @@ def muon(
<https://arxiv.org/abs/2506.10935>`_, 2025

Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm`,
<https://arxiv.org/pdf/2505.16932>`, 2025
Application to the Muon Algorithm
<https://arxiv.org/abs/2505.16932>`_, 2025
"""

if adam_learning_rate is None:
adam_learning_rate = learning_rate

if isinstance(ns_coeffs, str):
if ns_coeffs not in _NS_COEFFS_PRESET_DICT:
if ns_coeffs == 'polar_express':
ns_coeffs_ = polar_express_coeffs(num_iters=ns_steps)
elif ns_coeffs in _NS_COEFFS_PRESET_DICT:
ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs]
else:
raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}')
ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs]

else:
ns_coeffs_ = ns_coeffs

Expand Down
121 changes: 121 additions & 0 deletions optax/contrib/_muon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def obj_fn(params):
return initial_params, final_params, obj_fn


def _random_matrix_with_svs(key, shape, singular_values):
"""Matrix with given singular values and random singular vectors."""
m, n = shape
k = min(m, n)
key1, key2 = jax.random.split(key)
u, _ = jnp.linalg.qr(jax.random.normal(key1, (m, k)))
v, _ = jnp.linalg.qr(jax.random.normal(key2, (n, k)))
return u @ jnp.diag(jnp.asarray(singular_values)) @ v.T


class MuonTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -355,6 +365,9 @@ def f(params_state, _):
('aol_square', 'aol', (100, 100)),
('aol_tall', 'aol', (100, 50)),
('aol_wide', 'aol', (50, 100)),
('schatten_square', 'schatten', (100, 100)),
('schatten_tall', 'schatten', (100, 50)),
('schatten_wide', 'schatten', (50, 100)),
)
def test_muon_orthogonalization_modes(self, preconditioning, shape):
"""Tests that Muon runs and produces near-orthogonal updates."""
Expand Down Expand Up @@ -420,6 +433,114 @@ def test_orthogonality(self, preconditioning):
self.assertLess(ortho_error, 1e-3,
f'Orthogonality error too high: {ortho_error}')

def test_polar_express(self):
"""Tests PolarExpress ns_coeffs with frobenius preconditioning."""
params = {'w': jnp.eye(8) * 2.0}
opt = _muon.muon(
learning_rate=0.1,
ns_coeffs='polar_express',
ns_steps=8,
)
updates, _ = opt.update(params, opt.init(params), params)
w_update = updates['w']

for leaf in jax.tree_util.tree_leaves(updates):
self.assertFalse(jnp.isnan(leaf).any(),
'Found NaN values in polar_express updates')

# Check orthogonality.
gram = jnp.dot(w_update.T, w_update)
gram = gram / jnp.max(gram)
ortho_error = jnp.linalg.norm(gram - jnp.eye(gram.shape[0]))
self.assertLess(ortho_error, 1e-3,
f'Orthogonality error too high: {ortho_error}')

def test_polar_express_numerical_difference(self):
"""Ensures PolarExpress produces different updates than standard Muon."""
params = {'w': jnp.eye(8) * 2.0}

opt_std = _muon.muon(learning_rate=0.1, preconditioning='frobenius')
u_std, _ = opt_std.update(params, opt_std.init(params), params)

opt_pe = _muon.muon(
learning_rate=0.1,
ns_coeffs='polar_express',
ns_steps=5,
preconditioning='frobenius'
)
u_pe, _ = opt_pe.update(params, opt_pe.init(params), params)

with self.assertRaises(AssertionError):
test_utils.assert_trees_all_close(u_std, u_pe)

def test_polar_express_coeffs_match_reference(self):
"""Computed coefficients match the paper's hard-coded values."""
# Hard-coded coefficients from Amsel et al., 2025, Algorithm 1
# (before safety factor application).
expected = [
(8.28721201814563, -23.595886519098837, 17.300387312530933),
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
(1.875, -1.25, 0.375),
]
computed = _muon.polar_express_coeffs(
l=1e-3, num_iters=8,
safety_factor_eps=0.0, cushion=0.02407327424182761,
)
for i, (exp, got) in enumerate(zip(expected, computed)):
np.testing.assert_allclose(
got, exp, rtol=1e-8,
err_msg=f'Coefficient mismatch at iteration {i}',
)

@parameterized.named_parameters(
('frobenius_low_rank', 'frobenius', 'low_rank'),
('spectral_low_rank', 'spectral', 'low_rank'),
('schatten_low_rank', 'schatten', 'low_rank'),
('frobenius_binary', 'frobenius', 'binary'),
('spectral_binary', 'spectral', 'binary'),
('schatten_binary', 'schatten', 'binary'),
)
def test_polar_express_hard_matrices(self, preconditioning, matrix_type):
"""PolarExpress coefficients on hard matrices with random singular vectors.

Tests two cases:
- low_rank: exponentially decaying singular values
- binary: singular values all 0 or 2
"""
key = jax.random.key(42)
shape = (50, 100)
k = min(shape)

if matrix_type == 'low_rank':
svs = 2.0 * np.exp(-0.5 * np.arange(k))
else:
svs = np.array([2.0] * (k // 2) + [0.0] * (k - k // 2))
mat = _random_matrix_with_svs(key, shape, svs)

ns_coeffs = jnp.array(_muon.polar_express_coeffs(
l=1e-3, num_iters=10,
safety_factor_eps=1e-2, cushion=0.02))

result = _muon.orthogonalize_via_newton_schulz(
mat, ns_coeffs, ns_steps=10,
preconditioning=preconditioning,
dimension_numbers=_muon.MuonDimensionNumbers(0, 1))

self.assertFalse(jnp.any(jnp.isnan(result)).item())
self.assertFalse(jnp.any(jnp.isinf(result)).item())
self.assertEqual(result.shape, shape)
# SVs above l should converge to 1 (orthogonality).
n_above_l = int(np.sum(svs > 1e-3))
out_svs = jnp.linalg.svd(result, compute_uv=False)[:n_above_l]
np.testing.assert_allclose(
out_svs, jnp.ones(n_above_l), atol=1e-3,
err_msg=f'SVs above l did not converge to 1 ({preconditioning})')


if __name__ == '__main__':
absltest.main()
Loading