feat: add Mel-Band-RoFormer architecture for vocal source separation #654

xocialize wants to merge 13 commits into Blaizzy:main
Conversation
…1 dB
The parity test was permanently skipped, so several load-bearing bugs in
the MLX port went undetected. Wiring up the reference inference surfaced
them; this commit fixes them and verifies numerical parity.
Config / conversion
- Add zfturbo_vocals_v1 preset (dim=192, depth=8, hop=512,
mask_estimator_depth=1) for ZFTurbo's v1.0.0 MIT-licensed vocals
release asset. None of the existing three presets matched it.
Port bugs fixed in model.py
- RoPE convention: MLX used the halved split (x[:half] / x[half:]);
rotary_embedding_torch (what ZFTurbo imports) uses interleaved pairs
(x[2i] / x[2i+1]) with a doubled-frequency layout. Same math family,
different rotation plane — attention output was uncorrelated with
the reference until this was fixed. Single biggest jump in SDR.
- RMSNorm eps: mlx.nn.RMSNorm uses additive eps=1e-5; ZFTurbo uses
F.normalize(x, dim=-1), which is effectively x / max(||x||, 1e-12). The
two agree for typical magnitudes but diverge by up to 70% at the
small-magnitude high-frequency STFT bins that the band-split's
per-band RMSNorm ingests. Replaced with a local RMSNorm matching
F.normalize semantics.
- to_gates: declared bias=False, but ZFTurbo's Linear(dim, heads) has
the default bias=True and ships a to_gates.bias tensor.
- MaskEstimator: hardcoded to 3 linears (mask_estimator_depth=2 in
effect). Now honors config.mask_estimator_depth, so depth-1
checkpoints (like this one) load with the correct number of linears.
- MelFilterbank: custom Slaney triangular filter did not quite match
librosa's. Swapped for librosa.filters.mel + ZFTurbo's explicit
fb[0,0]=1 / fb[-1,-1]=1 force-assigns so every freq bin is covered.
- STFT: zero-padded the audio before framing, but torch.stft(center=True)
pad_mode defaults to reflect. Built reflect pad via slice+concat
(mx.pad has no reflect mode).
- STFT post-processing: .real() / .imag() as method calls; these are
properties on mlx.core.array.
- BandSplit: freq_indices stored as np.ndarray, but MLX fancy-indexing
rejects raw numpy — convert once at MelFilterbank init.
sanitize() remaps added
- Drop rotary_embed.freqs — MLX computes RoPE freqs on the fly.
- Mask-estimator per-band Sequential -> list index: PyTorch stores
linears at positions {0, 2} (depth=1) or {0, 2, 4} (depth=2); MLX
list positions are consecutive {0, 1, 2}.
- to_out.0.weight -> to_out.weight — PyTorch wraps in Sequential(Linear,
Dropout), MLX uses a bare Linear.
- .gamma -> .weight — PyTorch RMSNorm names its scale "gamma";
mlx.nn.RMSNorm names it "weight".
Test wiring in tests/sts/test_mel_roformer_parity.py
- Replace the pytest.skip stub in test_sdr_parity with the real
single-chunk SDR comparison: dynamically imports a user-supplied
torch_infer.py (see PARITY_TESTING.md) via importlib, runs both
models on the same chunk of audio, asserts SDR > target (default
40 dB).
- Add fixtures for MEL_ROFORMER_TORCH_INFER, MEL_ROFORMER_TORCH_CONFIG,
MEL_ROFORMER_CHUNK_SAMPLES.
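The dynamic-import step the test performs can be sketched with stdlib importlib (error handling and the fixture plumbing are omitted; `load_torch_infer` is an illustrative name):

```python
import importlib.util

def load_torch_infer(path: str):
    # Load a user-supplied module by file path, the way the parity test
    # resolves the script named in the MEL_ROFORMER_TORCH_INFER env var.
    spec = importlib.util.spec_from_file_location("torch_infer", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```

The loaded module is then expected to expose the `run(...)` and `load_audio_chunk(...)` entry points described in PARITY_TESTING.md.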
Verified
- All 25 existing unit tests in tests/sts/test_mel_roformer.py still pass.
- test_qkv_split_preserves_weights: pass.
- test_sdr_parity on the v1.0.0 ZFTurbo MIT vocals ckpt +
ZFTurbo config + a deterministic synthetic 30s stereo signal:
SDR = 58.11 dB (target > 40 dB).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Register MelRoFormer, MelRoFormerConfig, and MelRoFormerResult in the parent package so users can import via mlx_audio.sts.models without reaching into the subpackage.

Rename our result dataclass from SeparationResult to MelRoFormerResult to avoid a name collision with sam_audio.SeparationResult, which is already upstream and has a different streaming-oriented shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pre-commit auto-fixes from running hooks before PR submission.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NOTE: I have now joined, so I should be able to proceed with mlx-community contributions when appropriate. Original request: I would appreciate it if you could add me to the mlx-community on Hugging Face so I can upload the mlx weights.
@xocialize You can just join the community on HF yourself, so no permission needed. I know this is still a draft, but make sure your final PR doesn't include any scratchpad / planning .md files. Thanks!
Per @lucasnewman review comment on PR Blaizzy#654 — the final PR should not contain scratchpad / planning .md files.

PARITY_TESTING.md was an internal M5 Pro setup guide used during parity verification; the content has no use for downstream pip-install users of mlx-audio. The parity test itself (`tests/sts/test_mel_roformer_parity.py`) remains and is self-documenting via its docstrings and fixture parameters. Users who want to run parity can follow the @pytest.mark.requires_torch flag and the env-var fixtures directly.

README.md is unchanged — it's user-facing architecture + license + usage documentation, not a planning doc.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thanks @lucasnewman — joined mlx-community on HF, and just pushed
@KimberleyJensen relicensed the Kim Vocal 2 weights from GPL-3.0 to MIT on HuggingFace (2026-04-18) in response to a redistribution-license inquiry on her model page. This commit reflects the new posture throughout the PR:
- README.md license posture table: Kim Vocal 2 → MIT (was GPL-3.0)
- README.md "running inference vs. redistributing" section: rewrite examples to use anvuew dereverb as the remaining GPL-3.0 exemplar; include Kim Vocal 2 in the MIT-safe recommendation list
- convert.py _LICENSE_HINTS: KimberleyJSN / kimvocal entries now emit an MIT note with context about the relicensing
- test_mel_roformer.py test_detects_kim_vocal_2: assert license_tag == "MIT" (was "GPL-3.0")

This widens the set of MIT-safe checkpoint families that downstream users can bundle in products — Kim Vocal 2 is now the strongest-SDR preset with no redistribution constraints. The `anvuew` dereverb model remains GPL-3.0; `viperx` remains undeclared. The PR's license-aware conversion tooling continues to flag those.

Tests: 25 passed, 1 skipped — no regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an output-precision selector to the Mel-Band-RoFormer converter so the
same source checkpoint can be packaged at multiple precisions for HuggingFace
redistribution.
- New --dtype {float32,float16,bfloat16} CLI flag, default float32 (no behavior
change for existing callers).
- Dtype is baked into the content-addressed output basename so simultaneous
fp32/fp16/bf16 conversions of the same source no longer collide.
- Companion config.json records the chosen dtype for downstream provenance.
Driven by the bf16-only scope decision for the planned mlx-community uploads
(mel-roformer-kim-vocal-2-mlx, mel-roformer-zfturbo-vocals-v1-mlx).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
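A hedged sketch of a dtype-suffixed, content-addressed basename as described above (the exact hashing granularity and digest length used by convert.py are assumptions here):

```python
import hashlib

def output_basename(stem: str, content: bytes, dtype: str) -> str:
    # Hash the source checkpoint bytes so fp32/fp16/bf16 conversions of the
    # same source produce distinct, collision-free filenames.
    sha = hashlib.sha256(content).hexdigest()[:8]
    return f"{stem}.{sha}.{dtype}.safetensors"

fp16 = output_basename("model_vocals", b"\x00\x01", "float16")
bf16 = output_basename("model_vocals", b"\x00\x01", "bfloat16")
assert fp16 != bf16  # same source, different precision: no collision
```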
Companion to tests/sts/test_mel_roformer_parity.py — wraps lucidrains
bs_roformer.MelBandRoformer (the canonical PyTorch reference used by
ZFTurbo's MSS-Training) so the parity harness can compare MLX outputs
against the upstream implementation rather than a custom reimplementation.
Provides the two functions the parity test loads via the
MEL_ROFORMER_TORCH_INFER env var:
run(ckpt_path, config_yaml, audio_path, out_path, chunk_samples)
-> np.ndarray [2, T_out]
load_audio_chunk(audio_path, chunk_samples)
-> (np.ndarray [2, T], sr)
Filters YAML hyperparameters against the installed bs_roformer init
signature so future lucidrains releases that add/remove parameters fail
visibly (with a printed list of dropped keys) rather than silently.
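The signature-filtering idea can be sketched with stdlib inspect (`filter_kwargs` and the toy `Model` class are illustrative, not the wrapper's actual names):

```python
import inspect

def filter_kwargs(cls, hparams):
    # Keep only the hyperparameters the installed class accepts; report
    # the rest so upstream API drift fails visibly, not silently.
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    kept = {k: v for k, v in hparams.items() if k in accepted}
    dropped = sorted(set(hparams) - accepted)
    if dropped:
        print(f"dropping unrecognized hyperparameters: {dropped}")
    return kept

class Model:
    def __init__(self, dim, depth):
        self.dim, self.depth = dim, depth

kept = filter_kwargs(Model, {"dim": 192, "depth": 8, "future_flag": True})
assert kept == {"dim": 192, "depth": 8}
```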
Validated against:
- Kim Vocal 2 (MelBandRoformer.ckpt, 684 keys, 228M params): SDR 66.08 dB
parity vs MLX bf16 conversion → published as
mlx-community/mel-roformer-kim-vocal-2-mlx
- ZFTurbo vocals_v1 (model_vocals_mel_band_roformer_sdr_8.42.ckpt,
612 keys, 33.7M params): SDR 44.19 dB parity vs MLX fp16 conversion
→ published as mlx-community/mel-roformer-zfturbo-vocals-v1-mlx
Pin: bs_roformer==0.3.10. Newer releases (0.4+) reorder the layers
ModuleList and add nGPT-style normalization, breaking checkpoint
compatibility for both Kim and ZFTurbo v1.0.0 weights.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| ZFTurbo MSS-Training release assets | [ZFTurbo MSS-Training releases](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases) | MIT (inherited from repo) |
| Community variants (dereverb, denoise, crowd) | various | Varies — review each |
> ## Running inference vs. redistributing weights
Can you remove this section? We don't support training at the moment so this isn't relevant to running pretrained models.
> @@ -0,0 +1,199 @@
> """PyTorch reference inference for Mel-Band-RoFormer parity testing.
We don't need/want this file, it looks like it was just for your testing.
> asyncio_default_fixture_loop_scope = function
> markers =
>     requires_torch: test requires PyTorch installed (skip in MLX-only CI)
We'd prefer not to include torch, so we don't need this (see above comment).
> [lucidrains/BS-RoFormer](https://github.com/lucidrains/BS-RoFormer) and
> [ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training).
>
> ## What this contribution is
Can you remove this section as well? It looks like your agent is injecting some kind of overview of the porting effort -- we just want clean docs.
> leaves them zero, which breaks the "every freq covered" invariant),
> then binarizes.
> """
> import librosa
Sorry, we can't have librosa as a dependency. We have the ability to generate identical filterbanks in mlx_audio/dsp.py -- please use that instead.
> # ---------- STFT / iSTFT ----------
>
> def stft(audio: mx.array, n_fft: int, hop_length: int, window: mx.array) -> tuple:
We already have an implementation of this in mlx_audio/dsp.py -- please use that.
> def istft(
>     real: mx.array,
Same -- let's use dsp.py
@xocialize Looks pretty good overall. Please see the comments on where we can avoid code duplication, and make sure we're not pulling in excessive dependencies (torch / librosa / etc).
…erbank

Per @lucasnewman's review (Blaizzy#654 comments r3142451654, r3142452858, r3142453107): drop librosa as a dependency, stop duplicating STFT/iSTFT implementations, and use the shared utilities in mlx_audio/dsp.py instead.
- MelFilterbank now uses mlx_audio.dsp.mel_filters with mel_scale="slaney" (matching the previous librosa.filters.mel default), then keeps the ZFTurbo-style DC/Nyquist forcing and binarization on top.
- The local stft() function becomes a thin batched/multichannel adapter over mlx_audio.dsp.stft (which is 1D-input). Loops over [B, channels] and reshapes into the [B, channels, freq_bins, frames] layout BandSplit expects.
- The local istft() function similarly delegates to mlx_audio.dsp.istft.

A few adapter notes worth flagging:
- ``length=None`` is passed (not the target length) because that branch correctly strips the ``win_length // 2`` center-pad; passing the target length skips center-pad removal and shifts the reconstruction. We trim to ``length`` in the adapter.
- ``normalized=True`` selects COLA-style ``window**2`` normalization, matching torch.istft and the previous local implementation.
- dsp.istft's center-strip can leave the output a few samples short of ``length`` when the chunk size isn't an integer multiple of hop_length (e.g. ZFTurbo's 352800 / 512). Tail-pad with zeros to reach ``length`` — those samples were in the center-pad region.

Parity preserved exactly across both validated checkpoints (re-measured on the same harness as before the refactor):
- Kim Vocal 2 (bf16): 66.08 dB SDR vs PyTorch reference (unchanged)
- ZFTurbo vocals_v1 (fp16): 44.19 dB SDR vs PyTorch reference (unchanged)
- ZFTurbo vocals_v1 (bf16): 21.96 dB SDR vs PyTorch reference (unchanged — bf16-sensitive at this small architecture, hence we ship fp16 for this preset on mlx-community)

Also fix TestContentAddressing.test_content_addressed_name to pass the ``dtype`` arg added in 8380ab8.
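The tail-pad described above can be sketched in a few lines (NumPy stand-in for the mx ops; `pad_to_length` is an illustrative name):

```python
import numpy as np

def pad_to_length(audio, length):
    # istft center-strip can come back a few samples short of the requested
    # length; zero-pad the tail (those samples were in the center-pad region),
    # then trim in case the output is instead slightly long.
    shortfall = length - audio.shape[-1]
    if shortfall > 0:
        pad = [(0, 0)] * (audio.ndim - 1) + [(0, shortfall)]
        audio = np.pad(audio, pad)
    return audio[..., :length]

x = np.ones((2, 352_797))  # e.g. three samples short of a ZFTurbo chunk
assert pad_to_length(x, 352_800).shape == (2, 352_800)
```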
Now that the parity test infrastructure (tests/sts/torch_infer.py, tests/sts/test_mel_roformer_parity.py) is removed, the requires_torch marker registration in pytest.ini is no longer used. Drops the section per @lucasnewman's review (Blaizzy#654 comment r3142448241).
Per @lucasnewman's review (Blaizzy#654 comments r3142445601, r3142449229):
- Remove the "What this contribution is" section. It was framed around this module's place in the upstream-PR narrative rather than describing the model. User-facing docs should describe the model, not the porting effort.
- Remove the "Running inference vs. redistributing weights" section. Training isn't supported in this module today, and the per-license redistribution table set up expectations the module doesn't deliver on. The shorter "License" section keeps the architecture-MIT note and defers to checkpoint source repos for weight licensing.
- Drop the parity-test footnote in the convert.py section now that the parity test infrastructure has been removed.
- Update the convert.py example to show the --dtype flag.
Thanks @lucasnewman — all addressed in 39c9325, 4555494, 5e5cea2:
Net delta: −558 LoC (mostly the duplicate STFT/iSTFT and the prose sections). The two open guidance items in the description (chunked-inference helper location, parity-fixture location) are no longer relevant — chunking still deferred for a follow-up, parity fixtures gone per your ask. Ready for re-review whenever you have a minute.
Summary
Adds the Mel-Band-RoFormer architecture to mlx_audio.sts for music source separation — particularly vocal isolation. Parity-verified against the PyTorch reference at two scales: 66.08 dB SDR for the 228M-parameter Kim Vocal 2 (bf16) and 44.19 dB SDR for the 33.7M-parameter ZFTurbo vocals_v1 (fp16), both above the 40 dB target the parity test enforces.

Architecture paper: Lu et al., "Mel-Band RoFormer for Music Source Separation" (2023) — https://arxiv.org/abs/2310.01809
Reference implementations:
- lucidrains/BS-RoFormer — https://github.com/lucidrains/BS-RoFormer (MIT)
- ZFTurbo/Music-Source-Separation-Training — https://github.com/ZFTurbo/Music-Source-Separation-Training (MIT)

The PR ships the architecture, conversion tooling, and parity-test infrastructure. Two MIT-compatible checkpoints have been published as separate artifacts on mlx-community, so from_pretrained() works out of the box once this lands; users wanting other checkpoints (viperx, anvuew, custom-trained) can run convert.py locally on their own copies.

Architecture
Modules: BandSplit, Transformer (×N layers × 2 axes), RoFormerAttention, MaskEstimator, stft/istft.

Pipeline: [B, 2, samples] → STFT → CaC interleave → BandSplit → N× DualAxisTransformer → MaskEstimate → complex multiply → iSTFT → [B, 2, samples] (stereo 44.1 kHz in, separated stereo out).

Config presets
- MelRoFormerConfig.kim_vocal_2() — KimberleyJSN/melbandroformer (MIT — relicensed Apr 2026)
- MelRoFormerConfig.viperx_vocals() — TRvlvr/model_repo viperx vocals (undeclared license)
- MelRoFormerConfig.zfturbo_bs_roformer()
- MelRoFormerConfig.zfturbo_vocals_v1() — model_vocals_mel_band_roformer_sdr_8.42.ckpt (MIT inherited)
- MelRoFormerConfig.custom(...)

No default() constructor — callers must name their checkpoint family explicitly to avoid accidentally pulling weights with restrictive licenses via a tutorial copy-paste.

Parity verification
Single-chunk SDR comparison against PyTorch reference, run at two checkpoint scales to verify the architecture holds across the most-used presets:
- Kim Vocal 2: dim=384, depth=6, mask_depth=2
- vocals_v1: dim=192, depth=8, mask_depth=1

The parity numbers are SDR between MLX output and PyTorch reference output on the same input chunk — not vs ground truth. Both are above the 40 dB target the parity test treats as "effectively bit-exact up to floating-point precision."
The parity test lives at tests/sts/test_mel_roformer_parity.py, gated behind @pytest.mark.requires_torch so CI skips it. Local runs need PyTorch + the original .ckpt + a reference WAV configured via MEL_ROFORMER_* env vars; the tests/sts/torch_infer.py fixture wraps bs_roformer.MelBandRoformer (the lucidrains PyTorch reference that ZFTurbo's MSS-Training imports).

A note on dtype: the smaller ZFTurbo architecture (dim=192) is bf16-sensitive — bf16 quantization drops parity to 21.96 dB on the same harness, while fp16 (10-bit mantissa vs bf16's 7) recovers 44.19 dB at the same file size. The new --dtype {float32,float16,bfloat16} flag on the converter lets callers pick precision per checkpoint family. The default is float32 to preserve existing behavior.

A note on bs_roformer version: parity testing pins bs_roformer==0.3.10. Newer lucidrains releases (0.4+) reorder the layers ModuleList nesting and add nGPT-style normalization params, breaking compatibility with the classic-architecture checkpoints (Kim, viperx, ZFTurbo v1.0.0) that the world is using today. The _build_torch_model helper in torch_infer.py filters YAML hyperparameters against the installed init signature, so future drift fails visibly with a printed list of dropped keys rather than silently.

Bugs found and fixed during parity wiring
The parity test caught five load-bearing bugs that would have silently degraded separation quality:
- RoPE convention: MLX used the halved-split (x[:half], x[half:]) convention; rotary_embedding_torch (what ZFTurbo imports) uses interleaved pairs (x[2i], x[2i+1]) with a doubled-frequency layout. Same math family, different rotation plane — attention output was uncorrelated with the reference until this was fixed. Biggest single SDR jump.
- RMSNorm eps: mlx.nn.RMSNorm uses additive eps=1e-5; ZFTurbo uses F.normalize(x, dim=-1), which is effectively x / max(||x||, 1e-12). The two agree for typical magnitudes but diverge by up to 70% at the small-magnitude high-frequency STFT bins the band-split's per-band RMSNorm ingests. Replaced with a local RMSNorm matching F.normalize semantics.
- to_gates bias: declared bias=False, but ZFTurbo's Linear(dim, heads) has the default bias=True and ships a to_gates.bias tensor.
- MaskEstimator: was hardcoded to 3 linears (mask_estimator_depth=2) — made configurable from the config so the zfturbo_vocals_v1 depth=1 variant loads correctly.
- RMSNorm scale name: gamma in the PyTorch code, weight in mlx.nn. Renamed for compatibility.
mlx-communityBoth MIT-compatible checkpoints are live and discoverable via the Mel-Band-RoFormer (MLX) Collection:
mlx-community/mel-roformer-kim-vocal-2-mlx(bf16, 435 MB)mlx-community/mel-roformer-zfturbo-vocals-v1-mlx(fp16, 64 MB)Each repo carries the canonical
model.safetensors+config.json+LICENSE+ a model card with full provenance: source repo, source commit/release, source SHA-256, mlx version, this PR's converter SHA, and parity numbers.Usage
For checkpoint conversion (one-time, for users bringing their own weights):
python -m mlx_audio.sts.models.mel_roformer.convert \
  --input model_vocals_mel_band_roformer_sdr_8.42.ckpt \
  --output ./weights/ \
  --preset zfturbo_vocals_v1 \
  --dtype bfloat16  # or float16 / float32

The conversion is idempotent (content-addressed by SHA-256) and writes a companion config.json so from_pretrained() can resolve the architecture without the user re-selecting a preset. Output filenames include the dtype (<basename>.<sha>.<dtype>.safetensors) so multi-precision conversions of the same source don't collide.

License posture
Architecture: MIT (inherited from lucidrains/BS-RoFormer and ZFTurbo/Music-Source-Separation-Training).

Checkpoints: independently licensed — see mlx_audio/sts/models/mel_roformer/README.md for the per-checkpoint table. Both shipped mlx-community checkpoints are MIT, with caveats worth disclosing for clean redistribution:
- Kim Vocal 2: relicensed from GPL-3.0 to MIT (ac9b061). The relicense was independently confirmed with the author the week of 2026-04-20 prior to the redistribution. The published model card carries the verbatim trace.
- vocals_v1 inherits MIT from the MSS-Training repository, but the repo's LICENSE file was committed on 2024-11-04 — about a year after the v1.0.0 release that produced this checkpoint asset. The published model card discloses the timeline transparently and notes the inheritance reasoning. (Repo is MIT today, asset is still distributed by the same MIT-licensed repo, author is an active OSS maintainer who has continued publishing under MIT.)
- convert.py emits license-aware warnings when known-source checkpoint paths are passed (KimberleyJSN, ZFTurbo, TRvlvr/viperx, anvuew), printing the recognized license and any redistribution caveat before conversion proceeds.

What this PR does NOT do
To keep scope reviewable, several adjacent things are deliberately deferred:
- No viperx / anvuew checkpoint uploads: viperx (TRvlvr/model_repo) ships no LICENSE file; anvuew is GPL-3.0. Users wanting these can run convert.py locally on their own copies.
- No chunked-inference helper beyond a single chunk_size window. The model's forward pass operates on a single chunk; overlap-add is left to the caller. Happy to follow up in a separate PR if there's interest — flagged as an open question in "What's next" below.
- No MelRoFormerProcessor class. The model takes raw mx.array [B, 2, samples] audio and returns mx.array [B, 2, samples] separated vocals. Audio loading happens via mlx_audio.utils.load_audio. Adding a Sam-Audio-style processor wrapper is straightforward but felt like scope creep.

Tests
All tests in tests/sts/test_mel_roformer.py pass locally:
- TestMelRoFormerConfig — defaults, derived properties, all presets, custom escape hatch, keyword-only enforcement
- TestMelRoFormerModel — construction with each preset
- TestSanitize — QKV packed-weight splitting
- TestConvert — training-state stripping, state-dict extraction across 3 PyTorch formats
- TestLicenseDetection — warn-on-substring matcher
- TestContentAddressing — SHA-256 hashing consistency, content-addressed naming with dtype suffix
- TestConfigSerialization — round-trip config → dict → config for companion .config.json
- TestPresetResolution — CLI preset lookup

Parity tests in tests/sts/test_mel_roformer_parity.py are @pytest.mark.requires_torch-gated and skipped by default. Local parity runs require PyTorch + bs_roformer==0.3.10 + the original .ckpt + a reference WAV configured via MEL_ROFORMER_* env vars. The tests/sts/torch_infer.py fixture wraps the lucidrains MelBandRoformer reference; it filters YAML hyperparameters against the installed init signature so future bs_roformer version drift surfaces visibly rather than silently.

What's next
- Where should chunked inference live: model.py, a separate pipeline.py, or external user code? Was deferred from the original PR; still deferred here.
- xocialize/mel-roformer-mlx-swift — independent SPM package, MLX-Swift native, ships a MelRoFormer.fromPretrained(_:) matching the Python convention. Not part of this PR; mentioned because the published mlx-community model cards reference it under "Usage" → Swift.
- Additional checkpoint families can land as presets (or via MelRoFormerConfig.custom(...)).

Checklist
- SDR > 40 dB target — verified at two scales (Kim 66.08 dB bf16, ZFTurbo 44.19 dB fp16)
- Checkpoints published to mlx-community with full provenance + a discoverable Collection
- Conversion handles .ckpt, .pt, .safetensors inputs across multiple PyTorch state-dict shapes
- pytest.mark.requires_torch gating for torch-dependent tests
- PARITY_TESTING.md removed per @lucasnewman's earlier review feedback (commit 5d5a9c2)
- 4448196
- Open question: should chunked inference live in model.py, a separate pipeline.py, or external user code?
- tests/sts/torch_infer.py is the test fixture; happy to move it under tests/sts/fixtures/ or similar if there's a preferred convention.

Tagging @Blaizzy and @lucasnewman per earlier review engagement. Two new commits since the last touch (rebased onto current main):
- 8380ab8 — --dtype {float32,float16,bfloat16} flag on convert.py with dtype-suffixed content-addressing.
- 1f9a555 — tests/sts/torch_infer.py (the parity-test fixture, previously external) now lives in-tree.

Both checkpoints have been parity-validated end-to-end and published to mlx-community since the last review pass — happy to walk through any specific section if helpful.