diff --git a/cli/SparkTTS.py b/cli/SparkTTS.py index bc86ce3..9f26733 100644 --- a/cli/SparkTTS.py +++ b/cli/SparkTTS.py @@ -91,6 +91,7 @@ def process_prompt( "<|end_global_token|>", "<|start_semantic_token|>", semantic_tokens, + "<|end_semantic_token|>", ] else: inputs = [ @@ -233,4 +234,4 @@ def inference( pred_semantic_ids.to(self.device), ) - return wav \ No newline at end of file + return wav