From 9cff5ebf0a7bfe88531852d5648249d85a204403 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 12 Mar 2025 16:27:59 +0800 Subject: [PATCH 1/5] fix: test bicodec to print model and params Signed-off-by: weedge --- sparktts/models/bicodec.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sparktts/models/bicodec.py b/sparktts/models/bicodec.py index 8cab2f0..e4a4202 100644 --- a/sparktts/models/bicodec.py +++ b/sparktts/models/bicodec.py @@ -228,11 +228,16 @@ def _remove_weight_norm(m): model = BiCodec.load_from_checkpoint( model_dir="pretrained_models/SparkTTS-0.5B/BiCodec", ) + device = "cpu" if not torch.cuda.is_available() else "cuda" + print(model) + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + print(f"{model_million_params}M parameters") + model.to(device) # Generate random inputs for testing duration = 0.96 - x = torch.randn(20, 1, int(duration * 16000)) - feat = torch.randn(20, int(duration * 50), 1024) + x = torch.randn(20, 1, int(duration * 16000)).to(device) + feat = torch.randn(20, int(duration * 50), 1024).to(device) inputs = {"feat": feat, "wav": x, "ref_wav": x} # Forward pass @@ -241,7 +246,8 @@ def _remove_weight_norm(m): wav_recon = model.detokenize(semantic_tokens, global_tokens) # Verify if the reconstruction matches - if torch.allclose(outputs["recons"].detach(), wav_recon): + if torch.allclose(outputs["recons"].detach(), wav_recon, rtol=1e-3, atol=1e-5): + #if torch.allclose(outputs["recons"].detach(), wav_recon): print("Test successful") else: print("Test failed") From f127b853797e57a57dfa241b341e9d789ad7bdf5 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 12 Mar 2025 17:04:18 +0800 Subject: [PATCH 2/5] fix: add cli/__init__.py Signed-off-by: weedge --- cli/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 cli/__init__.py diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 0000000..e69de29 From 58c8d105573b7a049267aa0f0d228416cd9204d5 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 12 Mar 2025 17:18:54 +0800 Subject: [PATCH 3/5] add TOKENIZER_PATH env for tokenizer paser Signed-off-by: weedge --- sparktts/utils/token_parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sparktts/utils/token_parser.py b/sparktts/utils/token_parser.py index cc43782..cc9cad9 100644 --- a/sparktts/utils/token_parser.py +++ b/sparktts/utils/token_parser.py @@ -156,10 +156,11 @@ def emotion(emotion: str): # test if __name__ == "__main__": + import os from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( - "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer" + os.getenv("TOKENIZER_PATH", "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer") ) tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"] From 27e0648d0d275a9df8c8ae27571517ca4f588b4a Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 12 Mar 2025 17:25:33 +0800 Subject: [PATCH 4/5] assert tokens==inputs Signed-off-by: weedge --- sparktts/utils/token_parser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sparktts/utils/token_parser.py b/sparktts/utils/token_parser.py index cc9cad9..9dbd13c 100644 --- a/sparktts/utils/token_parser.py +++ b/sparktts/utils/token_parser.py @@ -184,5 +184,7 @@ def emotion(emotion: str): inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion] inputs = "".join(inputs) ids = tokenizer.encode(inputs, add_special_tokens=False) - print(ids) - print("decode", tokenizer.decode(ids)) + print("tokenized ids",ids) + tokens = tokenizer.decode(ids) + print("decoded tokens", tokens) + assert tokens == inputs From b18a76c52621b7fef739c828a2c5d9d30e4b81cd Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 30 Mar 2025 20:52:00 +0800 Subject: [PATCH 5/5] fix Signed-off-by: weedge --- runtime/triton_trtllm/client_grpc_streaming.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runtime/triton_trtllm/client_grpc_streaming.py b/runtime/triton_trtllm/client_grpc_streaming.py index 60c2510..36a9ed9 100644 --- a/runtime/triton_trtllm/client_grpc_streaming.py +++ b/runtime/triton_trtllm/client_grpc_streaming.py @@ -85,10 +85,11 @@ def tts( if i == 0: new_audio = audio[:-cross_fade_samples] else: - cross_faded_overlap = audio[:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out + cross_faded_overlap = audios[i - 1][-cross_fade_samples:] * fade_out + audio[:cross_fade_samples] * fade_in new_audio = np.concatenate([new_audio, cross_faded_overlap, audio[cross_fade_samples:-cross_fade_samples]]) new_audio = np.concatenate([new_audio, audio[-cross_fade_samples:]]) + #new_audio = np.hstack(audios) sf.write(output_audio, new_audio, 16000, "PCM_16") print(f"save audio to {output_audio}")