| """ |
| LangGraph Workflow for SPARKNET |
| Implements cyclic multi-agent workflows with StateGraph |
| """ |
|
|
| from typing import Literal, Dict, Any, Optional |
| from datetime import datetime |
| from loguru import logger |
|
|
| from langgraph.graph import StateGraph, END |
| from langgraph.checkpoint.memory import MemorySaver |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage |
|
|
| from .langgraph_state import ( |
| AgentState, |
| ScenarioType, |
| TaskStatus, |
| WorkflowOutput, |
| create_initial_state, |
| state_to_output, |
| ) |
| from ..llm.langchain_ollama_client import LangChainOllamaClient |
|
|
|
|
class SparknetWorkflow:
    """
    LangGraph-powered workflow orchestrator for SPARKNET.

    Implements a cyclic workflow with conditional routing:

        START → PLANNER → ROUTER → EXECUTOR → CRITIC ─┬─→ FINISH → END
                  ↑                                   │
                  └─────────────── REFINE ←───────────┘

    The EXECUTOR node dispatches on the task scenario: PATENT_WAKEUP runs the
    dedicated four-agent pipeline, every other scenario runs a tool-calling LLM.
    CRITIC scores the output and ``_should_refine`` loops back through REFINE
    until the score reaches ``quality_threshold`` or ``max_iterations``
    refinement cycles have been made. On refinement passes the critic's
    feedback is included in the planner and executor prompts.
    """

    # Nodes visited by one refinement cycle (refine → planner → router →
    # executor → critic); used to size LangGraph's recursion limit.
    _NODES_PER_ITERATION = 5
    # LangGraph's default recursion limit; the computed limit never goes below it.
    _MIN_RECURSION_LIMIT = 25
    # Upper bound on LLM <-> tool round trips inside one executor pass.
    _MAX_TOOL_ROUNDS = 3
    # Score used by the LLM-only critic when no score can be parsed from its
    # reply (this used to be the unconditional score).
    _DEFAULT_LLM_SCORE = 0.90

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        planner_agent: Optional[Any] = None,
        critic_agent: Optional[Any] = None,
        memory_agent: Optional[Any] = None,
        vision_ocr_agent: Optional[Any] = None,
        quality_threshold: float = 0.85,
        max_iterations: int = 3,
    ):
        """
        Build and compile the workflow graph.

        Args:
            llm_client: Client used to obtain chat models via ``get_llm(complexity=...)``.
            planner_agent: Optional agent whose ``process_task`` produces the subtask
                plan; when absent, or when it does not complete, a single-subtask
                default plan is used.
            critic_agent: Optional agent that scores outputs; when absent, or when it
                does not complete, the LLM is asked to score the output directly.
            memory_agent: Optional agent used to retrieve relevant past episodes and
                to store successful ones.
            vision_ocr_agent: Optional agent forwarded to the patent document analysis step.
            quality_threshold: Validation score at or above which the workflow finishes.
            max_iterations: Maximum number of refinement cycles.
        """
        self.llm_client = llm_client
        self.planner_agent = planner_agent
        self.critic_agent = critic_agent
        self.memory_agent = memory_agent
        self.vision_ocr_agent = vision_ocr_agent
        self.quality_threshold = quality_threshold
        self.max_iterations = max_iterations

        self.graph = self._build_graph()
        # In-memory checkpointer: per-run state is keyed by the config's thread_id.
        self.checkpointer = MemorySaver()
        self.app = self.graph.compile(checkpointer=self.checkpointer)

        if vision_ocr_agent:
            logger.info("Initialized SparknetWorkflow with LangGraph StateGraph and VisionOCR support")
        else:
            logger.info("Initialized SparknetWorkflow with LangGraph StateGraph")

    def _build_graph(self) -> StateGraph:
        """Wire the (uncompiled) StateGraph: linear pipeline plus critic-driven loop."""
        workflow = StateGraph(AgentState)

        workflow.add_node("planner", self._planner_node)
        workflow.add_node("router", self._router_node)
        workflow.add_node("executor", self._executor_node)
        workflow.add_node("critic", self._critic_node)
        workflow.add_node("refine", self._refine_node)
        workflow.add_node("finish", self._finish_node)

        workflow.set_entry_point("planner")
        workflow.add_edge("planner", "router")
        workflow.add_edge("router", "executor")
        workflow.add_edge("executor", "critic")

        workflow.add_conditional_edges(
            "critic",
            self._should_refine,
            {
                "refine": "refine",
                "finish": "finish",
            },
        )

        # REFINE closes the cycle by sending control back to the planner.
        workflow.add_edge("refine", "planner")
        workflow.add_edge("finish", END)

        return workflow

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _current_score(state: AgentState) -> float:
        """Return the state's validation score as a float, treating a missing/None score as 0.0."""
        score = state.get("validation_score")
        return float(score) if score is not None else 0.0

    def _refinement_feedback(self, state: AgentState) -> str:
        """Return prompt text carrying the critic's feedback on refinement passes.

        Empty on the first pass (``iteration_count`` of 0) or when there is no
        feedback, so first-pass prompts are unchanged.
        """
        iteration = state.get("iteration_count") or 0
        feedback = state.get("validation_feedback")
        if iteration <= 0 or not feedback:
            return ""
        return (
            f"\n\nThis is refinement iteration {iteration}. "
            f"The previous attempt was rejected with this feedback; address it:\n{feedback}"
        )

    @staticmethod
    def _generate_task_id() -> str:
        """Return a fresh, unique task id (also used as the checkpointer thread id)."""
        import uuid

        return f"task_{uuid.uuid4().hex[:12]}"

    def _build_run_config(self, config: Optional[Dict[str, Any]], task_id: str) -> Dict[str, Any]:
        """Merge caller config with the defaults the compiled graph needs.

        ``configurable.thread_id`` defaults to ``task_id`` because the MemorySaver
        checkpointer needs one. ``recursion_limit`` defaults to a value large enough
        for ``max_iterations`` refinement cycles. Values supplied by the caller win.
        The caller's dicts are copied, not mutated.
        """
        run_config = dict(config or {})
        configurable = dict(run_config.get("configurable") or {})
        configurable.setdefault("thread_id", task_id)
        run_config["configurable"] = configurable
        run_config.setdefault(
            "recursion_limit",
            max(
                self._MIN_RECURSION_LIMIT,
                self._NODES_PER_ITERATION * (self.max_iterations + 2),
            ),
        )
        return run_config

    # -------------------------------------------------------------------- nodes

    async def _planner_node(self, state: AgentState) -> AgentState:
        """PLANNER: retrieve memory context and decompose the task into subtasks."""
        logger.info(f"PLANNER node processing task: {state['task_id']}")
        state["status"] = TaskStatus.PLANNING
        state["current_agent"] = "PlannerAgent"

        context_docs = await self._retrieve_memory_context(state)

        system_msg = SystemMessage(content="Decompose the task into executable subtasks.")

        context_text = ""
        if context_docs:
            # Only the first 200 characters of each memory go into the prompt.
            context_text = "\n\nRelevant past experiences:\n" + "".join(
                f"\n{i}. {doc.page_content[:200]}..." for i, doc in enumerate(context_docs, 1)
            )

        user_msg = HumanMessage(
            content=(
                f"Task: {state['task_description']}\nScenario: {state['scenario']}"
                f"{context_text}{self._refinement_feedback(state)}"
            )
        )

        planned = False
        if self.planner_agent:
            planned = await self._plan_with_agent(state)
        if not planned:
            await self._plan_with_llm(state, system_msg, user_msg)

        logger.info(f"Planning completed: {len(state.get('subtasks') or [])} subtasks created")
        return state

    async def _retrieve_memory_context(self, state: AgentState) -> list:
        """Fetch up to 3 past episodes (quality >= 0.8) relevant to the task.

        Best-effort: any failure is logged and yields an empty list. Retrieved
        documents are also recorded under ``agent_outputs["memory_context"]``.
        """
        if not self.memory_agent:
            return []
        try:
            logger.info("Retrieving relevant context from memory...")
            context_docs = await self.memory_agent.retrieve_relevant_context(
                query=state["task_description"],
                context_type="all",
                top_k=3,
                scenario_filter=state["scenario"],
                min_quality_score=0.8,
            )
            if context_docs:
                logger.info(f"Retrieved {len(context_docs)} relevant memories")
                state["agent_outputs"]["memory_context"] = [
                    {"content": doc.page_content, "metadata": doc.metadata}
                    for doc in context_docs
                ]
            return context_docs or []
        except Exception as e:
            logger.warning(f"Memory retrieval failed: {e}")
            return []

    async def _plan_with_agent(self, state: AgentState) -> bool:
        """Build the plan with ``planner_agent``; return False if it did not complete."""
        from ..agents.base_agent import Task

        task = Task(
            id=state["task_id"],
            description=state["task_description"],
            metadata={"scenario": state["scenario"].value},
        )
        result_task = await self.planner_agent.process_task(task)

        if result_task.status != "completed":
            logger.warning(
                f"PlannerAgent did not complete (status: {result_task.status}); using default plan"
            )
            return False

        state["subtasks"] = [
            {
                "id": st.id,
                "description": st.description,
                "agent_type": st.agent_type,
                "dependencies": st.dependencies,
            }
            for st in result_task.result["task_graph"].subtasks.values()
        ]
        state["execution_order"] = result_task.result["execution_order"]
        state["messages"].append(
            AIMessage(content=f"Created plan with {len(state['subtasks'])} subtasks")
        )
        return True

    async def _plan_with_llm(
        self, state: AgentState, system_msg: SystemMessage, user_msg: HumanMessage
    ) -> None:
        """Ask the LLM to decompose the task and install a single-subtask default plan.

        The LLM's reply is only recorded in ``messages``; the plan itself is fixed.
        """
        llm = self.llm_client.get_llm(complexity="complex")
        response = await llm.ainvoke([system_msg, user_msg])
        state["messages"].append(response)
        state["subtasks"] = [
            {"id": "subtask_1", "description": "Execute primary task", "agent_type": "ExecutorAgent", "dependencies": []}
        ]
        state["execution_order"] = [["subtask_1"]]

    async def _router_node(self, state: AgentState) -> AgentState:
        """ROUTER: record which agents the scenario uses; control always continues to EXECUTOR."""
        logger.info(f"ROUTER node routing for scenario: {state['scenario']}")
        state["current_agent"] = "Router"

        scenario = state["scenario"]
        state["messages"].append(AIMessage(content=f"Routing to {scenario.value} workflow agents"))

        state["agent_outputs"]["router"] = {
            "scenario": scenario.value,
            "agents_to_use": self._get_scenario_agents(scenario),
        }

        return state

    async def _executor_node(self, state: AgentState) -> AgentState:
        """EXECUTOR: run the scenario pipeline or a tool-calling LLM and set ``final_output``."""
        logger.info(f"EXECUTOR node executing for scenario: {state['scenario']}")
        state["status"] = TaskStatus.EXECUTING
        state["current_agent"] = "Executor"

        scenario = state["scenario"]

        if scenario == ScenarioType.PATENT_WAKEUP:
            logger.info("🎯 Routing to Patent Wake-Up pipeline")
            return await self._execute_patent_wakeup(state)

        agents = self._get_scenario_agents(scenario)

        from ..tools.langchain_tools import get_vista_tools
        tools = get_vista_tools(scenario.value)
        logger.info(f"Loaded {len(tools)} tools for scenario: {scenario.value}")

        llm = self.llm_client.get_llm(complexity="standard")
        llm_with_tools = llm.bind_tools(tools)

        tool_descriptions = "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
        # Refinement feedback is empty on the first pass, keeping the original prompt.
        execution_prompt = HumanMessage(
            content=f"""Execute the following task using the available tools when needed:

Task: {state['task_description']}
Scenario: {scenario.value}

Available tools:
{tool_descriptions}{self._refinement_feedback(state)}

Provide detailed results."""
        )

        response, tools_called = await self._invoke_with_tools(
            llm_with_tools, tools, execution_prompt, state
        )

        state["agent_outputs"]["executor"] = {
            "result": response.content,
            "agents_used": agents,
            "tools_available": [tool.name for tool in tools],
            "tools_called": tools_called,
        }
        state["final_output"] = response.content

        logger.info("Execution completed")
        return state

    async def _invoke_with_tools(
        self, llm_with_tools: Any, tools: list, prompt: HumanMessage, state: AgentState
    ) -> tuple:
        """Invoke the LLM, executing requested tool calls and feeding results back.

        Runs at most ``_MAX_TOOL_ROUNDS`` tool rounds; tool calls still present in
        the last response after that are not executed. Every AI response and tool
        result is appended to ``state["messages"]``.

        Returns:
            ``(final_response, names_of_tools_executed)``.
        """
        from langchain_core.messages import ToolMessage

        tools_by_name = {tool.name: tool for tool in tools}
        conversation = [prompt]
        tools_called = []

        response = await llm_with_tools.ainvoke(conversation)
        state["messages"].append(response)

        for _ in range(self._MAX_TOOL_ROUNDS):
            requested = getattr(response, "tool_calls", None) or []
            if not requested:
                break
            logger.info(f"LLM requested {len(requested)} tool calls")
            conversation.append(response)

            for tool_call in requested:
                tool_name = tool_call.get("name", "unknown")
                tools_called.append(tool_name)
                logger.info(f"Tool called: {tool_name}")
                result_text = await self._run_tool(
                    tools_by_name.get(tool_name), tool_name, tool_call.get("args") or {}
                )
                tool_msg = ToolMessage(content=result_text, tool_call_id=tool_call.get("id") or "")
                conversation.append(tool_msg)
                state["messages"].append(tool_msg)

            response = await llm_with_tools.ainvoke(conversation)
            state["messages"].append(response)

        return response, tools_called

    @staticmethod
    async def _run_tool(tool: Any, tool_name: str, args: Dict[str, Any]) -> str:
        """Execute one tool call, returning its output (or an error description) as text."""
        if tool is None:
            return f"Error: unknown tool '{tool_name}'"
        try:
            result = await tool.ainvoke(args)
        except Exception as e:
            # A failing tool should not abort the run; the LLM sees the error instead.
            logger.warning(f"Tool {tool_name} failed: {e}")
            return f"Error: tool '{tool_name}' failed: {e}"
        return result if isinstance(result, str) else str(result)

    async def _execute_patent_wakeup(self, state: AgentState) -> AgentState:
        """
        Execute Patent Wake-Up scenario pipeline.
        Sequential execution: Document → Market → Matchmaking → Outreach

        Failures in any step are caught: the executor output then carries an
        ``error`` key and ``final_output`` is set to ``"Error: ..."``.
        """
        logger.info("🚀 Executing Patent Wake-Up pipeline")

        from ..agents.scenario1 import (
            DocumentAnalysisAgent,
            MarketAnalysisAgent,
            MatchmakingAgent,
            OutreachAgent,
        )

        # input_data may be present but None; treat that like "no input data".
        input_data = state.get("input_data") or {}
        patent_path = input_data.get("patent_path", "mock_patent.txt")

        try:
            logger.info("📄 Step 1/4: Analyzing patent document...")
            doc_agent = DocumentAnalysisAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
                vision_ocr_agent=self.vision_ocr_agent,
            )
            patent_analysis = await doc_agent.analyze_patent(patent_path)
            state["agent_outputs"]["document_analysis"] = patent_analysis.model_dump()
            logger.success(f"✅ Patent analyzed: {patent_analysis.title}")

            logger.info("📊 Step 2/4: Analyzing market opportunities...")
            market_agent = MarketAnalysisAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            market_analysis = await market_agent.analyze_market(patent_analysis)
            state["agent_outputs"]["market_analysis"] = market_analysis.model_dump()
            logger.success(f"✅ Market analyzed: {len(market_analysis.opportunities)} opportunities")

            logger.info("🤝 Step 3/4: Finding potential partners...")
            matching_agent = MatchmakingAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            matches = await matching_agent.find_matches(
                patent_analysis,
                market_analysis,
                max_matches=10,
            )
            state["agent_outputs"]["matches"] = [m.model_dump() for m in matches]
            logger.success(f"✅ Found {len(matches)} potential partners")

            logger.info("📝 Step 4/4: Creating valorization brief...")
            outreach_agent = OutreachAgent(
                llm_client=self.llm_client,
                memory_agent=self.memory_agent,
            )
            brief = await outreach_agent.create_valorization_brief(
                patent_analysis,
                market_analysis,
                matches,
            )
            state["agent_outputs"]["brief"] = brief.model_dump()
            state["final_output"] = brief.content
            logger.success(f"✅ Brief created: {brief.pdf_path}")

            state["agent_outputs"]["executor"] = {
                "result": "Patent Wake-Up workflow completed successfully",
                "patent_title": patent_analysis.title,
                "opportunities_found": len(market_analysis.opportunities),
                "matches_found": len(matches),
                "brief_path": brief.pdf_path,
                "agents_used": ["DocumentAnalysisAgent", "MarketAnalysisAgent",
                                "MatchmakingAgent", "OutreachAgent"],
            }

            logger.success("✅ Patent Wake-Up pipeline completed successfully!")

        except Exception as e:
            logger.exception(f"Patent Wake-Up pipeline failed: {e}")
            state["agent_outputs"]["executor"] = {
                "result": f"Pipeline failed: {str(e)}",
                "error": str(e),
                "agents_used": [],
            }
            state["final_output"] = f"Error: {str(e)}"

        return state

    async def _critic_node(self, state: AgentState) -> AgentState:
        """CRITIC: score ``final_output`` via the critic agent, falling back to the LLM."""
        logger.info("CRITIC node validating output")
        state["status"] = TaskStatus.VALIDATING
        state["current_agent"] = "CriticAgent"

        validated = False
        if self.critic_agent:
            validated = await self._validate_with_agent(state)
        if not validated:
            # Always write a fresh score so a stale one from a previous
            # iteration can never drive the refine/finish decision.
            await self._validate_with_llm(state)

        logger.info(f"Validation completed: score={self._current_score(state):.2f}")
        return state

    async def _validate_with_agent(self, state: AgentState) -> bool:
        """Validate with ``critic_agent``; return False if it did not complete."""
        from ..agents.base_agent import Task

        task = Task(
            id=state["task_id"],
            description=state["task_description"],
            metadata={
                "output_to_validate": state["final_output"],
                "output_type": self._get_output_type(state["scenario"]),
            },
        )
        result_task = await self.critic_agent.process_task(task)

        if result_task.status != "completed":
            logger.warning(
                f"CriticAgent did not complete (status: {result_task.status}); "
                "falling back to LLM validation"
            )
            return False

        validation = result_task.result
        state["validation_score"] = validation.overall_score
        state["validation_feedback"] = self.critic_agent.get_feedback_for_iteration(validation)
        state["validation_issues"] = validation.issues
        state["validation_suggestions"] = validation.suggestions

        state["messages"].append(
            AIMessage(
                content=f"Validation score: {validation.overall_score:.2f}\n{state['validation_feedback']}"
            )
        )
        return True

    async def _validate_with_llm(self, state: AgentState) -> None:
        """Ask the LLM to score ``final_output`` and parse the score from its reply."""
        llm = self.llm_client.get_llm(complexity="analysis")
        validation_prompt = HumanMessage(
            content=(
                f"Validate the following output:\n\n{state['final_output']}\n\n"
                "Provide a quality score (0.0-1.0) and feedback. "
                "Begin your reply with a line of the form 'Score: <value>'."
            )
        )

        response = await llm.ainvoke([validation_prompt])
        state["messages"].append(response)

        score = self._parse_validation_score(response.content)
        if score is None:
            logger.warning(
                "Could not parse a validation score from the LLM reply; "
                f"defaulting to {self._DEFAULT_LLM_SCORE:.2f}"
            )
            score = self._DEFAULT_LLM_SCORE

        state["validation_score"] = score
        state["validation_feedback"] = response.content
        state["validation_issues"] = []
        state["validation_suggestions"] = []

    @staticmethod
    def _parse_validation_score(text: Any) -> Optional[float]:
        """Extract a 0.0-1.0 quality score from a free-text critique.

        Recognises forms such as ``Score: 0.8``, ``Quality score (0.0-1.0): 0.75``,
        ``**Score:** 7/10`` and ``score is 85%``. Returns None when no score is
        found, the text is not a string, or the normalised value is outside [0, 1].
        """
        import re

        if not isinstance(text, str):
            return None
        match = re.search(
            r"score\**\s*(?:\([^)]*\))?\s*\**\s*(?:[:=]|\bis\b|\bof\b)\s*\**\s*"
            r"(\d+(?:\.\d+)?)\s*(%|/\s*(\d+(?:\.\d+)?))?",
            text,
            re.IGNORECASE,
        )
        if match is None:
            return None

        value = float(match.group(1))
        if match.group(2) == "%":
            value /= 100.0
        elif match.group(3) is not None:
            denominator = float(match.group(3))
            if denominator == 0:
                return None
            value /= denominator
        return value if 0.0 <= value <= 1.0 else None

    async def _refine_node(self, state: AgentState) -> AgentState:
        """REFINE: archive the current attempt and bump ``iteration_count`` before re-planning."""
        logger.info(f"REFINE node preparing for iteration {state['iteration_count'] + 1}")
        state["status"] = TaskStatus.REFINING
        state["current_agent"] = "Refiner"
        state["iteration_count"] += 1

        state["messages"].append(
            HumanMessage(
                content=f"Iteration {state['iteration_count']}: Address the following issues:\n{state['validation_feedback']}"
            )
        )

        # Keep a record of the attempt being replaced (numbered before the bump).
        state["intermediate_results"].append({
            "iteration": state["iteration_count"] - 1,
            "output": state["final_output"],
            "score": state["validation_score"],
            "feedback": state["validation_feedback"],
        })

        logger.info(f"Refinement prepared for iteration {state['iteration_count']}")
        return state

    async def _finish_node(self, state: AgentState) -> AgentState:
        """FINISH: finalise status/timing and store successful episodes in memory.

        ``success`` is False (and status FAILED) when the executor recorded an
        ``error`` (the Patent Wake-Up pipeline does this on failure).
        """
        logger.info("FINISH node completing workflow")

        executor_output = (state.get("agent_outputs") or {}).get("executor") or {}
        pipeline_error = executor_output.get("error")

        state["success"] = pipeline_error is None
        state["status"] = TaskStatus.COMPLETED if state["success"] else TaskStatus.FAILED
        state["current_agent"] = None
        state["end_time"] = datetime.now()
        state["execution_time_seconds"] = (state["end_time"] - state["start_time"]).total_seconds()

        score = self._current_score(state)

        # Only successful, well-scored runs become retrievable "past experiences".
        if self.memory_agent and state["success"] and score >= 0.75:
            try:
                logger.info("Storing episode in memory...")
                await self.memory_agent.store_episode(
                    task_id=state["task_id"],
                    task_description=state["task_description"],
                    scenario=state["scenario"],
                    workflow_steps=state.get("subtasks", []),
                    outcome={
                        "final_output": state["final_output"],
                        "validation_score": score,
                        "success": state["success"],
                        "tools_used": executor_output.get("tools_called", []),
                    },
                    quality_score=score,
                    execution_time=state["execution_time_seconds"],
                    iterations_used=state.get("iteration_count", 0),
                )
                logger.info(f"Episode stored: {state['task_id']}")
            except Exception as e:
                logger.warning(f"Failed to store episode: {e}")

        if state["success"]:
            completion_text = f"Workflow completed successfully in {state['execution_time_seconds']:.2f}s"
        else:
            completion_text = (
                f"Workflow finished with errors in {state['execution_time_seconds']:.2f}s: {pipeline_error}"
            )
        state["messages"].append(AIMessage(content=completion_text))

        logger.info(f"Workflow completed: {state['task_id']}")
        return state

    def _should_refine(self, state: AgentState) -> Literal["refine", "finish"]:
        """Route after CRITIC: finish on a passing score or when the iteration budget is spent."""
        score = self._current_score(state)
        iterations = state.get("iteration_count") or 0

        if score >= self.quality_threshold:
            logger.info(f"Quality threshold met ({score:.2f} >= {self.quality_threshold}), finishing")
            return "finish"

        if iterations >= state.get("max_iterations", self.max_iterations):
            logger.warning(f"Max iterations reached ({iterations}), finishing anyway")
            return "finish"

        logger.info(f"Refining (score={score:.2f}, iteration={iterations})")
        return "refine"

    def _get_scenario_agents(self, scenario: ScenarioType) -> list:
        """Return the agent names associated with ``scenario`` (``["ExecutorAgent"]`` if unknown)."""
        scenario_map = {
            ScenarioType.PATENT_WAKEUP: ["DocumentAnalysisAgent", "MarketAnalysisAgent", "MatchmakingAgent", "OutreachAgent"],
            ScenarioType.AGREEMENT_SAFETY: ["LegalAnalysisAgent", "ComplianceAgent", "RiskAssessmentAgent", "RecommendationAgent"],
            ScenarioType.PARTNER_MATCHING: ["ProfilingAgent", "SemanticMatchingAgent", "NetworkAnalysisAgent", "ConnectionFacilitatorAgent"],
            ScenarioType.GENERAL: ["ExecutorAgent"],
        }
        return scenario_map.get(scenario, ["ExecutorAgent"])

    def _get_output_type(self, scenario: ScenarioType) -> str:
        """Return the output-type label passed to the critic agent (``"general"`` if unknown)."""
        type_map = {
            ScenarioType.PATENT_WAKEUP: "patent_analysis",
            ScenarioType.AGREEMENT_SAFETY: "legal_review",
            ScenarioType.PARTNER_MATCHING: "stakeholder_matching",
            ScenarioType.GENERAL: "general",
        }
        return type_map.get(scenario, "general")

    # --------------------------------------------------------------- public API

    async def run(
        self,
        task_description: str,
        scenario: ScenarioType = ScenarioType.GENERAL,
        task_id: Optional[str] = None,
        input_data: Optional[Dict[str, Any]] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> WorkflowOutput:
        """Run the workflow to completion and return its output.

        Args:
            task_description: Natural-language description of the task.
            scenario: Scenario that selects the executor pipeline.
            task_id: Run identifier; a unique ``task_<hex>`` id is generated when None.
            input_data: Scenario inputs (e.g. ``{"patent_path": ...}`` for PATENT_WAKEUP).
            config: Extra LangGraph run config. ``configurable.thread_id`` defaults to
                ``task_id`` and ``recursion_limit`` is sized for ``max_iterations``
                unless the caller provides them.

        Returns:
            The ``WorkflowOutput`` built from the final state. Exceptions raised while
            running the graph are not propagated; they are logged and reported as a
            failed output built from the initial state.
        """
        if task_id is None:
            task_id = self._generate_task_id()

        initial_state = create_initial_state(
            task_id=task_id,
            task_description=task_description,
            scenario=scenario,
            max_iterations=self.max_iterations,
            input_data=input_data,
        )

        logger.info(f"Starting workflow for task: {task_id}")

        try:
            final_state = await self.app.ainvoke(
                initial_state,
                config=self._build_run_config(config, task_id),
            )

            output = state_to_output(final_state)
            if final_state.get("success", True):
                logger.info(f"Workflow completed successfully: {task_id}")
            else:
                logger.warning(f"Workflow finished with errors: {task_id}")
            return output

        except Exception as e:
            logger.exception(f"Workflow failed: {e}")
            end_time = datetime.now()
            initial_state["status"] = TaskStatus.FAILED
            initial_state["success"] = False
            initial_state["error"] = str(e)
            initial_state["end_time"] = end_time
            start_time = initial_state.get("start_time")
            if start_time is not None:
                initial_state["execution_time_seconds"] = (end_time - start_time).total_seconds()
            return state_to_output(initial_state)

    async def stream(
        self,
        task_description: str,
        scenario: ScenarioType = ScenarioType.GENERAL,
        task_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        input_data: Optional[Dict[str, Any]] = None,
    ):
        """Run the workflow, yielding each LangGraph stream event as it is produced.

        Arguments mirror :meth:`run`. ``input_data`` is accepted last so existing
        positional callers are unaffected. Exceptions from the graph propagate to
        the consumer.
        """
        if task_id is None:
            task_id = self._generate_task_id()

        initial_state = create_initial_state(
            task_id=task_id,
            task_description=task_description,
            scenario=scenario,
            max_iterations=self.max_iterations,
            input_data=input_data,
        )

        async for event in self.app.astream(
            initial_state,
            config=self._build_run_config(config, task_id),
        ):
            yield event
|
|
|
|
def create_workflow(
    llm_client: LangChainOllamaClient,
    planner_agent: Optional[Any] = None,
    critic_agent: Optional[Any] = None,
    memory_agent: Optional[Any] = None,
    vision_ocr_agent: Optional[Any] = None,
    quality_threshold: float = 0.85,
    max_iterations: int = 3,
) -> SparknetWorkflow:
    """
    Factory for a compiled :class:`SparknetWorkflow`.

    Every argument is forwarded unchanged to the ``SparknetWorkflow``
    constructor; see that class for the meaning of each option.
    """
    # Optional collaborators and tuning knobs, gathered so the constructor
    # call below stays a single readable line.
    workflow_options = dict(
        planner_agent=planner_agent,
        critic_agent=critic_agent,
        memory_agent=memory_agent,
        vision_ocr_agent=vision_ocr_agent,
        quality_threshold=quality_threshold,
        max_iterations=max_iterations,
    )
    return SparknetWorkflow(llm_client=llm_client, **workflow_options)
|
|