# Chatterbox multilingual TTS server: FastAPI JSON/file API + Gradio demo UI.
# Standard library
import asyncio
import io
import os
import random
import traceback
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional

# Third-party
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware

# Project
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Lazily-populated singleton TTS model (see get_or_load_model).
MODEL = None
# Serializes GPU access so only one generation runs at a time.
gpu_lock = asyncio.Lock()
# In-memory registry of queued/processing/finished tasks, keyed by task id.
task_storage: Dict[str, 'QueueTask'] = {}
# Maximum number of tasks allowed to wait in the queue at once.
MAX_QUEUE_SIZE = 10
# Completed/failed tasks are purged after this many hours (see cleanup_old_tasks).
TASK_TTL_HOURS = 1

# Per-language UI defaults: a reference audio prompt URL and an example text.
LANGUAGE_CONFIG = {
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
    },
    "ru": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
        "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
    },
}
@dataclass
class QueueTask:
    """A single TTS generation request tracked through the in-process queue.

    Lifecycle: created as "pending", moved to "processing" while the GPU lock
    is held, then to "completed" (with `result` holding WAV bytes) or
    "failed" (with `error` set). `future` resolves when processing finishes
    so the HTTP handler can await completion.
    """
    id: str                # unique task id (uuid4 string)
    temp_path: str         # on-disk path of the uploaded reference audio
    text: str              # text to synthesize
    language_id: str       # language code, e.g. "en"
    params: dict           # generation parameters (exaggeration, temperature, ...)
    created_at: datetime = field(default_factory=datetime.now)
    status: str = "pending"  # "pending" | "processing" | "completed" | "failed"
    result: Optional[bytes] = None       # WAV payload once completed
    error: Optional[str] = None          # error message once failed
    completed_at: Optional[datetime] = None
    # FIX: asyncio.get_event_loop() is deprecated inside coroutines and can
    # create a future bound to the wrong loop; get_running_loop() ties the
    # future to the loop actually serving the request. Instances must be
    # created while an event loop is running (as the FastAPI handlers do).
    future: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
def default_audio_for_ui(lang: str) -> str | None:
    """Return the default reference-audio URL for *lang*, or None if unknown."""
    cfg = LANGUAGE_CONFIG.get(lang)
    if cfg is None:
        return None
    return cfg.get("audio")
def default_text_for_ui(lang: str) -> str:
    """Return the example sentence for *lang*, or an empty string if unknown."""
    cfg = LANGUAGE_CONFIG.get(lang)
    if cfg is None:
        return ""
    return cfg.get("text", "")
def get_supported_languages_display() -> str:
    """Render the supported-language list as two bullet-separated Markdown rows."""
    entries = [
        f"**{name}** (`{code}`)"
        for code, name in sorted(SUPPORTED_LANGUAGES.items())
    ]
    half = len(entries) // 2
    first_row = " • ".join(entries[:half])
    second_row = " • ".join(entries[half:])
    return f"""
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
{first_row}
{second_row}
"""
def get_or_load_model():
    """Return the shared ChatterboxMultilingualTTS instance, loading it on first use."""
    global MODEL
    if MODEL is not None:
        return MODEL
    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        # Some builds leave the model on another device after from_pretrained;
        # move it only when it actually differs from the target device.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Eagerly load the model at import time so the first request is not slow;
# a failure here is logged but does not abort startup.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int):
    """Seed the torch, CUDA, numpy and stdlib RNGs for reproducible generation."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        # Seed the current CUDA device as well as all visible devices.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: Optional[str] = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text_input* and return (sample_rate, waveform).

    Args:
        text_input: Text to speak; silently truncated to 300 characters.
        language_id: Language code understood by the model (e.g. "en").
        audio_prompt_path_input: Optional reference-audio path/URL for voice
            styling; falls back to the language's default prompt.
        exaggeration_input: Expressiveness control forwarded to the model.
        temperature_input: Sampling temperature forwarded to the model.
        seed_num_input: RNG seed; 0 means "do not reseed".
        cfgw_input: Classifier-free-guidance weight / pace control.

    Returns:
        Tuple of (sample rate in Hz, waveform as a numpy array).

    Raises:
        RuntimeError: If the TTS model could not be loaded.

    NOTE(review): runs synchronously on the GPU; callers serialize access via
    gpu_lock. Assumes model.generate returns a CPU tensor of shape (1, n) —
    confirm, since .numpy() would fail on a CUDA tensor.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # A seed of 0 is treated as "unseeded" (non-deterministic output).
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    print(f"[GPU] Processing: '{text_input[:50]}...'")
    # Prefer the caller-supplied prompt; otherwise use the per-language default.
    chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)

    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt

    # Hard cap at 300 characters to bound generation time.
    wav = current_model.generate(
        text_input[:300],
        language_id=language_id,
        **generate_kwargs
    )
    print(f"[GPU] Done: '{text_input[:50]}...'")
    return (current_model.sr, wav.squeeze(0).numpy())
async def cleanup_old_tasks():
    """Background loop: every 10 minutes drop finished tasks older than the TTL."""
    while True:
        await asyncio.sleep(600)
        now = datetime.now()
        # Collect first, delete after — never mutate the dict while iterating.
        expired = [
            tid
            for tid, t in task_storage.items()
            if t.status in ("completed", "failed")
            and t.completed_at
            and (now - t.completed_at).total_seconds() > 3600 * TASK_TTL_HOURS
        ]
        for tid in expired:
            del task_storage[tid]
            print(f"[Cleanup] Removed old task {tid}")
# --- Task processing ---
async def process_task(task: QueueTask):
    """Run one queued TTS task while holding the GPU lock.

    Generates audio in a worker thread, encodes it as 16-bit PCM WAV into
    task.result, resolves task.future for the waiting HTTP handler, and
    always removes the uploaded temp file afterwards.
    """
    async with gpu_lock:
        task.status = "processing"
        try:
            loop = asyncio.get_event_loop()
            # Run the blocking GPU generation off the event loop.
            sr, wav = await loop.run_in_executor(
                None,
                generate_tts_audio,
                task.text,
                task.language_id,
                task.temp_path,
                task.params["exaggeration_input"],
                task.params["temperature_input"],
                task.params["seed_num_input"],
                task.params["cfgw_input"]
            )

            # Write into BytesIO with an explicit format: BytesIO carries no
            # file extension for soundfile to sniff, so format must be given.
            buffer = io.BytesIO()
            # The lambda lets us pass keyword arguments through run_in_executor.
            await loop.run_in_executor(
                None,
                lambda: sf.write(buffer, wav, sr, format='WAV', subtype='PCM_16')
            )
            buffer.seek(0)  # rewind before reading the encoded bytes
            task.result = buffer.getvalue()
            task.status = "completed"
            task.completed_at = datetime.now()
            task.future.set_result(True)

        except Exception as e:
            task.error = str(e)
            task.status = "failed"
            task.completed_at = datetime.now()
            task.future.set_exception(e)
        finally:
            # Always delete the uploaded reference audio, success or failure.
            if os.path.exists(task.temp_path):
                os.remove(task.temp_path)
# ============ FASTAPI SETUP ============

class CORSHeaderMiddleware(BaseHTTPMiddleware):
    """Attach CORS headers to every response, including unhandled 500 errors."""

    _CORS_HEADERS = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Allow-Credentials": "true",
    }

    async def dispatch(self, request, call_next):
        try:
            response = await call_next(request)
        except Exception as e:
            # Build the error response ourselves so the CORS headers below
            # are still applied even when a handler raises.
            error_msg = f"Internal Server Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            response = JSONResponse(
                status_code=500,
                content={"detail": str(e), "traceback": traceback.format_exc()}
            )

        # Stamp CORS headers onto whatever response we ended up with.
        for header_name, header_value in self._CORS_HEADERS.items():
            response.headers[header_name] = header_value
        return response
app = FastAPI(title="Chatterbox TTS API")

# Custom middleware is added first so CORS headers reach every response,
# including 500s raised inside handlers.
app.add_middleware(CORSHeaderMiddleware)

# Standard CORS middleware handles preflight (OPTIONS) negotiation.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler: log the traceback, return JSON with CORS headers."""
    error_detail = f"{str(exc)}\n{traceback.format_exc()}"
    print(f"[ERROR] {error_detail}")
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    payload = {"detail": str(exc), "type": type(exc).__name__}
    return JSONResponse(status_code=500, content=payload, headers=cors_headers)
@app.on_event("startup")
async def startup():
    """Launch the periodic task-cleanup loop when the app starts."""
    # FIX: keep a strong reference on app.state — the event loop holds only a
    # weak reference to tasks, so an unreferenced background task may be
    # garbage-collected mid-flight (see asyncio.create_task documentation).
    app.state.cleanup_task = asyncio.create_task(cleanup_old_tasks())
@app.options("/{path:path}")
async def options_handler(path: str):
    """Answer CORS preflight requests for any path."""
    preflight_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Max-Age": "3600",
    }
    return PlainTextResponse("OK", headers=preflight_headers)
@app.post("/process-audio")
async def process_audio(
    audio_file: UploadFile = File(...),
    text: str = Form(...),
    language_id: str = Form("en"),
    exaggeration: float = Form(0.5),
    temperature: float = Form(0.8),
    seed_num: int = Form(0),
    cfg_weight: float = Form(0.5)
):
    """Queue a TTS request and stream the generated WAV back to the client.

    The uploaded file is saved to /tmp as the voice reference, a QueueTask is
    enqueued, and the handler waits up to 300 s for completion. On timeout it
    returns 202 with the task id so the client can poll /status/{task_id}.

    Raises:
        HTTPException: 503 when the queue is full, 500 when generation fails.
    """
    try:
        # Back-pressure: refuse new work while too many tasks are pending.
        pending_count = sum(1 for t in task_storage.values() if t.status == "pending")
        if pending_count >= MAX_QUEUE_SIZE:
            raise HTTPException(status_code=503, detail="Queue is full, try again later")

        task_id = str(uuid.uuid4())
        temp_path = f"/tmp/input_{task_id}.wav"

        # Persist the uploaded reference audio (deleted later by process_task).
        content = await audio_file.read()
        with open(temp_path, "wb") as f:
            f.write(content)

        task = QueueTask(
            id=task_id,
            temp_path=temp_path,
            text=text,
            language_id=language_id,
            params={
                "exaggeration_input": exaggeration,
                "temperature_input": temperature,
                "seed_num_input": seed_num,
                "cfgw_input": cfg_weight
            }
        )
        task_storage[task_id] = task

        # Fire-and-forget; completion is signaled through task.future.
        asyncio.create_task(process_task(task))

        try:
            await asyncio.wait_for(task.future, timeout=300.0)

            if task.error:
                raise HTTPException(status_code=500, detail=task.error)

            return StreamingResponse(
                io.BytesIO(task.result),
                media_type="audio/wav",
                headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"}
            )

        except asyncio.TimeoutError:
            # Generation outlived this request; client should poll /status.
            return JSONResponse(
                status_code=202,
                content={
                    "task_id": task_id,
                    "status": "processing",
                    "message": "Processing takes longer than expected."
                }
            )

    except Exception as e:
        # Log the full traceback, then let the global handler shape the response.
        print(f"[ERROR in process_audio]: {traceback.format_exc()}")
        raise
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the finished WAV for *task_id*, or a JSON status report."""
    task = task_storage.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    if task.status != "completed":
        # Still pending/processing, or failed — report state as JSON.
        return JSONResponse(content={
            "task_id": task_id,
            "status": task.status,
            "error": task.error,
            "created_at": task.created_at.isoformat()
        })

    wav_headers = {"Content-Disposition": f"attachment; filename=output_{task_id}.wav"}
    return StreamingResponse(
        io.BytesIO(task.result),
        media_type="audio/wav",
        headers=wav_headers
    )
@app.get("/queue-status")
async def queue_status():
    """Report queue occupancy, GPU busyness and remaining capacity."""
    counts = {"pending": 0, "processing": 0, "completed": 0}
    for t in task_storage.values():
        if t.status in counts:
            counts[t.status] += 1

    return JSONResponse(content={
        "queue_size": counts["pending"],
        "processing_now": counts["processing"],
        "completed_total": counts["completed"],
        "max_queue_size": MAX_QUEUE_SIZE,
        "gpu_busy": gpu_lock.locked(),
        "available_slots": MAX_QUEUE_SIZE - counts["pending"]
    })
# ============ GRADIO SETUP ============

with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
"""
    )
    gr.Markdown(get_supported_languages_display())

    with gr.Row():
        with gr.Column():
            # UI defaults start on Russian.
            initial_lang = "ru"
            text = gr.Textbox(
                value=default_text_for_ui(initial_lang),
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            language_id = gr.Dropdown(
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
                label="Language"
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value=default_audio_for_ui(initial_lang)
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
            queue_info = gr.Textbox(label="Queue Status", value="Ready", interactive=False)

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # Swap in the language's default prompt audio and example text on change.
    def on_language_change(lang):
        return default_audio_for_ui(lang), default_text_for_ui(lang)

    language_id.change(
        fn=on_language_change,
        inputs=[language_id],
        outputs=[ref_wav, text],
        show_progress=False
    )

    # Synchronous generation for the UI; refuses work when the queue is full.
    def generate_with_queue(*args):
        pending = sum(1 for t in task_storage.values() if t.status == "pending")
        if pending >= MAX_QUEUE_SIZE:
            return None, f"⚠️ Queue is full ({pending}/{MAX_QUEUE_SIZE})"
        try:
            result = generate_tts_audio(*args)
            return result, f"✅ Ready"
        except Exception as e:
            return None, f"❌ Error: {str(e)}"

    run_btn.click(
        fn=generate_with_queue,
        inputs=[text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output, queue_info],
    )

# Serve the Gradio UI from the FastAPI app under /gradio.
app = gr.mount_gradio_app(app, demo, path="/gradio")
if __name__ == "__main__":
    # Single-process server: Gradio UI at /gradio, file/JSON API at the root.
    print("🚀 Starting server with CORS fix...")
    print(f"📊 Gradio UI: http://localhost:8000/gradio")
    print(f"🔌 API endpoint: POST http://localhost:8000/process-audio")
    uvicorn.run(app, host="0.0.0.0", port=8000)