diff --git a/README.md b/README.md index bfa09a4..6a1905c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ ## Решение команды Atekin кейса от компании Центр Инвест! +### Реализованный функционал: +* Deepfake с возможностью использования своего голоса или файла +* Квиз по инфобезу по двум темам +* Информационная страничка как база знаний +* 5 кейсов по отработке типовых ситуаций в сфере информационной безопасности +* Прогресс и мини-лента новостей :D + + | Фотки работы | | ----------- | | ![Главный экран](./static/main.png) | @@ -16,11 +24,12 @@ | ----------- | | ![deepfake](./static/quiz.png) | | ![deepfake](./static/quiz_1.png) | -| ![deepfake](./static/deepfake_2.png) | -| ![deepfake](./static/deepfake_3.png) | -| ![deepfake](./static/deepfake_4.png) | -Фронтенд на react, бекенд с нейронкой по генерации deepfake голоса на, *барабанная дробь* питоне +| База знаний со ссылками | +| ----------- | +| ![deepfake](./static/wiki.png) | + +### Фронтенд на react, бекенд с нейронкой по генерации deepfake голоса на, *барабанная дробь* питоне Для сборки статики фронта требуется nodejs @@ -36,3 +45,4 @@ npm run build Выхлоп в папке `dist/` +### Бекенд для создания голоса лежит по пути `backend/chatterbox_app.py` diff --git a/backend/chatterbox_app.py b/backend/chatterbox_app.py new file mode 100644 index 0000000..052e1cd --- /dev/null +++ b/backend/chatterbox_app.py @@ -0,0 +1,436 @@ +import random +import numpy as np +import torch +import io +import os +import uuid +import asyncio +import traceback +from dataclasses import dataclass, field +from typing import Dict, Optional +from datetime import datetime +from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES +import gradio as gr +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request +from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware 
+import uvicorn +import soundfile as sf + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +print(f"🚀 Running on device: {DEVICE}") + +# --- Global Model Initialization --- +MODEL = None +gpu_lock = asyncio.Lock() +task_storage: Dict[str, 'QueueTask'] = {} +MAX_QUEUE_SIZE = 10 +TASK_TTL_HOURS = 1 + +LANGUAGE_CONFIG = { + "en": { + "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac", + "text": "Last month, we reached a new milestone with two billion views on our YouTube channel." + }, + "ru": { + "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac", + "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале." + }, +} + +@dataclass +class QueueTask: + id: str + temp_path: str + text: str + language_id: str + params: dict + created_at: datetime = field(default_factory=datetime.now) + status: str = "pending" + result: Optional[bytes] = None + error: Optional[str] = None + completed_at: Optional[datetime] = None + future: asyncio.Future = field(default_factory=lambda: asyncio.get_event_loop().create_future()) + +def default_audio_for_ui(lang: str) -> str | None: + return LANGUAGE_CONFIG.get(lang, {}).get("audio") + +def default_text_for_ui(lang: str) -> str: + return LANGUAGE_CONFIG.get(lang, {}).get("text", "") + +def get_supported_languages_display() -> str: + language_items = [] + for code, name in sorted(SUPPORTED_LANGUAGES.items()): + language_items.append(f"**{name}** (`{code}`)") + mid = len(language_items) // 2 + line1 = " • ".join(language_items[:mid]) + line2 = " • ".join(language_items[mid:]) + return f""" +### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total) +{line1} +{line2} +""" + +def get_or_load_model(): + global MODEL + if MODEL is None: + print("Model not loaded, initializing...") + try: + MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE) + if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE: + 
MODEL.to(DEVICE) + print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}") + except Exception as e: + print(f"Error loading model: {e}") + raise + return MODEL + +try: + get_or_load_model() +except Exception as e: + print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}") + +def set_seed(seed: int): + torch.manual_seed(seed) + if DEVICE == "cuda": + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + +def generate_tts_audio( + text_input: str, + language_id: str, + audio_prompt_path_input: str = None, + exaggeration_input: float = 0.5, + temperature_input: float = 0.8, + seed_num_input: int = 0, + cfgw_input: float = 0.5 +) -> tuple[int, np.ndarray]: + current_model = get_or_load_model() + if current_model is None: + raise RuntimeError("TTS model is not loaded.") + + if seed_num_input != 0: + set_seed(int(seed_num_input)) + + print(f"[GPU] Processing: '{text_input[:50]}...'") + chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id) + + generate_kwargs = { + "exaggeration": exaggeration_input, + "temperature": temperature_input, + "cfg_weight": cfgw_input, + } + if chosen_prompt: + generate_kwargs["audio_prompt_path"] = chosen_prompt + + wav = current_model.generate( + text_input[:300], + language_id=language_id, + **generate_kwargs + ) + print(f"[GPU] Done: '{text_input[:50]}...'") + return (current_model.sr, wav.squeeze(0).numpy()) + +async def cleanup_old_tasks(): + while True: + await asyncio.sleep(600) + now = datetime.now() + to_remove = [] + for task_id, task in task_storage.items(): + if task.status in ["completed", "failed"]: + if task.completed_at and (now - task.completed_at).total_seconds() > 3600 * TASK_TTL_HOURS: + to_remove.append(task_id) + for tid in to_remove: + del task_storage[tid] + print(f"[Cleanup] Removed old task {tid}") + +# Замени функцию process_task на эту: + +async def process_task(task: 
QueueTask): + """Обработка задачи с блокировкой GPU""" + async with gpu_lock: + task.status = "processing" + try: + loop = asyncio.get_event_loop() + sr, wav = await loop.run_in_executor( + None, + generate_tts_audio, + task.text, + task.language_id, + task.temp_path, + task.params["exaggeration_input"], + task.params["temperature_input"], + task.params["seed_num_input"], + task.params["cfgw_input"] + ) + + # 🔥 ИСПРАВЛЕНИЕ: правильная запись в BytesIO с явным форматом + buffer = io.BytesIO() + # Используем lambda чтобы передать именованные аргументы в sf.write + await loop.run_in_executor( + None, + lambda: sf.write(buffer, wav, sr, format='WAV', subtype='PCM_16') + ) + buffer.seek(0) # 🔥 Важно: перемотать в начало! + task.result = buffer.getvalue() + task.status = "completed" + task.completed_at = datetime.now() + task.future.set_result(True) + + except Exception as e: + task.error = str(e) + task.status = "failed" + task.completed_at = datetime.now() + task.future.set_exception(e) + finally: + if os.path.exists(task.temp_path): + os.remove(task.temp_path) + +# ============ FASTAPI SETUP ============ + +class CORSHeaderMiddleware(BaseHTTPMiddleware): + """Добавляет CORS заголовки к ЛЮБОМУ ответу, включая ошибки 500""" + async def dispatch(self, request, call_next): + try: + response = await call_next(request) + except Exception as e: + # Если произошла ошибка, создаем ответ с CORS заголовками + error_msg = f"Internal Server Error: {str(e)}\n{traceback.format_exc()}" + print(error_msg) + response = JSONResponse( + status_code=500, + content={"detail": str(e), "traceback": traceback.format_exc()} + ) + + # Добавляем CORS заголовки к ЛЮБОМУ ответу + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, PATCH" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Allow-Credentials"] = "true" + return response + +app = FastAPI(title="Chatterbox 
TTS API") + +# 🔥 Первым делом добавляем наш middleware для CORS (перед стандартным) +app.add_middleware(CORSHeaderMiddleware) + +# 🔥 Стандартный CORS middleware для preflight запросов +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + max_age=3600, +) + +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + """Глобальный обработчик исключений с CORS заголовками""" + error_detail = f"{str(exc)}\n{traceback.format_exc()}" + print(f"[ERROR] {error_detail}") + return JSONResponse( + status_code=500, + content={"detail": str(exc), "type": type(exc).__name__}, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "*", + } + ) + +@app.on_event("startup") +async def startup(): + asyncio.create_task(cleanup_old_tasks()) + +@app.options("/{path:path}") +async def options_handler(path: str): + """Обрабатывает OPTIONS для любого пути""" + return PlainTextResponse( + "OK", + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "3600", + } + ) + +@app.post("/process-audio") +async def process_audio( + audio_file: UploadFile = File(...), + text: str = Form(...), + language_id: str = Form("en"), + exaggeration: float = Form(0.5), + temperature: float = Form(0.8), + seed_num: int = Form(0), + cfg_weight: float = Form(0.5) +): + try: + pending_count = sum(1 for t in task_storage.values() if t.status == "pending") + if pending_count >= MAX_QUEUE_SIZE: + raise HTTPException(status_code=503, detail="Queue is full, try again later") + + task_id = str(uuid.uuid4()) + temp_path = f"/tmp/input_{task_id}.wav" + + # Сохраняем файл + content = await audio_file.read() + 
with open(temp_path, "wb") as f: + f.write(content) + + task = QueueTask( + id=task_id, + temp_path=temp_path, + text=text, + language_id=language_id, + params={ + "exaggeration_input": exaggeration, + "temperature_input": temperature, + "seed_num_input": seed_num, + "cfgw_input": cfg_weight + } + ) + task_storage[task_id] = task + + asyncio.create_task(process_task(task)) + + try: + await asyncio.wait_for(task.future, timeout=300.0) + + if task.error: + raise HTTPException(status_code=500, detail=task.error) + + return StreamingResponse( + io.BytesIO(task.result), + media_type="audio/wav", + headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"} + ) + + except asyncio.TimeoutError: + return JSONResponse( + status_code=202, + content={ + "task_id": task_id, + "status": "processing", + "message": "Processing takes longer than expected." + } + ) + + except Exception as e: + print(f"[ERROR in process_audio]: {traceback.format_exc()}") + raise + +@app.get("/status/{task_id}") +async def get_status(task_id: str): + task = task_storage.get(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + + if task.status == "completed": + return StreamingResponse( + io.BytesIO(task.result), + media_type="audio/wav", + headers={"Content-Disposition": f"attachment; filename=output_{task_id}.wav"} + ) + + return JSONResponse(content={ + "task_id": task_id, + "status": task.status, + "error": task.error, + "created_at": task.created_at.isoformat() + }) + +@app.get("/queue-status") +async def queue_status(): + pending = sum(1 for t in task_storage.values() if t.status == "pending") + processing = sum(1 for t in task_storage.values() if t.status == "processing") + + return JSONResponse(content={ + "queue_size": pending, + "processing_now": processing, + "completed_total": sum(1 for t in task_storage.values() if t.status == "completed"), + "max_queue_size": MAX_QUEUE_SIZE, + "gpu_busy": gpu_lock.locked(), + "available_slots": 
MAX_QUEUE_SIZE - pending + }) + +# ============ GRADIO SETUP ============ + +with gr.Blocks() as demo: + gr.Markdown( + """ + # Chatterbox Multilingual Demo + Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages. + """ + ) + gr.Markdown(get_supported_languages_display()) + + with gr.Row(): + with gr.Column(): + initial_lang = "ru" + text = gr.Textbox( + value=default_text_for_ui(initial_lang), + label="Text to synthesize (max chars 300)", + max_lines=5 + ) + language_id = gr.Dropdown( + choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()), + value=initial_lang, + label="Language" + ) + ref_wav = gr.Audio( + sources=["upload", "microphone"], + type="filepath", + label="Reference Audio File (Optional)", + value=default_audio_for_ui(initial_lang) + ) + exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5) + cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5) + with gr.Accordion("More options", open=False): + seed_num = gr.Number(value=0, label="Random seed") + temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8) + run_btn = gr.Button("Generate", variant="primary") + queue_info = gr.Textbox(label="Queue Status", value="Ready", interactive=False) + + with gr.Column(): + audio_output = gr.Audio(label="Output Audio") + + def on_language_change(lang): + return default_audio_for_ui(lang), default_text_for_ui(lang) + + language_id.change( + fn=on_language_change, + inputs=[language_id], + outputs=[ref_wav, text], + show_progress=False + ) + + def generate_with_queue(*args): + pending = sum(1 for t in task_storage.values() if t.status == "pending") + if pending >= MAX_QUEUE_SIZE: + return None, f"⚠️ Queue is full ({pending}/{MAX_QUEUE_SIZE})" + try: + result = generate_tts_audio(*args) + return result, f"✅ Ready" + except Exception as e: + return None, f"❌ Error: {str(e)}" + + run_btn.click( + fn=generate_with_queue, + inputs=[text, language_id, 
ref_wav, exaggeration, temp, seed_num, cfg_weight], + outputs=[audio_output, queue_info], + ) + +app = gr.mount_gradio_app(app, demo, path="/gradio") + +if __name__ == "__main__": + print("🚀 Starting server with CORS fix...") + print(f"📊 Gradio UI: http://localhost:8000/gradio") + print(f"🔌 API endpoint: POST http://localhost:8000/process-audio") + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/src/components/CyberSecurityArticle.tsx b/src/components/CyberSecurityArticle.tsx index 3e8ce65..ba4be23 100644 --- a/src/components/CyberSecurityArticle.tsx +++ b/src/components/CyberSecurityArticle.tsx @@ -306,11 +306,11 @@ const CyberSecurityArticle: React.FC = () => {