|
|
""" |
|
|
DungeonMaster AI - TTRPG Toolkit MCP Client |
|
|
|
|
|
Core client for connecting to the TTRPG-Toolkit MCP server. |
|
|
Uses llama-index-tools-mcp for connection management and tool conversion. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
from collections.abc import Sequence |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from llama_index.core.tools import FunctionTool |
|
|
from llama_index.tools.mcp import BasicMCPClient, McpToolSpec |
|
|
|
|
|
from src.config.settings import get_settings |
|
|
|
|
|
from .exceptions import ( |
|
|
MCPConnectionError, |
|
|
MCPToolExecutionError, |
|
|
MCPToolNotFoundError, |
|
|
) |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from typing import Any |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TOOL_CATEGORIES: dict[str, list[str]] = { |
|
|
"dice": [ |
|
|
"roll", |
|
|
"roll_check", |
|
|
"roll_table", |
|
|
"get_roll_statistics", |
|
|
|
|
|
"mcp_roll", |
|
|
"mcp_roll_check", |
|
|
"mcp_roll_table", |
|
|
"mcp_get_roll_statistics", |
|
|
], |
|
|
"character": [ |
|
|
"create_character", |
|
|
"get_character", |
|
|
"modify_hp", |
|
|
"rest", |
|
|
"generate_ability_scores", |
|
|
"add_condition", |
|
|
"remove_condition", |
|
|
"update_character", |
|
|
"level_up", |
|
|
"add_item", |
|
|
"remove_item", |
|
|
|
|
|
"mcp_create_character", |
|
|
"mcp_get_character", |
|
|
"mcp_modify_hp", |
|
|
"mcp_rest", |
|
|
"mcp_generate_ability_scores", |
|
|
"mcp_add_condition", |
|
|
"mcp_remove_condition", |
|
|
], |
|
|
"rules": [ |
|
|
"search_rules", |
|
|
"get_monster", |
|
|
"search_monsters", |
|
|
"get_spell", |
|
|
"search_spells", |
|
|
"get_class_info", |
|
|
"get_race_info", |
|
|
"get_item", |
|
|
"get_condition", |
|
|
|
|
|
"mcp_search_rules", |
|
|
"mcp_get_monster", |
|
|
"mcp_search_monsters", |
|
|
"mcp_get_spell", |
|
|
"mcp_search_spells", |
|
|
"mcp_get_class_info", |
|
|
"mcp_get_race_info", |
|
|
], |
|
|
"generators": [ |
|
|
"generate_name", |
|
|
"generate_npc", |
|
|
"generate_encounter", |
|
|
"generate_loot", |
|
|
"generate_location", |
|
|
|
|
|
"mcp_generate_name", |
|
|
"mcp_generate_npc", |
|
|
"mcp_generate_encounter", |
|
|
"mcp_generate_loot", |
|
|
"mcp_generate_location", |
|
|
], |
|
|
"combat": [ |
|
|
"start_combat", |
|
|
"roll_all_initiatives", |
|
|
"get_turn_order", |
|
|
"next_turn", |
|
|
"apply_damage", |
|
|
"apply_condition", |
|
|
"remove_combat_condition", |
|
|
"end_combat", |
|
|
"get_combat_status", |
|
|
|
|
|
"mcp_start_combat", |
|
|
"mcp_roll_all_initiatives", |
|
|
"mcp_get_turn_order", |
|
|
"mcp_next_turn", |
|
|
"mcp_apply_damage", |
|
|
"mcp_end_combat", |
|
|
"mcp_get_combat_status", |
|
|
], |
|
|
"session": [ |
|
|
"start_session", |
|
|
"end_session", |
|
|
"log_event", |
|
|
"get_session_summary", |
|
|
"add_session_note", |
|
|
"get_session_history", |
|
|
"list_sessions", |
|
|
|
|
|
"mcp_start_session", |
|
|
"mcp_end_session", |
|
|
"mcp_log_event", |
|
|
"mcp_get_session_summary", |
|
|
"mcp_add_session_note", |
|
|
], |
|
|
} |
|
|
|
|
|
|
|
|
class TTRPGToolkitClient:
    """
    Client for connecting to TTRPG-Toolkit MCP server.

    Uses llama-index-tools-mcp's BasicMCPClient for connection
    and McpToolSpec for converting MCP tools to LlamaIndex FunctionTool objects.

    Tool lookups accept either the bare tool name (``roll``) or its
    ``mcp_``-prefixed spelling (``mcp_roll``).

    Example:
        ```python
        client = TTRPGToolkitClient()
        await client.connect()

        # Get all tools for DM agent
        all_tools = await client.get_all_tools()

        # Get only rules tools for Rules agent
        rules_tools = await client.get_tools_by_category(["rules"])

        # Direct tool call
        result = await client.call_tool("roll", {"notation": "2d6+3"})
        ```
    """

    def __init__(
        self,
        mcp_url: str | None = None,
        timeout: int | None = None,
    ) -> None:
        """
        Initialize the TTRPG Toolkit client.

        No network activity happens here; call ``connect()`` to reach the server.

        Args:
            mcp_url: MCP server URL. Falsy values (None, "") fall back to
                settings.mcp.ttrpg_toolkit_mcp_url.
            timeout: Connection timeout in seconds. Falsy values (None, 0) fall
                back to settings.mcp.mcp_connection_timeout.
        """
        settings = get_settings()

        self._url = mcp_url or settings.mcp.ttrpg_toolkit_mcp_url
        self._timeout = timeout or settings.mcp.mcp_connection_timeout

        self._client: BasicMCPClient | None = None
        self._tool_spec: McpToolSpec | None = None
        # Ordered list of tools as returned by the server.
        self._tools_cache: list[FunctionTool] | None = None
        # Index of the same tools keyed by their exact metadata name.
        self._tools_by_name: dict[str, FunctionTool] | None = None
        self._connected = False

        logger.debug("TTRPGToolkitClient initialized with URL: %s", self._url)

    @staticmethod
    def _strip_mcp_prefix(tool_name: str) -> str:
        """Return *tool_name* with a single leading ``mcp_`` removed, if present.

        Only a leading prefix is stripped; ``str.replace`` would also mangle
        names that merely contain ``mcp_`` elsewhere.
        """
        prefix = "mcp_"
        return tool_name[len(prefix):] if tool_name.startswith(prefix) else tool_name

    def _reset_state(self) -> None:
        """Drop the client, tool spec and both tool caches and mark disconnected."""
        self._client = None
        self._tool_spec = None
        self._tools_cache = None
        self._tools_by_name = None
        self._connected = False

    async def _load_tools(self, tool_spec: McpToolSpec) -> list[FunctionTool]:
        """Fetch the tool list via *tool_spec* and replace both caches together.

        Returns:
            The freshly fetched tools.
        """
        tools = await tool_spec.to_tool_list_async()
        by_name = {
            tool.metadata.name: tool
            for tool in tools
            if tool.metadata.name is not None
        }
        # Assign only after both structures are built so a failure part-way
        # through cannot leave the list and the name index out of sync.
        self._tools_cache = tools
        self._tools_by_name = by_name
        return tools

    @property
    def url(self) -> str:
        """Get the MCP server URL."""
        return self._url

    @property
    def is_connected(self) -> bool:
        """Check if client is connected."""
        return self._connected

    @property
    def tools_count(self) -> int:
        """Get number of cached tools (0 when nothing is cached)."""
        return len(self._tools_cache) if self._tools_cache else 0

    async def connect(self) -> TTRPGToolkitClient:
        """
        Establish connection to MCP server.

        Creates BasicMCPClient and McpToolSpec, then fetches and caches
        all available tools to verify the connection. Calling this while
        already connected is a no-op.

        Returns:
            Self for method chaining.

        Raises:
            MCPConnectionError: If connection fails. All client state is
                reset before raising.
        """
        if self._connected and self._client is not None:
            logger.debug("Already connected to MCP server, skipping reconnection")
            return self

        logger.info("Connecting to MCP server at %s...", self._url)

        try:
            self._client = BasicMCPClient(
                command_or_url=self._url,
                timeout=self._timeout,
            )
            # No server-side allow-list; category filtering is done client-side
            # in get_tools_by_category().
            self._tool_spec = McpToolSpec(
                client=self._client,
                allowed_tools=None,
            )
            # Listing the tools doubles as the connectivity check.
            tools = await self._load_tools(self._tool_spec)
        except Exception as e:
            self._reset_state()
            logger.error("Failed to connect to MCP server: %s", e)
            raise MCPConnectionError(
                f"Failed to connect to MCP server: {e}",
                url=self._url,
            ) from e

        self._connected = True
        logger.info("Connected to MCP server with %d tools available", len(tools))
        return self

    async def disconnect(self) -> None:
        """
        Gracefully close connection and clear cache.

        Drops the references to the client and tool spec and clears both
        tool caches; no explicit close call is made on the underlying client.
        Safe to call even if not connected.
        """
        logger.info("Disconnecting from MCP server...")
        self._reset_state()
        logger.info("Disconnected from MCP server")

    async def get_all_tools(self) -> Sequence[FunctionTool]:
        """
        Get all available tools as LlamaIndex FunctionTool objects.

        Served from the cache filled by ``connect()``/``refresh_tools()``.

        Returns:
            A new list of FunctionTool objects. It is a copy, so mutating it
            does not affect the client's cache.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_cache is None:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )

        return list(self._tools_cache)

    async def get_tools_by_category(
        self,
        categories: Sequence[str],
    ) -> list[FunctionTool]:
        """
        Get tools filtered by category names.

        A tool matches when its exact name, or its name with a leading
        ``mcp_`` removed, is listed under one of the requested categories.
        Unknown categories are logged and skipped.

        Args:
            categories: Category names from TOOL_CATEGORIES
                (dice, character, rules, generators, combat, session).
                A single category given as a plain string is also accepted.

        Returns:
            Filtered list of FunctionTool objects, in server order.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_cache is None:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )

        if isinstance(categories, str):
            # A bare str is itself a Sequence[str] of characters; treat it as
            # one category name instead of iterating letter by letter.
            categories = [categories]

        allowed_names: set[str] = set()
        for category in categories:
            names = TOOL_CATEGORIES.get(category)
            if names is None:
                logger.warning("Unknown tool category: %s", category)
                continue
            allowed_names.update(names)

        filtered_tools = [
            tool
            for tool in self._tools_cache
            if (name := tool.metadata.name) is not None
            and (
                name in allowed_names
                or self._strip_mcp_prefix(name) in allowed_names
            )
        ]

        logger.debug(
            "Filtered %d tools from categories: %s", len(filtered_tools), categories
        )
        return filtered_tools

    async def get_tool_by_name(self, tool_name: str) -> FunctionTool | None:
        """
        Get a specific tool by name.

        Tries, in order: the exact name, the name with an ``mcp_`` prefix added,
        and (if the name starts with ``mcp_``) the name with that prefix removed.

        Args:
            tool_name: Name of the tool to find.

        Returns:
            FunctionTool if found, None otherwise (including when not connected).
        """
        if not self._connected or self._tools_by_name is None:
            return None

        candidates = [tool_name, f"mcp_{tool_name}"]
        if tool_name.startswith("mcp_"):
            candidates.append(tool_name[len("mcp_"):])

        for candidate in candidates:
            tool = self._tools_by_name.get(candidate)
            if tool is not None:
                return tool
        return None

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Call a tool directly without going through an agent.

        Useful for direct API access and internal operations.

        Args:
            tool_name: Name of the tool to call (bare or ``mcp_``-prefixed).
            arguments: Keyword arguments to pass to the tool.

        Returns:
            The tool output's ``raw_output`` when present, otherwise the
            object returned by the tool (type varies by tool).

        Raises:
            MCPConnectionError: If not connected.
            MCPToolNotFoundError: If tool doesn't exist.
            MCPToolExecutionError: If tool execution fails.
        """
        if not self._connected:
            raise MCPConnectionError(
                "Not connected to MCP server. Call connect() first."
            )

        tool = await self.get_tool_by_name(tool_name)
        if tool is None:
            raise MCPToolNotFoundError(tool_name)

        try:
            logger.debug("Calling tool '%s' with args: %s", tool_name, arguments)
            result = await tool.acall(**arguments)
        except Exception as e:
            logger.error("Tool '%s' execution failed: %s", tool_name, e)
            raise MCPToolExecutionError(tool_name, e) from e

        # LlamaIndex wraps results in an output object; unwrap the raw payload.
        if hasattr(result, "raw_output"):
            return result.raw_output
        return result

    def call_tool_sync(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """
        Synchronous wrapper for call_tool.

        Uses asyncio.run() for synchronous contexts. Prefer call_tool() in
        async code.

        Args:
            tool_name: Name of the tool to call.
            arguments: Arguments to pass to the tool.

        Returns:
            Tool result.

        Raises:
            RuntimeError: If called from a thread that already runs an event
                loop (asyncio.run() cannot be nested); await call_tool() instead.
            MCPConnectionError, MCPToolNotFoundError, MCPToolExecutionError:
                Propagated from call_tool().
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: safe to start one.
            return asyncio.run(self.call_tool(tool_name, arguments))

        # Checked before creating the coroutine so it is never left un-awaited.
        raise RuntimeError(
            "call_tool_sync() cannot be used while an event loop is running; "
            "await call_tool() instead."
        )

    async def list_tool_names(self) -> list[str]:
        """
        Get list of all available tool names.

        Returns:
            List of tool names exactly as reported by the server.

        Raises:
            MCPConnectionError: If not connected.
        """
        if not self._connected or self._tools_by_name is None:
            raise MCPConnectionError("Not connected to MCP server.")

        return list(self._tools_by_name)

    async def has_tool(self, tool_name: str) -> bool:
        """
        Check if a tool is available.

        Uses the same name resolution as get_tool_by_name().

        Args:
            tool_name: Name of the tool to check.

        Returns:
            True if tool exists, False otherwise (including when not connected).
        """
        return await self.get_tool_by_name(tool_name) is not None

    def get_category_for_tool(self, tool_name: str) -> str | None:
        """
        Get the category for a tool name.

        Matches either the exact name or the name with a leading ``mcp_``
        removed. Does not require a connection.

        Args:
            tool_name: Name of the tool.

        Returns:
            First matching category name in TOOL_CATEGORIES order, or None.
        """
        normalized = self._strip_mcp_prefix(tool_name)

        for category, tools in TOOL_CATEGORIES.items():
            if normalized in tools or tool_name in tools:
                return category

        return None

    async def refresh_tools(self) -> int:
        """
        Refresh the cached tool list from the server.

        Useful after server updates or if tools might have changed. On failure
        the previous cache is kept.

        Returns:
            Number of tools now available.

        Raises:
            MCPConnectionError: If not connected or refresh fails.
        """
        if not self._connected or self._tool_spec is None:
            raise MCPConnectionError("Not connected to MCP server.")

        try:
            tools = await self._load_tools(self._tool_spec)
        except Exception as e:
            logger.error("Failed to refresh tools: %s", e)
            raise MCPConnectionError(f"Failed to refresh tools: {e}") from e

        logger.info("Refreshed tool cache: %d tools", len(tools))
        return len(tools)

    def __repr__(self) -> str:
        """String representation showing URL, connection status and tool count."""
        status = "connected" if self._connected else "disconnected"
        tools = self.tools_count
        return f"TTRPGToolkitClient(url={self._url!r}, status={status}, tools={tools})"
|
|
|