diff --git a/webui.py b/webui.py index 6a4a653..157869b 100644 --- a/webui.py +++ b/webui.py @@ -19,31 +19,28 @@ import logging import argparse import gradio as gr -import platform - from datetime import datetime from cli.SparkTTS import SparkTTS from sparktts.utils.token_parser import LEVELS_MAP_UI +from fastapi import FastAPI, File, UploadFile, Form +from fastapi.responses import FileResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +import tempfile +import shutil +from pathlib import Path -def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0): +def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device="auto"): """Load the model once at the beginning.""" logging.info(f"Loading model from: {model_dir}") - - # Determine appropriate device based on platform and availability - if platform.system() == "Darwin": - # macOS with MPS support (Apple Silicon) - device = torch.device(f"mps:{device}") - logging.info(f"Using MPS device: {device}") - elif torch.cuda.is_available(): - # System with CUDA support - device = torch.device(f"cuda:{device}") - logging.info(f"Using CUDA device: {device}") - else: - # Fall back to CPU - device = torch.device("cpu") - logging.info("GPU acceleration not available, using CPU") - + + # Auto-detect device if set to "auto" + if device == "auto": + device = "cuda:0" if torch.cuda.is_available() else "cpu" + logging.info(f"Using device: {device}") + + # Handle device string input to support both CPU and GPU + device = torch.device(device) model = SparkTTS(model_dir, device) return model @@ -91,8 +88,8 @@ def run_tts( return save_path -def build_ui(model_dir, device=0): - +def build_ui(model_dir, device="cuda:0"): + # Initialize model model = initialize_model(model_dir, device=device) @@ -134,7 +131,101 @@ def voice_creation(text, gender, pitch, speed): ) return audio_output_path - with gr.Blocks() as demo: + # Create a FastAPI app + app = FastAPI() + + # Create results directory if it doesn't exist + results_dir = Path("example/results") + results_dir.mkdir(parents=True, exist_ok=True) + + # Mount static file directory for serving audio files + app.mount("/audio", StaticFiles(directory=str(results_dir)), name="audio") + + # API endpoint for voice cloning + @app.post("/api/voice-clone") + async def api_voice_clone( + text: str = Form(...), + prompt_text: str = Form(None), + prompt_audio: UploadFile = File(None) + ): + # Save uploaded audio to a temp file if provided + prompt_speech = None + if prompt_audio: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') + try: + shutil.copyfileobj(prompt_audio.file, temp_file) + temp_file.close() + prompt_speech = temp_file.name + finally: + prompt_audio.file.close() + + # Run TTS + prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text + audio_output_path = run_tts( + text, + model, + prompt_text=prompt_text_clean, + prompt_speech=prompt_speech + ) + + # Clean up temp file + if prompt_speech: + try: + os.unlink(prompt_speech) + except: + pass + + # Return URL to the audio file + filename = os.path.basename(audio_output_path) + audio_url = f"/audio/{filename}" + return JSONResponse({ + "audio_url": audio_url, + "filename": filename, + "text": text + }) + + # API endpoint for voice creation + @app.post("/api/voice-creation") + async def api_voice_creation( + text: str = Form(...), + gender: str = Form("male"), + pitch: str = Form(3), + speed: str = Form(3) + ): + pitch_val = LEVELS_MAP_UI[int(pitch)] + speed_val = LEVELS_MAP_UI[int(speed)] + + audio_output_path = run_tts( + text, + model, + gender=gender, + pitch=pitch_val, + speed=speed_val + ) + + # Return URL to the audio file + filename = os.path.basename(audio_output_path) + audio_url = f"/audio/{filename}" + return JSONResponse({ + "audio_url": audio_url, + "filename": filename, + "text": text, + "gender": gender, + "pitch": pitch, + "speed": speed + }) + + # Create a direct route to get audio by filename + @app.get("/audio/{filename}") + async def get_audio(filename: str): + audio_path = os.path.join("example/results", filename) + if os.path.exists(audio_path): + return FileResponse(audio_path, media_type="audio/wav") + return JSONResponse({"error": "File not found"}, status_code=404) + + # Create Gradio interface + demo = gr.Blocks() + with demo: # Use HTML for centered title gr.HTML('

Spark-TTS by SparkAudio

') with gr.Tabs(): @@ -218,7 +309,8 @@ def voice_creation(text, gender, pitch, speed): outputs=[audio_output], ) - return demo + # Return both the FastAPI app and Gradio interface + return app, demo def parse_arguments(): @@ -234,9 +326,9 @@ def parse_arguments(): ) parser.add_argument( "--device", - type=int, - default=0, - help="ID of the GPU device to use (e.g., 0 for cuda:0)." + type=str, + default="auto", + help="Device to run inference on: 'auto' (default, uses GPU if available), 'cpu', or 'cuda:x' where x is the GPU ID." ) parser.add_argument( "--server_name", @@ -256,14 +348,13 @@ def parse_arguments(): # Parse command-line arguments args = parse_arguments() - # Build the Gradio demo by specifying the model directory and GPU device - demo = build_ui( + # Build the Gradio demo and FastAPI app + app, demo = build_ui( model_dir=args.model_dir, device=args.device ) - # Launch Gradio with the specified server name and port - demo.launch( - server_name=args.server_name, - server_port=args.server_port - ) \ No newline at end of file + # Launch Gradio with FastAPI backend + gr.mount_gradio_app(app, demo, path="/") + import uvicorn + uvicorn.run(app, host=args.server_name, port=args.server_port) \ No newline at end of file