| """ |
| Base Agent for SPARKNET |
| Defines the core agent interface and functionality |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from typing import List, Dict, Optional, Any |
| from dataclasses import dataclass |
| from datetime import datetime |
| from loguru import logger |
| import json |
|
|
| from ..llm.ollama_client import OllamaClient |
| from ..tools.base_tool import BaseTool, ToolRegistry, ToolResult |
|
|
|
|
@dataclass
class Message:
    """A unit of conversation passed between agents, users, and the LLM.

    ``timestamp`` is filled with the local creation time (``datetime.now()``)
    when it is omitted or passed as ``None``. ``metadata`` stays ``None``
    unless the caller supplies it.
    """

    # Conversation role; values used in this module include "agent" and "assistant".
    role: str
    # Message text.
    content: str
    # Name of the sending agent, if any.
    sender: Optional[str] = None
    # Creation time; defaulted in __post_init__ when missing.
    timestamp: Optional[datetime] = None
    # Free-form extra data; not included by to_dict.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Only stamp the time when the caller did not provide one.
        if self.timestamp is not None:
            return
        self.timestamp = datetime.now()

    def to_dict(self) -> Dict[str, str]:
        """Build the role/content dict sent to the Ollama API.

        Inter-agent messages (role ``"agent"``) are presented as ``"user"``
        turns; every other role is passed through unchanged. ``sender``,
        ``timestamp`` and ``metadata`` are dropped.
        """
        api_role = self.role
        if api_role == "agent":
            api_role = "user"
        payload: Dict[str, str] = {"role": api_role}
        payload["content"] = self.content
        return payload
|
|
|
|
@dataclass
class Task:
    """Unit of work handed to ``BaseAgent.process_task``.

    ``metadata`` is normalised to a fresh, per-instance ``dict`` when it is
    omitted or ``None``, so callers can mutate it without a ``None`` check.
    """

    # Caller-assigned task identifier.
    id: str
    # Human-readable description of the work.
    description: str
    # Priority value; defaults to 0.
    priority: int = 0
    # Lifecycle status string; starts as "pending".
    status: str = "pending"
    # Result set by whoever processes the task.
    result: Optional[Any] = None
    # Error text set on failure, if any.
    error: Optional[str] = None
    # Free-form extra data; always a dict after construction.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Allocate here (not as a class-level default) so tasks never share a dict.
        if self.metadata is not None:
            return
        self.metadata = {}
|
|
|
|
class BaseAgent(ABC):
    """Base class for all SPARKNET agents.

    An agent binds an ``OllamaClient`` to a fixed model and system prompt,
    keeps a running message history, and can execute tools taken either from
    its own toolbox (``self.tools``) or from an optional shared
    ``ToolRegistry``. Subclasses must implement :meth:`process_task`.
    """

    def __init__(
        self,
        name: str,
        description: str,
        llm_client: OllamaClient,
        model: str,
        system_prompt: str,
        tools: Optional[List[BaseTool]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize agent.

        Args:
            name: Agent name (also used as ``sender`` on messages it creates)
            description: Agent description
            llm_client: Ollama client instance
            model: Model passed on every LLM call
            system_prompt: System prompt sent with every LLM call
            tools: Initial tools, keyed by ``tool.name`` (a later duplicate
                name replaces an earlier one)
            temperature: Default LLM temperature (overridable per call)
            max_tokens: Max tokens to generate (currently only forwarded on
                the single-prompt path of ``call_llm``)
        """
        self.name = name
        self.description = description
        self.llm_client = llm_client
        self.model = model
        self.system_prompt = system_prompt
        self.tools = {tool.name: tool for tool in (tools or [])}
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Conversation history in arrival order; process_message sends all of it.
        self.messages: List[Message] = []

        # Optional shared registry, consulted when a tool is not in self.tools.
        self.tool_registry: Optional[ToolRegistry] = None

        logger.info(f"Initialized agent: {self.name} with model {self.model}")

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent's toolbox, replacing any tool with the same name.

        Args:
            tool: Tool to add
        """
        self.tools[tool.name] = tool
        logger.info(f"Agent {self.name} added tool: {tool.name}")

    def remove_tool(self, tool_name: str):
        """
        Remove a tool from the agent's toolbox.

        Unknown names are ignored silently (nothing is logged).

        Args:
            tool_name: Name of tool to remove
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Agent {self.name} removed tool: {tool_name}")

    def set_tool_registry(self, registry: ToolRegistry):
        """
        Set the tool registry for accessing shared tools.

        Args:
            registry: Tool registry instance
        """
        self.tool_registry = registry

    async def call_llm(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Message]] = None,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Call the LLM with a prompt or messages.

        A truthy ``prompt`` takes precedence; ``messages`` is then ignored.
        On the messages path the agent's system prompt is always prepended.

        Args:
            prompt: Single prompt string
            messages: List of messages
            temperature: Override temperature (``None`` uses ``self.temperature``)

        Returns:
            LLM response

        Raises:
            ValueError: If neither a non-empty prompt nor a non-empty
                messages list is provided.
        """
        temp = temperature if temperature is not None else self.temperature

        # NOTE(review): the client calls below are not awaited, so they appear
        # to be synchronous and would block the event loop for the duration of
        # the request - confirm against OllamaClient.
        if prompt:
            response = self.llm_client.generate(
                prompt=prompt,
                model=self.model,
                system=self.system_prompt,
                temperature=temp,
                max_tokens=self.max_tokens,
            )
        elif messages:
            chat_messages = [{"role": "system", "content": self.system_prompt}]
            chat_messages.extend(msg.to_dict() for msg in messages)

            # NOTE(review): unlike generate(), max_tokens is not forwarded here;
            # confirm whether OllamaClient.chat accepts it.
            response = self.llm_client.chat(
                messages=chat_messages,
                model=self.model,
                temperature=temp,
            )
        else:
            raise ValueError("Either prompt or messages must be provided")

        logger.debug(f"Agent {self.name} received LLM response: {len(response)} chars")
        return response

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """
        Execute a tool by name.

        The agent's own toolbox is searched first, then the shared registry
        (if one is set). A missing tool is reported as a failed ``ToolResult``
        rather than raised.

        Args:
            tool_name: Name of tool to execute
            **kwargs: Tool parameters, passed to ``tool.safe_execute``

        Returns:
            ToolResult from tool execution
        """
        tool = self.tools.get(tool_name)

        # Local tools shadow registry tools with the same name.
        if tool is None and self.tool_registry:
            tool = self.tool_registry.get_tool(tool_name)

        if tool is None:
            logger.error(f"Tool not found: {tool_name}")
            return ToolResult(
                success=False,
                output=None,
                error=f"Tool not found: {tool_name}",
            )

        logger.info(f"Agent {self.name} executing tool: {tool_name}")
        return await tool.safe_execute(**kwargs)

    def add_message(self, message: Message):
        """
        Append a message to the agent's history.

        Args:
            message: Message to add
        """
        self.messages.append(message)

    async def receive_message(self, message: Message) -> Optional[str]:
        """
        Receive and process a message from another agent or user.

        The message is recorded in history before ``process_message`` runs.

        Args:
            message: Incoming message

        Returns:
            Response or None
        """
        logger.info(f"Agent {self.name} received message from {message.sender}")
        self.add_message(message)
        return await self.process_message(message)

    async def process_message(self, message: Message) -> Optional[str]:
        """
        Process an incoming message. Can be overridden by subclasses.

        The default implementation ignores ``message`` itself and sends the
        whole history (which already contains it when called through
        ``receive_message``) to the LLM. It then records the reply as an
        ``"assistant"`` message.

        Args:
            message: Message to process

        Returns:
            Response or None
        """
        response = await self.call_llm(messages=self.messages)

        self.add_message(
            Message(
                role="assistant",
                content=response,
                sender=self.name,
            )
        )

        return response

    @abstractmethod
    async def process_task(self, task: Task) -> Task:
        """
        Process a task. Must be implemented by subclasses.

        Args:
            task: Task to process

        Returns:
            Updated task with results
        """
        pass

    async def send_message(self, recipient: "BaseAgent", content: str) -> Optional[str]:
        """
        Send a message to another agent.

        The message is created with role ``"agent"``, which ``Message.to_dict``
        presents to the LLM as a user turn.

        Args:
            recipient: Recipient agent
            content: Message content

        Returns:
            Response from recipient
        """
        message = Message(
            role="agent",
            content=content,
            sender=self.name,
        )

        logger.info(f"Agent {self.name} sending message to {recipient.name}")
        return await recipient.receive_message(message)

    def get_available_tools(self) -> List[str]:
        """
        Get list of available tool names.

        Local tools come first, followed by registry tools not already listed.
        Each name appears once, and the order is deterministic.

        Returns:
            List of tool names
        """
        tool_names = list(self.tools)

        if self.tool_registry:
            tool_names.extend(self.tool_registry.list_tools())

        # dict.fromkeys de-duplicates while keeping first-seen order; the
        # previous set() round-trip produced an arbitrary, hash-seed-dependent
        # order that could differ between runs.
        return list(dict.fromkeys(tool_names))

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Get schemas for all available tools.

        Returns:
            List of tool schemas (local tools first, then registry schemas)
        """
        schemas = [tool.get_schema() for tool in self.tools.values()]

        # NOTE(review): unlike get_available_tools, nothing is de-duplicated
        # here, so a tool present both locally and in the registry may be
        # listed twice - confirm whether that is intended.
        if self.tool_registry:
            schemas.extend(self.tool_registry.get_schemas())

        return schemas

    def clear_history(self):
        """Clear message history."""
        self.messages.clear()
        logger.info(f"Agent {self.name} cleared message history")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get agent statistics.

        ``tools_count`` counts only the agent's own tools, not registry tools.

        Returns:
            Dictionary with agent stats
        """
        return {
            "name": self.name,
            "model": self.model,
            "messages_count": len(self.messages),
            "tools_count": len(self.tools),
        }

    def __repr__(self) -> str:
        return f"<Agent: {self.name} (model={self.model}, tools={len(self.tools)})>"
|
|