diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..08c34e3 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,75 @@ +# Git +.git +.github +.gitignore +.gitattributes + +# Docker +.dockerignore +Dockerfile +docker-compose.yml +docker_builder.sh + +# CI/CD and development files +.circleci/ +.travis.yml +.env +*.md +!README.md +!LICENSE +docs/ +tests/ + +# Virtual environments +venv/ +env/ +.venv/ +.env/ +.python-version + +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ +.hypothesis/ +.eggs/ +*.egg-info/ +*.egg + +# IDE specific files +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store + +# Temporary files +temp/ +tmp/ +*.tmp +*.log + +# API outputs (these should be created at runtime) +api/outputs/ + +# Local model directories (only include if specified) +# Uncomment if you never want to include models +# pretrained_models/ + +# Jupyter Notebooks +.ipynb_checkpoints +*.ipynb + +# Large unnecessary files +*.wav +*.wav.zip +*.mp3 +*.mp4 +*.tar.gz +output/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 146a43c..4d5f950 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,7 @@ cython_debug/ # PyPI configuration file .pypirc + + +api/.env +api/outputs/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2128a7c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,119 @@ +# Usage Instructions +# 1. Recommended way to build all images at once: +# ./docker_builder.sh +# This creates: spark-tts:latest-lite, spark-tts:latest (alias of latest-lite), and spark-tts:latest-full +# +# 2. Manual build without models: +# docker build -t spark-tts:latest-lite . +# docker tag spark-tts:latest-lite spark-tts:latest +# +# 3. Manual build with models: +# docker build --build-arg INCLUDE_MODELS=true -t spark-tts:latest-full . +# +# 4. 
# NOTE: the model-copy step below uses `RUN --mount`, which requires BuildKit
# (the default builder since Docker Engine 23.0; on older engines build with
# DOCKER_BUILDKIT=1).
FROM python:3.12-slim

# Set working directory (created automatically if missing)
WORKDIR /app

# Install system dependencies and initialize git-lfs in a single layer so the
# apt lists are removed in the same layer that created them.
RUN apt-get update && apt-get install -y --no-install-recommends \
        ffmpeg \
        git \
        git-lfs \
        libsndfile1 \
    && rm -rf /var/lib/apt/lists/* \
    && git lfs install

# Copy the dependency manifest on its own so the (large) pip layer stays cached
# until requirements.txt changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy project files (layered copying to optimize caching)
COPY cli/ ./cli/
COPY sparktts/ ./sparktts/
COPY src/ ./src/
COPY example/ ./example/
COPY api/ ./api/
COPY webui.py LICENSE README.md ./

# Create the model and API output directories and make the launcher executable.
# api/outputs is left world-writable so the container can also be started with
# an arbitrary --user; tighten this if the image is always run as one fixed UID.
RUN mkdir -p /app/pretrained_models /app/api/outputs \
    && chmod 777 /app/api/outputs \
    && chmod +x /app/api/run_api.sh

# Build argument to determine whether to include models.
# Declared as late as possible: every layer above is identical for the lite and
# full variants and can therefore be shared between the two builds.
ARG INCLUDE_MODELS=false

# Optionally copy model files from the build context.
# The context is bind-mounted read-only for this single step instead of being
# COPY'd into the image: a `COPY . /tmp/context` layer stays in the image
# history even after a later `rm -rf`, which would ship the models (and the
# rest of the context) inside the "lite" image as well.
RUN --mount=type=bind,target=/tmp/context \
    echo "INCLUDE_MODELS=${INCLUDE_MODELS}"; \
    if [ -d "/tmp/context/pretrained_models" ]; then \
        echo "Found pretrained_models directory"; \
    else \
        echo "pretrained_models directory not found"; \
    fi; \
    if [ "${INCLUDE_MODELS}" = "true" ]; then \
        echo "Including models in the image"; \
        if [ -d "/tmp/context/pretrained_models" ]; then \
            cp -r /tmp/context/pretrained_models/* /app/pretrained_models/ || echo "No model files to copy"; \
        else \
            echo "Warning: pretrained_models directory not found in build context"; \
        fi; \
    else \
        echo "Models will need to be mounted at runtime"; \
    fi

# Set environment variables
ENV PYTHONPATH=/app \
    SERVICE_TYPE=api

# Expose single port for both WebUI and API
EXPOSE 7860

# Set container startup command.
# Exec form with an explicit `exec` so the selected service replaces the shell,
# runs as PID 1 and receives SIGTERM from `docker stop`.
CMD ["/bin/sh", "-c", "if [ \"$SERVICE_TYPE\" = \"webui\" ]; then exec python webui.py --device 0; else exec ./api/run_api.sh; fi"]
**Running the API service in conda environment**: + ```sh + # Make sure you're in the Spark-TTS conda environment + conda activate sparktts + + # Execute from the project root directory + ./api/run_api.sh + ``` + The API will be available at http://localhost:7860 by default. + +2. **Docker support**: + You can build and run the Spark-TTS API using the provided build script: + ```sh + # Build Docker images (both full and lite versions) + chmod +x docker_builder.sh + ./docker_builder.sh + + # Run the API service in the background + docker compose up -d api + # OR for the lite version with mounted models + docker compose up -d api-lite + + # Run the WebUI service in the background + docker compose up -d webui + # OR for the lite version with mounted models + docker compose up -d webui-lite + + # To check running containers + docker compose ps + + # To stop services + docker compose down + ``` + + > **Note**: If you encounter YAML errors like `mapping key "<<" already defined`, it might be due to compatibility issues with YAML merge keys in your Docker Compose version. You can either: + > 1. Update Docker to the latest version + > 2. Modify the docker-compose.yml file to use a different syntax for environment variable inheritance + > 3. Use the Docker CLI directly: `docker run -p 7860:7860 --gpus all spark-tts:latest-full` + + For more customization options, see the environment variables in the docker-compose.yml file. + +3. 
**Client Example**: + The repository includes an example client script that demonstrates how to interact with the API: + ```sh + # Note: The example client requires librosa, which is not in requirements.txt + pip install librosa + + # Basic usage + python api/example_client.py --text "Text to synthesize" + + # Voice cloning with reference audio + python api/example_client.py --text "This is voice cloning" --prompt_audio example/prompt_audio.wav + + # Voice creation with parameters + python api/example_client.py --text "This is voice creation" --gender female --pitch high --speed moderate + ``` + +For more detailed information about the API service, including all available endpoints and parameters, please refer to the [API README](api/README.md). + ## Runtime diff --git a/api/.env.example b/api/.env.example new file mode 100644 index 0000000..39047bc --- /dev/null +++ b/api/.env.example @@ -0,0 +1,31 @@ +# Spark-TTS API Environment Variable Configuration Example +# Copy this file to .env and modify the configuration as needed + +# === Service Configuration === +SPARK_TTS_API_PORT=7860 +SPARK_TTS_API_HOST=0.0.0.0 +SPARK_TTS_API_DEBUG=False + +# === Security Configuration === +SPARK_TTS_API_KEY_NAME=X-SPARKTTS-API-KEY +SPARK_TTS_API_KEY= + +# === TTS Model Configuration === +SPARK_TTS_MODEL_DIR=pretrained_models/Spark-TTS-0.5B +# Device configuration: +# - cpu: Use CPU for inference +# - gpu: Use default GPU for inference +# - gpu:N: Use specific GPU (N is device ID) for inference +SPARK_TTS_DEVICE=gpu:0 + +# === Default Prompt Configuration === +SPARK_TTS_DEFAULT_PROMPT_TEXT=吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。 +SPARK_TTS_DEFAULT_PROMPT_SPEECH=example/prompt_audio.wav + +# === Output Configuration === +SPARK_TTS_OUTPUT_DIR=api/outputs +SPARK_TTS_OUTPUT_URL_PREFIX=/outputs + +# === Cleanup Configuration === +SPARK_TTS_CLEANUP_INTERVAL=3600 +SPARK_TTS_FILE_EXPIRY_TIME=86400 diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000..9873378 --- /dev/null 
+++ b/api/README.md @@ -0,0 +1,223 @@ +# Spark-TTS API + +This is a Web API interface based on FastAPI for accessing the functionality of the Spark-TTS speech synthesis model. Compared to the existing WebUI interface, this API provides more flexible feature selection and supports both voice cloning and voice creation features. + +## Features + +- Supports basic text-to-speech synthesis +- Supports voice cloning based on reference audio +- Supports voice creation based on parameter control +- Supports simultaneous use of voice cloning and voice creation features +- Supports multiple audio input methods: Base64, URL, default audio +- Supports multiple audio output methods: Base64, URL access +- Supports API key authentication (optional) +- Automatically cleans up expired audio files +- Flexible configuration, supports settings through environment variables or .env file + +## Getting Started + +### Configuring Environment Variables + +The API supports configuring environment variables in two ways: + +1. **Using a .env file (recommended)**: + ```bash + # Copy the example configuration file + cp api/.env.example api/.env + + # Edit the configuration file + nano api/.env + ``` + +2. **Setting environment variables directly**: + ```bash + export SPARK_TTS_API_PORT=8080 + export SPARK_TTS_DEVICE=1 + # Other environment variables... + ``` + +### Starting the API Service + +Use the provided script to start the API service: + +```bash +chmod +x api/run_api.sh +./api/run_api.sh +``` + +If you used a .env file, the script will automatically load the configurations from it. 
You can also override the settings in the .env file with command line arguments: + +```bash +./api/run_api.sh --port 8080 --device 1 --model_dir /path/to/model --debug +``` + +### Custom Startup Parameters + +You can customize the behavior of the service with the following parameters: + +```bash +./api/run_api.sh --port 8080 --device 1 --model_dir /path/to/model --debug --env /path/to/env/file +``` + +Parameter descriptions: +- `--port`: Port for the service to listen on (default: 7860) +- `--device`: GPU device ID to use (default: 0) +- `--model_dir`: Model directory path (default: pretrained_models/Spark-TTS-0.5B) +- `--debug`: Enable debug mode +- `--env`: Specify a custom .env file path (default: api/.env) + +## API Endpoints + +### 1. Text to Speech (POST /tts) + +Convert text to speech, supporting voice cloning and voice creation. + +**Request Parameters**: + +```json +{ + "text": "Text to synthesize", + + // Voice cloning parameters (optional) + "prompt_text": "Text content of the reference audio", + "prompt_audio_base64": "Base64 encoded reference audio", + "prompt_audio_url": "URL of the reference audio", + + // Voice creation parameters (optional) + "gender": "male or female", + "pitch": "very_low, low, moderate, high, or very_high", + "speed": "very_low, low, moderate, high, or very_high", + + // Other parameters + "temperature": 0.8, + "top_k": 50, + "top_p": 0.95, + "return_audio_data": true +} +``` + +**Note**: `prompt_audio_base64` and `prompt_audio_url` are mutually exclusive parameters and cannot be provided simultaneously. + +**Response**: + +```json +{ + "text": "Input text", + "audio_url": "/outputs/file_id.wav", + "audio_base64": "Base64 encoded audio (when return_audio_data=true)", + "duration": 3.5, + "sample_rate": 16000, + "file_id": "unique_file_id", + "created_at": "2023-05-20T12:34:56.789" +} +``` + +### 2. Get Audio File (GET /outputs/{file_id}) + +Retrieve the generated audio file using the file ID. 
+ +**Request**: + +``` +GET /outputs/{file_id} +``` + +**Response**: +Audio file (WAV format) + +### 3. Health Check (GET /health) + +Check if the API service is running normally. + +**Request**: + +``` +GET /health +``` + +**Response**: + +```json +{ + "status": "ok", + "timestamp": "2023-05-20T12:34:56.789", + "device": { + "configured": "gpu:0", + "actual": "cuda:0" + }, + "model_loaded": true +} +``` + +## Example Client + +An example client script is provided to demonstrate how to use the API: + +```bash +# Note: The example client requires librosa, which is not in requirements.txt +# Install it before running the client: +pip install librosa + +# Basic usage +python api/example_client.py --text "This is a test" +``` + +You can use different parameter combinations for different features: + +- **For voice cloning** (using reference audio): +```bash +python api/example_client.py --text "This is an example of voice cloning" --prompt_audio example/prompt_audio.wav +``` + +- **For voice creation** (using control parameters): +```bash +python api/example_client.py --text "This is an example of voice creation" --gender female --pitch high --speed moderate +``` + +- **Combined features** (both voice cloning and creation): +```bash +python api/example_client.py --text "This is an example of combined features" --prompt_audio example/prompt_audio.wav --gender male --pitch low +``` + +- **Using with API key**: +```bash +python api/example_client.py --text "This is a test" --api_key YOUR_API_KEY +``` + +## Environment Variable Configuration + +The API service can be configured using the following environment variables: + +| Environment Variable | Description | Default Value | +|----------|------|--------| +| SPARK_TTS_API_PORT | API service port | 7860 | +| SPARK_TTS_API_HOST | API service host | 0.0.0.0 | +| SPARK_TTS_API_DEBUG | Whether to enable debug mode | False | +| SPARK_TTS_API_KEY | API key (authentication not enabled if not set) | None | +| SPARK_TTS_API_KEY_NAME 
| API key request header name | X-SPARKTTS-API-KEY | +| SPARK_TTS_MODEL_DIR | Model directory path | pretrained_models/Spark-TTS-0.5B | +| SPARK_TTS_DEVICE | Inference device (`cpu`, `gpu`, or `gpu:N` for a specific GPU) | gpu:0 | +| SPARK_TTS_DEFAULT_PROMPT_TEXT | Default reference text | "吃燕窝就选燕之屋..." | +| SPARK_TTS_DEFAULT_PROMPT_SPEECH | Default reference audio path | example/prompt_audio.wav | +| SPARK_TTS_OUTPUT_DIR | Output audio file directory | api/outputs | +| SPARK_TTS_OUTPUT_URL_PREFIX | Output audio URL prefix | /outputs | +| SPARK_TTS_CLEANUP_INTERVAL | Cleanup task interval (seconds) | 3600 | +| SPARK_TTS_FILE_EXPIRY_TIME | File expiration time (seconds) | 86400 | + +## Docker Support + +The API is designed to run in a Docker environment and can be configured through environment variables or by mounting the api/.env file. A Dockerfile, a docker-compose.yml file, and the docker_builder.sh build script are provided in the project root; see the main README for build and run instructions. + +Example Docker mount command: +```bash +docker run -p 7860:7860 --gpus all -v /local/path/api/.env:/app/api/.env -v /local/path/pretrained_models:/app/pretrained_models spark-tts:latest +``` + +## Notes + +- Please ensure that the model files have been correctly downloaded and placed in the specified directory. +- If API key authentication is enabled, all requests must include the correct API key header. +- Generated audio files will be automatically deleted after the set expiration time, default is 24 hours. +- Static file service has been configured, and generated audio files can be accessed directly via URL. +- **The server only accepts audio files in WAV format**. If you need to use other formats (such as MP3), please convert them to WAV format on the client side before uploading. The example client includes automatic conversion functionality. +- When using the `prompt_audio_url` parameter to point to an audio file on the same server (such as `http://localhost:7860/outputs/xxx.wav`), the server will read the local file directly rather than downloading it via HTTP to avoid circular reference issues.
\ No newline at end of file diff --git a/api/example_client.py b/api/example_client.py new file mode 100644 index 0000000..26c17c9 --- /dev/null +++ b/api/example_client.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Spark-TTS API Client Example + +This script demonstrates how to use the Spark-TTS API service through HTTP requests. +All features are integrated, and you can set the appropriate parameter combinations as needed. + +Basic usage: + python example_client.py --text "Text to synthesize" + +Using reference audio for voice cloning: + python example_client.py --text "Text to synthesize" --prompt_audio example/prompt_audio.wav + +Using reference audio and text for more accurate voice cloning: + python example_client.py --text "Text to synthesize" --prompt_audio example/prompt_audio.wav --prompt_text "Text content of the reference audio" + +Using voice parameters for control: + python example_client.py --text "Text to synthesize" --gender female --pitch high --speed moderate + +Using both reference audio and voice parameters: + python example_client.py --text "Text to synthesize" --prompt_audio example/prompt_audio.wav --gender male + +Using API key: + python example_client.py --text "Text to synthesize" --api_key YOUR_API_KEY + +Output configuration: + python example_client.py --text "Text to synthesize" --output_dir custom_outputs + +Note: The server only accepts audio files in WAV format. If other formats are provided (such as MP3), the client will automatically convert them to WAV format. 
def _copy_if_valid_wav(input_file, output_file):
    """Validate a ``.wav`` input with soundfile and copy it to ``output_file``.

    Returns True on success, False if the file could not be read or copied.
    """
    try:
        import soundfile as sf
        _, sample_rate = sf.read(input_file)
        print(f"Input file is already in WAV format, sample rate: {sample_rate}Hz")

        # Only copy when the destination differs from the source
        if input_file != output_file:
            import shutil
            shutil.copy(input_file, output_file)
        return True
    except Exception as e:
        print(f"Warning: Input file has WAV extension but format is invalid: {str(e)}")
        return False


def _convert_with_ffmpeg(input_file, output_file):
    """Convert to 16 kHz mono WAV with the ffmpeg executable; return True on success."""
    try:
        import subprocess
        print(f"Attempting to convert using ffmpeg: {input_file}")
        result = subprocess.run(
            ["ffmpeg", "-i", input_file, "-ar", "16000", "-ac", "1", "-y", output_file],
            capture_output=True,
            text=True
        )
        if result.returncode == 0:
            print(f"Conversion with ffmpeg successful: {output_file}")
            return True
        print(f"ffmpeg conversion failed: {result.stderr}")
    except Exception as e:
        # Covers ffmpeg not being installed (FileNotFoundError) as well
        print(f"Conversion with ffmpeg failed: {str(e)}")
    return False


def _convert_with_librosa(input_file, output_file):
    """Convert to WAV with librosa + soundfile, keeping the source sample rate; return True on success."""
    try:
        import librosa
        import soundfile as sf
        print(f"Loading audio using librosa: {input_file}")
        audio_data, sample_rate = librosa.load(input_file, sr=None)
        print(f"Converting to WAV format, sample rate: {sample_rate}Hz")
        sf.write(output_file, audio_data, sample_rate)
        print(f"Audio conversion successful: {output_file}")
        return True
    except ImportError:
        print("Warning: librosa library not installed, cannot use this method for conversion")
        print("Tip: Install librosa library to support more audio formats: pip install librosa")
    except Exception as e:
        print(f"Conversion with librosa failed: {str(e)}")
    return False


def convert_audio_to_wav(input_file, output_file=None):
    """Convert audio file to WAV format.

    Tries, in order: validating/copying an existing ``.wav`` file, converting
    with ffmpeg, and converting with librosa.

    Args:
        input_file: Path of the audio file to convert.
        output_file: Destination path. When None, a temporary ``.wav`` file is
            created; on success the caller is responsible for deleting it.

    Returns:
        Path of the WAV file on success, or None if the input file does not
        exist or every conversion method failed.
    """
    # Validate the input BEFORE allocating a temporary file, so a wrong path
    # does not leave a stray temp file behind.
    if not os.path.exists(input_file):
        print(f"Error: Input file does not exist: {input_file}")
        return None

    # Remember whether the output file is ours, so we only ever delete a file
    # this function created.
    owns_output = output_file is None
    if owns_output:
        fd, output_file = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

    _, ext = os.path.splitext(input_file.lower())

    # If already in WAV format, copy directly; otherwise (or if the WAV is
    # unreadable) fall through to the converters.
    if ext == '.wav' and _copy_if_valid_wav(input_file, output_file):
        return output_file
    if _convert_with_ffmpeg(input_file, output_file):
        return output_file
    if _convert_with_librosa(input_file, output_file):
        return output_file

    print("Error: All conversion methods failed, unable to convert audio to WAV format")
    print("Please ensure ffmpeg is installed or install librosa library (pip install librosa)")

    # Nothing usable was produced: remove the temp file we created (a failed
    # converter may have left it empty or partially written).
    if owns_output and os.path.exists(output_file):
        os.remove(output_file)
    return None


def read_audio_file(file_path):
    """Read audio file and convert to Base64 encoding.

    Args:
        file_path: Path of the audio file (any format convert_audio_to_wav handles).

    Returns:
        Base64 string (UTF-8 decoded) of the WAV-converted audio.

    Raises:
        ValueError: If the file cannot be converted to WAV format.
    """
    # First ensure the file is in WAV format
    wav_file = convert_audio_to_wav(file_path)
    if not wav_file:
        raise ValueError(f"Unable to convert audio file to WAV format: {file_path}")

    try:
        with open(wav_file, "rb") as f:
            audio_data = f.read()
    finally:
        # The converted copy is temporary; delete it even if reading failed
        if wav_file != file_path and os.path.exists(wav_file):
            os.remove(wav_file)

    return base64.b64encode(audio_data).decode("utf-8")


def save_audio_file(base64_data, output_path):
    """Save Base64 encoded audio data as a file.

    Args:
        base64_data: Base64 encoded audio content.
        output_path: Path of the file to write (overwritten if it exists).
    """
    audio_data = base64.b64decode(base64_data)
    with open(output_path, "wb") as f:
        f.write(audio_data)


def tts_request(
    api_url,
    text,
    prompt_text=None,
    prompt_audio_path=None,
    prompt_audio_url=None,
    gender=None,
    pitch=None,
    speed=None,
    return_audio_data=True,
    api_key=None,
    output_dir="example_client_outputs",
    timeout=60,  # Request timeout, default 60 seconds
):
    """
    Send TTS request to API

    Args:
        api_url: Base URL of the API service
        text: Text to synthesize
        prompt_text: Text content of the reference audio
        prompt_audio_path: Path to reference audio file (takes precedence over prompt_audio_url)
        prompt_audio_url: URL of reference audio
        gender: Voice gender
        pitch: Pitch
        speed: Speech rate
        return_audio_data: Whether to return Base64 encoded audio data
        api_key: API key
        output_dir: Output directory for client to locally save audio
        timeout: Request timeout in seconds

    Returns:
        Response dictionary (with an extra "local_path" key when the audio was
        saved locally), or None if the request could not be completed
    """
    # Build the endpoint by plain concatenation: urljoin() would replace the last
    # path segment of a base URL that has a path but no trailing slash
    # (e.g. "http://host/api" + "tts" -> "http://host/tts").
    api_base = api_url.rstrip('/')
    endpoint = f"{api_base}/tts"

    # Check if URL pointing to the same service is used
    if prompt_audio_url and prompt_audio_url.startswith(api_base):
        print(f"Note: You are using a URL pointing to the same API service as reference audio: {prompt_audio_url}")
        print("The server will read the local file directly instead of downloading via HTTP")

    # Prepare request data
    payload = {"text": text, "return_audio_data": return_audio_data}

    # Add voice cloning parameters
    if prompt_text:
        payload["prompt_text"] = prompt_text

    if prompt_audio_path:
        try:
            # Read and encode audio file
            print(f"Reading audio file: {prompt_audio_path}")
            payload["prompt_audio_base64"] = read_audio_file(prompt_audio_path)
            print(f"Audio file encoding complete, size approximately {len(payload['prompt_audio_base64'])//1024} KB")
        except Exception as e:
            print(f"Failed to read audio file: {str(e)}")
            print("Please ensure the audio file exists and is in the correct format, or install librosa library to support more formats")
            return None
    elif prompt_audio_url:
        payload["prompt_audio_url"] = prompt_audio_url

    # Add voice creation parameters (only those that were supplied)
    for key, value in (("gender", gender), ("pitch", pitch), ("speed", speed)):
        if value:
            payload[key] = value

    # Prepare request headers
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["X-SPARKTTS-API-KEY"] = api_key

    # Send request
    print(f"Sending request to {endpoint}")
    print("Request processing, this may take some time...")

    try:
        response = requests.post(endpoint, json=payload, headers=headers, timeout=timeout)

        # Check response
        if response.status_code != 200:
            print(f"Error: {response.status_code} - {response.text}")

            # Special handling for audio format errors
            if response.status_code == 400 and "WAV format" in response.text:
                print("\nAudio format error: The server only accepts audio files in WAV format")
                print("The client attempted to automatically convert the audio format, but it may have failed")
                print("Suggestions:")
                print("1. Install librosa library: pip install librosa")
                print("2. Or manually convert the audio to WAV format before uploading")
                print("3. Or use ffmpeg to manually convert: ffmpeg -i input.mp3 -ar 16000 -ac 1 output.wav")

            return None

        # Parse response
        result = response.json()
        print("Request successful!")

        # Note: This example client saves audio files locally, which is different from the files saved by the server in the api/outputs directory:
        # - Server-side: Saves audio in the api/outputs directory, provides access via API URL
        # - Client-side: Saves a local copy of the audio in the example_client_outputs directory for local use

        # If Base64 audio was returned, save to file
        if return_audio_data and result.get("audio_base64"):
            # Ensure output directory exists
            os.makedirs(output_dir, exist_ok=True)

            # Save audio file
            output_path = os.path.join(output_dir, f"{result.get('file_id')}.wav")
            save_audio_file(result["audio_base64"], output_path)
            print(f"Client locally saved audio to: {output_path} (Note: The server also saves a copy in the api/outputs directory)")
            result["local_path"] = output_path
        elif result.get("audio_url"):
            print(f"Audio URL: {api_base}{result['audio_url']}")

        return result
    except requests.exceptions.Timeout:
        print(f"Request timeout, server processing time exceeded {timeout} seconds")
        print("This may be due to processing large audio files or high server load")
        print("You can try the following:")
        print("1. Use a smaller audio file")
        print("2. Increase timeout: --timeout 120")
        print("3. Check server logs for detailed errors")
        return None
    except requests.exceptions.ConnectionError:
        print("Connection error, unable to connect to server")
        print("Please ensure the API service is running and the port settings are correct")
        return None
    except Exception as e:
        print(f"Error sending request: {str(e)}")
        return None


def main():
    """Parse command line arguments, send one TTS request and print a summary.

    Returns:
        Process exit code: 0 when the request succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Spark-TTS API Client Example")
    parser.add_argument("--api_url", default="http://localhost:7860/", help="Base URL of the API service")
    parser.add_argument("--api_key", default=None, help="API key")
    parser.add_argument("--text", default="Welcome to the Spark-TTS speech synthesis system.", help="Text to synthesize")
    parser.add_argument("--output_dir", default="example_client_outputs", help="Directory for client to locally save audio, separate from server-side storage directory")

    # Voice cloning parameters
    parser.add_argument("--prompt_text", default=None, help="Text content of the reference audio")
    parser.add_argument("--prompt_audio", default=None, help="Path to reference audio file")
    parser.add_argument("--prompt_audio_url", default=None, help="URL of reference audio")

    # Voice creation parameters
    levels = ["very_low", "low", "moderate", "high", "very_high"]
    parser.add_argument("--gender", choices=["male", "female"], default=None, help="Voice gender")
    parser.add_argument("--pitch", choices=levels, default=None, help="Pitch")
    parser.add_argument("--speed", choices=levels, default=None, help="Speech rate")

    # Other parameters
    parser.add_argument("--timeout", type=int, default=120, help="Request timeout in seconds")

    args = parser.parse_args()

    # Print audio save location note
    print(f"\nNote: The client will save a local copy of the audio file in the {args.output_dir} directory")
    print("At the same time, the server will also save the same audio file in the api/outputs directory\n")

    # Prepare feature description
    features = []
    if args.prompt_audio or args.prompt_audio_url:
        features.append("Voice Cloning")
    if args.gender or args.pitch or args.speed:
        features.append("Voice Creation")

    # Indicate which features are being used
    if features:
        print(f"Using features: {', '.join(features)}")
    else:
        print("Using default settings for TTS")

    # Make TTS request
    result = tts_request(
        api_url=args.api_url,
        text=args.text,
        prompt_text=args.prompt_text,
        prompt_audio_path=args.prompt_audio,
        prompt_audio_url=args.prompt_audio_url,
        gender=args.gender,
        pitch=args.pitch,
        speed=args.speed,
        api_key=args.api_key,
        output_dir=args.output_dir,
        timeout=args.timeout,
    )

    # tts_request already printed the reason; just signal failure to the shell
    if not result:
        return 1

    # Print result summary
    print("\nResult Summary:")
    print(f"Text: {result['text']}")
    print(f"Duration: {result['duration']:.2f} seconds")
    print(f"Sample Rate: {result['sample_rate']} Hz")
    print(f"File ID: {result['file_id']}")
    print(f"Created At: {result['created_at']}")
    if "local_path" in result:
        print(f"Local File Path: {result['local_path']}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
+from pathlib import Path +from functools import lru_cache + +import torch +import uvicorn +import requests +import soundfile as sf +from fastapi import FastAPI, Depends, HTTPException, Security, BackgroundTasks, UploadFile, File, Form, Query +from fastapi.security.api_key import APIKeyHeader, APIKey +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, Field, HttpUrl, validator +import numpy as np +from dotenv import load_dotenv + +from cli.SparkTTS import SparkTTS + +# Load .env file +project_root = Path(__file__).parent +env_file = project_root / '.env' +if env_file.exists(): + load_dotenv(env_file) + logging.info(f"Environment variables file loaded: {env_file}") +else: + logging.info(f"Environment variables file not found: {env_file}, using environment variables or default configuration") + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() # Ensure logs go to console + ] +) +logger = logging.getLogger(__name__) + +# Force logging level to INFO for this module +logger.setLevel(logging.INFO) + +# Add a direct print of key configurations (will show even if logging is filtered) +def print_config_info(): + """Print configuration information directly to stdout, bypassing logging system""" + settings = get_settings() + print("\n" + "="*80) + print("SPARK-TTS CONFIGURATION SUMMARY") + print("="*80) + + # Print environment variables + print("\nENVIRONMENT VARIABLES:") + print(f"SPARK_TTS_DEFAULT_PROMPT_SPEECH = {os.getenv('SPARK_TTS_DEFAULT_PROMPT_SPEECH', 'not set')}") + print(f"SPARK_TTS_MODEL_DIR = {os.getenv('SPARK_TTS_MODEL_DIR', 'not set')}") + print(f"SPARK_TTS_OUTPUT_DIR = {os.getenv('SPARK_TTS_OUTPUT_DIR', 'not set')}") + print(f"SPARK_TTS_DEVICE = {os.getenv('SPARK_TTS_DEVICE', 'not set')}") + + # Print calculated paths + 
print("\nCALCULATED PATHS:") + # Project root + print(f"PROJECT_ROOT = {settings.PROJECT_ROOT}") + print(f"Current directory = {os.getcwd()}") + + # Default prompt speech + prompt_speech_path = settings.get_absolute_path(settings.DEFAULT_PROMPT_SPEECH) + print(f"DEFAULT_PROMPT_SPEECH = {settings.DEFAULT_PROMPT_SPEECH}") + print(f" Absolute path = {prompt_speech_path}") + print(f" File exists = {os.path.exists(prompt_speech_path)}") + + # Model directory + model_dir_path = settings.get_absolute_path(settings.MODEL_DIR) + print(f"MODEL_DIR = {settings.MODEL_DIR}") + print(f" Absolute path = {model_dir_path}") + print(f" Directory exists = {os.path.exists(model_dir_path)}") + + # Output directory + output_dir_path = settings.get_absolute_path(settings.OUTPUT_DIR) + print(f"OUTPUT_DIR = {settings.OUTPUT_DIR}") + print(f" Absolute path = {output_dir_path}") + print(f" Directory exists = {os.path.exists(output_dir_path)}") + + print("="*80 + "\n") + +# === Configuration Items === +class Settings: + # Service configuration + API_PORT: int = int(os.getenv('SPARK_TTS_API_PORT', 7860)) + API_HOST: str = os.getenv('SPARK_TTS_API_HOST', '0.0.0.0') + API_DEBUG: bool = os.getenv('SPARK_TTS_API_DEBUG', 'False').lower() == 'true' + + # Security configuration + API_KEY_NAME: str = os.getenv('SPARK_TTS_API_KEY_NAME', 'X-SPARKTTS-API-KEY') + API_KEY: Optional[str] = os.getenv('SPARK_TTS_API_KEY', None) + + # TTS model configuration + MODEL_DIR: str = os.getenv('SPARK_TTS_MODEL_DIR', 'pretrained_models/Spark-TTS-0.5B') + DEVICE: str = os.getenv('SPARK_TTS_DEVICE', 'gpu:0') + + # Default prompt audio and text + DEFAULT_PROMPT_TEXT: str = os.getenv('SPARK_TTS_DEFAULT_PROMPT_TEXT', + "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。") + DEFAULT_PROMPT_SPEECH: str = os.getenv('SPARK_TTS_DEFAULT_PROMPT_SPEECH', + "example/prompt_audio.wav") + logger.info(f"Environment variable SPARK_TTS_DEFAULT_PROMPT_SPEECH value: {os.getenv('SPARK_TTS_DEFAULT_PROMPT_SPEECH', 'not set')}") + 
# Values accepted by the voice-creation controls, as documented in the
# TTSRequest field descriptions below.
_GENDER_CHOICES = ("male", "female")
_LEVEL_CHOICES = ("very_low", "low", "moderate", "high", "very_high")


def _normalize_choice(value: Optional[str], allowed: tuple, field_name: str) -> Optional[str]:
    """Normalize an optional enumerated string and reject unknown values.

    Args:
        value: Raw value from the request (may be None).
        allowed: Accepted lower-case values.
        field_name: Field name used in the error message.

    Returns:
        None for a missing or blank value, otherwise the stripped,
        lower-cased value.

    Raises:
        ValueError: If the value is not one of ``allowed``.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # Treat "" the same as "not provided" so form-style clients that send
        # empty strings do not accidentally select voice-creation mode.
        return None
    if normalized not in allowed:
        raise ValueError(
            f"Invalid {field_name} '{value}'. Allowed values: {', '.join(allowed)}"
        )
    return normalized


class TTSRequest(BaseModel):
    """Request body of the /tts endpoint.

    Exactly one reference-audio source (Base64 or URL) may be supplied for
    voice cloning. gender/pitch/speed select voice-creation mode and are
    validated against the documented value sets so that a typo is rejected at
    the request boundary instead of failing later during inference.
    """

    text: str = Field(..., description="Text to be synthesized")

    # Voice cloning parameters (all are optional)
    prompt_text: Optional[str] = Field(None, description="Text content of reference audio")
    prompt_audio_base64: Optional[str] = Field(None, description="Base64 encoded reference audio data")
    prompt_audio_url: Optional[HttpUrl] = Field(None, description="Reference audio URL")

    # Voice creation parameters (all are optional)
    gender: Optional[str] = Field(None, description="Voice gender (male/female)")
    pitch: Optional[str] = Field(None, description="Pitch (very_low/low/moderate/high/very_high)")
    speed: Optional[str] = Field(None, description="Speech speed (very_low/low/moderate/high/very_high)")

    # Advanced parameters
    temperature: float = Field(0.8, description="Sampling temperature")
    top_k: int = Field(50, description="Top K sampling")
    top_p: float = Field(0.95, description="Top P sampling")
    return_audio_data: bool = Field(False, description="Whether to include audio data in the response")

    @validator('prompt_audio_url')
    def validate_audio_sources(cls, v, values):
        # prompt_audio_base64 is declared before prompt_audio_url, so it is
        # already present in `values` when this validator runs.
        if v is not None and values.get('prompt_audio_base64') is not None:
            raise ValueError("Cannot provide both prompt_audio_base64 and prompt_audio_url, please choose one method")
        return v

    @validator('text')
    def validate_text(cls, v):
        if not v or len(v.strip()) < 2:
            raise ValueError("Input text too short. Please provide at least 2 characters of text.")
        return v

    @validator('gender')
    def validate_gender(cls, v):
        return _normalize_choice(v, _GENDER_CHOICES, "gender")

    @validator('pitch')
    def validate_pitch(cls, v):
        return _normalize_choice(v, _LEVEL_CHOICES, "pitch")

    @validator('speed')
    def validate_speed(cls, v):
        return _normalize_choice(v, _LEVEL_CHOICES, "speed")
# (connect, read) timeouts in seconds for downloading a reference audio URL.
_PROMPT_DOWNLOAD_TIMEOUT = (5, 30)


def _remove_temp_file(path: Optional[str]) -> None:
    """Best-effort removal of a temporary prompt file created by this module."""
    if path and os.path.exists(path):
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"Failed to remove temporary file {path}: {str(e)}")


def _verify_audio_file(path: str, status_code: int, detail: str, label: str = "Audio") -> None:
    """Raise HTTPException(status_code, detail) unless soundfile can decode ``path``."""
    try:
        _, sample_rate = sf.read(path)
        logger.info(f"{label} verification successful, sample rate: {sample_rate}")
    except Exception as e:
        logger.error(f"{label} is not a valid WAV file: {str(e)}")
        raise HTTPException(status_code=status_code, detail=detail)


def _resolve_self_reference(url_str: str, settings) -> Optional[str]:
    """Map a URL that targets this service's own output mount to a local file.

    Returns:
        The absolute path of the referenced output file, or None when the URL
        does not point at this service's output prefix.

    Raises:
        HTTPException: 404 when the URL targets the output mount but the file
            is missing or resolves outside the output directory.
    """
    from urllib.parse import unquote, urlparse

    parsed = urlparse(url_str)
    own_hosts = {"localhost", "127.0.0.1", settings.API_HOST.lower()}
    if parsed.scheme != "http" or parsed.hostname not in own_hosts:
        return None
    try:
        port = parsed.port
    except ValueError:
        return None
    if (port or 80) != settings.API_PORT:
        return None

    stripped_prefix = settings.OUTPUT_URL_PREFIX.strip("/")
    prefix = f"/{stripped_prefix}/" if stripped_prefix else "/"
    url_path = unquote(parsed.path)
    if not url_path.startswith(prefix):
        return None

    relative = url_path[len(prefix):]
    output_root = os.path.realpath(settings.get_absolute_path(settings.OUTPUT_DIR, log=True))
    candidate = os.path.realpath(os.path.join(output_root, relative))
    # realpath + commonpath rejects "../" traversal out of the output directory.
    inside_output = os.path.commonpath([output_root, candidate]) == output_root
    if not inside_output or not os.path.isfile(candidate):
        logger.error(f"Local file does not exist: {candidate}")
        raise HTTPException(status_code=404, detail=f"Local referenced file does not exist: {relative}")
    return candidate


def _download_audio(url: str, dest_path: str) -> None:
    """Stream ``url`` into ``dest_path`` (blocking; intended to run in an executor)."""
    # NOTE(review): the URL is user supplied and fetched server-side; consider
    # an allow-list if this service can reach internal networks.
    response = requests.get(url, stream=True, timeout=_PROMPT_DOWNLOAD_TIMEOUT)
    try:
        if response.status_code != 200:
            logger.error(f"Failed to download audio, HTTP status code: {response.status_code}")
            raise HTTPException(status_code=400, detail=f"Failed to download audio: HTTP {response.status_code}")
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logger.info("URL audio written to temporary file successfully")
    finally:
        response.close()


async def process_audio_source(request: TTSRequest) -> tuple:
    """Resolve the prompt (reference) audio for a request to a local file.

    Sources are tried in this order: Base64 payload, URL, configured default
    prompt audio. Base64 and URL sources are materialized in a temporary file
    that the caller must delete; the temporary file is removed here if
    anything fails before it is handed back.

    Args:
        request: The incoming TTS request.

    Returns:
        tuple: (prompt_speech_path, need_cleanup) where need_cleanup is True
        when the path is a temporary file owned by the caller.

    Raises:
        HTTPException: 400 for undecodable/invalid user audio or a failed
            download, 404 for a missing self-referenced output file, 500 when
            the default prompt audio cannot be read.
    """
    settings = get_settings()

    logger.info(f"Processing audio source - request.prompt_audio_base64 exists: {request.prompt_audio_base64 is not None}")
    logger.info(f"Processing audio source - request.prompt_audio_url exists: {request.prompt_audio_url is not None}")

    # --- Base64 encoded audio -------------------------------------------
    if request.prompt_audio_base64:
        prompt_speech_path = None
        try:
            fd, prompt_speech_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            logger.info(f"Created temporary file from Base64: {prompt_speech_path}")

            audio_data = base64.b64decode(request.prompt_audio_base64)
            logger.info(f"Decoded audio data size: {len(audio_data)} bytes")
            with open(prompt_speech_path, "wb") as f:
                f.write(audio_data)

            _verify_audio_file(
                prompt_speech_path,
                400,
                "Provided audio is not a valid WAV format. Please convert audio to WAV format on the client side before uploading.",
            )
            logger.info(f"Audio processing completed: {prompt_speech_path}")
            return prompt_speech_path, True
        except Exception as e:
            # Never leave the temporary file behind on a failed request.
            _remove_temp_file(prompt_speech_path)
            if isinstance(e, HTTPException):
                raise
            logger.error(f"Failed to process Base64 audio: {str(e)}", exc_info=True)
            raise HTTPException(status_code=400, detail=f"Invalid Base64 audio: {str(e)}")

    # --- Audio URL --------------------------------------------------------
    if request.prompt_audio_url:
        prompt_speech_path = None
        try:
            url_str = str(request.prompt_audio_url)
            local_file_path = _resolve_self_reference(url_str, settings)

            fd, prompt_speech_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            logger.info(f"Created temporary file from URL: {prompt_speech_path}")

            if local_file_path is not None:
                # The URL points at one of our own outputs: copy it directly
                # instead of issuing an HTTP request to ourselves.
                logger.info(f"Detected self-reference, directly reading file: {local_file_path}")
                shutil.copy(local_file_path, prompt_speech_path)
                logger.info(f"Direct copy of local file successful: {local_file_path} -> {prompt_speech_path}")
            else:
                logger.info(f"Downloading audio from URL: {request.prompt_audio_url}")
                # requests is blocking; run it off the event loop so other
                # requests keep being served while the download is in flight.
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, _download_audio, url_str, prompt_speech_path)

            _verify_audio_file(
                prompt_speech_path,
                400,
                "Provided URL audio is not a valid WAV format. Please convert audio to WAV format on the client side before uploading.",
            )
            return prompt_speech_path, True
        except Exception as e:
            _remove_temp_file(prompt_speech_path)
            if isinstance(e, HTTPException):
                raise
            logger.error(f"Failed to download audio: {str(e)}", exc_info=True)
            raise HTTPException(status_code=400, detail=f"Failed to download audio: {str(e)}")

    # --- Default prompt audio --------------------------------------------
    prompt_speech_path = settings.get_absolute_path(settings.DEFAULT_PROMPT_SPEECH, log=True)
    logger.info(f"Using default prompt audio: {prompt_speech_path}")

    if not os.path.exists(prompt_speech_path):
        logger.warning(f"Default prompt audio does not exist: {prompt_speech_path}")
        alt_paths = [
            os.path.join(settings.PROJECT_ROOT, "example/prompt_audio.wav"),
            os.path.join(os.getcwd(), "example/prompt_audio.wav"),
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "../example/prompt_audio.wav"),
        ]

        logger.info("Searching for alternative prompt audio files...")
        for path in alt_paths:
            logger.info(f"Checking alternative path: {path}")
            if os.path.exists(path):
                logger.info(f"✅ Found alternative prompt audio: {path}")
                # Remember the fallback for later requests AND use it for this
                # one; previously only the setting was updated, so this
                # request still tried to read the missing file.
                settings.DEFAULT_PROMPT_SPEECH = path
                prompt_speech_path = path
                logger.info(f"Updated DEFAULT_PROMPT_SPEECH value: {path}")
                break
        else:
            logger.warning("❌ No alternative prompt audio files found in any location, service may not work properly")

    _verify_audio_file(
        prompt_speech_path,
        500,
        "Default audio is not a valid WAV format",
        label="Default audio",
    )
    logger.info(f"Default prompt audio exists, size: {os.path.getsize(prompt_speech_path)} bytes")

    return prompt_speech_path, False
AudioDataType = TypeVar('AudioDataType')

async def save_output_audio(audio_data, sample_rate: int = 16000) -> tuple:
    """Write synthesized audio to a uniquely named WAV file in the output directory.

    Args:
        audio_data: Audio samples as a torch.Tensor, numpy array or sequence.
            Singleton batch/channel axes (e.g. shape (1, N)) are removed
            before writing.
        sample_rate: Sample rate of audio data, in Hz.

    Returns:
        tuple: (file_id, output_path, output_url, duration) where duration is
        in seconds.

    Raises:
        Exception: Whatever soundfile raises if the file cannot be written
            (logged, then re-raised).
    """
    settings = get_settings()

    # Convert to a numpy array. detach() is required for tensors that still
    # track gradients; .numpy() refuses those.
    if isinstance(audio_data, torch.Tensor):
        audio_np = audio_data.detach().cpu().numpy()
    else:
        audio_np = np.asarray(audio_data)

    # Collapse singleton axes so (1, N) / (N, 1) become a mono (N,) signal.
    # atleast_1d keeps a one-sample clip from degenerating into a 0-d scalar.
    audio_np = np.atleast_1d(np.squeeze(audio_np))

    # Generate output file path
    file_id = str(uuid.uuid4())
    file_name = f"{file_id}.wav"

    # Ensure output directory exists
    output_dir = settings.get_absolute_path(settings.OUTPUT_DIR, log=True)
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, file_name)
    output_url = f"{settings.OUTPUT_URL_PREFIX}/{file_name}"

    # soundfile treats axis 0 as frames, so that axis defines the duration.
    duration = audio_np.shape[0] / sample_rate

    try:
        sf.write(output_path, audio_np, sample_rate)
        logger.info(f"Saved audio to {output_path}, size: {os.path.getsize(output_path)} bytes")
        return file_id, output_path, output_url, duration
    except Exception as e:
        logger.error(f"Failed to save audio: {str(e)}")
        raise
# Upper bound, in seconds, on one inference call before the request is
# answered with HTTP 504.
INFERENCE_TIMEOUT_SECONDS = 120


def _remove_expired_files(output_dir: str, expiry_time: datetime) -> int:
    """Delete regular files in ``output_dir`` last modified before ``expiry_time``.

    Returns:
        Number of files actually deleted. Files that cannot be inspected or
        removed are logged and skipped.
    """
    deleted_count = 0
    for filename in os.listdir(output_dir):
        file_path = os.path.join(output_dir, filename)
        if not os.path.isfile(file_path):
            continue
        try:
            file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mod_time >= expiry_time:
                continue
            os.remove(file_path)
            deleted_count += 1
            logger.info(f"Deleted expired file: {file_path}")
        except OSError as e:
            logger.error(f"Failed to delete file {file_path}: {str(e)}")
    return deleted_count


async def cleanup_old_files():
    """Periodically delete expired output files.

    Runs forever (until the task is cancelled), sleeping CLEANUP_INTERVAL
    seconds between passes. A failing pass is logged and does not stop the
    loop.
    """
    settings = get_settings()

    while True:
        try:
            logger.info("Starting to clean up expired files...")
            expiry_time = datetime.now() - timedelta(seconds=settings.FILE_EXPIRY_TIME)

            # Logging disabled here to avoid a duplicate path message on every pass.
            output_dir = settings.get_absolute_path(settings.OUTPUT_DIR, log=False)

            if os.path.exists(output_dir):
                deleted_count = _remove_expired_files(output_dir, expiry_time)
                logger.info(f"Cleanup complete, deleted {deleted_count} expired files")
            else:
                logger.warning(f"Output directory does not exist: {output_dir}, skipping cleanup")
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")

        # Wait for next cleanup interval
        await asyncio.sleep(settings.CLEANUP_INTERVAL)


def _log_cuda_state() -> None:
    """Log basic CUDA diagnostics after a failed inference (never raises)."""
    try:
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA device count: {torch.cuda.device_count()}")
            logger.info(f"Current CUDA device: {torch.cuda.current_device()}")
            logger.info(f"CUDA device name: {torch.cuda.get_device_name(0)}")
            logger.info(f"CUDA memory allocation: {torch.cuda.memory_allocated(0)}")
            logger.info(f"CUDA memory cache: {torch.cuda.memory_reserved(0)}")
    except Exception as cuda_error:
        logger.error(f"Failed to get CUDA information: {str(cuda_error)}")


# === API Endpoints ===
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, api_key: APIKey = Depends(get_api_key)):
    """
    Text-to-speech endpoint supporting voice cloning and voice creation.

    Mode selection:
    - gender set: voice creation (pitch/speed default to "moderate").
    - otherwise: voice cloning with the supplied prompt audio, or basic TTS
      with the default prompt audio when none is supplied.

    Raises:
        HTTPException: 400 for pitch/speed without gender or for invalid
            prompt audio, 504 when inference exceeds
            INFERENCE_TIMEOUT_SECONDS, 500 for any other failure.
    """
    global tts_model

    logger.info(f"Received TTS request: {request.text[:100]}{'...' if len(request.text) > 100 else ''}")

    # Reuse the model loaded at startup; only load here if startup failed.
    # Loading on every request is extremely slow and duplicates the weights.
    if tts_model is None:
        try:
            tts_model = initialize_model()
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to initialize TTS model")
    model = tts_model

    voice_creation = bool(request.gender)
    if not voice_creation and (request.pitch or request.speed):
        # The inference call below only forwards pitch/speed together with gender.
        raise HTTPException(
            status_code=400,
            detail="pitch and speed require gender to be set (voice creation mode)"
        )

    prompt_speech_path = None
    need_cleanup_audio = False

    try:
        if voice_creation:
            logger.info("Voice creation mode detected")
            logger.info(f"Setting gender: {request.gender}")
            logger.info(f"Setting pitch: {request.pitch or 'moderate'}")
            logger.info(f"Setting speed: {request.speed or 'moderate'}")
            mode_kwargs = {
                "gender": request.gender,
                "pitch": request.pitch or "moderate",
                "speed": request.speed or "moderate",
            }
        else:
            if request.prompt_audio_base64 is not None or request.prompt_audio_url is not None:
                logger.info("Voice cloning mode detected")
            else:
                logger.info("Basic TTS mode, using default prompt")
            prompt_speech_path, need_cleanup_audio = await process_audio_source(request)
            prompt_text = get_prompt_text(request)
            logger.info(f"Using prompt audio: {prompt_speech_path}")
            logger.info(f"Using prompt text: {prompt_text[:100]}{'...' if len(prompt_text) > 100 else ''}")
            mode_kwargs = {
                "prompt_speech_path": str(prompt_speech_path),
                "prompt_text": prompt_text,
            }

        def run_inference():
            return model.inference(
                text=request.text,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                **mode_kwargs
            )

        logger.info(
            f"Starting TTS inference, text length: {len(request.text)}, "
            f"timeout: {INFERENCE_TIMEOUT_SECONDS} seconds"
        )
        # Inference is blocking; run it in the default executor for BOTH modes
        # so the event loop stays responsive and the timeout can fire.
        # Note: on timeout the worker thread is not interrupted, only abandoned.
        loop = asyncio.get_running_loop()
        try:
            audio = await asyncio.wait_for(
                loop.run_in_executor(None, run_inference),
                timeout=INFERENCE_TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError:
            logger.error(f"TTS inference timeout ({INFERENCE_TIMEOUT_SECONDS} seconds)")
            raise HTTPException(status_code=504, detail="TTS processing timeout, possibly due to reference audio too large or format incompatibility")
        logger.info(f"TTS inference completed, audio shape: {getattr(audio, 'shape', None)}")

        # Save output audio
        file_id, output_path, output_url, duration = await save_output_audio(audio)

        return TTSResponse(
            text=request.text,
            audio_url=output_url,
            audio_base64=get_audio_base64(output_path) if request.return_audio_data else None,
            duration=duration,
            sample_rate=16000,
            file_id=file_id,
            created_at=datetime.now().isoformat()
        )

    except HTTPException:
        # Keep deliberate status codes (400/404/504) instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"TTS processing failed: {str(e)}", exc_info=True)
        _log_cuda_state()
        raise HTTPException(status_code=500, detail=f"TTS processing failed: {str(e)}")
    finally:
        # Remove the temporary prompt file on success AND on failure.
        if need_cleanup_audio and prompt_speech_path and os.path.exists(prompt_speech_path):
            try:
                os.remove(prompt_speech_path)
            except OSError as e:
                logger.warning(f"Failed to remove temporary file {prompt_speech_path}: {str(e)}")

@app.get("/health")
async def health_check():
    """Health check endpoint reporting the configured and actual inference device."""
    settings = get_settings()
    model_loaded = tts_model is not None

    return {
        "status": "ok",
        "timestamp": datetime.now().isoformat(),
        "device": {
            "configured": settings.DEVICE,
            "actual": str(tts_model.device) if model_loaded else "Not initialized"
        },
        "model_loaded": model_loaded
    }

# === Application startup and shutdown events ===
@app.on_event("startup")
async def startup_event():
    """Load the TTS model, start the cleanup task and mount the output directory."""
    global tts_model, cleanup_task

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    settings = get_settings()

    # Direct print so the effective configuration is visible even if logging is filtered.
    print_config_info()

    # Store the model in the module-level slot so request handlers and
    # /health use this instance. On failure the first request retries.
    try:
        tts_model = initialize_model()
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        tts_model = None

    # Start scheduled cleanup task
    cleanup_task = asyncio.create_task(cleanup_old_files())

    # Mount static files directory for audio output access
    try:
        output_dir = settings.get_absolute_path(settings.OUTPUT_DIR, log=True)
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Using output directory: {output_dir}")

        # Remove URL prefix leading slash, if any
        static_url_path = settings.OUTPUT_URL_PREFIX
        if static_url_path.startswith('/'):
            static_url_path = static_url_path[1:]

        app.mount(f"/{static_url_path}", StaticFiles(directory=output_dir), name="audio_files")
        logger.info(f"Mounted static files directory: {output_dir} to /{static_url_path}")
    except Exception as e:
        logger.error(f"Mounting static files directory failed: {str(e)}")

    logger.info("Spark-TTS API service started successfully")

@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the background cleanup task and wait for it to finish."""
    global cleanup_task

    if cleanup_task:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass

    logger.info("Spark-TTS API service shut down")

# === Main program ===
if __name__ == "__main__":
    settings = get_settings()
    uvicorn.run(
        # uvicorn only supports reload when given an import string, not an
        # app object; "api.main:app" matches what api/run_api.sh launches.
        "api.main:app" if settings.API_DEBUG else app,
        host=settings.API_HOST,
        port=settings.API_PORT,
        reload=settings.API_DEBUG
    )
Set defaults or read from environment variables (middle priority) +# Using environment variables if available, otherwise use defaults +HOST="${SPARK_TTS_API_HOST:-0.0.0.0}" +PORT="${SPARK_TTS_API_PORT:-7860}" +DEBUG="${SPARK_TTS_API_DEBUG:-false}" +RELOAD=false + +# 3. Process command line arguments (highest priority) +while [[ $# -gt 0 ]]; do + case $1 in + --host) + HOST="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --debug) + DEBUG=true + RELOAD=true + shift + ;; + --reload) + RELOAD=true + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Convert string boolean to actual boolean for DEBUG if needed +if [[ "$DEBUG" == "true" || "$DEBUG" == "True" || "$DEBUG" == "TRUE" || "$DEBUG" == "1" ]]; then + RELOAD=true +fi + +# Create output directory if it doesn't exist +OUTPUT_DIR_PATH="${SPARK_TTS_OUTPUT_DIR:-api/outputs}" +if [[ "$OUTPUT_DIR_PATH" = /* ]]; then + # Absolute path + FINAL_OUTPUT_DIR="$OUTPUT_DIR_PATH" +else + # Relative path, convert to absolute path + FINAL_OUTPUT_DIR="$PROJECT_ROOT/$OUTPUT_DIR_PATH" +fi +mkdir -p "$FINAL_OUTPUT_DIR" + +# Start API service +cd "$PROJECT_ROOT" +echo "Starting Spark-TTS API service..." +echo "Host: $HOST, Port: $PORT, Debug: $DEBUG, Reload: $RELOAD" + +# Set RELOAD parameter +if [ "$RELOAD" = true ]; then + RELOAD_ARG="--reload" +else + RELOAD_ARG="" +fi + +# Start API service +python -m uvicorn api.main:app --host "$HOST" --port "$PORT" $RELOAD_ARG + +# Check exit status +if [ $? -ne 0 ]; then + echo "Failed to start API service!" 
# Compose project names must be lowercase (letters, digits, dashes, underscores).
name: spark-tts

# Spark-TTS Docker Compose Configuration
# This file provides different service configurations for Spark-TTS

# === Environment Variables ===
x-spark-tts-env: &spark-tts-env
  # Service type - will be overridden for each service
  SERVICE_TYPE: api

  # === Service Configuration ===
  # API service listening port
  SPARK_TTS_API_PORT: 7860

  # API service host address
  SPARK_TTS_API_HOST: 0.0.0.0

  # Enable debug mode (quoted so YAML keeps it a string, not a boolean)
  SPARK_TTS_API_DEBUG: "False"

  # === Security Configuration ===
  # Request header name for API key
  # SPARK_TTS_API_KEY_NAME: X-SPARKTTS-API-KEY

  # API key (authentication disabled if not set)
  # SPARK_TTS_API_KEY: your_secret_api_key

  # === TTS Model Configuration ===
  # Model directory path
  # SPARK_TTS_MODEL_DIR: pretrained_models/Spark-TTS-0.5B

  # GPU device ID
  # Options: cpu (CPU inference), gpu (default GPU), gpu:N (specific GPU where N is device ID)
  # SPARK_TTS_DEVICE: gpu:0

  # === Default Prompt Configuration ===
  # Default reference text
  # SPARK_TTS_DEFAULT_PROMPT_TEXT: 吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。

  # Default reference audio path
  # SPARK_TTS_DEFAULT_PROMPT_SPEECH: example/prompt_audio.wav

  # === Output Configuration ===
  # Output audio file directory
  # SPARK_TTS_OUTPUT_DIR: api/outputs

  # Output audio URL prefix
  # SPARK_TTS_OUTPUT_URL_PREFIX: /outputs

  # === Cleanup Configuration ===
  # Cleanup task interval (seconds)
  # SPARK_TTS_CLEANUP_INTERVAL: 3600

  # File expiration time (seconds)
  # SPARK_TTS_FILE_EXPIRY_TIME: 86400

# === Usage Instructions ===
#
# Start API with models included:
#   docker compose up -d api
#
# Start API with mounted models:
#   docker compose up -d api-lite
#
# Start WebUI with models included:
#   docker compose up -d webui
#
# Start WebUI with mounted models:
#   docker compose up -d webui-lite
#
# Note: every service publishes host port 7860, so start only one of them at
# a time (or change the host side of the port mapping).
#

services:
  # API service with full image (includes models)
  api:
    image: spark-tts:latest-full
    environment:
      <<: *spark-tts-env
      SERVICE_TYPE: api
    ports:
      - "7860:7860"
    volumes:
      - ./api_output:/app/api/outputs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # API service with lite image (mount models)
  api-lite:
    image: spark-tts:latest-lite
    environment:
      <<: *spark-tts-env
      SERVICE_TYPE: api
    ports:
      - "7860:7860"
    volumes:
      - ./pretrained_models:/app/pretrained_models
      - ./api_output:/app/api/outputs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # WebUI service with full image (includes models)
  webui:
    image: spark-tts:latest-full
    environment:
      <<: *spark-tts-env
      SERVICE_TYPE: webui
    ports:
      - "7860:7860"
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # WebUI service with lite image (mount models)
  webui-lite:
    image: spark-tts:latest-lite
    environment:
      <<: *spark-tts-env
      SERVICE_TYPE: webui
    ports:
      - "7860:7860"
    volumes:
      - ./pretrained_models:/app/pretrained_models
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
+YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +# Get the absolute path of the script +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +cd "$SCRIPT_DIR" + +# Define image names +IMAGE_NAME="spark-tts" +BASE_TAG="latest" +FULL_TAG="latest-full" +LITE_TAG="latest-lite" + +# Print header +echo -e "${GREEN}====================================${NC}" +echo -e "${GREEN} Spark-TTS Docker Image Builder ${NC}" +echo -e "${GREEN}====================================${NC}" +echo + +# Check if Docker is installed +if ! command -v docker &> /dev/null; then + echo -e "${RED}Error: Docker is not installed or not in PATH${NC}" + exit 1 +fi + +# Check if pretrained_models directory exists +if [ ! -d "./pretrained_models" ]; then + echo -e "${YELLOW}Warning: pretrained_models directory not found${NC}" + echo -e "${YELLOW}Models will not be included in the 'full' image${NC}" + echo -e "${YELLOW}You can download models later or mount them when running the container${NC}" + read -p "Do you want to continue? (y/n): " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo -e "${RED}Build canceled${NC}" + exit 1 + fi +fi + +# Build lite version (without models) +echo -e "${GREEN}Building ${IMAGE_NAME}:${LITE_TAG} (without models)...${NC}" +docker build -t ${IMAGE_NAME}:${LITE_TAG} . + +if [ $? -ne 0 ]; then + echo -e "${RED}Failed to build ${IMAGE_NAME}:${LITE_TAG}${NC}" + exit 1 +fi + +# Set as latest tag +echo -e "${GREEN}Tagging ${IMAGE_NAME}:${LITE_TAG} as ${IMAGE_NAME}:${BASE_TAG}...${NC}" +docker tag ${IMAGE_NAME}:${LITE_TAG} ${IMAGE_NAME}:${BASE_TAG} + +# Build full version (with models) +echo -e "${GREEN}Building ${IMAGE_NAME}:${FULL_TAG} (with models)...${NC}" +docker build --build-arg INCLUDE_MODELS=true -t ${IMAGE_NAME}:${FULL_TAG} . + +if [ $? 
-ne 0 ]; then + echo -e "${RED}Failed to build ${IMAGE_NAME}:${FULL_TAG}${NC}" + echo -e "${YELLOW}Note: The lite version was built successfully and is available${NC}" + exit 1 +fi + +# Summary +echo +echo -e "${GREEN}====================================${NC}" +echo -e "${GREEN} Build Completed Successfully ${NC}" +echo -e "${GREEN}====================================${NC}" +echo +echo -e "Image tags created:" +echo -e " - ${IMAGE_NAME}:${BASE_TAG} (alias of ${LITE_TAG})" +echo -e " - ${IMAGE_NAME}:${LITE_TAG} (without models)" +echo -e " - ${IMAGE_NAME}:${FULL_TAG} (with models)" +echo +echo -e "To run API (default):" +echo -e " docker run -p 7860:7860 --gpus all ${IMAGE_NAME}:${FULL_TAG}" +echo +echo -e "To run WebUI:" +echo -e " docker run -p 7860:7860 --gpus all -e SERVICE_TYPE=webui ${IMAGE_NAME}:${FULL_TAG}" +echo +echo -e "To use the lite version, you must mount the models directory:" +echo -e " docker run -p 7860:7860 --gpus all -v /local/path/to/models:/app/pretrained_models ${IMAGE_NAME}:${LITE_TAG}" +echo \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0e05ebb..b82ad1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,7 @@ torch==2.5.1 torchaudio==2.5.1 tqdm==4.66.5 transformers==4.46.2 -gradio==5.18.0 \ No newline at end of file +gradio==5.18.0 +fastapi==0.115.11 +uvicorn==0.34.0 +python-dotenv==1.0.1 \ No newline at end of file