Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added cli/__init__.py
Empty file.
3 changes: 2 additions & 1 deletion runtime/triton_trtllm/client_grpc_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ def tts(
if i == 0:
new_audio = audio[:-cross_fade_samples]
else:
cross_faded_overlap = audio[:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out
cross_faded_overlap = audios[i - 1][-cross_fade_samples:] * fade_out + audio[:cross_fade_samples] * fade_in
new_audio = np.concatenate([new_audio, cross_faded_overlap, audio[cross_fade_samples:-cross_fade_samples]])
new_audio = np.concatenate([new_audio, audio[-cross_fade_samples:]])

#new_audio = np.hstack(audios)
sf.write(output_audio, new_audio, 16000, "PCM_16")
print(f"save audio to {output_audio}")

Expand Down
12 changes: 9 additions & 3 deletions sparktts/models/bicodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,16 @@ def _remove_weight_norm(m):
model = BiCodec.load_from_checkpoint(
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
)
device = "cpu" if not torch.cuda.is_available() else "cuda"
print(model)
model_million_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"{model_million_params}M parameters")
model.to(device)

# Generate random inputs for testing
duration = 0.96
x = torch.randn(20, 1, int(duration * 16000))
feat = torch.randn(20, int(duration * 50), 1024)
x = torch.randn(20, 1, int(duration * 16000)).to(device)
feat = torch.randn(20, int(duration * 50), 1024).to(device)
inputs = {"feat": feat, "wav": x, "ref_wav": x}

# Forward pass
Expand All @@ -241,7 +246,8 @@ def _remove_weight_norm(m):
wav_recon = model.detokenize(semantic_tokens, global_tokens)

# Verify if the reconstruction matches
if torch.allclose(outputs["recons"].detach(), wav_recon):
if torch.allclose(outputs["recons"].detach(), wav_recon, rtol=1e-3, atol=1e-5):
#if torch.allclose(outputs["recons"].detach(), wav_recon):
print("Test successful")
else:
print("Test failed")
9 changes: 6 additions & 3 deletions sparktts/utils/token_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@ def emotion(emotion: str):

# test
if __name__ == "__main__":
import os
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
"/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
os.getenv("TOKENIZER_PATH", "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer")
)

tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
Expand All @@ -183,5 +184,7 @@ def emotion(emotion: str):
inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
inputs = "".join(inputs)
ids = tokenizer.encode(inputs, add_special_tokens=False)
print(ids)
print("decode", tokenizer.decode(ids))
print("tokenized ids",ids)
tokens = tokenizer.decode(ids)
print("decoded tokens", tokens)
assert tokens == inputs