Add NorMuon optimizer #1612

Open
tinker495 wants to merge 1 commit into google-deepmind:main from tinker495:feature/add-normuon

@tinker495
Contributor

Summary

  • Add normuon and scale_by_normuon to optax.contrib, implementing NorMuon (Neuron-wise Normalized Muon) (Li et al., 2025)
  • NorMuon augments Muon with a per-neuron second-moment statistic computed from orthogonalized updates, normalizing update magnitudes across neurons for improved scalability
  • Non-NorMuon parameters (non-2D) fall back to AdamW via combine.partition
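
The neuron-wise normalization described above can be sketched roughly as follows. This is a NumPy illustration under stated assumptions, not the code in this PR (which is written in JAX): the function name `normalize_per_neuron`, the `beta2` and `eps` defaults, and the choice of the last axis as the neuron (output) axis are all made up for the example, and the Newton-Schulz orthogonalization step is assumed to have already run.

```python
import numpy as np


def normalize_per_neuron(ortho_update, v, beta2=0.95, eps=1e-8):
    """One neuron-wise normalization step (illustrative sketch only).

    ortho_update: (reduction, output) array, assumed already orthogonalized.
    v: (output,) running second moment, one entry per output neuron.
    beta2, eps: assumed hyperparameter values, not taken from the PR.
    """
    # Mean of squared entries over the reduction axis: one value per neuron.
    sq_mean = np.mean(ortho_update ** 2, axis=0)
    # Exponential moving average of the per-neuron second moment.
    v_new = beta2 * v + (1.0 - beta2) * sq_mean
    # Divide each neuron's slice by its own root second moment.
    normalized = ortho_update / (np.sqrt(v_new) + eps)
    return normalized, v_new


# Toy input whose three output neurons have very different magnitudes.
rng = np.random.default_rng(0)
update = rng.normal(size=(4, 3)) * np.array([0.1, 1.0, 10.0])
normalized, v_new = normalize_per_neuron(update, np.zeros(3))

# Root-mean-square of each neuron's slice after normalization.
rms_per_neuron = np.sqrt(np.mean(normalized ** 2, axis=0))
```

On the first step from a zero state, every neuron's slice ends up with roughly the same root-mean-square magnitude, which is the balancing effect the summary refers to.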

Changes

| File | Description |
| --- | --- |
| `optax/contrib/_normuon.py` | `scale_by_normuon` (core transform) and `normuon` (full optimizer with AdamW fallback) |
| `optax/contrib/_normuon_test.py` | Unit tests (2D default, custom dimension numbers, partition with AdamW) |
| `optax/contrib/__init__.py` | Export `normuon`, `scale_by_normuon`, `NorMuonState` |
| `optax/contrib/_common_test.py` | Add `normuon` to shared contrib optimizer tests |
| `docs/api/contrib.rst` | Add NorMuon to API docs |

Design

  • Reuses Muon's orthogonalize_via_newton_schulz and MuonDimensionNumbers for consistency
  • Per-neuron second moment v is stored as (batch, output) shape, matching the paper's Algorithm 1
  • Supports custom weight_dimension_numbers for non-2D tensors (same interface as Muon)
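
To make the `(batch, output)` layout concrete, here is a small NumPy sketch of how a per-neuron statistic could be reduced from a tensor with one reduction axis and one output axis. The helper `per_neuron_sq_mean` and its integer-axis arguments are hypothetical and deliberately simplified; they are not the `MuonDimensionNumbers` interface, and the PR's implementation may differ.

```python
import numpy as np


def per_neuron_sq_mean(update, reduction_axis, output_axis):
    """Mean of squared entries per neuron, returned as (batch, output).

    Hypothetical helper: assumes exactly one reduction axis and one output
    axis; every remaining axis is flattened into a single batch axis.
    """
    # Move the reduction and output axes to the last two positions.
    moved = np.moveaxis(update, (reduction_axis, output_axis), (-2, -1))
    # Collapse all leading axes into one batch axis (size 1 for a 2D input).
    batched = moved.reshape(-1, moved.shape[-2], moved.shape[-1])
    # Average the squared entries over the reduction axis.
    return np.mean(batched ** 2, axis=1)


# A plain 2D weight gets a batch axis of size 1.
stat_2d = per_neuron_sq_mean(np.ones((5, 3)), reduction_axis=0, output_axis=1)
# A 3D tensor with a leading batch axis keeps that axis in the statistic.
stat_3d = per_neuron_sq_mean(np.ones((2, 5, 3)), reduction_axis=1, output_axis=2)
# Axes can sit anywhere: here the output axis comes first.
x = np.arange(30, dtype=float).reshape(3, 2, 5)
stat_perm = per_neuron_sq_mean(x, reduction_axis=2, output_axis=0)
```

Keeping one entry per (batch, output) pair means the state is much smaller than the weight itself, since the reduction axis is averaged away.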

Test plan

  • pytest optax/contrib/_normuon_test.py — dedicated unit tests
  • pytest optax/contrib/_common_test.py — shared contrib optimizer tests
  • ruff lint check passes

This commit introduces the NorMuon optimizer, which enhances the Muon algorithm by incorporating neuron-wise normalization of update magnitudes. The implementation includes the core functionality in `_normuon.py`, along with a dedicated test suite in `_normuon_test.py` to validate its behavior. Additionally, references to the new optimizer have been added in the documentation and the main `__init__.py` file.

Key changes:
- New optimizer: NorMuon with associated scaling and state management.
- Tests for NorMuon functionality and behavior.
- Documentation updates to include NorMuon in the API reference.

PiperOrigin-RevId: [insert-rev-id-here]
@rdyro
Collaborator

rdyro commented Mar 5, 2026

Hey, we really appreciate the contribution! This is a lot of code, which might be difficult to maintain with our small team of maintainers, so it might take us a while to merge.

I'll try to get to this as soon as possible, thanks!
