diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/triton_trtllm/client_grpc_streaming.py b/runtime/triton_trtllm/client_grpc_streaming.py index 60c2510..36a9ed9 100644 --- a/runtime/triton_trtllm/client_grpc_streaming.py +++ b/runtime/triton_trtllm/client_grpc_streaming.py @@ -85,10 +85,11 @@ def tts( if i == 0: new_audio = audio[:-cross_fade_samples] else: - cross_faded_overlap = audio[:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out + cross_faded_overlap = audios[i - 1][-cross_fade_samples:] * fade_out + audio[:cross_fade_samples] * fade_in new_audio = np.concatenate([new_audio, cross_faded_overlap, audio[cross_fade_samples:-cross_fade_samples]]) new_audio = np.concatenate([new_audio, audio[-cross_fade_samples:]]) + #new_audio = np.hstack(audios) sf.write(output_audio, new_audio, 16000, "PCM_16") print(f"save audio to {output_audio}") diff --git a/sparktts/models/bicodec.py b/sparktts/models/bicodec.py index 8cab2f0..e4a4202 100644 --- a/sparktts/models/bicodec.py +++ b/sparktts/models/bicodec.py @@ -228,11 +228,16 @@ def _remove_weight_norm(m): model = BiCodec.load_from_checkpoint( model_dir="pretrained_models/SparkTTS-0.5B/BiCodec", ) + device = "cpu" if not torch.cuda.is_available() else "cuda" + print(model) + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + print(f"{model_million_params}M parameters") + model.to(device) # Generate random inputs for testing duration = 0.96 - x = torch.randn(20, 1, int(duration * 16000)) - feat = torch.randn(20, int(duration * 50), 1024) + x = torch.randn(20, 1, int(duration * 16000)).to(device) + feat = torch.randn(20, int(duration * 50), 1024).to(device) inputs = {"feat": feat, "wav": x, "ref_wav": x} # Forward pass @@ -241,7 +246,8 @@ def _remove_weight_norm(m): wav_recon = model.detokenize(semantic_tokens, global_tokens) # Verify if the reconstruction matches - if torch.allclose(outputs["recons"].detach(), wav_recon): + if torch.allclose(outputs["recons"].detach(), wav_recon, rtol=1e-3, atol=1e-5): + #if torch.allclose(outputs["recons"].detach(), wav_recon): print("Test successful") else: print("Test failed") diff --git a/sparktts/utils/token_parser.py b/sparktts/utils/token_parser.py index cc43782..9dbd13c 100644 --- a/sparktts/utils/token_parser.py +++ b/sparktts/utils/token_parser.py @@ -156,10 +156,11 @@ def emotion(emotion: str): # test if __name__ == "__main__": + import os from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( - "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer" + os.getenv("TOKENIZER_PATH", "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer") ) tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"] @@ -183,5 +184,7 @@ def emotion(emotion: str): inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion] inputs = "".join(inputs) ids = tokenizer.encode(inputs, add_special_tokens=False) - print(ids) - print("decode", tokenizer.decode(ids)) + print("tokenized ids",ids) + tokens = tokenizer.decode(ids) + print("decoded tokens", tokens) + assert tokens == inputs