"""
MemoryAgent for SPARKNET.
Provides a vector memory system using ChromaDB and LangChain.
Supports episodic, semantic, and stakeholder memory.
"""
|
|
| from typing import Optional, Dict, Any, List, Literal |
| from datetime import datetime |
| from loguru import logger |
| import json |
|
|
| from langchain_chroma import Chroma |
| from langchain_core.documents import Document |
|
|
| from .base_agent import BaseAgent, Task, Message |
| from ..llm.langchain_ollama_client import LangChainOllamaClient |
| from ..workflow.langgraph_state import ScenarioType, TaskStatus |
|
|
|
|
# Which memory collection(s) a retrieval call may target; "all" searches every one.
MemoryType = Literal["episodic", "semantic", "stakeholders", "all"]
|
|
|
|
class MemoryAgent(BaseAgent):
    """
    Vector memory system using ChromaDB and LangChain.
    Stores and retrieves context for agent decision-making.

    Three collections:
    - episodic_memory: Past workflow executions, outcomes, lessons learned
    - semantic_memory: Domain knowledge (patents, legal frameworks, market data)
    - stakeholder_profiles: Researcher and industry partner profiles
    """

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        persist_directory: str = "data/vector_store",
        memory_agent: Optional['MemoryAgent'] = None,
    ):
        """
        Initialize MemoryAgent with ChromaDB collections.

        Args:
            llm_client: LangChain Ollama client for embeddings
            persist_directory: Directory to persist ChromaDB data
            memory_agent: Not used (for interface compatibility)

        Raises:
            Exception: Propagated if any ChromaDB collection fails to initialize.
        """
        # NOTE(review): BaseAgent.__init__ is deliberately not invoked here;
        # name/description are assigned directly below — confirm BaseAgent
        # tolerates being skipped.
        self.llm_client = llm_client
        self.persist_directory = persist_directory

        # One embedding function shared by all three collections.
        self.embeddings = llm_client.get_embeddings()

        self._initialize_collections()

        self.name = "MemoryAgent"
        self.description = "Vector memory and context retrieval"

        logger.info(f"Initialized MemoryAgent with ChromaDB at {persist_directory}")

    def _initialize_collections(self) -> None:
        """Initialize the three ChromaDB collections, each in its own subdirectory.

        Raises:
            Exception: Re-raised after logging if any collection cannot be created.
        """
        try:
            self.episodic_memory = Chroma(
                collection_name="episodic_memory",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/episodic"
            )
            logger.debug("Initialized episodic_memory collection")

            self.semantic_memory = Chroma(
                collection_name="semantic_memory",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/semantic"
            )
            logger.debug("Initialized semantic_memory collection")

            self.stakeholder_profiles = Chroma(
                collection_name="stakeholder_profiles",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/stakeholders"
            )
            logger.debug("Initialized stakeholder_profiles collection")

        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB collections: {e}")
            raise

    async def process_task(self, task: Task) -> Task:
        """
        Process a memory-related task.

        Dispatches on ``task.metadata['operation']``:
        'store_episode', 'retrieve_context', or 'store_knowledge'.
        Any failure is caught; the task is returned with status "failed"
        and the error message in ``task.error`` rather than raising.

        Args:
            task: Task with memory operation

        Returns:
            Updated task with results
        """
        logger.info(f"MemoryAgent processing task: {task.id}")
        task.status = "in_progress"

        try:
            operation = task.metadata.get('operation') if task.metadata else None

            if operation == 'store_episode':
                # Episode fields are forwarded as keyword arguments;
                # missing required keys surface as TypeError and are caught below.
                episode_data = task.metadata.get('episode_data', {})
                await self.store_episode(**episode_data)
                task.result = {"stored": True}

            elif operation == 'retrieve_context':
                query = task.metadata.get('query', '')
                context_type = task.metadata.get('context_type', 'all')
                top_k = task.metadata.get('top_k', 3)

                results = await self.retrieve_relevant_context(
                    query=query,
                    context_type=context_type,
                    top_k=top_k
                )
                task.result = {"contexts": results}

            elif operation == 'store_knowledge':
                documents = task.metadata.get('documents', [])
                metadatas = task.metadata.get('metadatas', [])
                category = task.metadata.get('category', 'general')

                await self.store_knowledge(documents, metadatas, category)
                task.result = {"stored": len(documents)}

            else:
                raise ValueError(f"Unknown memory operation: {operation}")

            task.status = "completed"
            logger.info(f"Memory operation completed: {operation}")

        except Exception as e:
            logger.error(f"Memory operation failed: {e}")
            task.status = "failed"
            task.error = str(e)

        return task

    async def store_episode(
        self,
        task_id: str,
        task_description: str,
        scenario: ScenarioType,
        workflow_steps: List[Dict],
        outcome: Dict,
        quality_score: float,
        execution_time: Optional[float] = None,
        iterations_used: Optional[int] = None,
    ) -> None:
        """
        Store a completed workflow execution for learning.

        Args:
            task_id: Unique task identifier
            task_description: Natural language task description
            scenario: VISTA scenario type
            workflow_steps: List of subtasks executed
            outcome: Final output and results
            quality_score: Quality score from validation (0.0-1.0)
            execution_time: Total execution time in seconds
            iterations_used: Number of refinement iterations

        Raises:
            Exception: Re-raised after logging if the document cannot be stored.
        """
        try:
            # Outcome JSON is truncated to 500 chars to keep the embedded
            # text compact.
            content = f"""
Task: {task_description}
Scenario: {scenario.value if hasattr(scenario, 'value') else scenario}
Quality Score: {quality_score:.2f}
Steps: {len(workflow_steps)}
Outcome: {json.dumps(outcome, indent=2)[:500]}
"""

            metadata = {
                "task_id": task_id,
                "scenario": scenario.value if hasattr(scenario, 'value') else str(scenario),
                "quality_score": float(quality_score),
                "timestamp": datetime.now().isoformat(),
                "num_steps": len(workflow_steps),
                "execution_time": execution_time or 0.0,
                "iterations": iterations_used or 0,
                # 0.85 is the success threshold used for filtering episodes later.
                "success": quality_score >= 0.85,
            }

            document = Document(
                page_content=content,
                metadata=metadata
            )

            self.episodic_memory.add_documents([document])

            logger.info(f"Stored episode: {task_id} (score: {quality_score:.2f})")

        except Exception as e:
            logger.error(f"Failed to store episode: {e}")
            raise

    async def retrieve_relevant_context(
        self,
        query: str,
        context_type: MemoryType = "episodic",
        top_k: int = 3,
        scenario_filter: Optional[ScenarioType] = None,
        min_quality_score: Optional[float] = None,
    ) -> List[Document]:
        """
        Semantic search across specified memory type.

        Metadata filters (scenario / quality score) apply only to the
        episodic collection. Results from multiple collections are
        de-duplicated by page content and truncated to ``top_k`` overall.

        Args:
            query: Natural language query
            context_type: Memory type to search
            top_k: Number of results to return
            scenario_filter: Filter by VISTA scenario
            min_quality_score: Minimum quality score for episodes

        Returns:
            List of Document objects with content and metadata;
            empty list on any retrieval error.
        """
        try:
            results = []

            # Build a Chroma "where" filter; multiple conditions must be
            # combined under a single "$and" operator.
            where_filter = None
            if scenario_filter and min_quality_score is not None:
                where_filter = {
                    "$and": [
                        {"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)},
                        {"quality_score": {"$gte": min_quality_score}}
                    ]
                }
            elif scenario_filter:
                where_filter = {"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)}
            elif min_quality_score is not None:
                where_filter = {"quality_score": {"$gte": min_quality_score}}

            if context_type == "episodic" or context_type == "all":
                episodic_results = self.episodic_memory.similarity_search(
                    query=query,
                    k=top_k,
                    filter=where_filter
                )
                results.extend(episodic_results)
                logger.debug(f"Found {len(episodic_results)} episodic memories")

            if context_type == "semantic" or context_type == "all":
                semantic_results = self.semantic_memory.similarity_search(
                    query=query,
                    k=top_k
                )
                results.extend(semantic_results)
                logger.debug(f"Found {len(semantic_results)} semantic memories")

            if context_type == "stakeholders" or context_type == "all":
                stakeholder_results = self.stakeholder_profiles.similarity_search(
                    query=query,
                    k=top_k
                )
                results.extend(stakeholder_results)
                logger.debug(f"Found {len(stakeholder_results)} stakeholder profiles")

            # De-duplicate by page content (dict keeps the last occurrence),
            # then cap the combined result set at top_k.
            unique_results = list({doc.page_content: doc for doc in results}.values())
            return unique_results[:top_k]

        except Exception as e:
            logger.error(f"Failed to retrieve context: {e}")
            return []

    async def store_knowledge(
        self,
        documents: List[str],
        metadatas: List[Dict],
        category: str,
    ) -> None:
        """
        Store domain knowledge in semantic memory.

        Args:
            documents: List of knowledge documents (text)
            metadatas: List of metadata dicts (not mutated; copies are enriched
                with category, timestamp, and doc_id before storage)
            category: Knowledge category

        Categories:
            - "patent_templates": Common patent structures
            - "legal_frameworks": GDPR, Law 25 regulations
            - "market_data": Industry sectors, trends
            - "best_practices": Successful valorization strategies

        Raises:
            Exception: Re-raised after logging if storage fails.
        """
        try:
            # zip() silently truncates to the shorter list; surface a
            # mismatch instead of dropping documents without a trace.
            if len(documents) != len(metadatas):
                logger.warning(
                    f"documents/metadatas length mismatch "
                    f"({len(documents)} vs {len(metadatas)}); extra entries ignored"
                )

            docs = []
            for i, (text, metadata) in enumerate(zip(documents, metadatas)):
                # Copy so the caller's metadata dicts are not mutated.
                enriched = dict(metadata)
                enriched['category'] = category
                enriched['timestamp'] = datetime.now().isoformat()
                enriched['doc_id'] = f"{category}_{i}"

                doc = Document(
                    page_content=text,
                    metadata=enriched
                )
                docs.append(doc)

            self.semantic_memory.add_documents(docs)

            logger.info(f"Stored {len(docs)} knowledge documents in category: {category}")

        except Exception as e:
            logger.error(f"Failed to store knowledge: {e}")
            raise

    async def store_stakeholder_profile(
        self,
        name: str,
        profile: Dict,
        categories: List[str],
    ) -> None:
        """
        Store researcher or industry partner profile.

        Args:
            name: Stakeholder name
            profile: Profile data
            categories: List of categories (expertise areas)

        Profile includes:
            - expertise: List of expertise areas
            - interests: Research interests
            - collaborations: Past collaborations
            - technologies: Technology domains
            - location: Geographic location
            - contact: Contact information

        Raises:
            Exception: Re-raised after logging if storage fails.
        """
        try:
            content = f"""
Name: {name}
Expertise: {', '.join(profile.get('expertise', []))}
Interests: {', '.join(profile.get('interests', []))}
Technologies: {', '.join(profile.get('technologies', []))}
Location: {profile.get('location', 'Unknown')}
Past Collaborations: {profile.get('collaborations', 'None listed')}
"""

            metadata = {
                "name": name,
                # Chroma metadata values must be scalars; join list into a string.
                "categories": ", ".join(categories),
                "timestamp": datetime.now().isoformat(),
                "location": profile.get('location', 'Unknown'),
                "num_expertise": len(profile.get('expertise', [])),
            }

            # Full profile is serialized so find_matching_stakeholders can
            # round-trip it with json.loads.
            metadata['profile'] = json.dumps(profile)

            document = Document(
                page_content=content,
                metadata=metadata
            )

            self.stakeholder_profiles.add_documents([document])

            logger.info(f"Stored stakeholder profile: {name}")

        except Exception as e:
            logger.error(f"Failed to store stakeholder profile: {e}")
            raise

    async def learn_from_feedback(
        self,
        task_id: str,
        feedback: str,
        updated_score: Optional[float] = None,
    ) -> None:
        """
        Update episodic memory with user feedback.
        Mark successful strategies for reuse.

        NOTE(review): this appends a revised copy of the episode; the
        original document is not deleted, so both versions remain in the
        collection — confirm this duplication is intended.

        Args:
            task_id: Task identifier
            feedback: User feedback text
            updated_score: Updated quality score after feedback
        """
        try:
            results = self.episodic_memory.similarity_search(
                query=task_id,
                k=1,
                filter={"task_id": task_id}
            )

            if results:
                logger.info(f"Found episode {task_id} for feedback update")

                original = results[0]
                content = f"{original.page_content}\n\nUser Feedback: {feedback}"

                metadata = original.metadata.copy()
                if updated_score is not None:
                    metadata['quality_score'] = updated_score
                metadata['has_feedback'] = True
                metadata['feedback_timestamp'] = datetime.now().isoformat()

                doc = Document(page_content=content, metadata=metadata)
                self.episodic_memory.add_documents([doc])

                logger.info(f"Updated episode {task_id} with feedback")
            else:
                logger.warning(f"Episode {task_id} not found for feedback")

        except Exception as e:
            # Best-effort: feedback learning failures are logged, not raised.
            logger.error(f"Failed to learn from feedback: {e}")

    async def get_similar_episodes(
        self,
        task_description: str,
        scenario: Optional[ScenarioType] = None,
        min_quality_score: float = 0.8,
        top_k: int = 3,
    ) -> List[Dict]:
        """
        Find similar past episodes for learning.

        Args:
            task_description: Current task description
            scenario: Optional scenario filter
            min_quality_score: Minimum quality threshold
            top_k: Number of results

        Returns:
            List of episode dictionaries with "content" and "metadata" keys
        """
        results = await self.retrieve_relevant_context(
            query=task_description,
            context_type="episodic",
            top_k=top_k,
            scenario_filter=scenario,
            min_quality_score=min_quality_score
        )

        episodes = []
        for doc in results:
            episodes.append({
                "content": doc.page_content,
                "metadata": doc.metadata
            })

        return episodes

    async def get_domain_knowledge(
        self,
        query: str,
        category: Optional[str] = None,
        top_k: int = 3,
    ) -> List[Document]:
        """
        Retrieve domain knowledge from semantic memory.

        Args:
            query: Knowledge query
            category: Optional category filter
            top_k: Number of results

        Returns:
            List of knowledge documents
        """
        where_filter = {"category": category} if category else None

        results = self.semantic_memory.similarity_search(
            query=query,
            k=top_k,
            filter=where_filter
        )

        return results

    async def find_matching_stakeholders(
        self,
        requirements: str,
        categories: Optional[List[str]] = None,
        location: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Dict]:
        """
        Find stakeholders matching requirements.

        Args:
            requirements: Description of needed expertise/capabilities
            categories: Optional category filters (currently unused in the
                metadata filter; only location is applied)
            location: Optional location filter
            top_k: Number of matches

        Returns:
            List of matching stakeholder profiles, each with name, parsed
            profile dict, match text, and raw metadata
        """
        where_filter = {}
        if location:
            where_filter["location"] = location

        results = self.stakeholder_profiles.similarity_search(
            query=requirements,
            k=top_k,
            filter=where_filter if where_filter else None
        )

        stakeholders = []
        for doc in results:
            # Profiles are stored JSON-encoded by store_stakeholder_profile.
            profile_data = json.loads(doc.metadata.get('profile', '{}'))
            stakeholders.append({
                "name": doc.metadata.get('name'),
                "profile": profile_data,
                "match_text": doc.page_content,
                "metadata": doc.metadata
            })

        return stakeholders

    def get_collection_stats(self) -> Dict[str, int]:
        """
        Get statistics about memory collections.

        Returns:
            Dictionary with collection counts; zeros if counting fails.
        """
        try:
            # NOTE(review): reaches into Chroma's private _collection attribute;
            # may break across langchain_chroma versions.
            stats = {
                "episodic_count": self.episodic_memory._collection.count(),
                "semantic_count": self.semantic_memory._collection.count(),
                "stakeholders_count": self.stakeholder_profiles._collection.count(),
            }
            return stats
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {"episodic_count": 0, "semantic_count": 0, "stakeholders_count": 0}
|
|
|
|
| |
def create_memory_agent(
    llm_client: LangChainOllamaClient,
    persist_directory: str = "data/vector_store",
) -> MemoryAgent:
    """
    Factory helper that builds a MemoryAgent.

    Args:
        llm_client: LangChain Ollama client
        persist_directory: Directory for ChromaDB persistence

    Returns:
        MemoryAgent instance
    """
    agent = MemoryAgent(
        llm_client=llm_client,
        persist_directory=persist_directory,
    )
    return agent
|
|