From 89657bf38b54c4f47c8711a76f7ff4f734b24d3c Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:41:27 -0800 Subject: [PATCH 01/15] added spectron + test Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 303 ++++++++++++++++ tests/test_spectron.py | 343 ++++++++++++++++++ 2 files changed, 646 insertions(+) create mode 100644 emerging_optimizers/orthogonalized_optimizers/spectron.py create mode 100644 tests/test_spectron.py diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py new file mode 100644 index 00000000..4df1d97f --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, overload, override + +import torch +import torch.optim as optim +from absl import logging +from torch.optim.optimizer import ParamsT + +from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers import registry, utils +from emerging_optimizers.orthogonalized_optimizers import muon_utils +from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT +from emerging_optimizers.utils import FP32MatmulPrecT + + +__all__ = ["Spectron", "power_iteration"] + + +def power_iteration( + W: torch.Tensor, + u: torch.Tensor, + k: int = 1, + eps: float = 1e-8, +) -> tuple[torch.Tensor, torch.Tensor]: + """Approximate largest singular value and left singular vector using power iteration. + + Implements Algorithm 3 from the Spectron paper. This method iteratively refines + estimates of the dominant singular value and corresponding left singular vector + of a matrix W. + + Args: + W: Matrix of shape (p, q) to analyze + u: Initial left singular vector of shape (p,), should be normalized + k: Number of power iteration steps. Default: 1 + eps: Small constant for numerical stability. Default: 1e-8 + + Returns: + Tuple of (sigma, u) where: + - sigma: Approximation of the largest singular value (scalar tensor) + - u: Updated left singular vector of shape (p,) + """ + # Ensure initial normalization + u = u / u.norm(p=2).clamp_min(eps) + + # Power iteration loop + for _ in range(k): + # v ← W^T u (right vector) + v = W.mT @ u + + # v ← v / ||v||_2 (normalize right vector) + v = v / v.norm(p=2).clamp_min(eps) + + # u ← W v (left vector) + u = W @ v + + # u ← u / ||u||_2 (normalize left vector) + u = u / u.norm(p=2).clamp_min(eps) + + # σ ← u^T W v (Rayleigh quotient approximation) + v = W.mT @ u + v = v / v.norm(p=2).clamp_min(eps) + sigma = u @ (W @ v) + + # Return σ and u + return sigma, u + + +@registry.register_optimizer("spectron") +class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer): + """Spectron: Low-rank spectral optimizer with orthogonalized momentum. + + Spectron maintains each 2D weight matrix W as a low-rank factorization W = A @ B^T, + where A ∈ R^(m×r) and B ∈ R^(n×r). It applies momentum, orthogonalizes the updates + using Newton-Schulz iteration, and scales the learning rate by the spectral radii + of both factors. + + The algorithm: + 1. Compute gradients with respect to A and B from parameter gradients + 2. Apply momentum to both factors + 3. Orthogonalize momentum buffers using Newton-Schulz iteration + 4. Estimate spectral radius of A and B using power iteration + 5. Update with scaled learning rate: η / (σ_A + σ_B + 1) + 6. Reconstruct full weight matrix W = A @ B^T + + References: + - Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper. + Low-rank spectral optimization with orthogonalized momentum. + + Warning: + - This optimizer requires that all parameters passed in are 2D. + - Low-rank factorization may not be suitable for all parameter types. + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups + lr: The learning rate (η in the algorithm). Default: 3e-4 + rank: The rank of the low-rank factorization. Default: 64 + momentum_beta: The momentum decay coefficient (β). Default: 0.9 + weight_decay: The weight decay coefficient. Default: 0.01 + weight_decay_method: Method to apply weight decay. Default: "decoupled" + fp32_matmul_prec: Precision of matmul operations. Default: "medium" + num_ns_steps: Number of Newton-Schulz iteration steps. Default: 5 + num_power_iter: Number of power iteration steps for spectral radius. Default: 1 + coefficient_type: Type of coefficient set for Newton-Schulz. Default: "quintic" + """ + + def __init__( + self, + params: ParamsT, + lr: float = 3e-4, + rank: int = 64, + momentum_beta: float = 0.9, + weight_decay: float = 0.01, + *, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + fp32_matmul_prec: FP32MatmulPrecT = "medium", + num_ns_steps: int = 5, + num_power_iter: int = 1, + coefficient_type: NSCoeffT = "quintic", + ) -> None: + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if rank < 1: + raise ValueError(f"Invalid rank: {rank}") + if not 0.0 <= momentum_beta < 1.0: + raise ValueError(f"Invalid momentum_beta: {momentum_beta}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay: {weight_decay}") + if num_ns_steps < 1: + raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") + if num_power_iter < 1: + raise ValueError(f"num_power_iter must be at least 1, got {num_power_iter}") + + self.fp32_matmul_prec = fp32_matmul_prec + self.weight_decay_method = weight_decay_method + self.rank = rank + self.num_power_iter = num_power_iter + + # Create orthogonalization function following OrthogonalizedOptimizer pattern + def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: + logging.debug( + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient" + ) + return muon_utils.newton_schulz( + grad, + steps=num_ns_steps, + coefficient_type=coefficient_type, + ) + + self.scaled_orthogonalize_fn = scaled_orthogonalize_fn + + defaults = dict( + lr=lr, + momentum_beta=momentum_beta, + weight_decay=weight_decay, + ) + + super().__init__(params, defaults) + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if p.ndim != 2: + raise ValueError(f"Spectron only supports 2D parameters, got shape {p.shape}") + + grad = p.grad + state = self.state[p] + + # Initialize low-rank factors and momentum buffers + if "factor_A" not in state: + self._initialize_state(p, state) + + # Get state variables + factor_A = state["factor_A"] + factor_B = state["factor_B"] + momentum_A = state["momentum_A"] + momentum_B = state["momentum_B"] + u_A = state["u_A"] + u_B = state["u_B"] + + # Compute gradients for A and B from parameter gradient + # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A + grad_A = grad @ factor_B # shape: (m, r) + grad_B = grad.mT @ factor_A # shape: (n, r) + + # Apply weight decay + self._apply_weight_decay_inplace( + factor_A, grad_A, group["lr"], group["weight_decay"] + ) + self._apply_weight_decay_inplace( + factor_B, grad_B, group["lr"], group["weight_decay"] + ) + + # Update momentum buffers (EMA of gradients) + momentum_A.lerp_(grad_A, 1 - group["momentum_beta"]) + momentum_B.lerp_(grad_B, 1 - group["momentum_beta"]) + + # Orthogonalize momentum using Newton-Schulz + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A) + orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B) + + # Estimate spectral radius using power iteration (Algorithm 3) + sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter) + sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter) + + # Update power iteration vectors + state["u_A"] = u_A + state["u_B"] = u_B + + # Compute scaled learning rate + scaled_lr = group["lr"] / (sigma_A + sigma_B + 1.0) + + # Update low-rank factors + factor_A.add_(orth_momentum_A, alpha=-scaled_lr) + factor_B.add_(orth_momentum_B, alpha=-scaled_lr) + + # Reconstruct full weight matrix: W = A @ B^T + p.copy_(factor_A @ factor_B.mT) + + return loss + + def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> None: + """Initialize low-rank factors and state for a parameter. + + Args: + p: The parameter tensor (shape: m × n) + state: The state dictionary for this parameter + """ + m, n = p.shape + r = min(self.rank, m, n) # Ensure rank doesn't exceed dimensions + + # Initialize A and B using SVD of the parameter + # This provides a good initialization close to the original weights + with torch.no_grad(): + U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False) + # Keep only top r singular values/vectors + sqrt_S = torch.sqrt(S[:r]) + factor_A = (U[:, :r] * sqrt_S).to(p.dtype) + factor_B = (Vh[:r, :].mT * sqrt_S).to(p.dtype) + + state["factor_A"] = factor_A.clone() + state["factor_B"] = factor_B.clone() + state["momentum_A"] = torch.zeros_like(factor_A) + state["momentum_B"] = torch.zeros_like(factor_B) + + # Initialize power iteration vectors (normalized random vectors) + u_A = torch.randn(m, dtype=p.dtype, device=p.device) + u_A = u_A / u_A.norm() + u_B = torch.randn(n, dtype=p.dtype, device=p.device) + u_B = u_B / u_B.norm() + + state["u_A"] = u_A + state["u_B"] = u_B + + def _power_iteration( + self, matrix: torch.Tensor, u: torch.Tensor, num_iters: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """Estimate the largest singular value using power iteration. + + Args: + matrix: The matrix to estimate largest singular value for (shape: p × q) + u: The current approximation of the dominant left singular vector + num_iters: Number of power iteration steps + + Returns: + Tuple of (largest singular value, updated_u) + """ + return power_iteration(matrix, u, k=num_iters) \ No newline at end of file diff --git a/tests/test_spectron.py b/tests/test_spectron.py new file mode 100644 index 00000000..5f823e3a --- /dev/null +++ b/tests/test_spectron.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import flags +import torch +import torch.nn as nn +from absl.testing import absltest, parameterized + +from emerging_optimizers.orthogonalized_optimizers import spectron +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + +class PowerIterationTest(parameterized.TestCase): + @parameterized.parameters( + {"shape": (10, 8), "k": 1}, + {"shape": (32, 16), "k": 5}, + {"shape": (64, 32), "k": 10}, + {"shape": (100, 50), "k": 20}, + ) + def test_power_iteration_converges_to_largest_singular_value(self, shape, k) -> None: + """Test that power iteration approximates the largest singular value.""" + # Create a random matrix with known singular values + W = torch.randn(shape, dtype=torch.float32, device=FLAGS.device) + + # Get ground truth largest singular value using SVD + _, S, _ = torch.linalg.svd(W, full_matrices=False) + true_sigma_max = S[0].item() + + # Initialize random left singular vector + u = torch.randn(shape[0], dtype=torch.float32, device=FLAGS.device) + u = u / u.norm() + + # Run power iteration + sigma_approx, u_out = spectron.power_iteration(W, u, k=k) + + # Check that approximation is close to true value + # More iterations should give better approximation + rel_error = abs(sigma_approx.item() - true_sigma_max) / true_sigma_max + + # With more iterations, error should be smaller + if k >= 10: + self.assertLess(rel_error, 0.01, f"Relative error {rel_error} too large with {k} iterations") + else: + self.assertLess(rel_error, 0.1, f"Relative error {rel_error} too large with {k} iterations") + + def test_power_iteration_output_normalized(self) -> None: + """Test that power iteration returns normalized left singular vector.""" + W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) + u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) + + _, u_out = spectron.power_iteration(W, u, k=5) + + # Check that output is normalized + torch.testing.assert_close( + u_out.norm(), + torch.tensor(1.0, device=FLAGS.device), + atol=1e-6, + rtol=1e-6, + ) + + def test_power_iteration_handles_unnormalized_input(self) -> None: + """Test that power iteration works even with unnormalized input.""" + W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) + u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) * 100 # Unnormalized + + # Should not raise error and should normalize internally + sigma, u_out = spectron.power_iteration(W, u, k=5) + + self.assertIsInstance(sigma.item(), float) + torch.testing.assert_close( + u_out.norm(), + torch.tensor(1.0, device=FLAGS.device), + atol=1e-6, + rtol=1e-6, + ) + + def test_power_iteration_deterministic(self) -> None: + """Test that power iteration is deterministic given same inputs.""" + W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) + u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) + + sigma1, u1 = spectron.power_iteration(W, u.clone(), k=5) + sigma2, u2 = spectron.power_iteration(W, u.clone(), k=5) + + torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0) + torch.testing.assert_close(u1, u2, atol=0, rtol=0) + + +class SpectronTest(parameterized.TestCase): + @parameterized.product( + shape=[(10, 8), (32, 16), (64, 32)], + rank=[4, 8, 16], + weight_decay_method=["decoupled", "independent", "l2"], + fp32_matmul_prec=["highest", "medium"], + ) + def test_smoke(self, shape, rank, weight_decay_method, fp32_matmul_prec) -> None: + """Smoke test Spectron optimizer with various configurations.""" + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randn_like(test_param) + + spectron_opt = spectron.Spectron( + [test_param], + lr=0.01, + rank=rank, + weight_decay=0.01, + weight_decay_method=weight_decay_method, + fp32_matmul_prec=fp32_matmul_prec, + ) + spectron_opt.step() + + # Check that parameter was updated + self.assertIsNotNone(test_param.data) + self.assertEqual(test_param.shape, shape) + + @parameterized.parameters( + {"shape": (32, 16), "rank": 8}, + {"shape": (64, 32), "rank": 16}, + {"shape": (100, 50), "rank": 20}, + ) + def test_low_rank_reconstruction_quality(self, shape, rank) -> None: + """Test that low-rank factorization preserves parameter reasonably after initialization.""" + # Create parameter with known structure + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + original_param = test_param.data.clone() + + spectron_opt = spectron.Spectron( + [test_param], + lr=0.0, # No update, just check initialization + rank=rank, + momentum_beta=0.0, + weight_decay=0.0, + ) + + # Initialize state + test_param.grad = torch.randn_like(test_param) + spectron_opt.step() + + # Get state + state = spectron_opt.state[test_param] + factor_A = state["factor_A"] + factor_B = state["factor_B"] + + # Reconstruct should give back the parameter (since lr=0) + reconstructed = factor_A @ factor_B.mT + + # Check reconstruction quality (won't be perfect due to low-rank) + rel_error = (reconstructed - original_param).norm() / original_param.norm() + + # Error should decrease with higher rank + self.assertLess(rel_error.item(), 0.5, f"Reconstruction error {rel_error.item()} too large") + + def test_momentum_accumulation(self) -> None: + """Test that momentum is properly accumulated over multiple steps.""" + shape = (32, 16) + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + + momentum_beta = 0.9 + spectron_opt = spectron.Spectron( + [test_param], + lr=0.01, + rank=8, + momentum_beta=momentum_beta, + weight_decay=0.0, + ) + + # First step + test_param.grad = torch.ones_like(test_param) + spectron_opt.step() + + state = spectron_opt.state[test_param] + momentum_A_step1 = state["momentum_A"].clone() + momentum_B_step1 = state["momentum_B"].clone() + + # Second step with same gradient + test_param.grad = torch.ones_like(test_param) + spectron_opt.step() + + momentum_A_step2 = state["momentum_A"] + momentum_B_step2 = state["momentum_B"] + + # Momentum should have changed (accumulated) + self.assertFalse(torch.allclose(momentum_A_step1, momentum_A_step2)) + self.assertFalse(torch.allclose(momentum_B_step1, momentum_B_step2)) + + def test_spectral_scaling_reduces_lr_for_large_sigma(self) -> None: + """Test that learning rate is scaled down when spectral radius is large.""" + shape = (32, 16) + + # Create parameter with large norm (will have large spectral radius) + test_param_large = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 10) + test_param_small = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 0.1) + + test_param_large.grad = torch.ones_like(test_param_large) * 0.01 + test_param_small.grad = torch.ones_like(test_param_small) * 0.01 + + lr = 0.1 + + opt_large = spectron.Spectron([test_param_large], lr=lr, rank=8, momentum_beta=0.0) + opt_small = spectron.Spectron([test_param_small], lr=lr, rank=8, momentum_beta=0.0) + + param_large_before = test_param_large.data.clone() + param_small_before = test_param_small.data.clone() + + opt_large.step() + opt_small.step() + + # Get effective learning rates from spectral scaling + state_large = opt_large.state[test_param_large] + state_small = opt_small.state[test_param_small] + + # Compute sigma values after step + sigma_A_large, _ = spectron.power_iteration(state_large["factor_A"], state_large["u_A"], k=1) + sigma_B_large, _ = spectron.power_iteration(state_large["factor_B"], state_large["u_B"], k=1) + + sigma_A_small, _ = spectron.power_iteration(state_small["factor_A"], state_small["u_A"], k=1) + sigma_B_small, _ = spectron.power_iteration(state_small["factor_B"], state_small["u_B"], k=1) + + scaled_lr_large = lr / (sigma_A_large + sigma_B_large + 1.0) + scaled_lr_small = lr / (sigma_A_small + sigma_B_small + 1.0) + + # Larger spectral radius should result in smaller effective learning rate + self.assertLess(scaled_lr_large.item(), scaled_lr_small.item()) + + def test_rank_capped_by_dimensions(self) -> None: + """Test that rank is automatically capped by matrix dimensions.""" + shape = (10, 8) # Small matrix + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randn_like(test_param) + + # Request rank larger than min dimension + spectron_opt = spectron.Spectron( + [test_param], + lr=0.01, + rank=100, # Larger than both dimensions + ) + spectron_opt.step() + + state = spectron_opt.state[test_param] + factor_A = state["factor_A"] + factor_B = state["factor_B"] + + # Rank should be capped at min(m, n) = 8 + self.assertEqual(factor_A.shape[1], 8) + self.assertEqual(factor_B.shape[1], 8) + + def test_raises_error_for_1d_params(self) -> None: + """Test that Spectron raises error for 1D parameters.""" + test_param = nn.Parameter(torch.randn(10, dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randn_like(test_param) + + spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=4) + + with self.assertRaises(ValueError): + spectron_opt.step() + + @parameterized.parameters( + {"num_ns_steps": 1}, + {"num_ns_steps": 3}, + {"num_ns_steps": 5}, + {"num_ns_steps": 10}, + ) + def test_different_ns_steps(self, num_ns_steps) -> None: + """Test that different numbers of Newton-Schulz steps work.""" + shape = (32, 16) + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randn_like(test_param) + + spectron_opt = spectron.Spectron( + [test_param], + lr=0.01, + rank=8, + num_ns_steps=num_ns_steps, + ) + + # Should not raise error + spectron_opt.step() + + @parameterized.parameters( + {"num_power_iter": 1}, + {"num_power_iter": 5}, + {"num_power_iter": 10}, + ) + def test_different_power_iter_steps(self, num_power_iter) -> None: + """Test that different numbers of power iteration steps work.""" + shape = (32, 16) + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randn_like(test_param) + + spectron_opt = spectron.Spectron( + [test_param], + lr=0.01, + rank=8, + num_power_iter=num_power_iter, + ) + + # Should not raise error + spectron_opt.step() + + def test_state_persistence_across_steps(self) -> None: + """Test that optimizer state (A, B, momentum, u) persists correctly across steps.""" + shape = (32, 16) + test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) + + spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=8) + + # First step + test_param.grad = torch.randn_like(test_param) + spectron_opt.step() + + state = spectron_opt.state[test_param] + factor_A_step1 = state["factor_A"].clone() + u_A_step1 = state["u_A"].clone() + + # Second step + test_param.grad = torch.randn_like(test_param) + spectron_opt.step() + + # State should still exist and be updated + self.assertIn("factor_A", state) + self.assertIn("u_A", state) + + # Values should have changed + self.assertFalse(torch.allclose(state["factor_A"], factor_A_step1)) + # u vector should be updated (but might be similar due to slow changes) + self.assertEqual(state["u_A"].shape, u_A_step1.shape) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 8be51a85d498d2df8b71d2821590b6e244fc1fa7 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:47:42 -0800 Subject: [PATCH 02/15] moved power iteration to eig Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 52 +------------------ emerging_optimizers/utils/eig.py | 50 ++++++++++++++++++ tests/test_spectron.py | 19 +++---- 3 files changed, 62 insertions(+), 59 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 4df1d97f..8ecc17a8 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -25,58 +25,10 @@ from emerging_optimizers.orthogonalized_optimizers import muon_utils from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT from emerging_optimizers.utils import FP32MatmulPrecT +from emerging_optimizers.utils.eig import power_iteration -__all__ = ["Spectron", "power_iteration"] - - -def power_iteration( - W: torch.Tensor, - u: torch.Tensor, - k: int = 1, - eps: float = 1e-8, -) -> tuple[torch.Tensor, torch.Tensor]: - """Approximate largest singular value and left singular vector using power iteration. - - Implements Algorithm 3 from the Spectron paper. This method iteratively refines - estimates of the dominant singular value and corresponding left singular vector - of a matrix W. - - Args: - W: Matrix of shape (p, q) to analyze - u: Initial left singular vector of shape (p,), should be normalized - k: Number of power iteration steps. Default: 1 - eps: Small constant for numerical stability. Default: 1e-8 - - Returns: - Tuple of (sigma, u) where: - - sigma: Approximation of the largest singular value (scalar tensor) - - u: Updated left singular vector of shape (p,) - """ - # Ensure initial normalization - u = u / u.norm(p=2).clamp_min(eps) - - # Power iteration loop - for _ in range(k): - # v ← W^T u (right vector) - v = W.mT @ u - - # v ← v / ||v||_2 (normalize right vector) - v = v / v.norm(p=2).clamp_min(eps) - - # u ← W v (left vector) - u = W @ v - - # u ← u / ||u||_2 (normalize left vector) - u = u / u.norm(p=2).clamp_min(eps) - - # σ ← u^T W v (Rayleigh quotient approximation) - v = W.mT @ u - v = v / v.norm(p=2).clamp_min(eps) - sigma = u @ (W @ v) - - # Return σ and u - return sigma, u +__all__ = ["Spectron"] @registry.register_optimizer("spectron") diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index b139ae39..b2d1b789 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -22,9 +22,59 @@ "met_approx_eigvals_criteria", "conjugate", "orthogonal_iteration", + "power_iteration", ] +def power_iteration( + W: torch.Tensor, + u: torch.Tensor, + k: int = 1, + eps: float = 1e-8, +) -> tuple[torch.Tensor, torch.Tensor]: + """Approximate largest singular value and left singular vector using power iteration. + + Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429). This method iteratively refines + estimates of the dominant singular value and corresponding left singular vector + of a matrix W. + + Args: + W: Matrix of shape (p, q) to analyze + u: Initial left singular vector of shape (p,), should be normalized + k: Number of power iteration steps. Default: 1 + eps: Small constant for numerical stability. Default: 1e-8 + + Returns: + Tuple of (sigma, u) where: + - sigma: Approximation of the largest singular value (scalar tensor) + - u: Updated left singular vector of shape (p,) + """ + # Ensure initial normalization + u = u / u.norm(p=2).clamp_min(eps) + + # Power iteration loop + for _ in range(k): + # v ← W^T u (right vector) + v = W.mT @ u + + # v ← v / ||v||_2 (normalize right vector) + v = v / v.norm(p=2).clamp_min(eps) + + # u ← W v (left vector) + u = W @ v + + # u ← u / ||u||_2 (normalize left vector) + u = u / u.norm(p=2).clamp_min(eps) + + # σ ← u^T W v (Rayleigh quotient approximation) + v = W.mT @ u + v = v / v.norm(p=2).clamp_min(eps) + sigma = u @ (W @ v) + + # Return σ and u + return sigma, u + + def eigh_with_fallback( x: Tensor, force_double: bool = False, diff --git a/tests/test_spectron.py b/tests/test_spectron.py index 5f823e3a..d862f3eb 100644 --- a/tests/test_spectron.py +++ b/tests/test_spectron.py @@ -19,6 +19,7 @@ from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers import spectron +from emerging_optimizers.utils.eig import power_iteration # Define command line flags flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") @@ -45,7 +46,7 @@ def test_power_iteration_converges_to_largest_singular_value(self, shape, k) -> u = u / u.norm() # Run power iteration - sigma_approx, u_out = spectron.power_iteration(W, u, k=k) + sigma_approx, u_out = power_iteration(W, u, k=k) # Check that approximation is close to true value # More iterations should give better approximation @@ -62,7 +63,7 @@ def test_power_iteration_output_normalized(self) -> None: W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - _, u_out = spectron.power_iteration(W, u, k=5) + _, u_out = power_iteration(W, u, k=5) # Check that output is normalized torch.testing.assert_close( @@ -78,7 +79,7 @@ def test_power_iteration_handles_unnormalized_input(self) -> None: u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) * 100 # Unnormalized # Should not raise error and should normalize internally - sigma, u_out = spectron.power_iteration(W, u, k=5) + sigma, u_out = power_iteration(W, u, k=5) self.assertIsInstance(sigma.item(), float) torch.testing.assert_close( @@ -93,8 +94,8 @@ def test_power_iteration_deterministic(self) -> None: W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - sigma1, u1 = spectron.power_iteration(W, u.clone(), k=5) - sigma2, u2 = spectron.power_iteration(W, u.clone(), k=5) + sigma1, u1 = power_iteration(W, u.clone(), k=5) + sigma2, u2 = power_iteration(W, u.clone(), k=5) torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0) torch.testing.assert_close(u1, u2, atol=0, rtol=0) @@ -223,11 +224,11 @@ def test_spectral_scaling_reduces_lr_for_large_sigma(self) -> None: state_small = opt_small.state[test_param_small] # Compute sigma values after step - sigma_A_large, _ = spectron.power_iteration(state_large["factor_A"], state_large["u_A"], k=1) - sigma_B_large, _ = spectron.power_iteration(state_large["factor_B"], state_large["u_B"], k=1) + sigma_A_large, _ = power_iteration(state_large["factor_A"], state_large["u_A"], k=1) + sigma_B_large, _ = power_iteration(state_large["factor_B"], state_large["u_B"], k=1) - sigma_A_small, _ = spectron.power_iteration(state_small["factor_A"], state_small["u_A"], k=1) - sigma_B_small, _ = spectron.power_iteration(state_small["factor_B"], state_small["u_B"], k=1) + sigma_A_small, _ = power_iteration(state_small["factor_A"], state_small["u_A"], k=1) + sigma_B_small, _ = power_iteration(state_small["factor_B"], state_small["u_B"], k=1) scaled_lr_large = lr / (sigma_A_large + sigma_B_large + 1.0) scaled_lr_small = lr / (sigma_A_small + sigma_B_small + 1.0) From ce307989270a0e2f7f9e5fa7d90ae72c337e6530 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:49:38 -0800 Subject: [PATCH 03/15] generalized power iteration to return both left and right singular vectors Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 4 +++- emerging_optimizers/utils/eig.py | 13 +++++++------ tests/test_spectron.py | 18 +++++++++--------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 8ecc17a8..453e3cae 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -252,4 +252,6 @@ def _power_iteration( Returns: Tuple of (largest singular value, updated_u) """ - return power_iteration(matrix, u, k=num_iters) \ No newline at end of file + # power_iteration returns (sigma, u, v) but Spectron only needs sigma and u (left singular vector) + sigma, u, _v = power_iteration(matrix, u, k=num_iters) + return sigma, u \ No newline at end of file diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index b2d1b789..7a121ea7 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -31,11 +31,11 @@ def power_iteration( u: torch.Tensor, k: int = 1, eps: float = 1e-8, -) -> tuple[torch.Tensor, torch.Tensor]: - """Approximate largest singular value and left singular vector using power iteration. +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Approximate largest singular value and left/right singular vectors using power iteration. Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429). This method iteratively refines - estimates of the dominant singular value and corresponding left singular vector + estimates of the dominant singular value and corresponding left and right singular vectors of a matrix W. Args: @@ -45,9 +45,10 @@ def power_iteration( eps: Small constant for numerical stability. Default: 1e-8 Returns: - Tuple of (sigma, u) where: + Tuple of (sigma, u, v) where: - sigma: Approximation of the largest singular value (scalar tensor) - u: Updated left singular vector of shape (p,) + - v: Updated right singular vector of shape (q,) """ # Ensure initial normalization u = u / u.norm(p=2).clamp_min(eps) @@ -71,8 +72,8 @@ def power_iteration( v = v / v.norm(p=2).clamp_min(eps) sigma = u @ (W @ v) - # Return σ and u - return sigma, u + # Return σ, u, and v + return sigma, u, v def eigh_with_fallback( diff --git a/tests/test_spectron.py b/tests/test_spectron.py index d862f3eb..ce06ed95 100644 --- a/tests/test_spectron.py +++ b/tests/test_spectron.py @@ -46,7 +46,7 @@ def test_power_iteration_converges_to_largest_singular_value(self, shape, k) -> u = u / u.norm() # Run power iteration - sigma_approx, u_out = power_iteration(W, u, k=k) + sigma_approx, u_out, _v_out = power_iteration(W, u, k=k) # Check that approximation is close to true value # More iterations should give better approximation @@ -63,7 +63,7 @@ def test_power_iteration_output_normalized(self) -> None: W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - _, u_out = power_iteration(W, u, k=5) + _, u_out, _v_out = power_iteration(W, u, k=5) # Check that output is normalized torch.testing.assert_close( @@ -79,7 +79,7 @@ def test_power_iteration_handles_unnormalized_input(self) -> None: u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) * 100 # Unnormalized # Should not raise error and should normalize internally - sigma, u_out = power_iteration(W, u, k=5) + sigma, u_out, _v_out = power_iteration(W, u, k=5) self.assertIsInstance(sigma.item(), float) torch.testing.assert_close( @@ -94,8 +94,8 @@ def test_power_iteration_deterministic(self) -> None: W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - sigma1, u1 = power_iteration(W, u.clone(), k=5) - sigma2, u2 = power_iteration(W, u.clone(), k=5) + sigma1, u1, _v1 = power_iteration(W, u.clone(), k=5) + sigma2, u2, _v2 = power_iteration(W, u.clone(), k=5) torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0) torch.testing.assert_close(u1, u2, atol=0, rtol=0) @@ -224,11 +224,11 @@ def test_spectral_scaling_reduces_lr_for_large_sigma(self) -> None: state_small = opt_small.state[test_param_small] # Compute sigma values after step - sigma_A_large, _ = power_iteration(state_large["factor_A"], state_large["u_A"], k=1) - sigma_B_large, _ = power_iteration(state_large["factor_B"], state_large["u_B"], k=1) + sigma_A_large, _, _ = power_iteration(state_large["factor_A"], state_large["u_A"], k=1) + sigma_B_large, _, _ = power_iteration(state_large["factor_B"], state_large["u_B"], k=1) - sigma_A_small, _ = power_iteration(state_small["factor_A"], state_small["u_A"], k=1) - sigma_B_small, _ = power_iteration(state_small["factor_B"], state_small["u_B"], k=1) + sigma_A_small, _, _ = power_iteration(state_small["factor_A"], state_small["u_A"], k=1) + sigma_B_small, _, _ = power_iteration(state_small["factor_B"], state_small["u_B"], k=1) scaled_lr_large = lr / (sigma_A_large + sigma_B_large + 1.0) scaled_lr_small = lr / (sigma_A_small + sigma_B_small + 1.0) From d83bc106e0466e1060e756561c24bfc8da764ae0 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:50:03 -0800 Subject: [PATCH 04/15] added a unit test for both singular vectors Signed-off-by: mikail --- tests/test_spectron.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_spectron.py b/tests/test_spectron.py index ce06ed95..99bed1cc 100644 --- a/tests/test_spectron.py +++ b/tests/test_spectron.py @@ -99,6 +99,37 @@ def test_power_iteration_deterministic(self) -> None: torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0) torch.testing.assert_close(u1, u2, atol=0, rtol=0) + + def test_power_iteration_returns_both_singular_vectors(self) -> None: + """Test that power iteration returns both left and right singular vectors normalized.""" + W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) + u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) + + sigma, u_out, v_out = power_iteration(W, u, k=10) + + # Both singular vectors should be normalized + torch.testing.assert_close( + u_out.norm(), + torch.tensor(1.0, device=FLAGS.device), + atol=1e-6, + rtol=1e-6, + ) + torch.testing.assert_close( + v_out.norm(), + torch.tensor(1.0, device=FLAGS.device), + atol=1e-6, + rtol=1e-6, + ) + + # Check that W @ v ≈ sigma * u (definition of singular vectors) + Wv = W @ v_out + sigma_u = sigma * u_out + torch.testing.assert_close(Wv, sigma_u, atol=1e-4, rtol=1e-4) + + # Check that W^T @ u ≈ sigma * v + WTu = W.mT @ u_out + sigma_v = sigma * v_out + torch.testing.assert_close(WTu, sigma_v, atol=1e-4, rtol=1e-4) class SpectronTest(parameterized.TestCase): From 9d7c6fabbc922352f4a2b3aff5807396685a4c19 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:52:15 -0800 Subject: [PATCH 05/15] linting errors fix Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 14 +- tests/test_spectron.py | 142 +++++++++--------- 2 files changed, 75 insertions(+), 81 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 453e3cae..f105873f 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -103,9 +103,7 @@ def __init__( # Create orthogonalization function following OrthogonalizedOptimizer pattern def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: - logging.debug( - f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient" - ) + logging.debug(f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient") return muon_utils.newton_schulz( grad, steps=num_ns_steps, @@ -170,12 +168,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad_B = grad.mT @ factor_A # shape: (n, r) # Apply weight decay - self._apply_weight_decay_inplace( - factor_A, grad_A, group["lr"], group["weight_decay"] - ) - self._apply_weight_decay_inplace( - factor_B, grad_B, group["lr"], group["weight_decay"] - ) + self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"]) + self._apply_weight_decay_inplace(factor_B, grad_B, group["lr"], group["weight_decay"]) # Update momentum buffers (EMA of gradients) momentum_A.lerp_(grad_A, 1 - group["momentum_beta"]) @@ -254,4 +248,4 @@ def _power_iteration( """ # power_iteration returns (sigma, u, v) but Spectron only needs sigma and u (left singular vector) sigma, u, _v = power_iteration(matrix, u, k=num_iters) - return sigma, u \ No newline at end of file + return sigma, u diff --git a/tests/test_spectron.py b/tests/test_spectron.py index 99bed1cc..f5c3d3c3 100644 --- a/tests/test_spectron.py +++ b/tests/test_spectron.py @@ -13,18 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl import flags import torch import torch.nn as nn +from absl import flags from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers import spectron from emerging_optimizers.utils.eig import power_iteration + + # Define command line flags flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") FLAGS = flags.FLAGS + class PowerIterationTest(parameterized.TestCase): @parameterized.parameters( {"shape": (10, 8), "k": 1}, @@ -36,35 +39,35 @@ def test_power_iteration_converges_to_largest_singular_value(self, shape, k) -> """Test that power iteration approximates the largest singular value.""" # Create a random matrix with known singular values W = torch.randn(shape, dtype=torch.float32, device=FLAGS.device) - + # Get ground truth largest singular value using SVD _, S, _ = torch.linalg.svd(W, full_matrices=False) true_sigma_max = S[0].item() - + # Initialize random left singular vector u = torch.randn(shape[0], dtype=torch.float32, device=FLAGS.device) u = u / u.norm() - + # Run power iteration sigma_approx, u_out, _v_out = power_iteration(W, u, k=k) - + # Check that approximation is close to true value # More iterations should give better approximation rel_error = abs(sigma_approx.item() - true_sigma_max) / true_sigma_max - + # With more iterations, error should be smaller if k >= 10: self.assertLess(rel_error, 0.01, f"Relative error {rel_error} too large with {k} iterations") else: self.assertLess(rel_error, 0.1, f"Relative error {rel_error} too large with {k} iterations") - + def test_power_iteration_output_normalized(self) -> None: """Test that power iteration returns normalized left singular vector.""" W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - + _, u_out, _v_out = power_iteration(W, u, k=5) - + # Check that output is normalized torch.testing.assert_close( u_out.norm(), @@ -72,15 +75,15 @@ def test_power_iteration_output_normalized(self) -> None: atol=1e-6, rtol=1e-6, ) - + def test_power_iteration_handles_unnormalized_input(self) -> None: """Test that power iteration works even with unnormalized input.""" W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) * 100 # Unnormalized - + # Should not raise error and should normalize internally sigma, u_out, _v_out = power_iteration(W, u, k=5) - + self.assertIsInstance(sigma.item(), float) torch.testing.assert_close( u_out.norm(), @@ -88,25 +91,25 @@ def test_power_iteration_handles_unnormalized_input(self) -> None: atol=1e-6, rtol=1e-6, ) - + def test_power_iteration_deterministic(self) -> None: """Test that power iteration is deterministic given same inputs.""" W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - + sigma1, u1, _v1 = power_iteration(W, u.clone(), k=5) sigma2, u2, _v2 = power_iteration(W, u.clone(), k=5) - + torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0) torch.testing.assert_close(u1, u2, atol=0, rtol=0) - + def test_power_iteration_returns_both_singular_vectors(self) -> None: """Test that power iteration returns both left and right singular vectors normalized.""" W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device) u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) - + sigma, u_out, v_out = power_iteration(W, u, k=10) - + # Both singular vectors should be normalized torch.testing.assert_close( u_out.norm(), @@ -120,12 +123,12 @@ def test_power_iteration_returns_both_singular_vectors(self) -> None: atol=1e-6, rtol=1e-6, ) - + # Check that W @ v ≈ sigma * u (definition of singular vectors) Wv = W @ v_out sigma_u = sigma * u_out torch.testing.assert_close(Wv, sigma_u, atol=1e-4, rtol=1e-4) - + # Check that W^T @ u ≈ sigma * v WTu = W.mT @ u_out sigma_v = sigma * v_out @@ -143,7 +146,7 @@ def test_smoke(self, shape, rank, weight_decay_method, fp32_matmul_prec) -> None """Smoke test Spectron optimizer with various configurations.""" test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randn_like(test_param) - + spectron_opt = spectron.Spectron( [test_param], lr=0.01, @@ -153,11 +156,11 @@ def test_smoke(self, shape, rank, weight_decay_method, fp32_matmul_prec) -> None fp32_matmul_prec=fp32_matmul_prec, ) spectron_opt.step() - + # Check that parameter was updated self.assertIsNotNone(test_param.data) self.assertEqual(test_param.shape, shape) - + @parameterized.parameters( {"shape": (32, 16), "rank": 8}, {"shape": (64, 32), "rank": 16}, @@ -168,7 +171,7 @@ def test_low_rank_reconstruction_quality(self, shape, rank) -> None: # Create parameter with known structure test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) original_param = test_param.data.clone() - + spectron_opt = spectron.Spectron( [test_param], lr=0.0, # No update, just check initialization @@ -176,30 +179,30 @@ def test_low_rank_reconstruction_quality(self, shape, rank) -> None: momentum_beta=0.0, weight_decay=0.0, ) - + # Initialize state test_param.grad = torch.randn_like(test_param) spectron_opt.step() - + # Get state state = spectron_opt.state[test_param] factor_A = state["factor_A"] factor_B = state["factor_B"] - + # Reconstruct should give back the parameter (since lr=0) reconstructed = factor_A @ factor_B.mT - + # Check reconstruction quality (won't be perfect due to low-rank) rel_error = (reconstructed - original_param).norm() / original_param.norm() - + # Error should decrease with higher rank self.assertLess(rel_error.item(), 0.5, f"Reconstruction error {rel_error.item()} too large") - + def test_momentum_accumulation(self) -> None: """Test that momentum is properly accumulated over multiple steps.""" shape = (32, 16) test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) - + momentum_beta = 0.9 spectron_opt = spectron.Spectron( [test_param], @@ -208,71 +211,68 @@ def test_momentum_accumulation(self) -> None: momentum_beta=momentum_beta, weight_decay=0.0, ) - + # First step test_param.grad = torch.ones_like(test_param) spectron_opt.step() - + state = spectron_opt.state[test_param] momentum_A_step1 = state["momentum_A"].clone() momentum_B_step1 = state["momentum_B"].clone() - + # Second step with same gradient test_param.grad = torch.ones_like(test_param) spectron_opt.step() - + momentum_A_step2 = state["momentum_A"] momentum_B_step2 = state["momentum_B"] - + # Momentum should have changed (accumulated) self.assertFalse(torch.allclose(momentum_A_step1, momentum_A_step2)) self.assertFalse(torch.allclose(momentum_B_step1, momentum_B_step2)) - + def test_spectral_scaling_reduces_lr_for_large_sigma(self) -> None: """Test that learning rate is scaled down when spectral radius is large.""" shape = (32, 16) - + # Create parameter with large norm (will have large spectral radius) test_param_large = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 10) test_param_small = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 0.1) - + test_param_large.grad = torch.ones_like(test_param_large) * 0.01 test_param_small.grad = torch.ones_like(test_param_small) * 0.01 - + lr = 0.1 - + opt_large = spectron.Spectron([test_param_large], lr=lr, rank=8, momentum_beta=0.0) opt_small = spectron.Spectron([test_param_small], lr=lr, rank=8, momentum_beta=0.0) - - param_large_before = test_param_large.data.clone() - param_small_before = test_param_small.data.clone() - + opt_large.step() opt_small.step() - + # Get effective learning rates from spectral scaling state_large = opt_large.state[test_param_large] state_small = opt_small.state[test_param_small] - + # Compute sigma values after step sigma_A_large, _, _ = power_iteration(state_large["factor_A"], state_large["u_A"], k=1) sigma_B_large, _, _ = power_iteration(state_large["factor_B"], state_large["u_B"], k=1) - + sigma_A_small, _, _ = power_iteration(state_small["factor_A"], state_small["u_A"], k=1) sigma_B_small, _, _ = power_iteration(state_small["factor_B"], state_small["u_B"], k=1) - + scaled_lr_large = lr / (sigma_A_large + sigma_B_large + 1.0) scaled_lr_small = lr / (sigma_A_small + sigma_B_small + 1.0) - + # Larger spectral radius should result in smaller effective learning rate self.assertLess(scaled_lr_large.item(), scaled_lr_small.item()) - + def test_rank_capped_by_dimensions(self) -> None: """Test that rank is automatically capped by matrix dimensions.""" shape = (10, 8) # Small matrix test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randn_like(test_param) - + # Request rank larger than min dimension spectron_opt = spectron.Spectron( [test_param], @@ -280,25 +280,25 @@ def test_rank_capped_by_dimensions(self) -> None: rank=100, # Larger than both dimensions ) spectron_opt.step() - + state = spectron_opt.state[test_param] factor_A = state["factor_A"] factor_B = state["factor_B"] - + # Rank should be capped at min(m, n) = 8 self.assertEqual(factor_A.shape[1], 8) self.assertEqual(factor_B.shape[1], 8) - + def test_raises_error_for_1d_params(self) -> None: """Test that Spectron raises error for 1D parameters.""" test_param = nn.Parameter(torch.randn(10, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randn_like(test_param) - + spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=4) - + with self.assertRaises(ValueError): spectron_opt.step() - + @parameterized.parameters( {"num_ns_steps": 1}, {"num_ns_steps": 3}, @@ -310,17 +310,17 @@ def test_different_ns_steps(self, num_ns_steps) -> None: shape = (32, 16) test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randn_like(test_param) - + spectron_opt = spectron.Spectron( [test_param], lr=0.01, rank=8, num_ns_steps=num_ns_steps, ) - + # Should not raise error spectron_opt.step() - + @parameterized.parameters( {"num_power_iter": 1}, {"num_power_iter": 5}, @@ -331,40 +331,40 @@ def test_different_power_iter_steps(self, num_power_iter) -> None: shape = (32, 16) test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randn_like(test_param) - + spectron_opt = spectron.Spectron( [test_param], lr=0.01, rank=8, num_power_iter=num_power_iter, ) - + # Should not raise error spectron_opt.step() - + def test_state_persistence_across_steps(self) -> None: """Test that optimizer state (A, B, momentum, u) persists correctly across steps.""" shape = (32, 16) test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device)) - + spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=8) - + # First step test_param.grad = torch.randn_like(test_param) spectron_opt.step() - + state = spectron_opt.state[test_param] factor_A_step1 = state["factor_A"].clone() u_A_step1 = state["u_A"].clone() - + # Second step test_param.grad = torch.randn_like(test_param) spectron_opt.step() - + # State should still exist and be updated self.assertIn("factor_A", state) self.assertIn("u_A", state) - + # Values should have changed self.assertFalse(torch.allclose(state["factor_A"], factor_A_step1)) # u vector should be updated (but might be similar due to slow changes) @@ -372,4 +372,4 @@ def test_state_persistence_across_steps(self) -> None: if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 0f01d191d2060f5340a5b6c7c41b857d71f81acf Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:53:42 -0800 Subject: [PATCH 06/15] added ref to paper in algorithm Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/spectron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index f105873f..fd308f8f 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -49,7 +49,7 @@ class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer): 6. Reconstruct full weight matrix W = A @ B^T References: - - Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper. + - Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper (https://arxiv.org/abs/2602.12429). Low-rank spectral optimization with orthogonalized momentum. Warning: From e8b59eae8233a1a3cb965d90c59aa3b72a4e8212 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 19:55:55 -0800 Subject: [PATCH 07/15] updated documentation and ci script with spectron Signed-off-by: mikail --- docs/apidocs/orthogonalized-optimizers.md | 6 ++++++ emerging_optimizers/orthogonalized_optimizers/__init__.py | 1 + tests/ci/L0_Tests_CPU.sh | 1 + tests/ci/L0_Tests_GPU.sh | 1 + 4 files changed, 9 insertions(+) diff --git a/docs/apidocs/orthogonalized-optimizers.md b/docs/apidocs/orthogonalized-optimizers.md index 051bcfe7..afbf5294 100644 --- a/docs/apidocs/orthogonalized-optimizers.md +++ b/docs/apidocs/orthogonalized-optimizers.md @@ -39,6 +39,12 @@ emerging_optimizers.orthogonalized_optimizers .. autoclass:: MuonHyperball :members: +:hidden:`Spectron` +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Spectron + :members: + :hidden:`Newton-Schulz` ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 7e8ddc4d..0a42869d 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -19,3 +19,4 @@ from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * +from emerging_optimizers.orthogonalized_optimizers.spectron import * \ No newline at end of file diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index ca6a1a9b..8f6964bf 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -17,6 +17,7 @@ error=0 torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cpu -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1 exit "${error}" diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 25866cea..02d29a92 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -19,6 +19,7 @@ error=0 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cuda -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1 From d1c11799b66d818c63533a6a78301c2937f4086b Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 20:05:46 -0800 Subject: [PATCH 08/15] handled float32 miss Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index fd308f8f..8efe45c2 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -164,6 +164,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Compute gradients for A and B from parameter gradient # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A + # grad is always fp32, so no cast needed grad_A = grad @ factor_B # shape: (m, r) grad_B = grad.mT @ factor_A # shape: (n, r) @@ -196,7 +197,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: factor_B.add_(orth_momentum_B, alpha=-scaled_lr) # Reconstruct full weight matrix: W = A @ B^T - p.copy_(factor_A @ factor_B.mT) + p.copy_((factor_A @ factor_B.mT).to(p.dtype)) return loss @@ -212,22 +213,24 @@ def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> # Initialize A and B using SVD of the parameter # This provides a good initialization close to the original weights + # Low-rank factors are stored in fp32 for numerical stability with torch.no_grad(): U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False) # Keep only top r singular values/vectors sqrt_S = torch.sqrt(S[:r]) - factor_A = (U[:, :r] * sqrt_S).to(p.dtype) - factor_B = (Vh[:r, :].mT * sqrt_S).to(p.dtype) + factor_A = U[:, :r] * sqrt_S + factor_B = Vh[:r, :].mT * sqrt_S state["factor_A"] = factor_A.clone() state["factor_B"] = factor_B.clone() - state["momentum_A"] = torch.zeros_like(factor_A) - state["momentum_B"] = torch.zeros_like(factor_B) + # Momentum buffers are always stored in fp32 for numerical stability + state["momentum_A"] = torch.zeros_like(factor_A, dtype=torch.float32) + state["momentum_B"] = torch.zeros_like(factor_B, dtype=torch.float32) - # Initialize power iteration vectors (normalized random vectors) - u_A = torch.randn(m, dtype=p.dtype, device=p.device) + # Initialize power iteration vectors (normalized random vectors in fp32) + u_A = torch.randn(m, dtype=torch.float32, device=p.device) u_A = u_A / u_A.norm() - u_B = torch.randn(n, dtype=p.dtype, device=p.device) + u_B = torch.randn(n, dtype=torch.float32, device=p.device) u_B = u_B / u_B.norm() state["u_A"] = u_A From a8096d0fa86afbc31804bc4415d800dc6015cae2 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 20:06:57 -0800 Subject: [PATCH 09/15] linting Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 0a42869d..3ae4ac5d 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -19,4 +19,4 @@ from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * -from emerging_optimizers.orthogonalized_optimizers.spectron import * \ No newline at end of file +from emerging_optimizers.orthogonalized_optimizers.spectron import * From cd985ad65e6f491400c311f85574f6e69e79f7b6 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 20:42:38 -0800 Subject: [PATCH 10/15] cleanup some bad practices, fix state initialization handling Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 8efe45c2..bc907b7f 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -150,10 +150,22 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad = p.grad state = self.state[p] - # Initialize low-rank factors and momentum buffers - if "factor_A" not in state: + # State initialization + if len(state) == 0: + state["step"] = 0 + + if state["step"] == 0: + assert all( + key not in state + for key in ["factor_A", "factor_B", "momentum_A", "momentum_B", "u_A", "u_B"] + ), ( + "factor_A, factor_B, momentum_A, momentum_B, u_A, u_B should not be initialized at step 0. " + "Some mismatch has been created likely in checkpointing" + ) self._initialize_state(p, state) + state["step"] += 1 + # Get state variables factor_A = state["factor_A"] factor_B = state["factor_B"] @@ -197,7 +209,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: factor_B.add_(orth_momentum_B, alpha=-scaled_lr) # Reconstruct full weight matrix: W = A @ B^T - p.copy_((factor_A @ factor_B.mT).to(p.dtype)) + p.copy_(factor_A @ factor_B.mT) return loss @@ -221,8 +233,8 @@ def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> factor_A = U[:, :r] * sqrt_S factor_B = Vh[:r, :].mT * sqrt_S - state["factor_A"] = factor_A.clone() - state["factor_B"] = factor_B.clone() + state["factor_A"] = factor_A + state["factor_B"] = factor_B # Momentum buffers are always stored in fp32 for numerical stability state["momentum_A"] = torch.zeros_like(factor_A, dtype=torch.float32) state["momentum_B"] = torch.zeros_like(factor_B, dtype=torch.float32) @@ -237,12 +249,12 @@ def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> state["u_B"] = u_B def _power_iteration( - self, matrix: torch.Tensor, u: torch.Tensor, num_iters: int + self, X: torch.Tensor, u: torch.Tensor, num_iters: int ) -> tuple[torch.Tensor, torch.Tensor]: """Estimate the largest singular value using power iteration. Args: - matrix: The matrix to estimate largest singular value for (shape: p × q) + X: The matrix to estimate largest singular value for u: The current approximation of the dominant left singular vector num_iters: Number of power iteration steps @@ -250,5 +262,5 @@ def _power_iteration( Tuple of (largest singular value, updated_u) """ # power_iteration returns (sigma, u, v) but Spectron only needs sigma and u (left singular vector) - sigma, u, _v = power_iteration(matrix, u, k=num_iters) + sigma, u, _v = power_iteration(X, u, k=num_iters) return sigma, u From 86f11260f3532723e38310945f183ca8c4bc4582 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 20:46:56 -0800 Subject: [PATCH 11/15] removed stale trailing comment Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/spectron.py | 1 - 1 file changed, 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index bc907b7f..0b154681 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -176,7 +176,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Compute gradients for A and B from parameter gradient # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A - # grad is always fp32, so no cast needed grad_A = grad @ factor_B # shape: (m, r) grad_B = grad.mT @ factor_A # shape: (n, r) From 326f3f6d0a6ab1a87e9d68b2fd4d50059ef06eda Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 16 Feb 2026 20:56:50 -0800 Subject: [PATCH 12/15] linting Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/spectron.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 0b154681..487a9ccc 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -156,8 +156,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: if state["step"] == 0: assert all( - key not in state - for key in ["factor_A", "factor_B", "momentum_A", "momentum_B", "u_A", "u_B"] + key not in state for key in ["factor_A", "factor_B", "momentum_A", "momentum_B", "u_A", "u_B"] ), ( "factor_A, factor_B, momentum_A, momentum_B, u_A, u_B should not be initialized at step 0. " "Some mismatch has been created likely in checkpointing" @@ -247,9 +246,7 @@ def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> state["u_A"] = u_A state["u_B"] = u_B - def _power_iteration( - self, X: torch.Tensor, u: torch.Tensor, num_iters: int - ) -> tuple[torch.Tensor, torch.Tensor]: + def _power_iteration(self, X: torch.Tensor, u: torch.Tensor, num_iters: int) -> tuple[torch.Tensor, torch.Tensor]: """Estimate the largest singular value using power iteration. Args: From b7eb9be6396bd6f886ce74bb6ff13f7056448652 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 17 Feb 2026 11:28:33 -0800 Subject: [PATCH 13/15] added fp32 matmul decorators for safety Signed-off-by: mikail --- .../orthogonalized_optimizers/spectron.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 487a9ccc..6a169601 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -175,8 +175,9 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Compute gradients for A and B from parameter gradient # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A - grad_A = grad @ factor_B # shape: (m, r) - grad_B = grad.mT @ factor_A # shape: (n, r) + with utils.fp32_matmul_precision("highest"): + grad_A = grad @ factor_B # shape: (m, r) + grad_B = grad.mT @ factor_A # shape: (n, r) # Apply weight decay self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"]) @@ -187,13 +188,13 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: momentum_B.lerp_(grad_B, 1 - group["momentum_beta"]) # Orthogonalize momentum using Newton-Schulz - with utils.fp32_matmul_precision(self.fp32_matmul_prec): + with utils.fp32_matmul_precision("highest"): orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A) orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B) - # Estimate spectral radius using power iteration (Algorithm 3) - sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter) - sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter) + # Estimate spectral radius using power iteration + sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter) + sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter) # Update power iteration vectors state["u_A"] = u_A @@ -207,7 +208,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: factor_B.add_(orth_momentum_B, alpha=-scaled_lr) # Reconstruct full weight matrix: W = A @ B^T - p.copy_(factor_A @ factor_B.mT) + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + p.copy_(factor_A @ factor_B.mT) return loss From 74506932b5a2dde6a3afda9496be9290a2c260f5 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 17 Feb 2026 12:07:35 -0800 Subject: [PATCH 14/15] separate newton-schulz and other fp32 gemm decorators Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/spectron.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index 6a169601..a369a830 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -188,10 +188,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: momentum_B.lerp_(grad_B, 1 - group["momentum_beta"]) # Orthogonalize momentum using Newton-Schulz - with utils.fp32_matmul_precision("highest"): + with utils.fp32_matmul_precision(self.fp32_matmul_prec): orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A) orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B) + with utils.fp32_matmul_precision("highest"): # Estimate spectral radius using power iteration sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter) sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter) From d2686bbd70beb9143dea27e90d4ecc0c531e1bb4 Mon Sep 17 00:00:00 2001 From: mikail Date: Fri, 20 Feb 2026 10:51:09 -0800 Subject: [PATCH 15/15] linting Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/spectron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py index a369a830..c4915060 100644 --- a/emerging_optimizers/orthogonalized_optimizers/spectron.py +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -192,7 +192,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A) orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B) - with utils.fp32_matmul_precision("highest"): + with utils.fp32_matmul_precision("highest"): # Estimate spectral radius using power iteration sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter) sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter)