From 244144322c738a64d98ce51247d8de75a74e9449 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 3 Apr 2025 16:16:53 +0800 Subject: [PATCH] feat: audio tokenizer add batch encode Signed-off-by: weedge --- .gitignore | 4 + sparktts/models/audio_tokenizer.py | 141 ++++++++++++++++++++++------- sparktts/models/bicodec.py | 29 +++--- 3 files changed, 128 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 146a43c..8bd9469 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,7 @@ cython_debug/ # PyPI configuration file .pypirc + +# runtime +triton_python_backend_utils.py +*.wav \ No newline at end of file diff --git a/sparktts/models/audio_tokenizer.py b/sparktts/models/audio_tokenizer.py index d7065eb..55c57ad 100644 --- a/sparktts/models/audio_tokenizer.py +++ b/sparktts/models/audio_tokenizer.py @@ -18,7 +18,7 @@ import numpy as np from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Any, Dict, Literal, Optional, Tuple, List, Union from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model from sparktts.utils.file import load_config @@ -39,19 +39,25 @@ def __init__(self, model_dir: Path, device: torch.device = None, **kwargs): self.device = device self.model_dir = model_dir self.config = load_config(f"{model_dir}/config.yaml") - self._initialize_model() + self._initialize_model(**kwargs) - def _initialize_model(self): + def _initialize_model( + self, + attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = None, + ): """Load and initialize the BiCodec model and Wav2Vec2 feature extractor.""" - self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to( - self.device - ) + self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(self.device) self.processor = Wav2Vec2FeatureExtractor.from_pretrained( f"{self.model_dir}/wav2vec2-large-xlsr-53" ) - self.feature_extractor = Wav2Vec2Model.from_pretrained( - f"{self.model_dir}/wav2vec2-large-xlsr-53" - ).to(self.device) + self.feature_extractor = ( + Wav2Vec2Model.from_pretrained( + f"{self.model_dir}/wav2vec2-large-xlsr-53", + attn_implementation=attn_implementation, + ) + .to(self.device) + .eval() + ) self.feature_extractor.config.output_hidden_states = True def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: @@ -69,8 +75,11 @@ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: return wav[:ref_segment_length] - def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]: - """load auido and get reference audio from wav path""" + def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, np.ndarray]: + """ + load auido and get reference audio from wav path + return (wav, wav_ref) # (shape:(seq_len), shape:(seq_len)) + """ wav = load_audio( wav_path, sampling_rate=self.config["sample_rate"], @@ -79,24 +88,23 @@ def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]: wav_ref = self.get_ref_clip(wav) - wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float() return wav, wav_ref - def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: - """extract wav2vec2 features""" + def extract_wav2vec2_features(self, wavs: np.ndarray | List[np.ndarray]) -> torch.Tensor: + """extract wav2vec2 features + return: torch.Tensor shape:(batch_size, features_seq_len, feature_dim) + """ inputs = self.processor( wavs, sampling_rate=16000, return_tensors="pt", padding=True, - output_hidden_states=True, - ).input_values - feat = self.feature_extractor(inputs.to(self.feature_extractor.device)) - feats_mix = ( - feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16] - ) / 3 + # output_hidden_states=True, + ).to(self.feature_extractor.device) + feat = self.feature_extractor(**inputs) + feats_mix = (feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]) / 3 - return feats_mix + return feats_mix.detach() def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor: """tokenize the batch of audio @@ -116,22 +124,66 @@ def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor: return global_tokens, semantic_tokens + def batch_tokenize( + self, audio_paths: Union[str | List[str]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + return (global_tokens, semantic_tokens): + - semantic_tokens: semantic tokens. shape: (batch_size, latent_dim) + - global_tokens: global tokens. shape: (batch_size, channel, global_dim) + """ + if isinstance(audio_paths, str): + audio_paths = [audio_paths] + wav_list = [] + audio_clip = [] + for audio_path in audio_paths: + wav, wav_ref = self.process_audio(audio_path) + wav_list.append(wav) + audio_clip.append(torch.from_numpy(wav_ref)) + + audio_clip = torch.stack(audio_clip).to(self.device) + audio_features = self.extract_wav2vec2_features(wav_list) + + batch = { + "ref_wav": audio_clip.float().to(self.device), # [batch_size,seq_len] + "feat": audio_features.to(self.device), # [batch_size,features_seq_len,feature_dim] + } + semantic_tokens, global_tokens = self.model.tokenize(batch) # [batch_size,seq_len] + + if self.device.type == "cuda": + torch.cuda.empty_cache() + + return global_tokens, semantic_tokens + + def batch_detokenize( + self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor + ) -> np.array: + wav_rec = self.model.detokenize(semantic_tokens, global_tokens) + if self.device.type == "cuda": + torch.cuda.empty_cache() + return wav_rec.squeeze().cpu().numpy() + def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: - """tokenize the audio""" + """tokenize the audio + return (global_tokens, semantic_tokens): + - semantic_tokens: semantic tokens. shape: (batch_size, latent_dim) + - global_tokens: global tokens. shape: (batch_size, channel, global_dim) + """ wav, ref_wav = self.process_audio(audio_path) feat = self.extract_wav2vec2_features(wav) batch = { - "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device), - "ref_wav": ref_wav.to(self.device), + # "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device), + "ref_wav": torch.from_numpy(ref_wav).unsqueeze(0).float().to(self.device), "feat": feat.to(self.device), } semantic_tokens, global_tokens = self.model.tokenize(batch) + if self.device.type == "cuda": + torch.cuda.empty_cache() + return global_tokens, semantic_tokens - def detokenize( - self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor - ) -> np.array: + def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.array: """detokenize the tokens to waveform Args: @@ -143,21 +195,42 @@ def detokenize( """ global_tokens = global_tokens.unsqueeze(1) wav_rec = self.model.detokenize(semantic_tokens, global_tokens) - return wav_rec.detach().squeeze().cpu().numpy() + if self.device.type == "cuda": + torch.cuda.empty_cache() + return wav_rec.squeeze().cpu().numpy() # test if __name__ == "__main__": import soundfile as sf + import os + from time import perf_counter device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BiCodecTokenizer( - model_dir="pretrained_models/Spark-TTS-0.5B", + model_dir=os.getenv("MODEL_DIR", "pretrained_models/Spark-TTS-0.5B"), device=device, ) - wav_path = "example/prompt_audio.wav" - - global_tokens, semantic_tokens = tokenizer.tokenize(wav_path) - wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens) - sf.write("example/prompt_recon.wav", wav_rec, 16000) + wav_cases = { + "single": "example/prompt_audio.wav", + "multi": ["example/prompt_audio.wav", "example/prompt_audio.wav"], + } + for case, wav_path in wav_cases.items(): + start_time = perf_counter() + if isinstance(wav_path, list): + global_tokens, semantic_tokens = tokenizer.batch_tokenize(wav_path) + else: + global_tokens, semantic_tokens = tokenizer.tokenize(wav_path) + print(f"""{case} encode elapsed time: {perf_counter()-start_time:.4f} seconds""") + print(semantic_tokens.shape, global_tokens.shape) + + start_time = perf_counter() + wav_rec = tokenizer.detokenize(global_tokens.squeeze(1), semantic_tokens) + print(f"""{case} decode elapsed time: {perf_counter()-start_time:.4f} seconds""") + print(wav_rec.shape) + if len(wav_rec.shape) > 1: + for i, wav in enumerate(wav_rec): + sf.write(f"example/prompt_recon_{i}.wav", wav, 16000) + if len(wav_rec.shape) == 1: + sf.write("example/prompt_recon.wav", wav_rec, 16000) diff --git a/sparktts/models/bicodec.py b/sparktts/models/bicodec.py index 8cab2f0..2555ab3 100644 --- a/sparktts/models/bicodec.py +++ b/sparktts/models/bicodec.py @@ -17,7 +17,6 @@ import torch.nn as nn from pathlib import Path from typing import Dict, Any -from omegaconf import DictConfig from safetensors.torch import load_file from sparktts.utils.file import load_config @@ -43,7 +42,7 @@ def __init__( speaker_encoder: nn.Module, prenet: nn.Module, postnet: nn.Module, - **kwargs + **kwargs, ) -> None: """ Initializes the BiCodec model with the required components. @@ -73,12 +72,12 @@ def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec": Args: model_dir (Path): Path to the model directory containing checkpoint and config. - + Returns: BiCodec: The initialized BiCodec model. """ - ckpt_path = f'{model_dir}/model.safetensors' - config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer'] + ckpt_path = f"{model_dir}/model.safetensors" + config = load_config(f"{model_dir}/config.yaml")["audio_tokenizer"] mel_params = config["mel_params"] encoder = Encoder(**config["encoder"]) quantizer = FactorizedVectorQuantize(**config["quantizer"]) @@ -116,7 +115,7 @@ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: Args: batch (dict): A dictionary containing features, reference waveform, and target waveform. - + Returns: dict: A dictionary containing the reconstruction, features, and other metrics. """ @@ -166,7 +165,7 @@ def tokenize(self, batch: Dict[str, Any]): semantic_tokens = self.quantizer.tokenize(z) global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) - return semantic_tokens, global_tokens + return semantic_tokens.detach(), global_tokens.detach() @torch.no_grad() def detokenize(self, semantic_tokens, global_tokens): @@ -186,7 +185,7 @@ def detokenize(self, semantic_tokens, global_tokens): x = x + d_vector.unsqueeze(-1) wav_recon = self.decoder(x) - return wav_recon + return wav_recon.detach() def init_mel_transformer(self, config: Dict[str, Any]): """ @@ -212,6 +211,7 @@ def init_mel_transformer(self, config: Dict[str, Any]): def remove_weight_norm(self): """Removes weight normalization from all layers.""" + def _remove_weight_norm(m): try: torch.nn.utils.remove_weight_norm(m) @@ -223,16 +223,20 @@ def _remove_weight_norm(m): # Test the model if __name__ == "__main__": - config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml") model = BiCodec.load_from_checkpoint( model_dir="pretrained_models/SparkTTS-0.5B/BiCodec", ) + device = "cpu" if not torch.cuda.is_available() else "cuda" + print(model) + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + print(f"{model_million_params}M parameters") + model.to(device) # Generate random inputs for testing duration = 0.96 - x = torch.randn(20, 1, int(duration * 16000)) - feat = torch.randn(20, int(duration * 50), 1024) + x = torch.randn(20, 1, int(duration * 16000)).to(device) + feat = torch.randn(20, int(duration * 50), 1024).to(device) inputs = {"feat": feat, "wav": x, "ref_wav": x} # Forward pass @@ -241,7 +245,8 @@ def _remove_weight_norm(m): wav_recon = model.detokenize(semantic_tokens, global_tokens) # Verify if the reconstruction matches - if torch.allclose(outputs["recons"].detach(), wav_recon): + if torch.allclose(outputs["recons"].detach(), wav_recon, rtol=1e-3, atol=1e-5): + # if torch.allclose(outputs["recons"].detach(), wav_recon): print("Test successful") else: print("Test failed")