BrianIsaac's picture
feat: add streaming progress and fix ReAct agent tool compatibility
3109d12
"""LangGraph agent state models.
This module defines the state schema for multi-agent orchestration using LangGraph.
"""
from typing import TypedDict, Annotated, Sequence, Dict, Any, Optional, List
from datetime import datetime, timezone
import operator
from pydantic import BaseModel, Field
def merge_dicts(current: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Combine two dictionaries into one for a LangGraph state update.

    Either argument may be ``None``, which is treated as "nothing to
    contribute" from that side.

    Args:
        current: Dictionary already held in the state (or ``None``).
        new: Dictionary carrying the incoming update (or ``None``).

    Returns:
        When both inputs are set, a fresh dictionary with the keys of both,
        entries from ``new`` winning on key collisions. When ``current`` is
        ``None``, ``new`` itself (or an empty dict if ``new`` is empty or
        ``None``). When only ``new`` is ``None``, ``current`` itself.
    """
    if current is not None and new is not None:
        # Later unpacking wins, so incoming keys overwrite existing ones.
        return {**current, **new}
    if current is not None:
        # Nothing incoming: hand back the existing mapping untouched.
        return current
    # No existing value: use the update, or an empty dict when it is falsy.
    return new if new else {}
def last_value(current: Any, new: Any) -> Any:
    """Reducer that prefers the incoming value unless it is ``None``.

    Only ``None`` counts as "no update"; other falsy values such as ``0``,
    ``""`` or ``[]`` still replace the current value.

    Args:
        current: Value already held in the state.
        new: Incoming value.

    Returns:
        ``new`` when it is not ``None``, otherwise ``current``.
    """
    if new is None:
        return current
    return new
def sum_or_last(current: Optional[int], new: Optional[int]) -> Optional[int]:
    """Reducer for metrics: add the two values, or keep the one that is set.

    Args:
        current: Value already held in the state (or ``None``).
        new: Incoming value (or ``None``).

    Returns:
        ``current + new`` when both are set; otherwise whichever one is not
        ``None``; ``None`` when neither is set.
    """
    if current is not None and new is not None:
        return current + new
    # At most one side is set here: pick it (or None when both are missing).
    return new if current is None else current
class AgentState(TypedDict):
    """LangGraph state for multi-agent portfolio analysis workflow.

    This state is passed between agents in the workflow and accumulates
    results from each phase of the analysis.

    Every field is declared as ``Annotated[<type>, <reducer>]``. The reducers
    used below combine an existing value with an incoming one as follows:

    * ``last_value``   - keep the incoming value unless it is ``None``.
    * ``merge_dicts``  - shallow-merge two dicts; incoming keys overwrite.
    * ``sum_or_last``  - add two ints, or keep whichever one is set.
    * ``operator.add`` - concatenate the two lists.
    """
    # Input (use last_value reducer for parallel safety)
    portfolio_id: Annotated[str, last_value]
    user_query: Annotated[str, last_value]
    risk_tolerance: Annotated[str, last_value]
    holdings: Annotated[List[Dict[str, Any]], last_value]
    # Phase 1: Data Layer Results (from MCPs); dict updates are shallow-merged
    historical_prices: Annotated[Dict[str, Any], merge_dicts]
    fundamentals: Annotated[Dict[str, Any], merge_dicts]
    economic_data: Annotated[Dict[str, Any], merge_dicts]
    realtime_data: Annotated[Dict[str, Any], merge_dicts]
    technical_indicators: Annotated[Dict[str, Any], merge_dicts]
    sentiment_data: Annotated[Dict[str, Any], merge_dicts]  # Enhancement #3: News Sentiment MCP
    # Phase 1.5: Feature Engineering
    feature_vectors: Annotated[Dict[str, Any], merge_dicts]
    # Phase 2: Computation Layer Results
    optimisation_results: Annotated[Dict[str, Any], merge_dicts]
    risk_analysis: Annotated[Dict[str, Any], merge_dicts]
    # Phase 2.5: ML Predictions (P1)
    ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]
    # Phase 3: LLM Synthesis
    ai_synthesis: Annotated[str, last_value]
    # NOTE: recommendations are replaced wholesale (last_value), whereas
    # reasoning_steps are appended to (operator.add).
    recommendations: Annotated[List[str], last_value]
    reasoning_steps: Annotated[List[str], operator.add]
    # Metadata; errors and mcp_calls accumulate by list concatenation
    current_step: Annotated[str, last_value]
    errors: Annotated[List[str], operator.add]
    mcp_calls: Annotated[List[Dict[str, Any]], operator.add]
    # Performance Metrics
    # Phase durations keep the last non-None value reported (they are NOT summed).
    phase_1_duration_ms: Annotated[Optional[int], last_value]
    phase_1_5_duration_ms: Annotated[Optional[int], last_value]
    phase_2_duration_ms: Annotated[Optional[int], last_value]
    phase_2_5_duration_ms: Annotated[Optional[int], last_value]
    phase_3_duration_ms: Annotated[Optional[int], last_value]
    # LLM usage counters are summed, so contributions from parallel branches add up.
    llm_input_tokens: Annotated[Optional[int], sum_or_last]
    llm_output_tokens: Annotated[Optional[int], sum_or_last]
    llm_total_tokens: Annotated[Optional[int], sum_or_last]
    llm_request_count: Annotated[Optional[int], sum_or_last]
class MCPCall(BaseModel):
    """Record of an MCP tool call.

    Accepts both 'mcp_server' and 'mcp' field names for backward compatibility.
    """

    # Lets the declared field names be used for input alongside the
    # validation aliases configured on the fields below.
    model_config = {"populate_by_name": True}

    # What was invoked. Both fields are required; "mcp" / "tool" are the
    # alternative input keys.
    mcp_server: str = Field(default=..., validation_alias="mcp", description="MCP server name")
    tool_name: str = Field(default=..., validation_alias="tool", description="Tool called")
    parameters: Dict[str, Any] = Field(default_factory=dict)

    # Outcome of the call; both default to None.
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    # Timing: duration must be non-negative when given; the timestamp
    # defaults to the current time in UTC at model creation.
    duration_ms: Optional[int] = Field(default=None, ge=0)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class AgentMessage(BaseModel):
    """Message from an agent with metadata."""

    # Required core of the message. `role` is a free-form string (the
    # description lists the expected values but nothing enforces them);
    # `content` must contain at least one character.
    role: str = Field(default=..., description="Agent role (user/assistant/system)")
    content: str = Field(default=..., min_length=1)

    # Optional provenance and reasoning details.
    agent_name: Optional[str] = Field(default=None, description="Name of agent that generated message")
    thinking: Optional[str] = Field(default=None, description="Agent reasoning")

    # Tool calls attached to this message; defaults to a new empty list.
    tools_used: Optional[List[MCPCall]] = Field(default_factory=list)

    # Confidence is constrained to the closed interval [0.0, 1.0] when given.
    confidence: Optional[float] = Field(default=None, ge=0.0, le=1.0)

    # Defaults to the current time in UTC at model creation.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class WorkflowStatus(BaseModel):
    """Status of the multi-agent workflow."""

    # Identity and position in the workflow (both required).
    session_id: str
    current_phase: str = Field(default=..., description="Current execution phase")

    # Completion flags for the three main phases; all start out False.
    phase_1_complete: bool = False
    phase_2_complete: bool = False
    phase_3_complete: bool = False

    # Accumulated diagnostics; each instance gets its own empty list.
    errors: List[str] = Field(default_factory=list)
    mcp_calls: List[MCPCall] = Field(default_factory=list)

    # Timing: start defaults to the current time in UTC at model creation;
    # the completion fields stay None until set by the caller.
    started_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
    completed_at: Optional[datetime] = None
    execution_time_ms: Optional[int] = None