Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from shared import RealtimeSpeakerDiarization | |
| import numpy as np | |
| import uvicorn | |
| import logging | |
| import asyncio | |
| import json | |
| import time | |
| from typing import Set, Dict, Any | |
| import traceback | |
# RealtimeSTT must be importable at runtime. When it is missing, install it
# into the running interpreter with pip and then retry the import.
# NOTE(review): installing an unpinned package from the network at import time
# is a supply-chain risk; consider declaring it in the project requirements.
# NOTE(review): AudioToTextRecorder is not referenced elsewhere in the visible code.
try:
    from RealtimeSTT import AudioToTextRecorder
except ImportError:
    import subprocess
    import sys

    print("Installing RealtimeSTT dependency...")
    _pip_install_cmd = [sys.executable, "-m", "pip", "install", "RealtimeSTT"]
    subprocess.check_call(_pip_install_cmd)
    from RealtimeSTT import AudioToTextRecorder
# Configure root logging once for the whole service at INFO level, then take a
# logger named after this module.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# The FastAPI application object that every endpoint in this module belongs to.
app = FastAPI(
    title="Real-time Speaker Diarization API",
    version="1.0.0",
)

# Permissive CORS so browser clients served from any origin can call the API.
# NOTE(review): browsers refuse wildcard CORS responses on credentialed requests,
# and Starlette's docs say origins/methods/headers must be listed explicitly for
# credentials to be allowed. Confirm whether allow_credentials is actually needed
# here, and tighten allow_origins if cookies/auth are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# ---- Shared module-level state ---------------------------------------------
# Diarization engine instance; None until initialize_diarization_system() runs.
diart = None

# NOTE(review): appears unused in the visible code (ConnectionManager tracks its
# own `active_connections`); kept in case something else imports it.
active_connections: Set[WebSocket] = set()

# Counters reported by the health/status/stats endpoints.
connection_stats: Dict[str, Any] = dict(
    total_connections=0,
    current_connections=0,
    last_audio_received=None,
    total_audio_chunks=0,
)
class ConnectionManager:
    """Manages WebSocket connections and broadcasting.

    Besides tracking sockets, the manager drives the recorder lifecycle:
    recording is started when the first client connects and stopped when the
    last client leaves. It also keeps the module-level ``connection_stats``
    counters in sync.
    """

    def __init__(self):
        # Sockets that are currently registered.
        self.active_connections: Set[WebSocket] = set()
        # Per-socket bookkeeping: client_id, connected_at (epoch seconds),
        # messages_sent.
        self.connection_metadata: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, client_id: str = None):
        """Accept *websocket* and register it.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited here.
            client_id: Label used in logs and metadata. Defaults to
                ``client_<epoch seconds>``.
        """
        await websocket.accept()
        self.active_connections.add(websocket)
        meta = {
            "client_id": client_id or f"client_{int(time.time())}",
            "connected_at": time.time(),
            "messages_sent": 0,
        }
        self.connection_metadata[websocket] = meta

        connection_stats["current_connections"] = len(self.active_connections)
        connection_stats["total_connections"] += 1

        # Recording is started lazily: only when the first client arrives and a
        # diarization system exists but is not yet running.
        if len(self.active_connections) == 1 and diart and not diart.is_running:
            logger.info("First connection established, starting recording")
            diart.start_recording()

        logger.info(f"WebSocket connected: {meta['client_id']}. "
                    f"Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a no-op if it is not registered.

        Stops recording when the last connection goes away, to save resources.
        """
        if websocket not in self.active_connections:
            return
        client_id = self.connection_metadata.get(websocket, {}).get("client_id", "unknown")
        self.active_connections.discard(websocket)
        self.connection_metadata.pop(websocket, None)
        connection_stats["current_connections"] = len(self.active_connections)

        if not self.active_connections and diart and diart.is_running:
            logger.info("No active connections, stopping recording")
            diart.stop_recording()

        logger.info(f"WebSocket disconnected: {client_id}. "
                    f"Remaining connections: {len(self.active_connections)}")

    async def _send_to(self, websocket: WebSocket, message: str) -> bool:
        """Send *message* to one socket; return False (after logging) on failure."""
        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.warning(f"Failed to send message to client: {e}")
            return False
        meta = self.connection_metadata.get(websocket)
        if meta is not None:
            meta["messages_sent"] += 1
        return True

    async def broadcast(self, message: str):
        """Send *message* to every active connection concurrently.

        Sends run concurrently so a single slow or stalled client cannot hold
        up delivery to the others. Clients whose send fails are disconnected.
        """
        if not self.active_connections:
            return
        # Snapshot the set: disconnect() may mutate it while sends are in flight.
        targets = list(self.active_connections)
        results = await asyncio.gather(*(self._send_to(ws, message) for ws in targets))
        for ws, delivered in zip(targets, results):
            if not delivered:
                self.disconnect(ws)

    def get_stats(self):
        """Return the connection count and per-connection metadata.

        Sockets are not JSON-serializable, so the metadata is re-keyed by its
        0-based position (dict insertion order).
        """
        return {
            "active_connections": len(self.active_connections),
            "connection_metadata": dict(enumerate(self.connection_metadata.values())),
        }
# One process-wide manager, shared by the WebSocket endpoint, the background
# broadcaster and the REST handlers.
manager: ConnectionManager = ConnectionManager()
async def initialize_diarization_system():
    """Create the diarization engine and load its models.

    The module-level ``diart`` is assigned only after ``initialize_models()``
    reports success. On failure it keeps its previous value (None at startup),
    so the ``if not diart`` guards elsewhere keep treating the system as not
    initialized instead of driving a half-built engine.

    Recording is deliberately not started here; ConnectionManager.connect
    starts it when the first client connects.

    Returns:
        bool: True if the models were initialized, False otherwise. Errors are
        logged rather than raised.
    """
    global diart
    # NOTE(review): initialize_models() is called synchronously and will block
    # the event loop while models load; acceptable at startup, but confirm.
    try:
        logger.info("Initializing diarization system...")
        system = RealtimeSpeakerDiarization()
        if not system.initialize_models():
            logger.error("Failed to initialize models")
            return False
    except Exception as e:
        logger.exception(f"Error initializing diarization system: {e}")
        return False

    diart = system
    logger.info("Models initialized successfully")
    logger.info("System ready for connections")
    return True
async def send_conversation_updates(update_interval: float = 0.5):
    """Periodically broadcast the conversation to all connected clients.

    Loops forever and is meant to run as a background task. On each tick, if
    the diarization system is running and at least one client is connected, a
    ``conversation_update`` message is broadcast, but only when the formatted
    conversation differs from the last one sent (to reduce bandwidth).

    Args:
        update_interval: Seconds to sleep between polls (default 500 ms).
    """
    # Sentinel that never equals a real value, so the first conversation,
    # even None or "", is always sent.
    last_sent = object()
    while True:
        try:
            if diart and diart.is_running and manager.active_connections:
                conversation_html = diart.get_formatted_conversation()
                # Compare the text itself rather than its hash(): exact, with no
                # chance of a collision suppressing a real change.
                if conversation_html != last_sent:
                    update_message = json.dumps({
                        "type": "conversation_update",
                        "timestamp": time.time(),
                        "conversation_html": conversation_html,
                        "status": diart.get_status_info() if hasattr(diart, 'get_status_info') else {},
                    })
                    await manager.broadcast(update_message)
                    last_sent = conversation_html
        except Exception as e:
            # Log and keep looping; a transient failure must not kill the task.
            logger.error(f"Error in conversation update: {e}")
        await asyncio.sleep(update_interval)
async def startup_event():
    """Initialize the diarization system and start the background broadcaster.

    A failed initialization is logged but not fatal: the app keeps serving, and
    the broadcaster only sends once a running system exists.
    """
    # NOTE(review): no startup-hook registration (e.g. @app.on_event("startup"))
    # is visible for this function; confirm it is actually wired up.
    global _conversation_update_task
    logger.info("Starting Real-time Speaker Diarization Service")

    success = await initialize_diarization_system()
    if not success:
        logger.error("Failed to initialize diarization system!")

    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    _conversation_update_task = asyncio.create_task(send_conversation_updates())
    logger.info("Background tasks started")
async def shutdown_event():
    """Stop recording and release the RealtimeSTT recorder on shutdown.

    Each step is attempted independently: a failure to stop recording no
    longer prevents the transcription model from being shut down. Errors are
    logged, not raised.
    """
    logger.info("Shutting down...")
    if not diart:
        return

    try:
        diart.stop_recording()
        logger.info("Recording stopped")
    except Exception as e:
        logger.error(f"Error stopping recording: {e}")

    # Shut down RealtimeSTT properly if the engine exposes a recorder.
    recorder = getattr(diart, 'recorder', None)
    if recorder:
        try:
            recorder.shutdown()
            logger.info("Transcription model shut down")
        except Exception as e:
            logger.error(f"Error shutting down transcription model: {e}")
async def root():
    """Describe the service and list its main endpoints.

    ``status`` is "running" once the diarization system exists and is
    recording, and "initializing" otherwise.
    """
    is_running = bool(diart) and diart.is_running
    endpoint_map = {
        "websocket": "/ws_inference",
        "health": "/health",
        "conversation": "/conversation",
        "status": "/status",
    }
    service_info = {
        "service": "Real-time Speaker Diarization API",
        "version": "1.0.0",
    }
    service_info["status"] = "running" if is_running else "initializing"
    service_info["endpoints"] = endpoint_map
    return service_info
async def health_check():
    """Report service health plus connection and diarization statistics.

    ``system_running`` is always a bool (previously it was ``None``, i.e. JSON
    ``null``, when the diarization system had not been created). The response
    is returned with the default status code even when unhealthy.
    """
    system_healthy = bool(diart and diart.is_running)
    diarization_status = (
        diart.get_status_info() if diart and hasattr(diart, 'get_status_info') else {}
    )
    return {
        "status": "healthy" if system_healthy else "unhealthy",
        "system_running": system_healthy,
        "active_connections": len(manager.active_connections),
        "connection_stats": connection_stats,
        "diarization_status": diarization_status,
    }
def _numpy_json_default(obj):
    """``json.dumps`` fallback that converts NumPy values to native Python types.

    The encoder consults this hook only for objects it cannot already encode,
    at any nesting depth. NumPy scalars (``np.generic``: ints, floats,
    ``np.bool_``, ...) become Python scalars and arrays become lists.

    Raises:
        TypeError: For any other object, matching the encoder's behaviour
            without a hook.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


async def ws_inference(websocket: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Protocol:
      * After registering with ``manager``, the server sends a
        ``connection_established`` message containing the client id, the
        system status and the current conversation HTML.
      * Each binary frame is passed to ``diart.process_audio_chunk`` while the
        system is running. A truthy result is sent back as a
        ``processing_result`` message.
      * Per-chunk failures are reported to the client as ``error`` messages
        without closing the socket. The socket is always unregistered on exit.
    """
    client_id = f"client_{int(time.time())}"
    try:
        await manager.connect(websocket, client_id)

        # Confirm the connection and hand over the conversation so far.
        initial_message = json.dumps({
            "type": "connection_established",
            "client_id": client_id,
            "system_status": "ready" if diart and diart.is_running else "initializing",
            "conversation": diart.get_formatted_conversation() if diart else "",
        })
        await websocket.send_text(initial_message)

        async for data in websocket.iter_bytes():
            try:
                if data and diart and diart.is_running:
                    connection_stats["last_audio_received"] = time.time()
                    connection_stats["total_audio_chunks"] += 1

                    # NOTE(review): process_audio_chunk runs synchronously on the
                    # event loop; if it is CPU-heavy it stalls all other clients.
                    result = diart.process_audio_chunk(data)

                    chunk_count = connection_stats["total_audio_chunks"]
                    if chunk_count % 100 == 0:  # Log every 100 chunks
                        logger.debug(f"Processed {chunk_count} audio chunks")

                    if result:
                        # NumPy values anywhere in the result are converted by the
                        # default hook, without mutating the engine's object.
                        result_message = json.dumps(
                            {
                                "type": "processing_result",
                                "timestamp": time.time(),
                                "data": result,
                            },
                            default=_numpy_json_default,
                        )
                        await websocket.send_text(result_message)
                elif not diart:
                    logger.warning("Received audio data but diarization system not initialized")
                    error_message = json.dumps({
                        "type": "error",
                        "message": "Diarization system not initialized",
                        "timestamp": time.time(),
                    })
                    await websocket.send_text(error_message)
            except Exception as e:
                logger.error(f"Error processing audio chunk: {e}")
                error_message = json.dumps({
                    "type": "error",
                    "message": "Error processing audio",
                    "details": str(e),
                    "timestamp": time.time(),
                })
                await websocket.send_text(error_message)
    except WebSocketDisconnect:
        logger.info(f"WebSocket {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket {client_id} error: {e}")
    finally:
        manager.disconnect(websocket)
async def get_conversation():
    """Return the current conversation HTML with a timestamp and status snapshot.

    Raises:
        HTTPException: 503 if the diarization system has not been created,
            500 if reading the conversation fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")
    try:
        html = diart.get_formatted_conversation()
        payload = {"conversation": html, "timestamp": time.time()}
        if hasattr(diart, 'get_status_info'):
            payload["system_status"] = diart.get_status_info()
        else:
            payload["system_status"] = {}
        return payload
    except Exception as err:
        logger.error(f"Error getting conversation: {err}")
        raise HTTPException(status_code=500, detail="Error retrieving conversation")
async def get_status():
    """Return the engine's status merged with connection counters and uptime.

    Never raises: a missing system yields ``{"status": "system_not_initialized"}``
    and any error yields ``{"status": "error", "message": ...}``.
    """
    if not diart:
        return {"status": "system_not_initialized"}
    try:
        engine_status = diart.get_status_info() if hasattr(diart, 'get_status_info') else {}
        now = time.time()
        # If no start time was recorded, report an uptime of zero.
        started_at = connection_stats.get("system_start_time", now)
        report = {**engine_status}
        report.update(
            connection_stats=connection_stats,
            active_connections=len(manager.active_connections),
            system_uptime=now - started_at,
        )
        return report
    except Exception as err:
        logger.error(f"Error getting status: {err}")
        return {"status": "error", "message": str(err)}
async def update_settings(threshold: float = None, max_speakers: int = None):
    """Validate and apply new speaker-detection settings.

    Args:
        threshold: Optional value that must lie in [0, 1]; NaN is rejected.
        max_speakers: Optional value that must lie in [1, 20].
        Both are passed to ``diart.update_settings`` as-is, including None.

    Returns:
        The engine's result together with the requested settings.

    Raises:
        HTTPException: 503 if the system has not been created, 400 for an
            out-of-range argument, 500 if applying the settings fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    # Validate before the try block: HTTPException is an Exception subclass, so
    # raising it inside the try would be caught and re-reported as a 500.
    # The negated chained comparison also rejects NaN, which passes `x < 0 or x > 1`.
    if threshold is not None and not (0 <= threshold <= 1):
        raise HTTPException(status_code=400, detail="Threshold must be between 0 and 1")
    if max_speakers is not None and not (1 <= max_speakers <= 20):
        raise HTTPException(status_code=400, detail="Max speakers must be between 1 and 20")

    try:
        result = diart.update_settings(threshold, max_speakers)
    except Exception as e:
        logger.exception(f"Error updating settings: {e}")
        raise HTTPException(status_code=500, detail="Error updating settings") from e

    return {
        "result": result,
        "updated_settings": {
            "threshold": threshold,
            "max_speakers": max_speakers,
        },
    }
async def clear_conversation():
    """Wipe the conversation history and notify every connected client.

    Raises:
        HTTPException: 503 if the diarization system has not been created,
            500 if clearing or notifying fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")
    try:
        outcome = diart.clear_conversation()
        notice = {"type": "conversation_cleared", "timestamp": time.time()}
        await manager.broadcast(json.dumps(notice))
        return {"result": outcome, "message": "Conversation cleared successfully"}
    except Exception as err:
        logger.error(f"Error clearing conversation: {err}")
        raise HTTPException(status_code=500, detail="Error clearing conversation")
async def get_connection_stats():
    """Return raw counters, the manager's per-connection view and a system summary."""
    per_connection = manager.get_stats()
    running = diart.is_running if diart else False
    summary = dict(
        diarization_running=running,
        total_active_connections=len(manager.active_connections),
    )
    return dict(
        connection_stats=connection_stats,
        manager_stats=per_connection,
        system_info=summary,
    )
# Mount the optional UI. A missing `ui` module means API-only mode. Any other
# failure, including a missing dependency *of* `ui` or an error inside
# `mount_ui`, is logged as an error with its traceback instead of being
# misreported as "module not found". Both cases are non-fatal.
try:
    import ui
except ImportError as exc:
    if exc.name == "ui":
        logger.warning("UI module not found, running in API-only mode")
    else:
        logger.exception(f"Error mounting UI: {exc}")
else:
    try:
        ui.mount_ui(app)
        logger.info("Gradio UI mounted successfully")
    except Exception as e:
        logger.exception(f"Error mounting UI: {e}")
# Record when the module finished loading; get_status() reports uptime from this.
connection_stats.update(system_start_time=time.time())
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    # NOTE(review): 7860 is presumably the Hugging Face Spaces default port;
    # confirm before changing it.
    uvicorn_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "access_log": True,
    }
    uvicorn.run(app, **uvicorn_options)