""" DungeonMaster AI - MCP Connection Manager Manages MCP connection lifecycle with health checks, automatic reconnection, circuit breaker pattern, and graceful degradation. """ from __future__ import annotations import asyncio import contextlib import logging import random import time from collections.abc import Sequence from datetime import datetime from typing import TYPE_CHECKING from llama_index.core.tools import FunctionTool from src.config.settings import get_settings from .exceptions import ( MCPCircuitBreakerOpenError, MCPUnavailableError, ) from .fallbacks import FallbackHandler from .models import ( CircuitBreakerState, ConnectionState, MCPConnectionStatus, ) from .toolkit_client import TTRPGToolkitClient if TYPE_CHECKING: from typing import Any logger = logging.getLogger(__name__) class CircuitBreaker: """ Circuit breaker pattern implementation. Prevents repeated calls to a failing service by tracking failures and temporarily rejecting requests when failure threshold is reached. States: - CLOSED: Normal operation, requests allowed - OPEN: Too many failures, requests rejected for reset_timeout - HALF_OPEN: Testing if service recovered with a single request """ def __init__( self, failure_threshold: int = 5, reset_timeout: float = 30.0, half_open_max_calls: int = 1, ) -> None: """ Initialize circuit breaker. Args: failure_threshold: Failures before opening circuit reset_timeout: Seconds to wait before trying again half_open_max_calls: Max calls allowed in half-open state """ self.failure_threshold = failure_threshold self.reset_timeout = reset_timeout self.half_open_max_calls = half_open_max_calls self._state = CircuitBreakerState.CLOSED self._failure_count = 0 self._last_failure_time: float | None = None self._half_open_calls = 0 @property def state(self) -> CircuitBreakerState: """Get current circuit breaker state.""" # Check if we should transition from OPEN to HALF_OPEN if ( self._state == CircuitBreakerState.OPEN and self._last_failure_time is not None ): elapsed = time.time() - self._last_failure_time if elapsed >= self.reset_timeout: self._state = CircuitBreakerState.HALF_OPEN self._half_open_calls = 0 logger.info("Circuit breaker transitioned to HALF_OPEN") return self._state @property def is_open(self) -> bool: """Check if circuit is open (rejecting requests).""" return self.state == CircuitBreakerState.OPEN @property def time_until_retry(self) -> float | None: """Seconds until retry is allowed, or None if not in OPEN state.""" if self._state != CircuitBreakerState.OPEN: return None if self._last_failure_time is None: return None elapsed = time.time() - self._last_failure_time remaining = self.reset_timeout - elapsed return max(0.0, remaining) def record_success(self) -> None: """Record a successful call.""" if self._state == CircuitBreakerState.HALF_OPEN: # Successful call in half-open means we can close self._state = CircuitBreakerState.CLOSED logger.info("Circuit breaker closed after successful recovery") self._failure_count = 0 self._last_failure_time = None def record_failure(self) -> None: """Record a failed call.""" self._failure_count += 1 self._last_failure_time = time.time() if self._state == CircuitBreakerState.HALF_OPEN: # Failure in half-open means service still failing self._state = CircuitBreakerState.OPEN logger.warning("Circuit breaker reopened after half-open failure") elif self._failure_count >= self.failure_threshold: self._state = CircuitBreakerState.OPEN logger.warning( f"Circuit breaker opened after {self._failure_count} failures" ) def allow_request(self) -> bool: """ Check if a request should be allowed. Returns: True if request is allowed, False if should be rejected. """ state = self.state # This may transition OPEN -> HALF_OPEN if state == CircuitBreakerState.CLOSED: return True if state == CircuitBreakerState.HALF_OPEN: if self._half_open_calls < self.half_open_max_calls: self._half_open_calls += 1 return True return False # OPEN state return False def reset(self) -> None: """Reset circuit breaker to initial state.""" self._state = CircuitBreakerState.CLOSED self._failure_count = 0 self._last_failure_time = None self._half_open_calls = 0 class ConnectionManager: """ Manages MCP connection lifecycle with health checks and reconnection. Features: - Automatic reconnection with exponential backoff - Circuit breaker to prevent hammering failed server - Health check monitoring - Graceful degradation via FallbackHandler - Connection status tracking Example: ```python manager = ConnectionManager() connected = await manager.connect() if manager.is_available: tools = await manager.get_tools() result = await manager.execute_tool("roll", {"notation": "1d20"}) else: # Fallback handling result = await manager.execute_with_fallback("roll", {"notation": "1d20"}) ``` """ def __init__( self, toolkit_client: TTRPGToolkitClient | None = None, max_retries: int | None = None, retry_delay: float | None = None, fallback_handler: FallbackHandler | None = None, ) -> None: """ Initialize connection manager. Args: toolkit_client: Pre-configured client, or None to create new one max_retries: Max reconnection attempts (default from settings) retry_delay: Base delay between retries (default from settings) fallback_handler: Handler for graceful degradation """ settings = get_settings() self._client = toolkit_client or TTRPGToolkitClient() self._max_retries: int = max_retries or settings.mcp.mcp_retry_attempts self._retry_delay: float = retry_delay or settings.mcp.mcp_retry_delay self._fallback_handler = fallback_handler or FallbackHandler() # Connection state self._state = ConnectionState.DISCONNECTED self._last_successful_call: datetime | None = None self._consecutive_failures = 0 self._last_error: str | None = None # Circuit breaker self._circuit_breaker = CircuitBreaker( failure_threshold=5, reset_timeout=30.0, ) # Health check task self._health_check_task: asyncio.Task[None] | None = None self._health_check_interval = 60.0 # seconds @property def state(self) -> ConnectionState: """Get current connection state.""" return self._state @property def is_available(self) -> bool: """Check if MCP is available for use.""" return ( self._state == ConnectionState.CONNECTED and not self._circuit_breaker.is_open ) @property def client(self) -> TTRPGToolkitClient: """Get the underlying toolkit client.""" return self._client def get_status(self) -> MCPConnectionStatus: """Get detailed connection status.""" return MCPConnectionStatus( state=self._state, is_available=self.is_available, url=self._client.url, last_successful_call=self._last_successful_call, consecutive_failures=self._consecutive_failures, circuit_breaker_state=self._circuit_breaker.state, tools_count=self._client.tools_count, error_message=self._last_error, ) async def connect(self) -> bool: """ Connect to MCP server with retry logic. Returns: True if connection successful, False otherwise. """ self._state = ConnectionState.CONNECTING logger.info("Attempting to connect to MCP server...") for attempt in range(self._max_retries): try: await self._client.connect() self._state = ConnectionState.CONNECTED self._consecutive_failures = 0 self._last_successful_call = datetime.now() self._last_error = None self._circuit_breaker.reset() logger.info("Successfully connected to MCP server") return True except Exception as e: self._consecutive_failures += 1 self._last_error = str(e) logger.warning( f"Connection attempt {attempt + 1}/{self._max_retries} failed: {e}" ) if attempt < self._max_retries - 1: delay = self._calculate_backoff_delay(attempt) logger.info(f"Retrying in {delay:.2f} seconds...") await asyncio.sleep(delay) self._state = ConnectionState.ERROR logger.error(f"Failed to connect after {self._max_retries} attempts") return False async def disconnect(self) -> None: """Disconnect from MCP server.""" # Stop health check task if running if self._health_check_task and not self._health_check_task.done(): self._health_check_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._health_check_task await self._client.disconnect() self._state = ConnectionState.DISCONNECTED logger.info("Disconnected from MCP server") async def health_check(self) -> bool: """ Perform health check by listing tools. Returns: True if healthy, False otherwise. """ try: await self._client.list_tool_names() self._last_successful_call = datetime.now() self._consecutive_failures = 0 self._circuit_breaker.record_success() return True except Exception as e: self._consecutive_failures += 1 self._circuit_breaker.record_failure() logger.warning(f"Health check failed: {e}") return False async def get_tools( self, categories: Sequence[str] | None = None, ) -> Sequence[FunctionTool]: """ Get tools with automatic reconnection on failure. Args: categories: Optional list of categories to filter. Returns: Sequence of FunctionTool objects. Raises: MCPUnavailableError: If MCP is unavailable after reconnection attempt. """ if not self.is_available: await self._attempt_reconnect() if not self.is_available: raise MCPUnavailableError( "MCP server is unavailable", reason=self._last_error, ) try: tools: Sequence[FunctionTool] if categories: tools = await self._client.get_tools_by_category(categories) else: tools = await self._client.get_all_tools() self._last_successful_call = datetime.now() self._circuit_breaker.record_success() return tools except Exception as e: self._circuit_breaker.record_failure() self._last_error = str(e) logger.error(f"Failed to get tools: {e}") # Try reconnection if await self._attempt_reconnect(): # Retry after reconnection if categories: return await self._client.get_tools_by_category(categories) return await self._client.get_all_tools() raise MCPUnavailableError( "Unable to get tools after reconnection", reason=str(e), ) from e async def execute_tool( self, tool_name: str, arguments: dict[str, Any], ) -> Any: """ Execute a tool with connection management. Args: tool_name: Name of the tool to call. arguments: Tool arguments. Returns: Tool result. Raises: MCPCircuitBreakerOpenError: If circuit breaker is open. MCPUnavailableError: If MCP is unavailable. """ # Check circuit breaker if not self._circuit_breaker.allow_request(): retry_after = self._circuit_breaker.time_until_retry raise MCPCircuitBreakerOpenError(retry_after_seconds=retry_after) if not self.is_available: await self._attempt_reconnect() if not self.is_available: raise MCPUnavailableError( "MCP server is unavailable", reason=self._last_error, ) try: result = await self._client.call_tool(tool_name, arguments) self._last_successful_call = datetime.now() self._consecutive_failures = 0 self._circuit_breaker.record_success() return result except Exception as e: self._consecutive_failures += 1 self._circuit_breaker.record_failure() self._last_error = str(e) raise async def execute_with_fallback( self, tool_name: str, arguments: dict[str, Any], ) -> Any: """ Execute tool with automatic fallback on failure. If MCP fails and a fallback handler can handle the tool, uses the fallback. Otherwise, raises the original error. Args: tool_name: Name of the tool to call. arguments: Tool arguments. Returns: Tool result (from MCP or fallback). Raises: MCPUnavailableError: If MCP fails and no fallback available. """ try: return await self.execute_tool(tool_name, arguments) except (MCPUnavailableError, MCPCircuitBreakerOpenError) as e: # Try fallback if self._fallback_handler.can_handle(tool_name): logger.info(f"Using fallback for tool '{tool_name}'") return await self._fallback_handler.handle(tool_name, arguments) # No fallback available raise MCPUnavailableError( f"MCP unavailable and no fallback for '{tool_name}'", reason=str(e), ) from e async def _attempt_reconnect(self) -> bool: """ Attempt reconnection with exponential backoff. Returns: True if reconnection successful, False otherwise. """ if self._state == ConnectionState.RECONNECTING: # Already reconnecting return False self._state = ConnectionState.RECONNECTING logger.info("Attempting to reconnect to MCP server...") for attempt in range(self._max_retries): try: await self._client.disconnect() await self._client.connect() self._state = ConnectionState.CONNECTED self._consecutive_failures = 0 self._last_successful_call = datetime.now() self._circuit_breaker.reset() logger.info("Reconnection successful") return True except Exception as e: logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}") delay = self._calculate_backoff_delay(attempt) await asyncio.sleep(delay) self._state = ConnectionState.ERROR logger.error("All reconnection attempts failed") return False def _calculate_backoff_delay(self, attempt: int) -> float: """ Calculate delay with exponential backoff and jitter. Args: attempt: Current attempt number (0-indexed). Returns: Delay in seconds. """ # Exponential backoff delay: float = self._retry_delay * (2**attempt) # Cap at 30 seconds delay = min(delay, 30.0) # Add jitter (10% random variation) jitter: float = delay * 0.1 * random.random() delay += jitter return delay async def start_health_monitoring(self) -> None: """Start background health check monitoring.""" if self._health_check_task and not self._health_check_task.done(): logger.warning("Health monitoring already running") return self._health_check_task = asyncio.create_task(self._health_check_loop()) logger.info("Started health check monitoring") async def stop_health_monitoring(self) -> None: """Stop background health check monitoring.""" if self._health_check_task and not self._health_check_task.done(): self._health_check_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._health_check_task logger.info("Stopped health check monitoring") async def _health_check_loop(self) -> None: """Background task for periodic health checks.""" while True: try: await asyncio.sleep(self._health_check_interval) if self._state == ConnectionState.CONNECTED: healthy = await self.health_check() if not healthy: logger.warning("Health check failed, attempting reconnection") await self._attempt_reconnect() except asyncio.CancelledError: break except Exception as e: logger.error(f"Health check loop error: {e}") def get_unavailable_message(self) -> str: """Get user-friendly message when MCP is unavailable.""" return self._fallback_handler.get_unavailable_message() def __repr__(self) -> str: """String representation.""" return ( f"ConnectionManager(state={self._state.value}, " f"available={self.is_available}, " f"failures={self._consecutive_failures})" )