1
0
Fork 0
security-lab/backend/chatterbox_app.py
2026-04-05 12:13:58 +03:00

436 lines
No EOL
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import numpy as np
import torch
import io
import os
import uuid
import asyncio
import traceback
from dataclasses import dataclass, field
from typing import Dict, Optional
from datetime import datetime
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
import uvicorn
import soundfile as sf
# Use the GPU when PyTorch reports one, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# --- Global model and queue state ---
MODEL = None  # filled in by get_or_load_model()
gpu_lock = asyncio.Lock()  # held by process_task() for the whole generation
task_storage: Dict[str, 'QueueTask'] = {}  # task id -> QueueTask
MAX_QUEUE_SIZE = 10  # maximum number of tasks allowed in the "pending" state
TASK_TTL_HOURS = 1  # finished tasks older than this are dropped by cleanup_old_tasks()

# Default reference audio and sample text per language code.
LANGUAGE_CONFIG = {
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
    },
    "ru": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
        "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
    },
}
def _new_task_future() -> asyncio.Future:
    """Create an unresolved future on the loop returned by asyncio.get_event_loop()."""
    return asyncio.get_event_loop().create_future()


@dataclass
class QueueTask:
    """Record of one queued text-to-speech request and its outcome."""

    id: str
    # Path of the uploaded reference audio saved on disk.
    temp_path: str
    text: str
    language_id: str
    # Generation parameters, read by key in process_task().
    params: dict
    created_at: datetime = field(default_factory=datetime.now)
    # One of "pending", "processing", "completed" or "failed".
    status: str = "pending"
    # Encoded audio bytes once the task has completed.
    result: Optional[bytes] = None
    # Error message once the task has failed.
    error: Optional[str] = None
    completed_at: Optional[datetime] = None
    # Resolved by process_task() when the task finishes or fails.
    future: asyncio.Future = field(default_factory=_new_task_future)
def default_audio_for_ui(lang: str) -> str | None:
    """Return the default reference-audio URL for *lang*, or None if there is none."""
    entry = LANGUAGE_CONFIG.get(lang)
    if entry is None:
        return None
    return entry.get("audio")
def default_text_for_ui(lang: str) -> str:
    """Return the sample text for *lang*, or an empty string if there is none."""
    entry = LANGUAGE_CONFIG.get(lang)
    if entry is None:
        return ""
    return entry.get("text", "")
def get_supported_languages_display() -> str:
    """Build the Markdown snippet that lists every supported language.

    Languages are sorted by code, formatted as ``**Name** (`code`)`` and split
    into two roughly equal groups.

    Returns:
        A Markdown string with a heading (including the total count) followed
        by the two groups of languages.
    """
    language_items = [
        f"**{name}** (`{code}`)"
        for code, name in sorted(SUPPORTED_LANGUAGES.items())
    ]
    mid = len(language_items) // 2
    # Join with a visible separator; joining on "" rendered the entries as one
    # unreadable run of text.
    separator = " • "
    line1 = separator.join(language_items[:mid])
    line2 = separator.join(language_items[mid:])
    # The blank line between the groups keeps Markdown from merging them into
    # a single paragraph.
    return f"""
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
{line1}

{line2}
"""
def get_or_load_model():
    """Return the shared TTS model, loading it into MODEL on first use.

    Raises:
        Exception: whatever loading or moving the model raised, after it has
            been printed.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        # Move the model only when it exposes `.to` and reports another device.
        needs_move = hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE
        if needs_move:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Load the model eagerly at import time. A failure is printed and the module
# still finishes importing.
try:
    get_or_load_model()
except Exception as load_error:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {load_error}")
def set_seed(seed: int):
    """Seed the torch, CUDA (when DEVICE is "cuda"), `random` and NumPy generators."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)
    random.seed(seed)
    np.random.seed(seed)
def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: Optional[str] = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text_input* with the shared TTS model.

    Args:
        text_input: Text to speak; only the first 300 characters are used.
        language_id: Language code passed to the model.
        audio_prompt_path_input: Reference audio path. When empty, the default
            audio for *language_id* is used if one is configured.
        exaggeration_input: Forwarded to the model as ``exaggeration``.
        temperature_input: Forwarded to the model as ``temperature``.
        seed_num_input: RNG seed; ``0`` or ``None`` leaves the generators unseeded.
        cfgw_input: Forwarded to the model as ``cfg_weight``.

    Returns:
        ``(current_model.sr, waveform)`` where the waveform is a NumPy array.

    Raises:
        RuntimeError: If no model is available.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # Truthiness covers both 0 and None (e.g. a cleared number field), so
    # int(None) is never attempted.
    if seed_num_input:
        set_seed(int(seed_num_input))

    print(f"[GPU] Processing: '{text_input[:50]}...'")
    chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    # Pass a prompt only when there is one.
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt

    wav = current_model.generate(
        text_input[:300],
        language_id=language_id,
        **generate_kwargs
    )
    print(f"[GPU] Done: '{text_input[:50]}...'")
    # Tensor.numpy() only works on CPU tensors that do not require grad;
    # detach()/cpu() are no-ops when the tensor already qualifies.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
async def cleanup_old_tasks():
    """Forever: sleep 10 minutes, then drop finished tasks older than TASK_TTL_HOURS."""
    while True:
        await asyncio.sleep(600)
        now = datetime.now()
        max_age_seconds = 3600 * TASK_TTL_HOURS
        # Collect ids first so the dict is not mutated while it is iterated.
        expired_ids = [
            task_id
            for task_id, task in task_storage.items()
            if task.status in ("completed", "failed")
            and task.completed_at
            and (now - task.completed_at).total_seconds() > max_age_seconds
        ]
        for task_id in expired_ids:
            del task_storage[task_id]
            print(f"[Cleanup] Removed old task {task_id}")
async def process_task(task: QueueTask):
    """Run one queued TTS job while holding the GPU lock.

    Marks the task "processing", runs generate_tts_audio() in the default
    executor, encodes the waveform as 16-bit PCM WAV into ``task.result`` and
    records "completed" or "failed" together with ``completed_at``. The
    uploaded reference file at ``task.temp_path`` is removed afterwards, and
    ``task.future`` is resolved with the outcome.

    Args:
        task: The queued request to process; it is updated in place.
    """
    async with gpu_lock:
        task.status = "processing"
        failure = None
        try:
            loop = asyncio.get_running_loop()
            sr, wav = await loop.run_in_executor(
                None,
                generate_tts_audio,
                task.text,
                task.language_id,
                task.temp_path,
                task.params["exaggeration_input"],
                task.params["temperature_input"],
                task.params["seed_num_input"],
                task.params["cfgw_input"]
            )
            buffer = io.BytesIO()
            # run_in_executor() forwards positional arguments only, so a
            # lambda carries the keyword arguments for sf.write.
            await loop.run_in_executor(
                None,
                lambda: sf.write(buffer, wav, sr, format='WAV', subtype='PCM_16')
            )
            # getvalue() returns the whole buffer regardless of the position.
            task.result = buffer.getvalue()
            task.status = "completed"
            task.completed_at = datetime.now()
        except Exception as e:
            failure = e
            task.error = str(e)
            task.status = "failed"
            task.completed_at = datetime.now()
        finally:
            # Best-effort cleanup: a missing or undeletable temp file must not
            # hide the task outcome.
            try:
                os.remove(task.temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                print(f"[WARN] Could not remove {task.temp_path}: {cleanup_error}")

        # Resolve the future outside the try block. asyncio.wait_for() in the
        # request handler cancels the future when it times out; calling
        # set_result() on it would raise InvalidStateError and used to flip a
        # successful task to "failed".
        if task.future.done():
            return
        if failure is None:
            task.future.set_result(True)
        else:
            task.future.set_exception(failure)
# ============ FASTAPI SETUP ============
class CORSHeaderMiddleware(BaseHTTPMiddleware):
    """Add CORS headers to every response, including 500 error responses."""

    async def dispatch(self, request, call_next):
        """Call the next handler and stamp CORS headers on whatever it returns.

        If the downstream handler raises, the error is printed with its
        traceback and converted into a JSON 500 response so that the CORS
        headers can still be attached.
        """
        try:
            response = await call_next(request)
        except Exception as e:
            # The traceback is logged on the server only. Returning it in the
            # response body exposed internal paths and code to any client.
            error_msg = f"Internal Server Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            response = JSONResponse(
                status_code=500,
                content={"detail": str(e)}
            )
        # Attach the CORS headers to every response.
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, PATCH"
        response.headers["Access-Control-Allow-Headers"] = "*"
        response.headers["Access-Control-Allow-Credentials"] = "true"
        return response
app = FastAPI(title="Chatterbox TTS API")

# Register our own CORS header middleware first (before the standard one).
app.add_middleware(CORSHeaderMiddleware)

# Standard CORS middleware, which handles preflight requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler: log the error and answer 500 with CORS headers."""
    message = str(exc)
    error_detail = f"{message}\n{traceback.format_exc()}"
    print(f"[ERROR] {error_detail}")

    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    payload = {"detail": message, "type": type(exc).__name__}
    return JSONResponse(status_code=500, content=payload, headers=cors_headers)
@app.on_event("startup")
async def startup():
    """Start the background loop that evicts expired tasks."""
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task may be garbage-collected mid-run. Keep a strong reference on the app.
    app.state.cleanup_task = asyncio.create_task(cleanup_old_tasks())
@app.options("/{path:path}")
async def options_handler(path: str):
    """Answer OPTIONS requests for any path with permissive CORS headers."""
    preflight_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Max-Age": "3600",
    }
    return PlainTextResponse("OK", headers=preflight_headers)
# Strong references to in-flight worker tasks. The event loop itself only
# keeps weak references, so an unreferenced task can be garbage-collected.
_background_tasks: set = set()


@app.post("/process-audio")
async def process_audio(
    audio_file: UploadFile = File(...),
    text: str = Form(...),
    language_id: str = Form("en"),
    exaggeration: float = Form(0.5),
    temperature: float = Form(0.8),
    seed_num: int = Form(0),
    cfg_weight: float = Form(0.5)
):
    """Queue a TTS request and wait up to 300 seconds for the audio.

    The uploaded reference audio is saved under /tmp, a QueueTask is
    registered in task_storage and processed in the background.

    Returns:
        The generated WAV as a streaming response, or a 202 JSON body with
        the task id when generation is still running after 300 seconds.

    Raises:
        HTTPException: 503 when MAX_QUEUE_SIZE tasks are already pending,
            500 when generation fails.
    """
    try:
        pending_count = sum(1 for t in task_storage.values() if t.status == "pending")
        if pending_count >= MAX_QUEUE_SIZE:
            raise HTTPException(status_code=503, detail="Queue is full, try again later")

        task_id = str(uuid.uuid4())
        temp_path = f"/tmp/input_{task_id}.wav"
        # Save the uploaded file to disk.
        content = await audio_file.read()
        with open(temp_path, "wb") as f:
            f.write(content)

        task = QueueTask(
            id=task_id,
            temp_path=temp_path,
            text=text,
            language_id=language_id,
            params={
                "exaggeration_input": exaggeration,
                "temperature_input": temperature,
                "seed_num_input": seed_num,
                "cfgw_input": cfg_weight
            }
        )
        task_storage[task_id] = task

        worker = asyncio.create_task(process_task(task))
        _background_tasks.add(worker)
        worker.add_done_callback(_background_tasks.discard)

        try:
            # shield() keeps wait_for() from cancelling task.future on
            # timeout; the worker still has to resolve it later.
            await asyncio.wait_for(asyncio.shield(task.future), timeout=300.0)
        except asyncio.TimeoutError:
            return JSONResponse(
                status_code=202,
                content={
                    "task_id": task_id,
                    "status": "processing",
                    "message": "Processing takes longer than expected."
                }
            )
        except Exception as exc:
            # Awaiting a failed future re-raises the worker's exception, so
            # the failure is translated here rather than by checking
            # task.error after the await.
            raise HTTPException(status_code=500, detail=task.error or str(exc)) from exc

        return StreamingResponse(
            io.BytesIO(task.result),
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"}
        )
    except Exception:
        print(f"[ERROR in process_audio]: {traceback.format_exc()}")
        raise
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the audio of a completed task, or its current state as JSON.

    Raises:
        HTTPException: 404 when no task with *task_id* is stored.
    """
    task = task_storage.get(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    # Anything that is not finished successfully is reported as JSON.
    if task.status != "completed":
        state = {
            "task_id": task_id,
            "status": task.status,
            "error": task.error,
            "created_at": task.created_at.isoformat()
        }
        return JSONResponse(content=state)

    disposition = f"attachment; filename=output_{task_id}.wav"
    return StreamingResponse(
        io.BytesIO(task.result),
        media_type="audio/wav",
        headers={"Content-Disposition": disposition}
    )
@app.get("/queue-status")
async def queue_status():
    """Report how many stored tasks are pending, processing and completed."""
    # Tally the three reported states in a single pass over the stored tasks.
    tally = {"pending": 0, "processing": 0, "completed": 0}
    for stored in task_storage.values():
        if stored.status in tally:
            tally[stored.status] += 1

    summary = {
        "queue_size": tally["pending"],
        "processing_now": tally["processing"],
        "completed_total": tally["completed"],
        "max_queue_size": MAX_QUEUE_SIZE,
        "gpu_busy": gpu_lock.locked(),
        "available_slots": MAX_QUEUE_SIZE - tally["pending"]
    }
    return JSONResponse(content=summary)
# ============ GRADIO SETUP ============
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
"""
    )
    gr.Markdown(get_supported_languages_display())

    with gr.Row():
        # Input controls.
        with gr.Column():
            initial_lang = "ru"
            text = gr.Textbox(
                value=default_text_for_ui(initial_lang),
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            language_id = gr.Dropdown(
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
                label="Language"
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value=default_audio_for_ui(initial_lang)
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
            queue_info = gr.Textbox(label="Queue Status", value="Ready", interactive=False)

        # Generated audio.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    def on_language_change(lang):
        """Return the default reference audio and sample text for *lang*."""
        return default_audio_for_ui(lang), default_text_for_ui(lang)

    # Changing the language swaps in that language's defaults.
    language_id.change(
        fn=on_language_change,
        inputs=[language_id],
        outputs=[ref_wav, text],
        show_progress=False
    )

    def generate_with_queue(*args):
        """Generate audio for the UI unless the API queue is already full.

        Returns an ``(audio, status message)`` pair; audio is None when the
        queue is full or generation failed. Note that generate_tts_audio() is
        called directly here, without acquiring gpu_lock.
        """
        waiting = sum(1 for t in task_storage.values() if t.status == "pending")
        if waiting >= MAX_QUEUE_SIZE:
            return None, f"⚠️ Queue is full ({waiting}/{MAX_QUEUE_SIZE})"
        try:
            audio = generate_tts_audio(*args)
        except Exception as e:
            return None, f"❌ Error: {str(e)}"
        return audio, "✅ Ready"

    run_btn.click(
        fn=generate_with_queue,
        inputs=[text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output, queue_info],
    )
# Serve the Gradio UI from the same application under /gradio.
app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    print("🚀 Starting server with CORS fix...")
    print("📊 Gradio UI: http://localhost:8000/gradio")
    print("🔌 API endpoint: POST http://localhost:8000/process-audio")
    uvicorn.run(app, host="0.0.0.0", port=8000)