"""
DungeonMaster AI - MCP Connection Manager
Manages MCP connection lifecycle with health checks, automatic reconnection,
circuit breaker pattern, and graceful degradation.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import random
import time
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING
from llama_index.core.tools import FunctionTool
from src.config.settings import get_settings
from .exceptions import (
MCPCircuitBreakerOpenError,
MCPUnavailableError,
)
from .fallbacks import FallbackHandler
from .models import (
CircuitBreakerState,
ConnectionState,
MCPConnectionStatus,
)
from .toolkit_client import TTRPGToolkitClient
if TYPE_CHECKING:
from typing import Any
logger = logging.getLogger(__name__)
class CircuitBreaker:
"""
Circuit breaker pattern implementation.
    Prevents repeated calls to a failing service by tracking failures
    and temporarily rejecting requests once the failure threshold is reached.
States:
- CLOSED: Normal operation, requests allowed
    - OPEN: Too many failures; requests rejected until reset_timeout elapses
    - HALF_OPEN: Testing whether the service has recovered with a limited number of trial calls
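    Example (a minimal usage sketch; `client` stands in for any awaitable MCP call):
        ```python
        breaker = CircuitBreaker(failure_threshold=3, reset_timeout=10.0)
        if breaker.allow_request():
            try:
                result = await client.call_tool("roll", {"notation": "1d20"})
                breaker.record_success()
            except Exception:
                breaker.record_failure()
        else:
            # Circuit is open; skip the call and consult time_until_retry instead.
            wait_seconds = breaker.time_until_retry
        ```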
"""
def __init__(
self,
failure_threshold: int = 5,
reset_timeout: float = 30.0,
half_open_max_calls: int = 1,
) -> None:
"""
Initialize circuit breaker.
Args:
failure_threshold: Failures before opening circuit
reset_timeout: Seconds to wait before trying again
half_open_max_calls: Max calls allowed in half-open state
"""
self.failure_threshold = failure_threshold
self.reset_timeout = reset_timeout
self.half_open_max_calls = half_open_max_calls
self._state = CircuitBreakerState.CLOSED
self._failure_count = 0
self._last_failure_time: float | None = None
self._half_open_calls = 0
@property
def state(self) -> CircuitBreakerState:
"""Get current circuit breaker state."""
# Check if we should transition from OPEN to HALF_OPEN
if (
self._state == CircuitBreakerState.OPEN
and self._last_failure_time is not None
):
elapsed = time.time() - self._last_failure_time
if elapsed >= self.reset_timeout:
self._state = CircuitBreakerState.HALF_OPEN
self._half_open_calls = 0
logger.info("Circuit breaker transitioned to HALF_OPEN")
return self._state
@property
def is_open(self) -> bool:
"""Check if circuit is open (rejecting requests)."""
return self.state == CircuitBreakerState.OPEN
@property
def time_until_retry(self) -> float | None:
"""Seconds until retry is allowed, or None if not in OPEN state."""
if self._state != CircuitBreakerState.OPEN:
return None
if self._last_failure_time is None:
return None
elapsed = time.time() - self._last_failure_time
remaining = self.reset_timeout - elapsed
return max(0.0, remaining)
def record_success(self) -> None:
"""Record a successful call."""
if self._state == CircuitBreakerState.HALF_OPEN:
# Successful call in half-open means we can close
self._state = CircuitBreakerState.CLOSED
logger.info("Circuit breaker closed after successful recovery")
self._failure_count = 0
self._last_failure_time = None
def record_failure(self) -> None:
"""Record a failed call."""
self._failure_count += 1
self._last_failure_time = time.time()
if self._state == CircuitBreakerState.HALF_OPEN:
# Failure in half-open means service still failing
self._state = CircuitBreakerState.OPEN
logger.warning("Circuit breaker reopened after half-open failure")
elif self._failure_count >= self.failure_threshold:
self._state = CircuitBreakerState.OPEN
logger.warning(
f"Circuit breaker opened after {self._failure_count} failures"
)
def allow_request(self) -> bool:
"""
Check if a request should be allowed.
Returns:
True if request is allowed, False if should be rejected.
"""
state = self.state # This may transition OPEN -> HALF_OPEN
if state == CircuitBreakerState.CLOSED:
return True
if state == CircuitBreakerState.HALF_OPEN:
if self._half_open_calls < self.half_open_max_calls:
self._half_open_calls += 1
return True
return False
# OPEN state
return False
def reset(self) -> None:
"""Reset circuit breaker to initial state."""
self._state = CircuitBreakerState.CLOSED
self._failure_count = 0
self._last_failure_time = None
self._half_open_calls = 0
class ConnectionManager:
"""
Manages MCP connection lifecycle with health checks and reconnection.
Features:
- Automatic reconnection with exponential backoff
- Circuit breaker to prevent hammering failed server
- Health check monitoring
- Graceful degradation via FallbackHandler
- Connection status tracking
Example:
```python
manager = ConnectionManager()
connected = await manager.connect()
if manager.is_available:
tools = await manager.get_tools()
result = await manager.execute_tool("roll", {"notation": "1d20"})
else:
# Fallback handling
result = await manager.execute_with_fallback("roll", {"notation": "1d20"})
```
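    Background health monitoring (optional; see start_health_monitoring below):
        ```python
        await manager.start_health_monitoring()
        # ... handle requests ...
        await manager.stop_health_monitoring()
        await manager.disconnect()
        ```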
"""
def __init__(
self,
toolkit_client: TTRPGToolkitClient | None = None,
max_retries: int | None = None,
retry_delay: float | None = None,
fallback_handler: FallbackHandler | None = None,
) -> None:
"""
Initialize connection manager.
Args:
toolkit_client: Pre-configured client, or None to create new one
max_retries: Max reconnection attempts (default from settings)
retry_delay: Base delay between retries (default from settings)
fallback_handler: Handler for graceful degradation
"""
settings = get_settings()
self._client = toolkit_client or TTRPGToolkitClient()
        self._max_retries: int = (
            max_retries if max_retries is not None else settings.mcp.mcp_retry_attempts
        )
        self._retry_delay: float = (
            retry_delay if retry_delay is not None else settings.mcp.mcp_retry_delay
        )
self._fallback_handler = fallback_handler or FallbackHandler()
# Connection state
self._state = ConnectionState.DISCONNECTED
self._last_successful_call: datetime | None = None
self._consecutive_failures = 0
self._last_error: str | None = None
# Circuit breaker
self._circuit_breaker = CircuitBreaker(
failure_threshold=5,
reset_timeout=30.0,
)
# Health check task
self._health_check_task: asyncio.Task[None] | None = None
self._health_check_interval = 60.0 # seconds
@property
def state(self) -> ConnectionState:
"""Get current connection state."""
return self._state
@property
def is_available(self) -> bool:
"""Check if MCP is available for use."""
return (
self._state == ConnectionState.CONNECTED
and not self._circuit_breaker.is_open
)
@property
def client(self) -> TTRPGToolkitClient:
"""Get the underlying toolkit client."""
return self._client
def get_status(self) -> MCPConnectionStatus:
"""Get detailed connection status."""
return MCPConnectionStatus(
state=self._state,
is_available=self.is_available,
url=self._client.url,
last_successful_call=self._last_successful_call,
consecutive_failures=self._consecutive_failures,
circuit_breaker_state=self._circuit_breaker.state,
tools_count=self._client.tools_count,
error_message=self._last_error,
)
async def connect(self) -> bool:
"""
Connect to MCP server with retry logic.
Returns:
True if connection successful, False otherwise.
"""
self._state = ConnectionState.CONNECTING
logger.info("Attempting to connect to MCP server...")
for attempt in range(self._max_retries):
try:
await self._client.connect()
self._state = ConnectionState.CONNECTED
self._consecutive_failures = 0
self._last_successful_call = datetime.now()
self._last_error = None
self._circuit_breaker.reset()
logger.info("Successfully connected to MCP server")
return True
except Exception as e:
self._consecutive_failures += 1
self._last_error = str(e)
logger.warning(
f"Connection attempt {attempt + 1}/{self._max_retries} failed: {e}"
)
if attempt < self._max_retries - 1:
delay = self._calculate_backoff_delay(attempt)
logger.info(f"Retrying in {delay:.2f} seconds...")
await asyncio.sleep(delay)
self._state = ConnectionState.ERROR
logger.error(f"Failed to connect after {self._max_retries} attempts")
return False
async def disconnect(self) -> None:
"""Disconnect from MCP server."""
# Stop health check task if running
if self._health_check_task and not self._health_check_task.done():
self._health_check_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._health_check_task
await self._client.disconnect()
self._state = ConnectionState.DISCONNECTED
logger.info("Disconnected from MCP server")
async def health_check(self) -> bool:
"""
Perform health check by listing tools.
Returns:
True if healthy, False otherwise.
"""
try:
await self._client.list_tool_names()
self._last_successful_call = datetime.now()
self._consecutive_failures = 0
self._circuit_breaker.record_success()
return True
except Exception as e:
self._consecutive_failures += 1
self._circuit_breaker.record_failure()
logger.warning(f"Health check failed: {e}")
return False
async def get_tools(
self,
categories: Sequence[str] | None = None,
) -> Sequence[FunctionTool]:
"""
Get tools with automatic reconnection on failure.
Args:
categories: Optional list of categories to filter.
Returns:
Sequence of FunctionTool objects.
Raises:
MCPUnavailableError: If MCP is unavailable after reconnection attempt.
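        Example (category names shown are illustrative):
            ```python
            dice_tools = await manager.get_tools(categories=["dice"])
            all_tools = await manager.get_tools()
            ```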
"""
if not self.is_available:
await self._attempt_reconnect()
if not self.is_available:
raise MCPUnavailableError(
"MCP server is unavailable",
reason=self._last_error,
)
try:
tools: Sequence[FunctionTool]
if categories:
tools = await self._client.get_tools_by_category(categories)
else:
tools = await self._client.get_all_tools()
self._last_successful_call = datetime.now()
self._circuit_breaker.record_success()
return tools
except Exception as e:
self._circuit_breaker.record_failure()
self._last_error = str(e)
logger.error(f"Failed to get tools: {e}")
# Try reconnection
if await self._attempt_reconnect():
# Retry after reconnection
if categories:
return await self._client.get_tools_by_category(categories)
return await self._client.get_all_tools()
raise MCPUnavailableError(
"Unable to get tools after reconnection",
reason=str(e),
) from e
async def execute_tool(
self,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""
Execute a tool with connection management.
Args:
tool_name: Name of the tool to call.
arguments: Tool arguments.
Returns:
Tool result.
Raises:
MCPCircuitBreakerOpenError: If circuit breaker is open.
MCPUnavailableError: If MCP is unavailable.
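        Example (a sketch of handling the documented errors):
            ```python
            try:
                result = await manager.execute_tool("roll", {"notation": "1d20"})
            except MCPCircuitBreakerOpenError:
                ...  # breaker is open; back off before retrying
            except MCPUnavailableError:
                ...  # server down and reconnection failed; consider execute_with_fallback()
            ```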
"""
# Check circuit breaker
if not self._circuit_breaker.allow_request():
retry_after = self._circuit_breaker.time_until_retry
raise MCPCircuitBreakerOpenError(retry_after_seconds=retry_after)
if not self.is_available:
await self._attempt_reconnect()
if not self.is_available:
raise MCPUnavailableError(
"MCP server is unavailable",
reason=self._last_error,
)
try:
result = await self._client.call_tool(tool_name, arguments)
self._last_successful_call = datetime.now()
self._consecutive_failures = 0
self._circuit_breaker.record_success()
return result
except Exception as e:
self._consecutive_failures += 1
self._circuit_breaker.record_failure()
self._last_error = str(e)
raise
async def execute_with_fallback(
self,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""
Execute tool with automatic fallback on failure.
If MCP fails and a fallback handler can handle the tool,
uses the fallback. Otherwise, raises the original error.
Args:
tool_name: Name of the tool to call.
arguments: Tool arguments.
Returns:
Tool result (from MCP or fallback).
Raises:
MCPUnavailableError: If MCP fails and no fallback available.
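        Example (assumes the FallbackHandler can handle "roll"):
            ```python
            result = await manager.execute_with_fallback("roll", {"notation": "2d6"})
            ```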
"""
try:
return await self.execute_tool(tool_name, arguments)
except (MCPUnavailableError, MCPCircuitBreakerOpenError) as e:
# Try fallback
if self._fallback_handler.can_handle(tool_name):
logger.info(f"Using fallback for tool '{tool_name}'")
return await self._fallback_handler.handle(tool_name, arguments)
# No fallback available
raise MCPUnavailableError(
f"MCP unavailable and no fallback for '{tool_name}'",
reason=str(e),
) from e
async def _attempt_reconnect(self) -> bool:
"""
Attempt reconnection with exponential backoff.
Returns:
True if reconnection successful, False otherwise.
"""
if self._state == ConnectionState.RECONNECTING:
# Already reconnecting
return False
self._state = ConnectionState.RECONNECTING
logger.info("Attempting to reconnect to MCP server...")
for attempt in range(self._max_retries):
try:
await self._client.disconnect()
await self._client.connect()
self._state = ConnectionState.CONNECTED
self._consecutive_failures = 0
self._last_successful_call = datetime.now()
self._circuit_breaker.reset()
logger.info("Reconnection successful")
return True
            except Exception as e:
                self._last_error = str(e)
                logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}")
                if attempt < self._max_retries - 1:
                    delay = self._calculate_backoff_delay(attempt)
                    await asyncio.sleep(delay)
self._state = ConnectionState.ERROR
logger.error("All reconnection attempts failed")
return False
def _calculate_backoff_delay(self, attempt: int) -> float:
"""
Calculate delay with exponential backoff and jitter.
Args:
attempt: Current attempt number (0-indexed).
Returns:
Delay in seconds.
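        Example:
            With retry_delay=1.0, attempts 0-4 produce base delays of
            1, 2, 4, 8, and 16 seconds (capped at 30), each plus up to 10% jitter.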
"""
# Exponential backoff
delay: float = self._retry_delay * (2**attempt)
# Cap at 30 seconds
delay = min(delay, 30.0)
# Add jitter (10% random variation)
jitter: float = delay * 0.1 * random.random()
delay += jitter
return delay
async def start_health_monitoring(self) -> None:
"""Start background health check monitoring."""
if self._health_check_task and not self._health_check_task.done():
logger.warning("Health monitoring already running")
return
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("Started health check monitoring")
async def stop_health_monitoring(self) -> None:
"""Stop background health check monitoring."""
if self._health_check_task and not self._health_check_task.done():
self._health_check_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._health_check_task
logger.info("Stopped health check monitoring")
async def _health_check_loop(self) -> None:
"""Background task for periodic health checks."""
while True:
try:
await asyncio.sleep(self._health_check_interval)
if self._state == ConnectionState.CONNECTED:
healthy = await self.health_check()
if not healthy:
logger.warning("Health check failed, attempting reconnection")
await self._attempt_reconnect()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Health check loop error: {e}")
def get_unavailable_message(self) -> str:
"""Get user-friendly message when MCP is unavailable."""
return self._fallback_handler.get_unavailable_message()
def __repr__(self) -> str:
"""String representation."""
return (
f"ConnectionManager(state={self._state.value}, "
f"available={self.is_available}, "
f"failures={self._consecutive_failures})"
)