From cecd130a068a5b01c40609968f0f651ee6117df3 Mon Sep 17 00:00:00 2001 From: mihara-bot <1147220090@qq.com> Date: Fri, 20 Mar 2026 18:20:32 +0800 Subject: [PATCH 1/7] Support CANS orthogonalization in Muon. This adds `coefficient_type=\"cans\"` Newton-Schulz coefficients (and tests) so the optimizer can match CANS-based Muon implementations. Made-with: Cursor Signed-off-by: mihara-bot <1147220090@qq.com> --- .../orthogonalized_optimizers/muon.py | 2 +- .../orthogonalized_optimizers/muon_utils.py | 17 ++++++++-- .../orthogonalized_optimizers/polargrad.py | 2 +- .../orthogonalized_optimizers/scion.py | 2 +- tests/test_muon_utils.py | 34 +++++++++++++++++++ 5 files changed, 52 insertions(+), 5 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index b3613a2b..7cee64f5 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -63,7 +63,7 @@ class Muon(OrthogonalizedOptimizer): Args: {_args_doc} coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of - ["simple", "quintic", "polar_express"]. + ["simple", "quintic", "polar_express", "cans"]. num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling. extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index 4eeb4500..1c421337 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -25,7 +25,7 @@ CoeffIterMode = Literal["cycle", "repeat_last"] -NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "custom"] +NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "cans", "custom"] _COEFFICIENT_SETS = { # Values are rounded to closest representable in single precision. @@ -55,6 +55,15 @@ (1.8564, -1.2132, 0.3568), (1.8750, -1.2500, 0.3750), ], + "cans": [ + # CANS iteration (Remez + adaptive interval) based coefficients. + # Source (generation): accelerating_orthogonalization/polynomials.py + (8.4703, -25.1081, 18.6293), + (4.1828, -3.1087, 0.5806), + (3.9619, -2.9541, 0.5630), + (3.2866, -2.4647, 0.5074), + (2.2737, -1.6447, 0.4162), + ], "aol": [ # from https://github.com/thib-s/flash-newton-schulz/blob/main/newton_schulz_triton.py#L511 (4.0098, -7.0585, 2.4635), @@ -136,6 +145,8 @@ def newton_schulz( - "simple": Default coefficient set. - "quintic": Quintic iteration with optimized coefficients. - "polar_express": Polar Express iteration with optimized coefficients. + - "cans": CANS iteration with Remez + adaptive interval coefficients. + - "aol": AOL coefficient set. - "custom": Custom coefficient sets. Arguments: @@ -179,7 +190,9 @@ def newton_schulz( else: raise ValueError(f"Invalid coefficient type: {coefficient_type}") - iter_mode: CoeffIterMode = "cycle" if coefficient_type != "polar_express" else "repeat_last" + iter_mode: CoeffIterMode = ( + "repeat_last" if coefficient_type in ("polar_express", "cans") else "cycle" + ) coeff_iter = get_coefficient_iterator(steps, coefficient_sets, mode=iter_mode) ns_step_fn = newton_schulz_step diff --git a/emerging_optimizers/orthogonalized_optimizers/polargrad.py b/emerging_optimizers/orthogonalized_optimizers/polargrad.py index 4b8104aa..af435653 100644 --- a/emerging_optimizers/orthogonalized_optimizers/polargrad.py +++ b/emerging_optimizers/orthogonalized_optimizers/polargrad.py @@ -55,7 +55,7 @@ class PolarGrad(OrthogonalizedOptimizer): Args: {_args_doc} coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of - ["simple", "quintic", "polar_express"]. + ["simple", "quintic", "polar_express", "cans"]. num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982. diff --git a/emerging_optimizers/orthogonalized_optimizers/scion.py b/emerging_optimizers/orthogonalized_optimizers/scion.py index d5982aa9..f3efc127 100644 --- a/emerging_optimizers/orthogonalized_optimizers/scion.py +++ b/emerging_optimizers/orthogonalized_optimizers/scion.py @@ -57,7 +57,7 @@ class Scion(OrthogonalizedOptimizer): momentum: The momentum used by the internal SGD. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of - ["simple", "quintic", "polar_express"]. + ["simple", "quintic", "polar_express", "cans"]. num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. spectral_radius: The spectral radius to use for the update, we are scaling the LMO by this spectral radius. """ diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index c8d25c7b..95d2f69e 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -220,6 +220,40 @@ def test_get_polar_express_9steps_close_to_reference(self, dim1, dim2): out_ref = newton_schulz_ref(x, coefficient_sets=coeff) torch.testing.assert_close(out_pe9, out_ref, atol=2e-6, rtol=1e-7) + @parameterized.parameters( + (512, 512), + (512, 256), + (256, 512), + ) + def test_cans_close_to_reference(self, dim1, dim2): + x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32) + out_cans_test = muon_utils.newton_schulz(x, steps=5, coefficient_type="cans") + out_cans_ref = newton_schulz_ref(x, coefficient_sets=muon_utils._COEFFICIENT_SETS["cans"]) + + torch.testing.assert_close( + out_cans_test, + out_cans_ref, + atol=1e-6, + rtol=1e-7, + ) + + @parameterized.parameters( + (511, 513), + (511, 257), + (257, 513), + ) + def test_get_cans_9steps_close_to_reference(self, dim1, dim2): + x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32) + out_cans9 = muon_utils.newton_schulz(x, steps=9, coefficient_type="cans") + + coeff = deepcopy(muon_utils._COEFFICIENT_SETS["cans"]) + # CANS uses repeat_last, so repeat the last tuple for remaining steps. + coeff.append(coeff[-1]) + coeff.append(coeff[-1]) + coeff.append(coeff[-1]) + out_ref = newton_schulz_ref(x, coefficient_sets=coeff) + torch.testing.assert_close(out_cans9, out_ref, atol=2e-6, rtol=1e-7) + @absltest.skipIf( _SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)), From 5ae47fa71b0e18292440f06681f50f8f84bf295d Mon Sep 17 00:00:00 2001 From: Xinlin Zhuang <1147220090@qq.com> Date: Fri, 20 Mar 2026 18:37:17 +0800 Subject: [PATCH 2/7] Update tests/test_muon_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xinlin Zhuang <1147220090@qq.com> Signed-off-by: mihara-bot <1147220090@qq.com> --- tests/test_muon_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index 95d2f69e..d4694bb8 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -242,17 +242,13 @@ def test_cans_close_to_reference(self, dim1, dim2): (511, 257), (257, 513), ) - def test_get_cans_9steps_close_to_reference(self, dim1, dim2): - x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32) - out_cans9 = muon_utils.newton_schulz(x, steps=9, coefficient_type="cans") - coeff = deepcopy(muon_utils._COEFFICIENT_SETS["cans"]) # CANS uses repeat_last, so repeat the last tuple for remaining steps. coeff.append(coeff[-1]) coeff.append(coeff[-1]) coeff.append(coeff[-1]) + coeff.append(coeff[-1]) out_ref = newton_schulz_ref(x, coefficient_sets=coeff) - torch.testing.assert_close(out_cans9, out_ref, atol=2e-6, rtol=1e-7) @absltest.skipIf( From 5d5cb6193d862515b819554b204a1061b2e596b2 Mon Sep 17 00:00:00 2001 From: Xinlin Zhuang <1147220090@qq.com> Date: Fri, 20 Mar 2026 18:39:38 +0800 Subject: [PATCH 3/7] Update tests/test_muon_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xinlin Zhuang <1147220090@qq.com> Signed-off-by: mihara-bot <1147220090@qq.com> --- tests/test_muon_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index d4694bb8..f494502e 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -242,6 +242,9 @@ def test_cans_close_to_reference(self, dim1, dim2): (511, 257), (257, 513), ) + def test_get_cans_9steps_close_to_reference(self, dim1, dim2): + x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32) + out_cans9 = muon_utils.newton_schulz(x, steps=9, coefficient_type="cans") coeff = deepcopy(muon_utils._COEFFICIENT_SETS["cans"]) # CANS uses repeat_last, so repeat the last tuple for remaining steps. coeff.append(coeff[-1]) @@ -249,6 +252,7 @@ def test_cans_close_to_reference(self, dim1, dim2): coeff.append(coeff[-1]) coeff.append(coeff[-1]) out_ref = newton_schulz_ref(x, coefficient_sets=coeff) + torch.testing.assert_close(out_cans9, out_ref, atol=2e-6, rtol=1e-7) @absltest.skipIf( From a731cea867fe4ce3f19b74cc7f5dcdf0d1cfe4ca Mon Sep 17 00:00:00 2001 From: Xinlin Zhuang <1147220090@qq.com> Date: Fri, 20 Mar 2026 18:44:33 +0800 Subject: [PATCH 4/7] Update emerging_optimizers/orthogonalized_optimizers/muon_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xinlin Zhuang <1147220090@qq.com> Signed-off-by: mihara-bot <1147220090@qq.com> --- emerging_optimizers/orthogonalized_optimizers/muon_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index 1c421337..098a0205 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -25,7 +25,7 @@ CoeffIterMode = Literal["cycle", "repeat_last"] -NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "cans", "custom"] +NSCoeffT = Literal["simple", "quintic", "polar_express", "cans", "aol", "custom"] _COEFFICIENT_SETS = { # Values are rounded to closest representable in single precision. From e225bb244c0a92f8631219a4af5b3eec24251e97 Mon Sep 17 00:00:00 2001 From: mihara-bot <1147220090@qq.com> Date: Sat, 21 Mar 2026 00:36:54 +0800 Subject: [PATCH 5/7] clarify source and fix verbose Signed-off-by: mihara-bot <1147220090@qq.com> --- emerging_optimizers/orthogonalized_optimizers/muon_utils.py | 3 ++- tests/test_muon_utils.py | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index 098a0205..5d7ee983 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -56,8 +56,9 @@ (1.8750, -1.2500, 0.3750), ], "cans": [ + # CANS from: http://arxiv.org/abs/2506.10935 # CANS iteration (Remez + adaptive interval) based coefficients. - # Source (generation): accelerating_orthogonalization/polynomials.py + # Source (for generating CANS coefficients): https://github.com/GrishKate/accelerating_orthogonalization/blob/main/polynomials.py (8.4703, -25.1081, 18.6293), (4.1828, -3.1087, 0.5806), (3.9619, -2.9541, 0.5630), diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index f494502e..6557f8e5 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -247,10 +247,7 @@ def test_get_cans_9steps_close_to_reference(self, dim1, dim2): out_cans9 = muon_utils.newton_schulz(x, steps=9, coefficient_type="cans") coeff = deepcopy(muon_utils._COEFFICIENT_SETS["cans"]) # CANS uses repeat_last, so repeat the last tuple for remaining steps. - coeff.append(coeff[-1]) - coeff.append(coeff[-1]) - coeff.append(coeff[-1]) - coeff.append(coeff[-1]) + coeff.extend([coeff[-1]] * 4) out_ref = newton_schulz_ref(x, coefficient_sets=coeff) torch.testing.assert_close(out_cans9, out_ref, atol=2e-6, rtol=1e-7) From 5f405d29f0fbdd19d97488fe534183f798034f52 Mon Sep 17 00:00:00 2001 From: mihara-bot <1147220090@qq.com> Date: Thu, 26 Mar 2026 11:23:15 +0800 Subject: [PATCH 6/7] fix linting error Signed-off-by: mihara-bot <1147220090@qq.com> --- emerging_optimizers/orthogonalized_optimizers/muon_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index 5d7ee983..925afa76 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -191,9 +191,7 @@ def newton_schulz( else: raise ValueError(f"Invalid coefficient type: {coefficient_type}") - iter_mode: CoeffIterMode = ( - "repeat_last" if coefficient_type in ("polar_express", "cans") else "cycle" - ) + iter_mode: CoeffIterMode = "repeat_last" if coefficient_type in ("polar_express", "cans") else "cycle" coeff_iter = get_coefficient_iterator(steps, coefficient_sets, mode=iter_mode) ns_step_fn = newton_schulz_step From 1b81d8d3eeedadabfb8a97e3b48dbd85e0be8e90 Mon Sep 17 00:00:00 2001 From: mihara-bot <1147220090@qq.com> Date: Thu, 26 Mar 2026 17:03:45 +0800 Subject: [PATCH 7/7] loosen strictness in test Signed-off-by: mihara-bot <1147220090@qq.com> --- tests/test_muon_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index 6557f8e5..22f04ec2 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -233,7 +233,7 @@ def test_cans_close_to_reference(self, dim1, dim2): torch.testing.assert_close( out_cans_test, out_cans_ref, - atol=1e-6, + atol=1e-5, rtol=1e-7, )