
feat: add Mel-Band-RoFormer architecture for vocal source separation#654

Open
xocialize wants to merge 13 commits into Blaizzy:main from xocialize:feat/mel-band-roformer

Conversation

@xocialize

@xocialize xocialize commented Apr 16, 2026

Summary

Adds the Mel-Band-RoFormer architecture to mlx_audio.sts for music source separation — particularly vocal isolation. Parity-verified against the PyTorch reference at two scales: 66.08 dB SDR for the 228M-parameter Kim Vocal 2 (bf16) and 44.19 dB SDR for the 33.7M-parameter ZFTurbo vocals_v1 (fp16), both above the 40 dB target the parity test enforces.

Architecture paper: Lu et al., "Mel-Band RoFormer for Music Source Separation" (2023) — https://arxiv.org/abs/2310.01809

Reference implementations:

The PR ships the architecture, conversion tooling, and parity-test infrastructure. Two MIT-compatible checkpoints have been published as separate artifacts on mlx-community so from_pretrained() works out of the box once this lands; users wanting other checkpoints (viperx, anvuew, custom-trained) can run convert.py locally on their own copies.

Architecture

| Component | Purpose |
| --- | --- |
| BandSplit | Mel-scale band splitting of the CaC-interleaved spectrogram into 60 bands |
| Transformer (×N layers × 2 axes) | Dual-axis RoFormer — alternating time-axis and frequency-axis attention |
| RoFormerAttention | Multi-head attention with interleaved-pair RoPE and per-head sigmoid gates |
| MaskEstimator | Per-band MLP with GLU activation producing complex masks |
| stft / istft | Center-padded STFT with Hann window and overlap-add reconstruction |

Pipeline: [B, 2, samples] → STFT → CaC interleave → BandSplit → N× DualAxisTransformer → MaskEstimate → complex multiply → iSTFT → [B, 2, samples] (stereo 44.1 kHz in, separated stereo out).

Config presets

| Preset | Depth | Dim | hop_length | mask_depth | Target checkpoints |
| --- | --- | --- | --- | --- | --- |
| MelRoFormerConfig.kim_vocal_2() | 6 | 384 | 441 | 2 | KimberleyJSN/melbandroformer (MIT — relicensed Apr 2026) |
| MelRoFormerConfig.viperx_vocals() | 12 | 384 | 441 | 2 | TRvlvr/model_repo viperx vocals (undeclared license) |
| MelRoFormerConfig.zfturbo_bs_roformer() | 12 | 384 | 441 | 2 | ZFTurbo MSS-Training release-asset BS-RoFormer (MIT inherited) |
| MelRoFormerConfig.zfturbo_vocals_v1() | 8 | 192 | 512 | 1 | model_vocals_mel_band_roformer_sdr_8.42.ckpt (MIT inherited) |
| MelRoFormerConfig.custom(...) | user | user | user | user | Escape hatch for other community variants |

No default() constructor — callers must name their checkpoint family explicitly to avoid accidentally pulling weights with restrictive licenses via a tutorial copy-paste.

Parity verification

Single-chunk SDR comparison against PyTorch reference, run at two checkpoint scales to verify the architecture holds across the most-used presets:

| Checkpoint | Architecture | MLX dtype | SDR vs PyTorch | Source |
| --- | --- | --- | --- | --- |
| Kim Vocal 2 | 228M params, dim=384, depth=6, mask_depth=2 | bf16 | 66.08 dB | KimberleyJSN/melbandroformer |
| ZFTurbo vocals_v1 | 33.7M params, dim=192, depth=8, mask_depth=1 | fp16 | 44.19 dB | Music-Source-Separation-Training v1.0.0 release |

Implementation parity numbers are SDR between MLX output and PyTorch reference output on the same input chunk — not vs ground truth. Both are above the 40 dB target the parity test treats as "effectively bit-exact up to floating-point precision."
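
For context, the SDR used here is the standard ratio of reference power to residual power, in dB. A minimal helper along those lines (my own sketch, not code from this PR):

```python
import numpy as np

def sdr_db(reference: np.ndarray, estimate: np.ndarray, eps: float = 1e-12) -> float:
    """Signal-to-distortion ratio in dB: reference power over residual power."""
    num = np.sum(reference ** 2)
    den = np.sum((reference - estimate) ** 2) + eps
    return float(10.0 * np.log10(num / den))

# An estimate with ~1e-3 absolute noise on a unit-amplitude signal lands in
# the high-50s dB, which calibrates what a 40+ dB parity target means.
t = np.linspace(0.0, 1.0, 44100, endpoint=False)
ref = np.sin(2 * np.pi * 440.0 * t)
est = ref + 1e-3 * np.random.default_rng(0).standard_normal(ref.shape)
print(round(sdr_db(ref, est), 1))
```

With `eps` in the denominator, bit-identical outputs give a finite ceiling rather than a divide-by-zero.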

The parity test lives at tests/sts/test_mel_roformer_parity.py, gated behind @pytest.mark.requires_torch so CI skips it. Local runs need PyTorch + the original .ckpt + a reference WAV configured via MEL_ROFORMER_* env vars; the tests/sts/torch_infer.py fixture wraps bs_roformer.MelBandRoformer (the lucidrains PyTorch reference that ZFTurbo's MSS-Training imports).

A note on dtype: the smaller ZFTurbo architecture (dim=192) is bf16-sensitive — bf16 quantization drops parity to 21.96 dB on the same harness, while fp16 (10-bit mantissa vs bf16's 7) recovers 44.19 dB at the same file size. The new --dtype {float32,float16,bfloat16} flag on the converter lets callers pick precision per checkpoint family. The default is float32 to preserve existing behavior.
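
The mantissa gap is visible outside the model too. The sketch below emulates bf16 round-to-nearest by operating on float32 bit patterns (an approximation of hardware bf16, not mlx code) and compares its rounding error against numpy's native float16:

```python
import numpy as np

def to_bf16(x: np.ndarray) -> np.ndarray:
    """Emulate bfloat16 round-to-nearest-even by truncating float32 bit patterns."""
    bits = x.astype(np.float32).view(np.uint32)
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

rng = np.random.default_rng(0)
w = rng.standard_normal(100_000).astype(np.float32)

err_fp16 = np.abs(w - w.astype(np.float16).astype(np.float32))
err_bf16 = np.abs(w - to_bf16(w))

# 10 vs 7 mantissa bits: bf16's mean rounding error is roughly 2^3 = 8x larger.
print(err_bf16.mean() / err_fp16.mean())
```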

A note on bs_roformer version: parity testing pins bs_roformer==0.3.10. Newer lucidrains releases (0.4+) reorder the layers ModuleList nesting and add nGPT-style normalization params, breaking compatibility with the classic-architecture checkpoints (Kim, viperx, ZFTurbo v1.0.0) that the world is using today. The _build_torch_model helper in torch_infer.py filters YAML hyperparameters against the installed init signature, so future drift fails visibly with a printed list of dropped keys rather than silently.
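
That filtering pattern is standard-library territory. A minimal sketch (illustrative names, not the fixture's actual code):

```python
import inspect

def filter_kwargs_for(cls, hparams: dict) -> dict:
    """Keep only hyperparameters the installed class's __init__ accepts,
    printing whatever was dropped so version drift fails visibly."""
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    kept = {k: v for k, v in hparams.items() if k in accepted}
    dropped = sorted(set(hparams) - accepted)
    if dropped:
        print(f"dropping unsupported hyperparameters: {dropped}")
    return kept

# Stand-in for the installed bs_roformer model class.
class Demo:
    def __init__(self, dim, depth, heads=8):
        self.dim, self.depth, self.heads = dim, depth, heads

kept = filter_kwargs_for(Demo, {"dim": 384, "depth": 6, "use_ngpt_norm": True})
print(kept)  # {'dim': 384, 'depth': 6}
```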

Bugs found and fixed during parity wiring

The parity test caught five load-bearing bugs that would have silently degraded separation quality:

  • RoPE convention: the initial port used the "halved split" (x[:half], x[half:]) convention. rotary_embedding_torch (what ZFTurbo imports) uses interleaved pairs (x[2i], x[2i+1]) with a doubled-frequency layout. Same math family, different rotation plane — attention output was uncorrelated with the reference until this was fixed. Biggest single SDR jump.
  • RMSNorm eps: mlx.nn.RMSNorm uses additive eps=1e-5; ZFTurbo uses F.normalize(x, dim=-1) which is effectively max(||x||, 1e-12). The two agree for typical magnitudes but diverge by up to 70% at the small-magnitude high-frequency STFT bins the band-split's per-band RMSNorm ingests. Replaced with a local RMSNorm matching F.normalize semantics.
  • to_gates bias: declared bias=False, but ZFTurbo's Linear(dim, heads) has the default bias=True and ships a to_gates.bias tensor.
  • MaskEstimator: was hardcoded to 3 linears (mask_estimator_depth=2) — made configurable from the config so the zfturbo_vocals_v1 depth=1 variant loads correctly.
  • RMSNorm parameter name: gamma in the PyTorch code, weight in mlx.nn. Renamed for compatibility.
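
The RoPE mismatch in particular is easy to reproduce in isolation. Both conventions are norm-preserving rotations, but they act on different coordinate pairs, so their outputs disagree on the same input; a numpy illustration (not the PR's code):

```python
import numpy as np

def rope_halved(x: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """'Halved split' convention: rotate pairs (x[i], x[i + d/2])."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos, sin = np.cos(theta), np.sin(theta)
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def rope_interleaved(x: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Interleaved-pair convention (rotary_embedding_torch): rotate (x[2i], x[2i+1])."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = np.cos(theta), np.sin(theta)
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = np.random.default_rng(0).standard_normal(8)
theta = np.geomspace(1.0, 1e-2, 4)  # one rotation angle per pair, dim=8

a, b = rope_halved(x, theta), rope_interleaved(x, theta)
print(np.allclose(a, b))  # False: same math family, different rotation planes
```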

Published artifacts on mlx-community

Both MIT-compatible checkpoints are live and discoverable via the Mel-Band-RoFormer (MLX) Collection:

Each repo carries the canonical model.safetensors + config.json + LICENSE + a model card with full provenance: source repo, source commit/release, source SHA-256, mlx version, this PR's converter SHA, and parity numbers.

Usage

```python
import soundfile as sf
import numpy as np
import mlx.core as mx

from mlx_audio.sts.models.mel_roformer import MelRoFormer, MelRoFormerConfig
from mlx_audio.utils import load_audio

# Load model + weights from the Hub. The bundled config.json's `checkpoint_family`
# field auto-resolves the matching preset.
model = MelRoFormer.from_pretrained(
    "mlx-community/mel-roformer-kim-vocal-2-mlx",
)
model.eval()

# Stereo 44.1 kHz input mixture.
mixture = load_audio("input_mixture.wav", sample_rate=44100)  # mx.array [2, samples]
batched = mixture[None, ...]                                  # [1, 2, samples]

# Separate vocals; derive instrumental as (mixture - vocals).
vocals = model(batched)[0]                                    # [2, samples]
instrumental = mixture - vocals

sf.write("vocals.wav",       np.array(vocals).T,       44100)
sf.write("instrumental.wav", np.array(instrumental).T, 44100)
```

For checkpoint conversion (one-time, for users bringing their own weights):

```shell
python -m mlx_audio.sts.models.mel_roformer.convert \
    --input model_vocals_mel_band_roformer_sdr_8.42.ckpt \
    --output ./weights/ \
    --preset zfturbo_vocals_v1 \
    --dtype bfloat16   # or float16 / float32
```

The conversion is idempotent (content-addressed by SHA-256) and writes a companion config.json so from_pretrained() can resolve the architecture without the user re-selecting a preset. Output filenames include the dtype (<basename>.<sha>.<dtype>.safetensors) so multi-precision conversions of the same source don't collide.
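
The naming scheme is reproducible with nothing but hashlib; the 8-character digest prefix below is an assumption for illustration:

```python
import hashlib

def content_addressed_name(basename: str, data: bytes, dtype: str) -> str:
    """Build `<basename>.<sha>.<dtype>.safetensors` from the source checkpoint bytes."""
    sha = hashlib.sha256(data).hexdigest()[:8]  # prefix length is illustrative
    return f"{basename}.{sha}.{dtype}.safetensors"

name = content_addressed_name("model_vocals", b"fake checkpoint bytes", "float16")
print(name)
```

Because the digest depends only on the source bytes, re-running conversion on an unchanged file reproduces the same name, which is what makes the idempotency check cheap.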

License posture

Architecture: MIT (inherited from lucidrains/BS-RoFormer and ZFTurbo/Music-Source-Separation-Training).

Checkpoints: independently licensed — see mlx_audio/sts/models/mel_roformer/README.md for the per-checkpoint table. Both shipped mlx-community checkpoints are MIT, with caveats worth disclosing for clean redistribution:

  • Kim Vocal 2 was originally released under GPL-3.0 on 2025-06-17 and relicensed to MIT on 2026-04-22 by the original author Kimberley Jensen (commit ac9b061). The relicense was independently confirmed with the author the week of 2026-04-20 prior to the redistribution. The published model card carries the verbatim trace.
  • ZFTurbo vocals_v1 inherits MIT from the MSS-Training repository, but the repo's LICENSE file was committed on 2024-11-04 — about a year after the v1.0.0 release that produced this checkpoint asset. The published model card discloses the timeline transparently and notes the inheritance reasoning. (Repo is MIT today, asset is still distributed by the same MIT-licensed repo, author is an active OSS maintainer who has continued publishing under MIT.)

convert.py emits license-aware warnings when known-source checkpoint paths are passed (KimberleyJSN, ZFTurbo, TRvlvr/viperx, anvuew), printing the recognized license and any redistribution caveat before conversion proceeds.
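
A warn-on-substring matcher of that shape is simple; the hint table below is a hypothetical sketch, not the actual _LICENSE_HINTS contents:

```python
# Hypothetical license-hint table keyed on known checkpoint-path substrings.
_LICENSE_HINTS = {
    "kimberleyjsn": ("MIT", "relicensed from GPL-3.0 in Apr 2026"),
    "zfturbo": ("MIT", "repo LICENSE postdates the v1.0.0 release asset"),
    "viperx": ("undeclared", "no LICENSE file in TRvlvr/model_repo"),
    "anvuew": ("GPL-3.0", "copyleft: review before redistributing conversions"),
}

def license_hint(path: str):
    """Return (license_tag, caveat) for a recognized source path, else None."""
    lowered = path.lower()
    for needle, hint in _LICENSE_HINTS.items():
        if needle in lowered:
            return hint
    return None

print(license_hint("~/models/KimberleyJSN/MelBandRoformer.ckpt"))
```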

What this PR does NOT do

To keep scope reviewable, several adjacent things are deliberately deferred:

  • No HuggingFace Hub mirroring of viperx or anvuew checkpoints. viperx (TRvlvr/model_repo) ships no LICENSE file; anvuew is GPL-3.0. Users wanting these can run convert.py locally on their own copies.
  • No automatic chunking helper for inputs longer than chunk_size. The model's forward pass operates on a single chunk; overlap-add is left to the caller. Happy to follow up in a separate PR if there's interest — flagged as an open question in "What's next" below.
  • No MelRoFormerProcessor class. The model takes raw mx.array [B, 2, samples] audio and returns mx.array [B, 2, samples] separated vocals. Audio loading happens via mlx_audio.utils.load_audio. Adding a Sam-Audio-style processor wrapper is straightforward but felt like scope creep.
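
If a chunking helper does land in a follow-up, crossfaded overlap-add over the single-chunk forward pass could look roughly like this (my own sketch with a numpy stand-in for the model, not code from this PR):

```python
import numpy as np

def separate_long(audio: np.ndarray, model_fn, chunk: int, hop: int) -> np.ndarray:
    """Overlap-add a per-chunk separator across a long [C, T] signal.

    model_fn maps [C, chunk] -> [C, chunk]; hop < chunk gives the crossfade."""
    C, T = audio.shape
    out = np.zeros((C, T))
    weight = np.zeros(T)
    window = np.hanning(chunk)
    for start in range(0, T, hop):
        seg = audio[:, start:start + chunk]
        pad = chunk - seg.shape[1]
        if pad:  # zero-pad the final short chunk
            seg = np.pad(seg, ((0, 0), (0, pad)))
        sep = model_fn(seg)
        end = min(start + chunk, T)
        out[:, start:end] += (sep * window)[:, :end - start]
        weight[start:end] += window[:end - start]
    # Normalizing by the accumulated window weight makes the crossfade exact.
    return out / np.maximum(weight, 1e-8)

# With an identity "model", interior samples round-trip exactly.
x = np.random.default_rng(0).standard_normal((2, 10_000))
y = separate_long(x, lambda s: s, chunk=2048, hop=1024)
print(np.allclose(x[:, 2048:-2048], y[:, 2048:-2048]))
```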

Tests

All tests in tests/sts/test_mel_roformer.py pass locally:

  • TestMelRoFormerConfig — defaults, derived properties, all presets, custom escape hatch, keyword-only enforcement
  • TestMelRoFormerModel — construction with each preset
  • TestSanitize — QKV packed-weight splitting
  • TestConvert — training-state stripping, state-dict extraction across 3 PyTorch formats
  • TestLicenseDetection — warn-on-substring matcher
  • TestContentAddressing — SHA-256 hashing consistency, content-addressed naming with dtype suffix
  • TestConfigSerialization — round-trip config → dict → config for companion .config.json
  • TestPresetResolution — CLI preset lookup

Parity tests in tests/sts/test_mel_roformer_parity.py are @pytest.mark.requires_torch-gated and skipped by default. Local parity runs require PyTorch + bs_roformer==0.3.10 + the original .ckpt + a reference WAV configured via MEL_ROFORMER_* env vars. The tests/sts/torch_infer.py fixture wraps the lucidrains MelBandRoformer reference; it filters YAML hyperparameters against the installed init signature so future bs_roformer version drift surfaces visibly rather than silently.

What's next

  • Chunked inference wrapper (8-second chunks with crossfade overlap-add) — currently implemented per-chunk; the wrapper is straightforward but I'd like reviewer input on where it fits (inline in model.py, separate pipeline.py, or external user code). Was deferred from the original PR; still deferred here.
  • Companion Swift port at xocialize/mel-roformer-mlx-swift — independent SPM package, MLX-Swift native, ships a MelRoFormer.fromPretrained(_:) matching the Python convention. Not part of this PR; mentioned because the published mlx-community model cards reference it under "Usage" → Swift.
  • Additional presets as community trains new variants (covered today by MelRoFormerConfig.custom(...)).

Checklist

  • Architecture implementation follows the reference PyTorch models
  • Parity test run with SDR > 40 dB target — verified at two scales (Kim 66.08 dB bf16, ZFTurbo 44.19 dB fp16)
  • Unit tests cover config presets, weight sanitization, conversion utilities, content-addressing with dtype, license detection
  • Two MIT-compatible checkpoints published on mlx-community with full provenance + a discoverable Collection
  • License posture documented in README + per-checkpoint license-aware converter warnings
  • Conversion script handles .ckpt, .pt, .safetensors inputs across multiple PyTorch state-dict shapes
  • Content-addressed output (now including dtype) avoids re-converting unchanged files
  • pytest.mark.requires_torch gating for torch-dependent tests
  • PARITY_TESTING.md removed per @lucasnewman's earlier review feedback (commit 5d5a9c2)
  • Pre-commit hooks pass (black, isort, flake8) — applied in 4448196
  • Reviewer guidance requested: chunked inference wrapper location (inline in model.py, separate pipeline.py, or external user code)?
  • Reviewer guidance requested: any preference on parity-test infrastructure location? tests/sts/torch_infer.py is the test fixture; happy to move it under tests/sts/fixtures/ or similar if there's a preferred convention.

Tagging @Blaizzy and @lucasnewman per earlier review engagement. Two new commits since the last touch (rebased onto current main):

  • 8380ab8: --dtype {float32,float16,bfloat16} flag on convert.py with dtype-suffixed content-addressing.
  • 1f9a555: tests/sts/torch_infer.py (the parity-test fixture, previously external) now lives in-tree.

Both checkpoints have been parity-validated end-to-end and published to mlx-community since the last review pass — happy to walk through any specific section if helpful.

xocialize and others added 4 commits April 16, 2026 15:14
…1 dB

The parity test was permanently skipped, so several load-bearing bugs in
the MLX port went undetected. Wiring up the reference inference surfaced
them; this commit fixes them and verifies numerical parity.

Config / conversion
- Add zfturbo_vocals_v1 preset (dim=192, depth=8, hop=512,
  mask_estimator_depth=1) for ZFTurbo's v1.0.0 MIT-licensed vocals
  release asset. None of the existing three presets matched it.

Port bugs fixed in model.py
- RoPE convention: MLX used the halved split (x[:half] / x[half:]);
  rotary_embedding_torch (what ZFTurbo imports) uses interleaved pairs
  (x[2i] / x[2i+1]) with a doubled-frequency layout. Same math family,
  different rotation plane — attention output was uncorrelated with
  the reference until this was fixed. Single biggest jump in SDR.
- RMSNorm eps: mlx.nn.RMSNorm uses additive eps=1e-5; ZFTurbo uses
  F.normalize(x, dim=-1) which is effectively max(||x||, 1e-12). The
  two agree for typical magnitudes but diverge by up to 70% at the
  small-magnitude high-frequency STFT bins that the band-split's
  per-band RMSNorm ingests. Replaced with a local RMSNorm matching
  F.normalize semantics.
- to_gates: declared bias=False, but ZFTurbo's Linear(dim, heads) has
  the default bias=True and ships a to_gates.bias tensor.
- MaskEstimator: hardcoded to 3 linears (mask_estimator_depth=2 in
  effect). Now honors config.mask_estimator_depth, so depth-1
  checkpoints (like this one) load with the correct number of linears.
- MelFilterbank: custom Slaney triangular filter did not quite match
  librosa's. Swapped for librosa.filters.mel + ZFTurbo's explicit
  fb[0,0]=1 / fb[-1,-1]=1 force-assigns so every freq bin is covered.
- STFT: zero-padded the audio before framing, but torch.stft(center=True)
  pad_mode defaults to reflect. Built reflect pad via slice+concat
  (mx.pad has no reflect mode).
- STFT post-processing: .real() / .imag() as method calls; these are
  properties on mlx.core.array.
- BandSplit: freq_indices stored as np.ndarray, but MLX fancy-indexing
  rejects raw numpy — convert once at MelFilterbank init.

sanitize() remaps added
- Drop rotary_embed.freqs — MLX computes RoPE freqs on the fly.
- Mask-estimator per-band Sequential -> list index: PyTorch stores
  linears at positions {0, 2} (depth=1) or {0, 2, 4} (depth=2); MLX
  list positions are consecutive {0, 1, 2}.
- to_out.0.weight -> to_out.weight — PyTorch wraps in Sequential(Linear,
  Dropout), MLX uses a bare Linear.
- .gamma -> .weight — PyTorch RMSNorm names its scale "gamma";
  mlx.nn.RMSNorm names it "weight".

Test wiring in tests/sts/test_mel_roformer_parity.py
- Replace the pytest.skip stub in test_sdr_parity with the real
  single-chunk SDR comparison: dynamically imports a user-supplied
  torch_infer.py (see PARITY_TESTING.md) via importlib, runs both
  models on the same chunk of audio, asserts SDR > target (default
  40 dB).
- Add fixtures for MEL_ROFORMER_TORCH_INFER, MEL_ROFORMER_TORCH_CONFIG,
  MEL_ROFORMER_CHUNK_SAMPLES.

Verified
- All 25 existing unit tests in tests/sts/test_mel_roformer.py still pass.
- test_qkv_split_preserves_weights: pass.
- test_sdr_parity on the v1.0.0 ZFTurbo MIT vocals ckpt +
  ZFTurbo config + a deterministic synthetic 30s stereo signal:
  SDR = 58.11 dB (target > 40 dB).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Register MelRoFormer, MelRoFormerConfig, and MelRoFormerResult in the
parent package so users can import via mlx_audio.sts.models without
reaching into the subpackage.

Rename our result dataclass from SeparationResult to MelRoFormerResult
to avoid a name collision with sam_audio.SeparationResult, which is
already upstream and has a different streaming-oriented shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pre-commit auto-fixes from running hooks before PR submission.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@xocialize
Author

xocialize commented Apr 17, 2026

NOTE: I have now joined, so I should be able to proceed with mlx-community when appropriate.

I would also appreciate it if you could add me to the mlx-community org on Hugging Face so I can upload the MLX weight contributions.

@lucasnewman
Collaborator

@xocialize You can just join the community on HF yourself, so no permission needed. I know this is still a draft, but make sure your final PR doesn't include any scratchpad / planning .md files. Thanks!

Per @lucasnewman review comment on PR Blaizzy#654 — the final PR should not
contain scratchpad / planning .md files. PARITY_TESTING.md was an
internal M5 Pro setup guide used during parity verification; the
content has no use for downstream pip-install users of mlx-audio.

The parity test itself (`tests/sts/test_mel_roformer_parity.py`) remains
and is self-documenting via its docstrings and fixture parameters. Users
who want to run parity can follow the @pytest.mark.requires_torch flag
and the env-var fixtures directly.

README.md is unchanged — it's user-facing architecture + license +
usage documentation, not a planning doc.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@xocialize
Author

Thanks @lucasnewman — joined mlx-community on HF, and just pushed 5d5a9c2 removing PARITY_TESTING.md (it was an internal M5 Pro setup guide — no value for pip-install users).

README.md is retained, since it's user-facing architecture + license + usage docs following the same convention as the other mlx_audio/sts/models/*/README.md files. Let me know if anything else stands out.

xocialize and others added 5 commits April 19, 2026 21:45
@KimberleyJensen relicensed the Kim Vocal 2 weights from GPL-3.0 to MIT
on HuggingFace (2026-04-18) in response to a redistribution-license
inquiry on her model page. This commit reflects the new posture
throughout the PR:

- README.md license posture table: Kim Vocal 2 → MIT (was GPL-3.0)
- README.md "running inference vs. redistributing" section: rewrite
  examples to use anvuew dereverb as the remaining GPL-3.0 exemplar;
  include Kim Vocal 2 in the MIT-safe recommendation list
- convert.py _LICENSE_HINTS: KimberleyJSN / kimvocal entries now emit
  an MIT note with context about the relicensing
- test_mel_roformer.py test_detects_kim_vocal_2: assert license_tag ==
  "MIT" (was "GPL-3.0")

This widens the set of MIT-safe checkpoint families that downstream
users can bundle in products — Kim Vocal 2 is now the strongest-SDR
preset with no redistribution constraints. The `anvuew` dereverb
model remains GPL-3.0; `viperx` remains undeclared. The PR's
license-aware conversion tooling continues to flag those.

Tests: 25 passed, 1 skipped — no regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an output-precision selector to the Mel-Band-RoFormer converter so the
same source checkpoint can be packaged at multiple precisions for HuggingFace
redistribution.

- New --dtype {float32,float16,bfloat16} CLI flag, default float32 (no behavior
  change for existing callers).
- Dtype is baked into the content-addressed output basename so simultaneous
  fp32/fp16/bf16 conversions of the same source no longer collide.
- Companion config.json records the chosen dtype for downstream provenance.

Driven by the bf16-only scope decision for the planned mlx-community uploads
(mel-roformer-kim-vocal-2-mlx, mel-roformer-zfturbo-vocals-v1-mlx).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Companion to tests/sts/test_mel_roformer_parity.py — wraps lucidrains
bs_roformer.MelBandRoformer (the canonical PyTorch reference used by
ZFTurbo's MSS-Training) so the parity harness can compare MLX outputs
against the upstream implementation rather than a custom reimplementation.

Provides the two functions the parity test loads via the
MEL_ROFORMER_TORCH_INFER env var:

  run(ckpt_path, config_yaml, audio_path, out_path, chunk_samples)
      -> np.ndarray [2, T_out]
  load_audio_chunk(audio_path, chunk_samples)
      -> (np.ndarray [2, T], sr)

Filters YAML hyperparameters against the installed bs_roformer init
signature so future lucidrains releases that add/remove parameters fail
visibly (with a printed list of dropped keys) rather than silently.

Validated against:
- Kim Vocal 2 (MelBandRoformer.ckpt, 684 keys, 228M params): SDR 66.08 dB
  parity vs MLX bf16 conversion → published as
  mlx-community/mel-roformer-kim-vocal-2-mlx
- ZFTurbo vocals_v1 (model_vocals_mel_band_roformer_sdr_8.42.ckpt,
  612 keys, 33.7M params): SDR 44.19 dB parity vs MLX fp16 conversion
  → published as mlx-community/mel-roformer-zfturbo-vocals-v1-mlx

Pin: bs_roformer==0.3.10. Newer releases (0.4+) reorder the layers
ModuleList and add nGPT-style normalization, breaking checkpoint
compatibility for both Kim and ZFTurbo v1.0.0 weights.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@xocialize xocialize marked this pull request as ready for review April 25, 2026 19:16
| ZFTurbo MSS-Training release assets | [ZFTurbo MSS-Training releases](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases) | MIT (inherited from repo) |
| Community variants (dereverb, denoise, crowd) | various | Varies — review each |

## Running inference vs. redistributing weights
Collaborator


Can you remove this section? We don't support training at the moment so this isn't relevant to running pretrained models.

Comment thread tests/sts/torch_infer.py Outdated
@@ -0,0 +1,199 @@
"""PyTorch reference inference for Mel-Band-RoFormer parity testing.
Collaborator


We don't need/want this file, it looks like it was just for your testing.

Comment thread pytest.ini Outdated
asyncio_default_fixture_loop_scope = function
asyncio_default_fixture_loop_scope = function
markers =
requires_torch: test requires PyTorch installed (skip in MLX-only CI)
Collaborator


We'd prefer not to include torch, so we don't need this (see above comment).

[lucidrains/BS-RoFormer](https://github.com/lucidrains/BS-RoFormer) and
[ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training).

## What this contribution is
Collaborator


Can you remove this section as well? It looks like your agent is injecting some kind of overview of the porting effort -- we just want clean docs.

leaves them zero, which breaks the "every freq covered" invariant),
then binarizes.
"""
import librosa
Collaborator


Sorry, we can't have librosa as a dependency. We have the ability to generate identical filterbanks in mlx_audio/dsp.py -- please use that instead.

# ---------- STFT / iSTFT ----------


def stft(audio: mx.array, n_fft: int, hop_length: int, window: mx.array) -> tuple:
Collaborator


We already have an implementation of this in mlx_audio/dsp.py -- please use that.



def istft(
real: mx.array,
Collaborator


Same -- let's use dsp.py

@lucasnewman
Collaborator

@xocialize Looks pretty good overall, please see the comments of where we can avoid code duplication and make sure we're not pulling in excessive dependencies (torch / librosa / etc).

…erbank

Per @lucasnewman's review (Blaizzy#654 comments r3142451654, r3142452858,
r3142453107): drop librosa as a dependency, stop duplicating STFT/iSTFT
implementations, and use the shared utilities in mlx_audio/dsp.py instead.

- MelFilterbank now uses mlx_audio.dsp.mel_filters with mel_scale="slaney"
  (matching the previous librosa.filters.mel default), then keeps the
  ZFTurbo-style DC/Nyquist forcing and binarization on top.
- The local stft() function becomes a thin batched/multichannel adapter
  over mlx_audio.dsp.stft (which is 1D-input). Loops over [B, channels]
  and reshapes into the [B, channels, freq_bins, frames] layout BandSplit
  expects.
- The local istft() function similarly delegates to mlx_audio.dsp.istft.
  Two adapter notes worth flagging:
    - ``length=None`` is passed (not the target length) because that
      branch correctly strips the ``win_length // 2`` center-pad; passing
      the target length skips center-pad removal and shifts the
      reconstruction. We trim to ``length`` in the adapter.
    - ``normalized=True`` for COLA-style ``window**2`` normalization,
      matching torch.istft and the previous local implementation.
    - dsp.istft's center-strip can leave the output a few samples short
      of ``length`` when the chunk size isn't an integer multiple of
      hop_length (e.g. ZFTurbo's 352800 / 512). Tail-pad with zeros to
      reach ``length`` — those samples were in the center-pad region.

Parity preserved exactly across both validated checkpoints (re-measured
on the same harness as before refactor):
- Kim Vocal 2 (bf16):  66.08 dB SDR vs PyTorch reference (unchanged)
- ZFTurbo vocals_v1 (fp16): 44.19 dB SDR vs PyTorch reference (unchanged)
- ZFTurbo vocals_v1 (bf16): 21.96 dB SDR vs PyTorch reference (unchanged
  — bf16-sensitive at this small architecture, hence we ship fp16 for
  this preset on mlx-community)

Also fix TestContentAddressing.test_content_addressed_name to pass the
``dtype`` arg added in 8380ab8.
Now that the parity test infrastructure (tests/sts/torch_infer.py,
tests/sts/test_mel_roformer_parity.py) is removed, the requires_torch
marker registration in pytest.ini is no longer used. Drops the section
per @lucasnewman's review (Blaizzy#654 comment r3142448241).
Per @lucasnewman's review (Blaizzy#654 comments r3142445601, r3142449229):

- Remove the "What this contribution is" section. It was framed around
  this module's place in the upstream-PR narrative rather than describing
  the model. User-facing docs should describe the model, not the porting
  effort.
- Remove the "Running inference vs. redistributing weights" section.
  Training isn't supported in this module today, and the per-license
  redistribution table set up expectations the module doesn't deliver
  on. The shorter "License" section keeps the architecture-MIT note and
  defers to checkpoint source repos for weight licensing.
- Drop the parity-test footnote in the convert.py section now that the
  parity test infrastructure has been removed.
- Update the convert.py example to show the --dtype flag.
@xocialize
Author

Thanks @lucasnewman — all addressed in 39c9325, 4555494, 5e5cea2:

  • mlx_audio/dsp.py over duplicate primitives (#discussion_r3142451654, r3142452858, r3142453107): librosa is gone, the local stft/istft are now thin batched adapters over dsp.stft/dsp.istft, and MelFilterbank uses dsp.mel_filters(mel_scale="slaney"). Two non-obvious adapter notes I called out in the commit message: passing length=None to dsp.istft (because the length is not None branch skips the center-pad strip), and tail-padding when ZFTurbo's chunk size isn't an integer multiple of hop_length. Parity preserved exactly — Kim Vocal 2 still 66.08 dB SDR vs PyTorch, ZFTurbo fp16 still 44.19 dB.
  • No torch in the project (r3142447526, r3142448241): tests/sts/torch_infer.py and tests/sts/test_mel_roformer_parity.py removed; requires_torch marker registration dropped from pytest.ini. The parity infra served its purpose during development; future contributors who want to re-validate can fork the harness from this branch's git history.
  • README trims (r3142445601, r3142449229): "What this contribution is" and "Running inference vs. redistributing weights" sections gone; License section now just notes architecture-MIT and defers to checkpoint source repos. Also removed the parity-test footnote and updated the convert.py example to show --dtype.

Net delta: −558 LoC (mostly the duplicate STFT/iSTFT and the prose sections). The two open guidance items in the description (chunked-inference helper location, parity-fixture location) are no longer relevant — chunking still deferred for a follow-up, parity fixtures gone per your ask. Ready for re-review whenever you have a minute.
