"""
DungeonMaster AI - MCP Connection Manager

Manages MCP connection lifecycle with health checks, automatic reconnection,
circuit breaker pattern, and graceful degradation.
"""

from __future__ import annotations

import asyncio
import contextlib
import logging
import random
import time
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING

from llama_index.core.tools import FunctionTool

from src.config.settings import get_settings

from .exceptions import (
    MCPCircuitBreakerOpenError,
    MCPUnavailableError,
)
from .fallbacks import FallbackHandler
from .models import (
    CircuitBreakerState,
    ConnectionState,
    MCPConnectionStatus,
)
from .toolkit_client import TTRPGToolkitClient

if TYPE_CHECKING:
    from typing import Any

logger = logging.getLogger(__name__)


class CircuitBreaker:
    """
    Circuit breaker pattern implementation.

    Prevents repeated calls to a failing service by tracking failures
    and temporarily rejecting requests when the failure threshold is reached.

    States:
        - CLOSED: Normal operation, requests allowed
        - OPEN: Too many failures, requests rejected for reset_timeout
        - HALF_OPEN: Testing if service recovered with a single request
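
    Example (a minimal usage sketch around an arbitrary awaitable call):
        ```python
        breaker = CircuitBreaker(failure_threshold=3, reset_timeout=10.0)

        if breaker.allow_request():
            try:
                result = await do_remote_call()  # hypothetical caller-supplied coroutine
                breaker.record_success()
            except Exception:
                breaker.record_failure()
        else:
            ...  # reject or fall back; breaker.time_until_retry says how long to wait
        ```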
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        reset_timeout: float = 30.0,
        half_open_max_calls: int = 1,
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Failures before opening circuit
            reset_timeout: Seconds to wait before trying again
            half_open_max_calls: Max calls allowed in half-open state
        """
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitBreakerState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float | None = None
        self._half_open_calls = 0

    @property
    def state(self) -> CircuitBreakerState:
        """Get current circuit breaker state."""
        # Lazily transition OPEN -> HALF_OPEN once the reset timeout has elapsed,
        # so the next caller gets a single probe request.
        if (
            self._state == CircuitBreakerState.OPEN
            and self._last_failure_time is not None
        ):
            elapsed = time.time() - self._last_failure_time
            if elapsed >= self.reset_timeout:
                self._state = CircuitBreakerState.HALF_OPEN
                self._half_open_calls = 0
                logger.info("Circuit breaker transitioned to HALF_OPEN")

        return self._state

    @property
    def is_open(self) -> bool:
        """Check if circuit is open (rejecting requests)."""
        return self.state == CircuitBreakerState.OPEN

    @property
    def time_until_retry(self) -> float | None:
        """Seconds until retry is allowed, or None if not in OPEN state."""
        if self._state != CircuitBreakerState.OPEN:
            return None
        if self._last_failure_time is None:
            return None
        elapsed = time.time() - self._last_failure_time
        remaining = self.reset_timeout - elapsed
        return max(0.0, remaining)

    def record_success(self) -> None:
        """Record a successful call."""
        if self._state == CircuitBreakerState.HALF_OPEN:
            # The probe call succeeded; the service has recovered.
            self._state = CircuitBreakerState.CLOSED
            logger.info("Circuit breaker closed after successful recovery")

        self._failure_count = 0
        self._last_failure_time = None

    def record_failure(self) -> None:
        """Record a failed call."""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitBreakerState.HALF_OPEN:
            # The probe call failed; go straight back to OPEN.
            self._state = CircuitBreakerState.OPEN
            logger.warning("Circuit breaker reopened after half-open failure")
        elif self._failure_count >= self.failure_threshold:
            self._state = CircuitBreakerState.OPEN
            logger.warning(
                f"Circuit breaker opened after {self._failure_count} failures"
            )

    def allow_request(self) -> bool:
        """
        Check if a request should be allowed.

        Returns:
            True if the request is allowed, False if it should be rejected.
        """
        state = self.state

        if state == CircuitBreakerState.CLOSED:
            return True

        if state == CircuitBreakerState.HALF_OPEN:
            if self._half_open_calls < self.half_open_max_calls:
                self._half_open_calls += 1
                return True
            return False

        # OPEN: reject until the reset timeout elapses.
        return False

    def reset(self) -> None:
        """Reset circuit breaker to initial state."""
        self._state = CircuitBreakerState.CLOSED
        self._failure_count = 0
        self._last_failure_time = None
        self._half_open_calls = 0


class ConnectionManager:
    """
    Manages MCP connection lifecycle with health checks and reconnection.

    Features:
        - Automatic reconnection with exponential backoff
        - Circuit breaker to prevent hammering a failed server
        - Health check monitoring
        - Graceful degradation via FallbackHandler
        - Connection status tracking

    Example:
        ```python
        manager = ConnectionManager()
        connected = await manager.connect()

        if manager.is_available:
            tools = await manager.get_tools()
            result = await manager.execute_tool("roll", {"notation": "1d20"})
        else:
            # Fallback handling
            result = await manager.execute_with_fallback("roll", {"notation": "1d20"})
        ```
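
    Background health monitoring is optional; a usage sketch (the check
    interval is fixed at 60 seconds in this module):

        ```python
        await manager.start_health_monitoring()
        # ... later, during shutdown ...
        await manager.stop_health_monitoring()
        await manager.disconnect()
        ```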
    """

    def __init__(
        self,
        toolkit_client: TTRPGToolkitClient | None = None,
        max_retries: int | None = None,
        retry_delay: float | None = None,
        fallback_handler: FallbackHandler | None = None,
    ) -> None:
        """
        Initialize connection manager.

        Args:
            toolkit_client: Pre-configured client, or None to create a new one
            max_retries: Max reconnection attempts (default from settings)
            retry_delay: Base delay between retries (default from settings)
            fallback_handler: Handler for graceful degradation
        """
        settings = get_settings()

        self._client = toolkit_client or TTRPGToolkitClient()
        self._max_retries: int = max_retries or settings.mcp.mcp_retry_attempts
        self._retry_delay: float = retry_delay or settings.mcp.mcp_retry_delay
        self._fallback_handler = fallback_handler or FallbackHandler()

        # Connection state tracking
        self._state = ConnectionState.DISCONNECTED
        self._last_successful_call: datetime | None = None
        self._consecutive_failures = 0
        self._last_error: str | None = None

        # Circuit breaker guarding calls to the MCP server
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            reset_timeout=30.0,
        )

        # Background health check task
        self._health_check_task: asyncio.Task[None] | None = None
        self._health_check_interval = 60.0

    @property
    def state(self) -> ConnectionState:
        """Get current connection state."""
        return self._state

    @property
    def is_available(self) -> bool:
        """Check if MCP is available for use."""
        return (
            self._state == ConnectionState.CONNECTED
            and not self._circuit_breaker.is_open
        )

    @property
    def client(self) -> TTRPGToolkitClient:
        """Get the underlying toolkit client."""
        return self._client

    def get_status(self) -> MCPConnectionStatus:
        """Get detailed connection status."""
        return MCPConnectionStatus(
            state=self._state,
            is_available=self.is_available,
            url=self._client.url,
            last_successful_call=self._last_successful_call,
            consecutive_failures=self._consecutive_failures,
            circuit_breaker_state=self._circuit_breaker.state,
            tools_count=self._client.tools_count,
            error_message=self._last_error,
        )

    async def connect(self) -> bool:
        """
        Connect to MCP server with retry logic.

        Returns:
            True if connection successful, False otherwise.
        """
        self._state = ConnectionState.CONNECTING
        logger.info("Attempting to connect to MCP server...")

        for attempt in range(self._max_retries):
            try:
                await self._client.connect()
                self._state = ConnectionState.CONNECTED
                self._consecutive_failures = 0
                self._last_successful_call = datetime.now()
                self._last_error = None
                self._circuit_breaker.reset()

                logger.info("Successfully connected to MCP server")
                return True

            except Exception as e:
                self._consecutive_failures += 1
                self._last_error = str(e)
                logger.warning(
                    f"Connection attempt {attempt + 1}/{self._max_retries} failed: {e}"
                )

                if attempt < self._max_retries - 1:
                    delay = self._calculate_backoff_delay(attempt)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    await asyncio.sleep(delay)

        self._state = ConnectionState.ERROR
        logger.error(f"Failed to connect after {self._max_retries} attempts")
        return False

    async def disconnect(self) -> None:
        """Disconnect from MCP server."""
        # Stop the background health check task, if one is running.
        if self._health_check_task and not self._health_check_task.done():
            self._health_check_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._health_check_task

        await self._client.disconnect()
        self._state = ConnectionState.DISCONNECTED
        logger.info("Disconnected from MCP server")

    async def health_check(self) -> bool:
        """
        Perform health check by listing tools.

        Returns:
            True if healthy, False otherwise.
        """
        try:
            await self._client.list_tool_names()
            self._last_successful_call = datetime.now()
            self._consecutive_failures = 0
            self._circuit_breaker.record_success()
            return True

        except Exception as e:
            self._consecutive_failures += 1
            self._circuit_breaker.record_failure()
            logger.warning(f"Health check failed: {e}")
            return False

    async def get_tools(
        self,
        categories: Sequence[str] | None = None,
    ) -> Sequence[FunctionTool]:
        """
        Get tools with automatic reconnection on failure.

        Args:
            categories: Optional list of categories to filter.

        Returns:
            Sequence of FunctionTool objects.

        Raises:
            MCPUnavailableError: If MCP is unavailable after reconnection attempt.
        """
        if not self.is_available:
            await self._attempt_reconnect()

        if not self.is_available:
            raise MCPUnavailableError(
                "MCP server is unavailable",
                reason=self._last_error,
            )

        try:
            tools: Sequence[FunctionTool]
            if categories:
                tools = await self._client.get_tools_by_category(categories)
            else:
                tools = await self._client.get_all_tools()

            self._last_successful_call = datetime.now()
            self._circuit_breaker.record_success()
            return tools

        except Exception as e:
            self._circuit_breaker.record_failure()
            self._last_error = str(e)
            logger.error(f"Failed to get tools: {e}")

            # Retry once after a reconnection attempt before giving up.
            if await self._attempt_reconnect():
                if categories:
                    return await self._client.get_tools_by_category(categories)
                return await self._client.get_all_tools()

            raise MCPUnavailableError(
                "Unable to get tools after reconnection",
                reason=str(e),
            ) from e

    async def execute_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Execute a tool with connection management.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments.

        Returns:
            Tool result.

        Raises:
            MCPCircuitBreakerOpenError: If circuit breaker is open.
            MCPUnavailableError: If MCP is unavailable.
        """
        # Reject immediately if the circuit breaker is not allowing calls.
        if not self._circuit_breaker.allow_request():
            retry_after = self._circuit_breaker.time_until_retry
            raise MCPCircuitBreakerOpenError(retry_after_seconds=retry_after)

        if not self.is_available:
            await self._attempt_reconnect()

        if not self.is_available:
            raise MCPUnavailableError(
                "MCP server is unavailable",
                reason=self._last_error,
            )

        try:
            result = await self._client.call_tool(tool_name, arguments)
            self._last_successful_call = datetime.now()
            self._consecutive_failures = 0
            self._circuit_breaker.record_success()
            return result

        except Exception as e:
            self._consecutive_failures += 1
            self._circuit_breaker.record_failure()
            self._last_error = str(e)
            raise

    async def execute_with_fallback(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Execute tool with automatic fallback on failure.

        If MCP fails and a fallback handler can handle the tool,
        uses the fallback. Otherwise, raises the original error.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments.

        Returns:
            Tool result (from MCP or fallback).

        Raises:
            MCPUnavailableError: If MCP fails and no fallback is available.
        """
        try:
            return await self.execute_tool(tool_name, arguments)

        except (MCPUnavailableError, MCPCircuitBreakerOpenError) as e:
            if self._fallback_handler.can_handle(tool_name):
                logger.info(f"Using fallback for tool '{tool_name}'")
                return await self._fallback_handler.handle(tool_name, arguments)

            # No fallback registered for this tool; surface the failure.
            raise MCPUnavailableError(
                f"MCP unavailable and no fallback for '{tool_name}'",
                reason=str(e),
            ) from e

    async def _attempt_reconnect(self) -> bool:
        """
        Attempt reconnection with exponential backoff.

        Returns:
            True if reconnection successful, False otherwise.
        """
        # Avoid overlapping reconnection attempts.
        if self._state == ConnectionState.RECONNECTING:
            return False

        self._state = ConnectionState.RECONNECTING
        logger.info("Attempting to reconnect to MCP server...")

        for attempt in range(self._max_retries):
            try:
                await self._client.disconnect()
                await self._client.connect()

                self._state = ConnectionState.CONNECTED
                self._consecutive_failures = 0
                self._last_successful_call = datetime.now()
                self._circuit_breaker.reset()

                logger.info("Reconnection successful")
                return True

            except Exception as e:
                logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}")
                delay = self._calculate_backoff_delay(attempt)
                await asyncio.sleep(delay)

        self._state = ConnectionState.ERROR
        logger.error("All reconnection attempts failed")
        return False

    def _calculate_backoff_delay(self, attempt: int) -> float:
        """
        Calculate delay with exponential backoff and jitter.

        Args:
            attempt: Current attempt number (0-indexed).

        Returns:
            Delay in seconds.
        """
        # Exponential backoff: the base delay doubles with each attempt.
        delay: float = self._retry_delay * (2**attempt)

        # Cap the delay at 30 seconds.
        delay = min(delay, 30.0)

        # Add up to 10% random jitter to avoid synchronized retries.
        jitter: float = delay * 0.1 * random.random()
        delay += jitter

        return delay

    async def start_health_monitoring(self) -> None:
        """Start background health check monitoring."""
        if self._health_check_task and not self._health_check_task.done():
            logger.warning("Health monitoring already running")
            return

        self._health_check_task = asyncio.create_task(self._health_check_loop())
        logger.info("Started health check monitoring")

    async def stop_health_monitoring(self) -> None:
        """Stop background health check monitoring."""
        if self._health_check_task and not self._health_check_task.done():
            self._health_check_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._health_check_task
            logger.info("Stopped health check monitoring")

    async def _health_check_loop(self) -> None:
        """Background task for periodic health checks."""
        while True:
            try:
                await asyncio.sleep(self._health_check_interval)

                if self._state == ConnectionState.CONNECTED:
                    healthy = await self.health_check()
                    if not healthy:
                        logger.warning("Health check failed, attempting reconnection")
                        await self._attempt_reconnect()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check loop error: {e}")

    def get_unavailable_message(self) -> str:
        """Get user-friendly message when MCP is unavailable."""
        return self._fallback_handler.get_unavailable_message()

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"ConnectionManager(state={self._state.value}, "
            f"available={self.is_available}, "
            f"failures={self._consecutive_failures})"
        )