# bhupesh-sf's picture
# first commit
# f8ba6bf verified
"""
DungeonMaster AI - TTRPG Toolkit MCP Client
Core client for connecting to the TTRPG-Toolkit MCP server.
Uses llama-index-tools-mcp for connection management and tool conversion.
"""
from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING

from llama_index.core.tools import FunctionTool
from llama_index.tools.mcp import BasicMCPClient, McpToolSpec

from src.config.settings import get_settings

from .exceptions import (
    MCPConnectionError,
    MCPToolExecutionError,
    MCPToolNotFoundError,
)

if TYPE_CHECKING:
    from typing import Any
logger = logging.getLogger(__name__)

# =============================================================================
# Tool Category Mapping
# =============================================================================

# Prefix under which each tool name may also appear (see TOOL_CATEGORIES).
_MCP_TOOL_PREFIX = "mcp_"

# Canonical (unprefixed) tool names, grouped by functional category.
# Edit this table; the prefixed aliases are derived from it below.
_BASE_TOOL_CATEGORIES: dict[str, list[str]] = {
    "dice": [
        "roll",
        "roll_check",
        "roll_table",
        "get_roll_statistics",
    ],
    "character": [
        "create_character",
        "get_character",
        "modify_hp",
        "rest",
        "generate_ability_scores",
        "add_condition",
        "remove_condition",
        "update_character",
        "level_up",
        "add_item",
        "remove_item",
    ],
    "rules": [
        "search_rules",
        "get_monster",
        "search_monsters",
        "get_spell",
        "search_spells",
        "get_class_info",
        "get_race_info",
        "get_item",
        "get_condition",
    ],
    "generators": [
        "generate_name",
        "generate_npc",
        "generate_encounter",
        "generate_loot",
        "generate_location",
    ],
    "combat": [
        "start_combat",
        "roll_all_initiatives",
        "get_turn_order",
        "next_turn",
        "apply_damage",
        "apply_condition",
        "remove_combat_condition",
        "end_combat",
        "get_combat_status",
    ],
    "session": [
        "start_session",
        "end_session",
        "log_event",
        "get_session_summary",
        "add_session_note",
        "get_session_history",
        "list_sessions",
    ],
}

# Public mapping: category -> canonical names followed by their "mcp_"-prefixed
# versions. Deriving the prefixed half keeps both halves in sync; the previous
# hand-written alias lists had drifted (e.g. "mcp_level_up" was missing).
TOOL_CATEGORIES: dict[str, list[str]] = {
    category: names + [f"{_MCP_TOOL_PREFIX}{name}" for name in names]
    for category, names in _BASE_TOOL_CATEGORIES.items()
}
class TTRPGToolkitClient:
    """
    Client for connecting to TTRPG-Toolkit MCP server.

    Uses llama-index-tools-mcp's BasicMCPClient for connection
    and McpToolSpec for converting MCP tools to LlamaIndex FunctionTool objects.
    Tool lookups accept names with or without a leading ``mcp_`` prefix.

    Example:
        ```python
        client = TTRPGToolkitClient()
        await client.connect()
        # Get all tools for DM agent
        all_tools = await client.get_all_tools()
        # Get only rules tools for Rules agent
        rules_tools = await client.get_tools_by_category(["rules"])
        # Direct tool call
        result = await client.call_tool("roll", {"notation": "2d6+3"})
        ```
    """

    # Optional prefix on tool names; lookups try names with and without it.
    _MCP_PREFIX = "mcp_"

    def __init__(
        self,
        mcp_url: str | None = None,
        timeout: int | None = None,
    ) -> None:
        """
        Initialize the TTRPG Toolkit client.

        No connection is opened here; call ``connect()`` for that.

        Args:
            mcp_url: MCP server URL. Falls back to
                settings.mcp.ttrpg_toolkit_mcp_url when None or empty.
            timeout: Connection timeout in seconds. Falls back to
                settings.mcp.mcp_connection_timeout when None or 0.
        """
        settings = get_settings()
        self._url = mcp_url or settings.mcp.ttrpg_toolkit_mcp_url
        self._timeout = timeout or settings.mcp.mcp_connection_timeout

        # Internal state: populated by connect(), cleared by _reset_state().
        self._client: BasicMCPClient | None = None
        self._tool_spec: McpToolSpec | None = None
        self._tools_cache: list[FunctionTool] | None = None
        self._tools_by_name: dict[str, FunctionTool] | None = None
        self._connected = False

        logger.debug("TTRPGToolkitClient initialized with URL: %s", self._url)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def url(self) -> str:
        """Get the MCP server URL."""
        return self._url

    @property
    def is_connected(self) -> bool:
        """Check if client is connected."""
        return self._connected

    @property
    def tools_count(self) -> int:
        """Get number of cached tools (0 when nothing is cached)."""
        return len(self._tools_cache) if self._tools_cache else 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _strip_mcp_prefix(cls, tool_name: str) -> str:
        """
        Return ``tool_name`` without a *leading* ``mcp_`` prefix.

        Unlike ``str.replace``, an ``mcp_`` substring elsewhere in the
        name is left untouched.
        """
        if tool_name.startswith(cls._MCP_PREFIX):
            return tool_name[len(cls._MCP_PREFIX):]
        return tool_name

    @staticmethod
    def _index_tools(tools: Sequence[FunctionTool]) -> dict[str, FunctionTool]:
        """Build a name -> tool lookup, skipping tools whose name is None."""
        return {
            tool.metadata.name: tool
            for tool in tools
            if tool.metadata.name is not None
        }

    def _reset_state(self) -> None:
        """Drop the client, tool spec and both tool caches; mark disconnected."""
        self._client = None
        self._tool_spec = None
        self._tools_cache = None
        self._tools_by_name = None
        self._connected = False

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> TTRPGToolkitClient:
        """
        Establish connection to MCP server.

        Creates BasicMCPClient and McpToolSpec, then fetches and caches
        all available tools to verify the connection. Calling this while
        already connected is a no-op.

        Returns:
            Self for method chaining.

        Raises:
            MCPConnectionError: If connection fails. All internal state is
                reset before raising.
        """
        if self._connected and self._client is not None:
            logger.debug("Already connected to MCP server, skipping reconnection")
            return self

        logger.info("Connecting to MCP server at %s...", self._url)
        try:
            # BasicMCPClient auto-detects SSE vs streamable-http based on URL
            self._client = BasicMCPClient(
                command_or_url=self._url,
                timeout=self._timeout,
            )
            self._tool_spec = McpToolSpec(
                client=self._client,
                allowed_tools=None,  # Get all tools
            )
            # Listing the tools doubles as the connectivity check and fills
            # the cache.
            self._tools_cache = await self._tool_spec.to_tool_list_async()
            self._tools_by_name = self._index_tools(self._tools_cache)
        except Exception as e:
            self._reset_state()
            logger.error("Failed to connect to MCP server: %s", e)
            raise MCPConnectionError(
                f"Failed to connect to MCP server: {e}",
                url=self._url,
            ) from e

        self._connected = True
        logger.info(
            "Connected to MCP server with %d tools available",
            len(self._tools_cache),
        )
        return self

    async def disconnect(self) -> None:
        """
        Clear the client references and the tool cache.

        Safe to call even if not connected.
        """
        logger.info("Disconnecting from MCP server...")
        self._reset_state()
        logger.info("Disconnected from MCP server")

    # ------------------------------------------------------------------
    # Tool discovery
    # ------------------------------------------------------------------

    async def get_all_tools(self) -> Sequence[FunctionTool]:
        """
        Get all available tools as LlamaIndex FunctionTool objects.

        Serves from the cache filled by ``connect()``/``refresh_tools()``.

        Returns:
            A new list of FunctionTool objects. Mutating it does not affect
            the client's cache.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_cache is None:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )
        return list(self._tools_cache)

    async def get_tools_by_category(
        self,
        categories: Sequence[str],
    ) -> list[FunctionTool]:
        """
        Get tools filtered by category names.

        Args:
            categories: Category names from TOOL_CATEGORIES (dice, character,
                rules, generators, combat, session). A single string is
                treated as one category. Unknown names are logged and
                ignored.

        Returns:
            Cached tools whose name, with or without a leading ``mcp_``
            prefix, belongs to one of the requested categories.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_cache is None:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )

        # A str is itself a Sequence[str]; iterating it would yield single
        # characters, so wrap it as one category name.
        if isinstance(categories, str):
            categories = [categories]

        allowed_names: set[str] = set()
        for category in categories:
            names = TOOL_CATEGORIES.get(category)
            if names is None:
                logger.warning("Unknown tool category: %s", category)
                continue
            allowed_names.update(names)

        filtered_tools = [
            tool
            for tool in self._tools_cache
            if tool.metadata.name is not None
            and (
                tool.metadata.name in allowed_names
                or self._strip_mcp_prefix(tool.metadata.name) in allowed_names
            )
        ]

        logger.debug(
            "Filtered %d tools from categories: %s", len(filtered_tools), categories
        )
        return filtered_tools

    async def get_tool_by_name(self, tool_name: str) -> FunctionTool | None:
        """
        Get a specific tool by name.

        Tries, in order: the exact name, the name with ``mcp_`` prepended,
        and (if it has one) the name with its leading ``mcp_`` removed.

        Args:
            tool_name: Name of the tool to find.

        Returns:
            FunctionTool if found, None otherwise (including when not
            connected).
        """
        if not self._connected or self._tools_by_name is None:
            return None

        candidates = [tool_name, f"{self._MCP_PREFIX}{tool_name}"]
        stripped = self._strip_mcp_prefix(tool_name)
        if stripped != tool_name:
            candidates.append(stripped)

        for candidate in candidates:
            tool = self._tools_by_name.get(candidate)
            if tool is not None:
                return tool
        return None

    # ------------------------------------------------------------------
    # Tool execution
    # ------------------------------------------------------------------

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Call a tool directly without going through an agent.

        Useful for direct API access and internal operations.

        Args:
            tool_name: Name of the tool to call (``mcp_`` prefix optional).
            arguments: Keyword arguments to pass to the tool.

        Returns:
            The ToolOutput's ``raw_output`` when present, otherwise the value
            returned by ``acall`` itself. The type varies by tool.

        Raises:
            MCPConnectionError: If not connected.
            MCPToolNotFoundError: If tool doesn't exist.
            MCPToolExecutionError: If tool execution fails.
        """
        if not self._connected:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )

        tool = await self.get_tool_by_name(tool_name)
        if tool is None:
            raise MCPToolNotFoundError(tool_name)

        logger.debug("Calling tool '%s' with args: %s", tool_name, arguments)
        try:
            result = await tool.acall(**arguments)
        except Exception as e:
            logger.error("Tool '%s' execution failed: %s", tool_name, e)
            raise MCPToolExecutionError(tool_name, e) from e

        # FunctionTool.acall returns ToolOutput; callers want the raw payload.
        return getattr(result, "raw_output", result)

    def call_tool_sync(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Synchronous wrapper for call_tool.

        Runs ``call_tool`` with ``asyncio.run()``, so it may only be used
        when no event loop is running in the current thread. Prefer
        ``call_tool()`` in async code.

        Args:
            tool_name: Name of the tool to call.
            arguments: Arguments to pass to the tool.

        Returns:
            Tool result.

        Raises:
            RuntimeError: If called while an event loop is running.
            MCPConnectionError, MCPToolNotFoundError, MCPToolExecutionError:
                Propagated from ``call_tool``.
        """
        import asyncio

        # Check before building the coroutine. Otherwise asyncio.run() fails
        # and leaves a never-awaited coroutine behind.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: asyncio.run() is safe.
        else:
            raise RuntimeError(
                "call_tool_sync() cannot be used while an event loop is "
                "running; await call_tool() instead."
            )
        return asyncio.run(self.call_tool(tool_name, arguments))

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    async def list_tool_names(self) -> list[str]:
        """
        Get list of all available tool names.

        Returns:
            List of tool names, exactly as exposed by the server.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_by_name is None:
            raise MCPConnectionError("Not connected to MCP server.")
        return list(self._tools_by_name.keys())

    async def has_tool(self, tool_name: str) -> bool:
        """
        Check if a tool is available.

        Uses the same prefix-tolerant lookup as ``get_tool_by_name``.

        Args:
            tool_name: Name of the tool to check.

        Returns:
            True if tool exists, False otherwise (including when not
            connected).
        """
        return await self.get_tool_by_name(tool_name) is not None

    def get_category_for_tool(self, tool_name: str) -> str | None:
        """
        Get the category for a tool name.

        Does not require a connection; consults TOOL_CATEGORIES only.

        Args:
            tool_name: Name of the tool (``mcp_`` prefix optional).

        Returns:
            The first matching category name, or None if not found.
        """
        normalized = self._strip_mcp_prefix(tool_name)
        for category, tools in TOOL_CATEGORIES.items():
            if normalized in tools or tool_name in tools:
                return category
        return None

    async def refresh_tools(self) -> int:
        """
        Refresh the cached tool list from the server.

        Useful after server updates or if tools might have changed. On
        failure the previous cache is kept unchanged.

        Returns:
            Number of tools now available.

        Raises:
            MCPConnectionError: If not connected or refresh fails.
        """
        if not self._connected or self._tool_spec is None:
            raise MCPConnectionError("Not connected to MCP server.")

        try:
            tools = await self._tool_spec.to_tool_list_async()
            tools_by_name = self._index_tools(tools)
        except Exception as e:
            logger.error("Failed to refresh tools: %s", e)
            raise MCPConnectionError(f"Failed to refresh tools: {e}") from e

        # Swap both caches together so the list and the name index never
        # disagree.
        self._tools_cache = tools
        self._tools_by_name = tools_by_name
        logger.info("Refreshed tool cache: %d tools", len(tools))
        return len(tools)

    def __repr__(self) -> str:
        """String representation showing URL, connection status and tool count."""
        status = "connected" if self._connected else "disconnected"
        return (
            f"TTRPGToolkitClient(url={self._url!r}, status={status}, "
            f"tools={self.tools_count})"
        )