| """ |
| Base Tool for SPARKNET |
| Defines the interface for all tools that agents can use |
| """ |
|
|
import copy
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from loguru import logger
from pydantic import BaseModel, Field
|
|
|
|
class ToolParameter(BaseModel):
    # Declarative description of one keyword argument a tool accepts; used by
    # BaseTool for argument validation, default filling and schema output.
    """Definition of a tool parameter."""
    # Keyword-argument name the tool's execute() receives.
    name: str = Field(..., description="Parameter name")
    # Type name as a plain string, e.g. "str", "int", "float", "bool",
    # "list" or "dict"; any other string is accepted by this model as-is.
    type: str = Field(..., description="Parameter type (str, int, float, bool, list, dict)")
    # Human-readable explanation shown to the caller / LLM.
    description: str = Field(..., description="Parameter description")
    # Parameters are required unless explicitly declared otherwise.
    required: bool = Field(default=True, description="Whether parameter is required")
    # Value substituted when an optional parameter is omitted; None if unset.
    default: Optional[Any] = Field(default=None, description="Default value if not required")
|
|
|
|
class ToolResult(BaseModel):
    # Uniform envelope returned by every tool invocation, successful or not.
    """Result from tool execution."""
    # True when the tool completed normally; failure paths set this to False.
    success: bool = Field(..., description="Whether execution was successful")
    # Tool payload; explicitly required via "...", so callers must pass it
    # (failure paths pass output=None).
    output: Any = Field(..., description="Tool output")
    # Failure reason; None unless a producer sets it.
    error: Optional[str] = Field(default=None, description="Error message if failed")
    # Free-form extra information; default_factory gives each result its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
|
|
|
class BaseTool(ABC):
    """
    Base class for all tools.

    Subclasses implement :meth:`execute` and declare their inputs with
    :meth:`add_parameter`. :meth:`safe_execute` is the guarded entry point:
    it validates arguments, fills defaults for omitted optional parameters,
    and turns exceptions raised by ``execute`` into a failed ``ToolResult``.
    """

    # Declared type names that validate_parameters() type-checks, mapped to
    # the classes used for the isinstance() test. Other names are not checked.
    _PYTHON_TYPES: Dict[str, type] = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
    }

    # Declared type names mapped to JSON Schema type names, which is what LLM
    # function-calling schemas expect. Unlisted names pass through unchanged,
    # so parameters already declared with JSON Schema names keep working.
    _JSON_SCHEMA_TYPES: Dict[str, str] = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    def __init__(self, name: str, description: str):
        """
        Initialize tool.

        Args:
            name: Tool name (also the key ToolRegistry stores it under)
            description: Tool description, exposed in the schema
        """
        self.name = name
        self.description = description
        self.parameters: list[ToolParameter] = []

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """
        Execute the tool with given parameters.

        Subclasses implement the tool's behavior here. Callers should prefer
        :meth:`safe_execute`, which validates ``kwargs`` first.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results
        """
        pass

    def add_parameter(
        self,
        name: str,
        param_type: str,
        description: str,
        required: bool = True,
        default: Optional[Any] = None,
    ) -> None:
        """
        Add a parameter definition to the tool.

        Args:
            name: Parameter name
            param_type: Parameter type name ("str", "int", "float", "bool",
                "list", "dict"; other names are accepted but not type-checked)
            description: Parameter description
            required: Whether parameter is required
            default: Default value used when an optional parameter is omitted
        """
        self.parameters.append(
            ToolParameter(
                name=name,
                type=param_type,
                description=description,
                required=required,
                default=default,
            )
        )

    @classmethod
    def _matches_type(cls, value: Any, expected_type: str) -> bool:
        """Return True if *value* is acceptable for a parameter declared as *expected_type*."""
        # bool is a subclass of int, so isinstance(True, int) is True; a
        # boolean must only satisfy a "bool" parameter.
        if isinstance(value, bool):
            return expected_type == "bool"
        # JSON (and therefore LLM tool calls) does not distinguish 1 from 1.0,
        # so an int is accepted where a float is declared.
        if expected_type == "float":
            return isinstance(value, (int, float))
        return isinstance(value, cls._PYTHON_TYPES[expected_type])

    def validate_parameters(self, **kwargs) -> tuple[bool, Optional[str]]:
        """
        Validate provided parameters against tool definition.

        Missing required parameters are reported before any type mismatch.
        Parameters whose declared type is not a known Python type name are
        not type-checked; extra, undeclared keyword arguments are ignored.

        Args:
            **kwargs: Provided parameters

        Returns:
            Tuple of (is_valid, error_message); error_message is None when valid
        """
        # Pass 1: every required parameter must be present.
        for param in self.parameters:
            if param.required and param.name not in kwargs:
                return False, f"Missing required parameter: {param.name}"

        # Pass 2: type-check the declared parameters that were supplied.
        for param in self.parameters:
            if param.name not in kwargs or param.type not in self._PYTHON_TYPES:
                continue
            if not self._matches_type(kwargs[param.name], param.type):
                return False, f"Parameter {param.name} must be of type {param.type}"

        return True, None

    def get_schema(self) -> Dict[str, Any]:
        """
        Get tool schema for LLM function calling.

        Parameter types are translated to JSON Schema type names
        ("str" -> "string", "int" -> "integer", ...); unknown type names are
        emitted unchanged.

        Returns:
            Tool schema dictionary with "name", "description" and a JSON
            Schema "parameters" object
        """
        properties = {
            param.name: {
                "type": self._JSON_SCHEMA_TYPES.get(param.type, param.type),
                "description": param.description,
            }
            for param in self.parameters
        }
        required = [param.name for param in self.parameters if param.required]
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    async def safe_execute(self, **kwargs) -> ToolResult:
        """
        Execute tool with parameter validation and error handling.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results. A failed result (success=False,
            output=None) is returned if validation fails or ``execute`` raises
            an ``Exception``; such failures are logged rather than propagated.
        """
        is_valid, error_msg = self.validate_parameters(**kwargs)
        if not is_valid:
            logger.error(f"Tool {self.name} parameter validation failed: {error_msg}")
            return ToolResult(success=False, output=None, error=error_msg)

        # Fill in defaults for optional parameters the caller omitted. The
        # default is deep-copied so an execute() that mutates a list/dict
        # default cannot leak that change into later calls.
        for param in self.parameters:
            if not param.required and param.name not in kwargs:
                kwargs[param.name] = copy.deepcopy(param.default)

        try:
            logger.info(f"Executing tool: {self.name}")
            result = await self.execute(**kwargs)
        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception(f"Tool {self.name} execution failed: {e}")
            return ToolResult(
                success=False,
                output=None,
                # Some exceptions stringify to "" (e.g. `raise ValueError()`);
                # fall back to the class name so the error is never blank.
                error=str(e) or type(e).__name__,
            )

        logger.info(f"Tool {self.name} executed successfully")
        return result

    def __repr__(self) -> str:
        return f"<Tool: {self.name}>"
|
|
|
|
class ToolRegistry:
    """Registry for managing available tools, keyed by tool name."""

    def __init__(self):
        """Initialize an empty tool registry."""
        self.tools: Dict[str, BaseTool] = {}
        logger.info("Tool registry initialized")

    def register(self, tool: BaseTool):
        """
        Register a tool under its ``name``.

        If a different tool is already registered under the same name it is
        replaced, and a warning is logged so the name collision is visible.

        Args:
            tool: Tool instance to register
        """
        existing = self.tools.get(tool.name)
        if existing is not None and existing is not tool:
            logger.warning(f"Replacing previously registered tool: {tool.name}")
        self.tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def unregister(self, tool_name: str):
        """
        Unregister a tool. Unknown names are ignored.

        Args:
            tool_name: Name of tool to unregister
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """
        Get a tool by name.

        Args:
            tool_name: Name of tool

        Returns:
            Tool instance or None if no tool has that name
        """
        return self.tools.get(tool_name)

    def list_tools(self) -> list[str]:
        """
        List all registered tools.

        Returns:
            List of tool names, in registration order
        """
        return list(self.tools.keys())

    def get_schemas(self) -> list[Dict[str, Any]]:
        """
        Get schemas for all tools.

        Returns:
            List of tool schemas, one per registered tool (via ``get_schema``)
        """
        return [tool.get_schema() for tool in self.tools.values()]
|
|
|
|
| |
# Module-wide registry instance, created lazily by get_tool_registry().
_tool_registry: Optional[ToolRegistry] = None


def get_tool_registry() -> ToolRegistry:
    """
    Return the shared ToolRegistry, creating it on first use.

    Returns:
        The module-level ToolRegistry instance (the same object on every call)
    """
    global _tool_registry
    registry = _tool_registry
    if registry is None:
        registry = ToolRegistry()
        _tool_registry = registry
    return registry
|
|