"""Chatterbox multilingual TTS server.

Exposes the same model through two fronts:
  * a FastAPI JSON/streaming API (``/process-audio``, ``/status/{id}``,
    ``/queue-status``) with an asyncio queue that serializes GPU access, and
  * a Gradio demo UI mounted at ``/gradio``.
"""

import asyncio
import io
import os
import random
import tempfile
import traceback
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware

from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
MODEL = None
# Only one request may touch the GPU at a time; the rest wait in the queue.
gpu_lock = asyncio.Lock()
task_storage: Dict[str, 'QueueTask'] = {}
MAX_QUEUE_SIZE = 10
TASK_TTL_HOURS = 1

# Per-language default reference prompt and demo text for the UI / API fallback.
LANGUAGE_CONFIG = {
    "en": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
        "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
    },
    "ru": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
        "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
    },
}


@dataclass
class QueueTask:
    """One queued TTS job plus its lifecycle bookkeeping."""

    id: str
    temp_path: str          # temp file holding the uploaded reference audio
    text: str
    language_id: str
    params: dict            # generation kwargs (exaggeration, temperature, ...)
    created_at: datetime = field(default_factory=datetime.now)
    status: str = "pending"           # pending -> processing -> completed/failed
    result: Optional[bytes] = None    # WAV bytes on success
    error: Optional[str] = None
    completed_at: Optional[datetime] = None
    # Instances are only created inside request handlers, where a loop is
    # guaranteed to be running — get_running_loop() avoids the deprecated
    # get_event_loop() fallback semantics.
    future: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )


def default_audio_for_ui(lang: str) -> str | None:
    """Default reference-audio URL for *lang*, or None if unknown."""
    return LANGUAGE_CONFIG.get(lang, {}).get("audio")


def default_text_for_ui(lang: str) -> str:
    """Default demo text for *lang*, or "" if unknown."""
    return LANGUAGE_CONFIG.get(lang, {}).get("text", "")


def get_supported_languages_display() -> str:
    """Render the supported-language list as two Markdown lines."""
    language_items = []
    for code, name in sorted(SUPPORTED_LANGUAGES.items()):
        language_items.append(f"**{name}** (`{code}`)")
    mid = len(language_items) // 2
    line1 = " • ".join(language_items[:mid])
    line2 = " • ".join(language_items[mid:])
    return f""" ### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total) {line1} {line2} """


def get_or_load_model():
    """Lazily load the TTS model onto DEVICE; reuse the cached instance."""
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
            if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
                MODEL.to(DEVICE)
            print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Eagerly warm the model at import time so the first request isn't slow;
# a failure here is logged but does not abort startup.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")


def set_seed(seed: int):
    """Seed torch (CPU+CUDA), random and numpy for reproducible generation."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text_input* (truncated to 300 chars).

    Returns ``(sample_rate, mono_waveform)``. Raises RuntimeError if the
    model failed to load. seed 0 means "don't reseed".
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    print(f"[GPU] Processing: '{text_input[:50]}...'")
    # Fall back to the per-language default prompt when none was provided.
    chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt
    wav = current_model.generate(
        text_input[:300],
        language_id=language_id,
        **generate_kwargs
    )
    print(f"[GPU] Done: '{text_input[:50]}...'")
    # .detach().cpu() so .numpy() also works if the model returns a CUDA
    # tensor (plain .numpy() would raise in that case).
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())


async def cleanup_old_tasks():
    """Background loop: every 10 min drop finished tasks older than the TTL."""
    while True:
        await asyncio.sleep(600)
        now = datetime.now()
        to_remove = []
        for task_id, task in task_storage.items():
            if task.status in ["completed", "failed"]:
                if task.completed_at and (now - task.completed_at).total_seconds() > 3600 * TASK_TTL_HOURS:
                    to_remove.append(task_id)
        for tid in to_remove:
            del task_storage[tid]
            print(f"[Cleanup] Removed old task {tid}")


async def process_task(task: QueueTask):
    """Run one TTS job while holding the GPU lock; encode the result as WAV.

    The blocking model call and the soundfile encode both run in the default
    executor so the event loop stays responsive.
    """
    async with gpu_lock:
        task.status = "processing"
        try:
            loop = asyncio.get_running_loop()
            sr, wav = await loop.run_in_executor(
                None,
                generate_tts_audio,
                task.text,
                task.language_id,
                task.temp_path,
                task.params["exaggeration_input"],
                task.params["temperature_input"],
                task.params["seed_num_input"],
                task.params["cfgw_input"]
            )
            # Write to an in-memory buffer with an explicit format (BytesIO
            # has no filename, so soundfile cannot infer WAV on its own).
            buffer = io.BytesIO()
            await loop.run_in_executor(
                None,
                lambda: sf.write(buffer, wav, sr, format='WAV', subtype='PCM_16')
            )
            buffer.seek(0)  # rewind before reading the bytes back out
            task.result = buffer.getvalue()
            task.status = "completed"
            task.completed_at = datetime.now()
            # BUGFIX: asyncio.wait_for() cancels the future on timeout, so it
            # may already be done here; calling set_result() on a cancelled
            # future raises InvalidStateError and used to flip a successful
            # task to "failed". Clients that timed out still fetch the result
            # via /status/{task_id}.
            if not task.future.done():
                task.future.set_result(True)
        except Exception as e:
            task.error = str(e)
            task.status = "failed"
            task.completed_at = datetime.now()
            if not task.future.done():
                task.future.set_exception(e)
            else:
                # Future already cancelled/resolved: keep the error on the
                # task record only (avoids "exception was never retrieved").
                print(f"[ERROR in process_audio]: {traceback.format_exc()}")
        finally:
            if os.path.exists(task.temp_path):
                os.remove(task.temp_path)


# ============ FASTAPI SETUP ============

class CORSHeaderMiddleware(BaseHTTPMiddleware):
    """Attach CORS headers to EVERY response, including 500 errors."""

    async def dispatch(self, request, call_next):
        try:
            response = await call_next(request)
        except Exception as e:
            # On failure, build an error response that still carries CORS headers.
            error_msg = f"Internal Server Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            response = JSONResponse(
                status_code=500,
                content={"detail": str(e), "traceback": traceback.format_exc()}
            )
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, PATCH"
        response.headers["Access-Control-Allow-Headers"] = "*"
        response.headers["Access-Control-Allow-Credentials"] = "true"
        return response


app = FastAPI(title="Chatterbox TTS API")

# Our middleware first so even errors get CORS headers...
app.add_middleware(CORSHeaderMiddleware)
# ...plus the standard CORS middleware for preflight handling.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler that keeps CORS headers on 500 responses."""
    error_detail = f"{str(exc)}\n{traceback.format_exc()}"
    print(f"[ERROR] {error_detail}")
    return JSONResponse(
        status_code=500,
        content={"detail": str(exc), "type": type(exc).__name__},
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
            "Access-Control-Allow-Headers": "*",
        }
    )


@app.on_event("startup")
async def startup():
    asyncio.create_task(cleanup_old_tasks())


@app.options("/{path:path}")
async def options_handler(path: str):
    """Answer CORS preflight OPTIONS for any path."""
    return PlainTextResponse(
        "OK",
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Max-Age": "3600",
        }
    )


@app.post("/process-audio")
async def process_audio(
    audio_file: UploadFile = File(...),
    text: str = Form(...),
    language_id: str = Form("en"),
    exaggeration: float = Form(0.5),
    temperature: float = Form(0.8),
    seed_num: int = Form(0),
    cfg_weight: float = Form(0.5)
):
    """Enqueue a TTS job; stream the WAV if it finishes within 5 minutes,
    otherwise return 202 with a task id for polling via /status/{task_id}."""
    try:
        pending_count = sum(1 for t in task_storage.values() if t.status == "pending")
        if pending_count >= MAX_QUEUE_SIZE:
            raise HTTPException(status_code=503, detail="Queue is full, try again later")

        task_id = str(uuid.uuid4())
        # Portable temp location instead of a hard-coded /tmp.
        temp_path = os.path.join(tempfile.gettempdir(), f"input_{task_id}.wav")

        # Persist the uploaded reference audio for the worker.
        content = await audio_file.read()
        with open(temp_path, "wb") as f:
            f.write(content)

        task = QueueTask(
            id=task_id,
            temp_path=temp_path,
            text=text,
            language_id=language_id,
            params={
                "exaggeration_input": exaggeration,
                "temperature_input": temperature,
                "seed_num_input": seed_num,
                "cfgw_input": cfg_weight
            }
        )
        task_storage[task_id] = task
        asyncio.create_task(process_task(task))

        try:
            # shield() so a timeout here does not cancel the task's future —
            # the job keeps running and the client can poll /status later.
            await asyncio.wait_for(asyncio.shield(task.future), timeout=300.0)
            if task.error:
                raise HTTPException(status_code=500, detail=task.error)
            return StreamingResponse(
                io.BytesIO(task.result),
                media_type="audio/wav",
                headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"}
            )
        except asyncio.TimeoutError:
            return JSONResponse(
                status_code=202,
                content={
                    "task_id": task_id,
                    "status": "processing",
                    "message": "Processing takes longer than expected."
                }
            )
    except Exception as e:
        print(f"[ERROR in process_audio]: {traceback.format_exc()}")
        raise


@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Poll a task: stream the WAV when done, else report current status."""
    task = task_storage.get(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if task.status == "completed":
        return StreamingResponse(
            io.BytesIO(task.result),
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"}
        )
    return JSONResponse(content={
        "task_id": task_id,
        "status": task.status,
        "error": task.error,
        "created_at": task.created_at.isoformat()
    })


@app.get("/queue-status")
async def queue_status():
    """Summarize queue depth, GPU busy flag and remaining slots."""
    pending = sum(1 for t in task_storage.values() if t.status == "pending")
    processing = sum(1 for t in task_storage.values() if t.status == "processing")
    return JSONResponse(content={
        "queue_size": pending,
        "processing_now": processing,
        "completed_total": sum(1 for t in task_storage.values() if t.status == "completed"),
        "max_queue_size": MAX_QUEUE_SIZE,
        "gpu_busy": gpu_lock.locked(),
        "available_slots": MAX_QUEUE_SIZE - pending
    })


# ============ GRADIO SETUP ============

with gr.Blocks() as demo:
    gr.Markdown(
        """ # Chatterbox Multilingual Demo Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages. """
    )
    gr.Markdown(get_supported_languages_display())
    with gr.Row():
        with gr.Column():
            initial_lang = "ru"
            text = gr.Textbox(
                value=default_text_for_ui(initial_lang),
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            language_id = gr.Dropdown(
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
                label="Language"
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value=default_audio_for_ui(initial_lang)
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
            queue_info = gr.Textbox(label="Queue Status", value="Ready", interactive=False)
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    def on_language_change(lang):
        return default_audio_for_ui(lang), default_text_for_ui(lang)

    language_id.change(
        fn=on_language_change,
        inputs=[language_id],
        outputs=[ref_wav, text],
        show_progress=False
    )

    def generate_with_queue(*args):
        # NOTE(review): this bypasses gpu_lock (it runs in Gradio's thread
        # pool, not the asyncio queue), so a UI request can overlap an API
        # task on the GPU — confirm whether that is acceptable.
        pending = sum(1 for t in task_storage.values() if t.status == "pending")
        if pending >= MAX_QUEUE_SIZE:
            return None, f"⚠️ Queue is full ({pending}/{MAX_QUEUE_SIZE})"
        try:
            result = generate_tts_audio(*args)
            return result, f"✅ Ready"
        except Exception as e:
            return None, f"❌ Error: {str(e)}"

    run_btn.click(
        fn=generate_with_queue,
        inputs=[text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output, queue_info],
    )

app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    print("🚀 Starting server with CORS fix...")
    print(f"📊 Gradio UI: http://localhost:8000/gradio")
    print(f"🔌 API endpoint: POST http://localhost:8000/process-audio")
    uvicorn.run(app, host="0.0.0.0", port=8000)