Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 123 additions & 32 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,28 @@
import logging
import argparse
import gradio as gr
import platform

from datetime import datetime
from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import tempfile
import shutil
from pathlib import Path


def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device="auto"):
"""Load the model once at the beginning."""
logging.info(f"Loading model from: {model_dir}")

# Determine appropriate device based on platform and availability
if platform.system() == "Darwin":
# macOS with MPS support (Apple Silicon)
device = torch.device(f"mps:{device}")
logging.info(f"Using MPS device: {device}")
elif torch.cuda.is_available():
# System with CUDA support
device = torch.device(f"cuda:{device}")
logging.info(f"Using CUDA device: {device}")
else:
# Fall back to CPU
device = torch.device("cpu")
logging.info("GPU acceleration not available, using CPU")


# Auto-detect device if set to "auto"
if device == "auto":
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")

# Handle device string input to support both CPU and GPU
device = torch.device(device)
model = SparkTTS(model_dir, device)
return model

Expand Down Expand Up @@ -91,8 +88,8 @@ def run_tts(
return save_path


def build_ui(model_dir, device=0):

def build_ui(model_dir, device="cuda:0"):
# Initialize model
model = initialize_model(model_dir, device=device)

Expand Down Expand Up @@ -134,7 +131,101 @@ def voice_creation(text, gender, pitch, speed):
)
return audio_output_path

with gr.Blocks() as demo:
# Create a FastAPI app
app = FastAPI()

# Create results directory if it doesn't exist
results_dir = Path("example/results")
results_dir.mkdir(parents=True, exist_ok=True)

# Mount static file directory for serving audio files
app.mount("/audio", StaticFiles(directory=str(results_dir)), name="audio")

# API endpoint for voice cloning
@app.post("/api/voice-clone")
async def api_voice_clone(
text: str = Form(...),
prompt_text: str = Form(None),
prompt_audio: UploadFile = File(None)
):
# Save uploaded audio to a temp file if provided
prompt_speech = None
if prompt_audio:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
try:
shutil.copyfileobj(prompt_audio.file, temp_file)
temp_file.close()
prompt_speech = temp_file.name
finally:
prompt_audio.file.close()

# Run TTS
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
audio_output_path = run_tts(
text,
model,
prompt_text=prompt_text_clean,
prompt_speech=prompt_speech
)

# Clean up temp file
if prompt_speech:
try:
os.unlink(prompt_speech)
except:
pass

# Return URL to the audio file
filename = os.path.basename(audio_output_path)
audio_url = f"/audio/{filename}"
return JSONResponse({
"audio_url": audio_url,
"filename": filename,
"text": text
})

# API endpoint for voice creation
@app.post("/api/voice-creation")
async def api_voice_creation(
text: str = Form(...),
gender: str = Form("male"),
pitch: str = Form(3),
speed: str = Form(3)
):
pitch_val = LEVELS_MAP_UI[int(pitch)]
speed_val = LEVELS_MAP_UI[int(speed)]

audio_output_path = run_tts(
text,
model,
gender=gender,
pitch=pitch_val,
speed=speed_val
)

# Return URL to the audio file
filename = os.path.basename(audio_output_path)
audio_url = f"/audio/{filename}"
return JSONResponse({
"audio_url": audio_url,
"filename": filename,
"text": text,
"gender": gender,
"pitch": pitch,
"speed": speed
})

# Create a direct route to get audio by filename
@app.get("/audio/{filename}")
async def get_audio(filename: str):
audio_path = os.path.join("example/results", filename)
if os.path.exists(audio_path):
return FileResponse(audio_path, media_type="audio/wav")
return JSONResponse({"error": "File not found"}, status_code=404)

# Create Gradio interface
demo = gr.Blocks()
with demo:
# Use HTML for centered title
gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
with gr.Tabs():
Expand Down Expand Up @@ -218,7 +309,8 @@ def voice_creation(text, gender, pitch, speed):
outputs=[audio_output],
)

return demo
# Return both the FastAPI app and Gradio interface
return app, demo


def parse_arguments():
Expand All @@ -234,9 +326,9 @@ def parse_arguments():
)
parser.add_argument(
"--device",
type=int,
default=0,
help="ID of the GPU device to use (e.g., 0 for cuda:0)."
type=str,
default="auto",
help="Device to run inference on: 'auto' (default, uses GPU if available), 'cpu', or 'cuda:x' where x is the GPU ID."
)
parser.add_argument(
"--server_name",
Expand All @@ -256,14 +348,13 @@ def parse_arguments():
# Parse command-line arguments
args = parse_arguments()

# Build the Gradio demo by specifying the model directory and GPU device
demo = build_ui(
# Build the Gradio demo and FastAPI app
app, demo = build_ui(
model_dir=args.model_dir,
device=args.device
)

# Launch Gradio with the specified server name and port
demo.launch(
server_name=args.server_name,
server_port=args.server_port
)
# Launch Gradio with FastAPI backend
gr.mount_gradio_app(app, demo, path="/")
import uvicorn
uvicorn.run(app, host=args.server_name, port=args.server_port)