基于 Stochastic Regeneration 框架的因果音色修复模型,用于补偿 DeepFilterNet 降噪造成的音色损失。
- ✅ 完全因果:适合实时流式处理,无前瞻延迟
- ✅ GAN 训练:生成模型补偿预测模型的 over-suppression
- ✅ 轻量级:~3.5M 参数,单帧处理 < 1ms
- ✅ ONNX 导出:便于 Rust/C++ 集成
pip install -r requirements.txt修改 configs/default.yaml 中的数据路径:
data:
vctk_path: "/workspace/vctk/wav48_silence_trimmed" # 英文
aishell3_path: "/workspace/Aishell3/train/wav" # 中文
ir_path: "/data/impulse/datasets_fullband/impulse_responses" # IR 递归加载多级子目录
noise_path: "/data/freesound/datasets_fullband/noise_fullband" # 噪声目录,递归随机抽样
output_dir: "/workspace/timbre_restore_data"
# 中/英比例(权重),默认 0.3 : 0.7
dataset_ratio:
vctk: 0.3
aishell3: 0.7运行数据准备(支持分片并行 + 断点续跑):
# 小规模调试(随机抽样 200 条,跳过 DF 加速)
python data/prepare_dataset.py \
--config configs/default.yaml \
--max_files 200 \
--skip_df \
--skip_existing
# 全量单进程(默认 24 worker,在进程内并行切片)
python data/prepare_dataset.py \
--config configs/default.yaml \
--skip_existing
# 全量分片并行(4 进程/4 卡示例)
for i in 0 1 2 3; do
CUDA_VISIBLE_DEVICES=$i \
python data/prepare_dataset.py \
--config configs/default.yaml \
--shard-idx $i --shard-count 4 \
--skip_existing \
--num_workers 12 &
done
wait关键参数:
--skip_existing已存在的 clean/degraded 直接跳过,便于断点续跑/多机多进程并行。--shard-idx/--shard-count分片跑,不同进程/机器用不同的 shard-idx。--max_files仅用于抽样调试。
单卡:
python train.py --config configs/default.yaml多卡 DDP:
torchrun --nproc_per_node=4 train.py --config configs/default.yaml或使用脚本:
chmod +x run.sh
./run.sh all # 完整流程:准备 → 训练 → 导出python export_onnx.py \
--checkpoint logs/xxx/checkpoints/checkpoint_final.pt \
--output timbre_restore.onnx \
--verify \
--benchmark-
ONNX 非流式(一次性):
python tools/inference.py \ -m timbre_restore_ep74.onnx \ -i nc2.wav
-
ONNX 流式(需传入隐藏状态规格):
python tools/inference.py \ -m timbre_restore_ep74.onnx \ -i nc2.wav \ --frame-size 480 \ --hidden-size 384 \ --num-layers 2
- 如只有 CPU,可加
OMP_NUM_THREADS=1 OMP_WAIT_POLICY=PASSIVE。 - 模型有
h_in/h_out即为流式导出。
- 如只有 CPU,可加
-
PyTorch
.pt:python tools/inference.py \ -m logs1209/20251208_063820/checkpoints/checkpoint_epoch_0074.pt \ -i nc2.wav \ --config configs/lightweight.yaml
流式导出(带隐藏状态):
python tools/export_onnx.py \
--checkpoint logs1209/20251208_063820/checkpoints/checkpoint_epoch_0074.pt \
--config configs/lightweight.yaml \
--output timbre_restore_ep74.onnx \
--streaming \
--verify \
--benchmark
# CPU 基准可结合线程控制
# OMP_NUM_THREADS=4 OMP_WAIT_POLICY=PASSIVE ...python - <<'PY'
import random, numpy as np, pathlib, sys
sys.path.insert(0,'.')
from data.dataset import TimbreRestoreDataset
txt = "/data/train_data_lite/train.txt" # 换成你的 train.txt
ds = TimbreRestoreDataset(
txt,
segment_length=48000,
sample_rate=48000,
align_df_delay=True,
align_max_shift=4000, # 与配置一致
augment=False, # 关闭增强,便于测量
)
print("样本数:", len(ds))
# 1) 存在性与大小检查(直接扫文件列表)
missing = []
pairs = [l.split('|',1) for l in pathlib.Path(txt).read_text().splitlines() if '|' in l]
for d,c in random.sample(pairs, min(200, len(pairs))):
d = pathlib.Path(d); c = pathlib.Path(c)
if (not d.exists()) or (not c.exists()) or d.stat().st_size < 1024 or c.stat().st_size < 1024:
missing.append((d,c))
print("存在性可疑条目:", len(missing))
# 2) 对齐偏移 & 长度差(经过 Dataset 对齐后)
shifts = []
len_diff = []
for _ in range(min(100, len(ds))):
d,c = ds[random.randint(0,len(ds)-1)]
x,y = d.numpy().flatten(), c.numpy().flatten()
n = min(len(x), len(y), 48000)
x,y = x[:n], y[:n]
corr = np.correlate(y, x, mode='full')
lag = np.argmax(corr) - (n-1)
shifts.append(lag)
len_diff.append(len(x)-len(y))
print("shift_samples: mean=%.1f, min=%d, max=%d" % (np.mean(shifts), np.min(shifts), np.max(shifts)))
print("len_diff: mean=%.1f, min=%d, max=%d" % (np.mean(len_diff), np.min(len_diff), np.max(len_diff)))
# 3) SNR 粗测(对齐后)
snrs = []
for _ in range(min(100, len(ds))):
d,c = ds[random.randint(0,len(ds)-1)]
x,y = d.numpy().flatten(), c.numpy().flatten()
n = min(len(x), len(y))
x,y = x[:n], y[:n]
nse = y - x
snr = 10*np.log10(np.mean(y**2)/(np.mean(nse**2)+1e-9))
snrs.append(snr)
print("SNR(dB): mean=%.2f, min=%.2f, max=%.2f" % (np.mean(snrs), np.min(snrs), np.max(snrs)))
PY判定参考:
- 存在性:缺失/零字节为 0
- 对齐:shift 绝对值应靠近 0(<~5-10 样本),len_diff 应为 0
- SNR:均值建议在 15~25 dB(明显低于 10 dB 需检查对齐/损坏样本)
| 配置 | 数据量 | 50 Epochs |
|---|---|---|
| 单卡 4090 | 10万样本 | ~6-8 小时 |
| 2× 4090 | 10万样本 | ~3.5-4.5 小时 |
| 4× 4090 | 10万样本 | ~2-2.5 小时 |
输入 (DF输出) [B, 1, T]
│
▼
┌─────────────────┐
│ Causal Encoder │ Conv1d ↓ [64→128→256→512]
└─────────────────┘
│
▼
┌─────────────────┐
│ GRU Bottleneck │ 因果时序建模
└─────────────────┘
│
▼
┌─────────────────┐
│ Causal Decoder │ Conv1d ↑ + Skip [512→256→128→64]
└─────────────────┘
│
▼
输出 (修复) [B, 1, T]
参见 rust_integration/timbre_restore.rs 示例。
核心集成点(在 capture.rs):
// DF 处理后
lsnr = df.process(inframe.view(), outframe.view_mut())?;
// 音色修复
if let Some(buf) = outframe.as_slice_mut() {
timbre_model.process_frame(buf)?;
}
// 后续 EQ/AGCtimbre_restore/
├── configs/
│ └── default.yaml # 训练配置
├── data/
│ ├── prepare_dataset.py # 数据准备
│ └── dataset.py # PyTorch Dataset
├── model/
│ ├── generator.py # 因果 U-Net Generator
│ ├── discriminator.py # 多尺度判别器
│ └── losses.py # 损失函数
├── rust_integration/
│ └── timbre_restore.rs # Rust 集成示例
├── train.py # 训练入口
├── export_onnx.py # ONNX 导出
├── run.sh # 运行脚本
└── requirements.txt # Python 依赖
- 数据抽样:干净语音按
dataset_ratio权重随机抽取(含洗牌),IR/噪声目录递归加载,噪声/IR 每次调用都会随机选取,确保样本多样。 - 批次顺序:数据列表与任务列表均打乱,训练 DataLoader 默认
shuffle=True。 - 如需复现,可在外部设置随机种子(
PYTHONHASHSEED、torch.manual_seed等);默认不固定种子以增加多样性,帮助泛化。
MIT