Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,505 Bytes
9b88b42 6752363 9b88b42 a081162 3109d12 9b88b42 3109d12 9b88b42 a081162 9f411df 9b88b42 d038452 9b88b42 a081162 9b88b42 4cc5533 9b88b42 3109d12 9b88b42 3109d12 9b88b42 3109d12 d72878a 9b88b42 f2c29a4 9b88b42 f2c29a4 9b88b42 6752363 9b88b42 6752363 9b88b42 6752363 9b88b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""LangGraph agent state models.
This module defines the state schema for multi-agent orchestration using LangGraph.
"""
from typing import TypedDict, Annotated, Sequence, Dict, Any, Optional, List
from datetime import datetime, timezone
import operator
from pydantic import BaseModel, Field
def merge_dicts(
    current: Optional[Dict[str, Any]], new: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Merge two dictionaries for LangGraph state updates.

    Either side may be ``None``; a missing side is treated as empty. The
    result is always a new top-level dict, so the returned value never
    aliases ``current`` or ``new`` (the copy is shallow: nested values are
    shared with the inputs).

    Args:
        current: Current dictionary value, or None.
        new: New dictionary to merge in, or None.

    Returns:
        A fresh dict containing the keys of ``current`` overlaid with the
        keys of ``new`` (values from ``new`` win on key collisions).
    """
    if current is None:
        # Copy so the stored state does not share identity with the update.
        return dict(new) if new else {}
    if new is None:
        return dict(current)
    return {**current, **new}
def last_value(current: Any, new: Any) -> Any:
    """Reducer that prefers the incoming value unless it is None.

    Args:
        current: Value already stored in the state.
        new: Incoming update for the same key.

    Returns:
        ``current`` when ``new`` is None; otherwise ``new`` (even if it is
        falsy, such as ``0`` or ``[]``).
    """
    if new is None:
        return current
    return new
def sum_or_last(current: Optional[int], new: Optional[int]) -> Optional[int]:
    """Reducer for counters: add both values, ignoring a missing side.

    Args:
        current: Value already stored in the state, or None.
        new: Incoming value, or None.

    Returns:
        ``current + new`` when both are set; otherwise whichever one is not
        None (None if both are None).
    """
    if current is not None and new is not None:
        return current + new
    return new if current is None else current
class AgentState(TypedDict):
    """LangGraph state for multi-agent portfolio analysis workflow.

    This state is passed between agents in the workflow and accumulates
    results from each phase of the analysis.

    Every field is declared as ``Annotated[<type>, <reducer>]``; the reducer
    decides how an incoming update is combined with the stored value:

    * ``last_value`` - the newest non-None value replaces the old one.
    * ``merge_dicts`` - dictionaries are shallow-merged; newer keys win.
    * ``operator.add`` - lists are concatenated, so entries accumulate.
    * ``sum_or_last`` - integers are added; a None on either side is ignored.
    """
    # Input (use last_value reducer for parallel safety)
    portfolio_id: Annotated[str, last_value]
    user_query: Annotated[str, last_value]
    risk_tolerance: Annotated[str, last_value]
    holdings: Annotated[List[Dict[str, Any]], last_value]
    # Phase 1: Data Layer Results (from MCPs)
    # Dict-valued results are merged so partial updates from several nodes combine.
    historical_prices: Annotated[Dict[str, Any], merge_dicts]
    fundamentals: Annotated[Dict[str, Any], merge_dicts]
    economic_data: Annotated[Dict[str, Any], merge_dicts]
    realtime_data: Annotated[Dict[str, Any], merge_dicts]
    technical_indicators: Annotated[Dict[str, Any], merge_dicts]
    sentiment_data: Annotated[Dict[str, Any], merge_dicts]  # Enhancement #3: News Sentiment MCP
    # Phase 1.5: Feature Engineering
    feature_vectors: Annotated[Dict[str, Any], merge_dicts]
    # Phase 2: Computation Layer Results
    optimisation_results: Annotated[Dict[str, Any], merge_dicts]
    risk_analysis: Annotated[Dict[str, Any], merge_dicts]
    # Phase 2.5: ML Predictions (P1)
    ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]
    # Phase 3: LLM Synthesis
    ai_synthesis: Annotated[str, last_value]
    recommendations: Annotated[List[str], last_value]  # replaced wholesale, not appended
    reasoning_steps: Annotated[List[str], operator.add]  # appended across updates
    # Metadata
    current_step: Annotated[str, last_value]
    errors: Annotated[List[str], operator.add]  # every node's errors are kept
    mcp_calls: Annotated[List[Dict[str, Any]], operator.add]  # call log accumulates
    # Performance metrics: phase durations use last_value (a later report
    # overwrites an earlier one), while the LLM token/request counters use
    # sum_or_last so counts reported by parallel branches add up.
    phase_1_duration_ms: Annotated[Optional[int], last_value]
    phase_1_5_duration_ms: Annotated[Optional[int], last_value]
    phase_2_duration_ms: Annotated[Optional[int], last_value]
    phase_2_5_duration_ms: Annotated[Optional[int], last_value]
    phase_3_duration_ms: Annotated[Optional[int], last_value]
    llm_input_tokens: Annotated[Optional[int], sum_or_last]
    llm_output_tokens: Annotated[Optional[int], sum_or_last]
    llm_total_tokens: Annotated[Optional[int], sum_or_last]
    llm_request_count: Annotated[Optional[int], sum_or_last]
class MCPCall(BaseModel):
    """Record of an MCP tool call.

    For backward compatibility, input may use either the field names or the
    short validation aliases: ``mcp_server`` or ``mcp``, and ``tool_name``
    or ``tool``.
    """
    # populate_by_name allows the field names to be used in addition to the
    # validation aliases declared on the fields below.
    # NOTE: validation_alias only affects parsing input; dumped output keeps
    # the field names (mcp_server / tool_name).
    model_config = {"populate_by_name": True}
    mcp_server: str = Field(..., validation_alias="mcp", description="MCP server name")
    tool_name: str = Field(..., validation_alias="tool", description="Tool called")
    # Arguments sent with the call; each instance gets its own empty dict by default.
    parameters: Dict[str, Any] = Field(default_factory=dict)
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None  # presumably set when the call failed — confirm with callers
    duration_ms: Optional[int] = Field(None, ge=0)  # optional; must be >= 0 when given
    # Creation time as a timezone-aware UTC datetime.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class AgentMessage(BaseModel):
    """A single message emitted during the workflow, plus agent metadata.

    Only ``role`` and a non-empty ``content`` are required; the remaining
    fields describe who produced the message, its reasoning, the MCP tool
    calls it relied on, a confidence in [0, 1], and a UTC creation time.
    """

    role: str = Field(description="Agent role (user/assistant/system)")
    content: str = Field(min_length=1)  # empty messages are rejected
    agent_name: Optional[str] = Field(
        default=None,
        description="Name of agent that generated message",
    )
    thinking: Optional[str] = Field(default=None, description="Agent reasoning")
    # A new empty list per instance unless tool calls are supplied.
    tools_used: Optional[List[MCPCall]] = Field(default_factory=list)
    confidence: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
class WorkflowStatus(BaseModel):
    """Progress snapshot of one multi-agent workflow session.

    Tracks the current phase, per-phase completion flags, collected errors
    and MCP calls, and timing. ``started_at`` defaults to the UTC time the
    status object is created.
    """

    session_id: str
    current_phase: str = Field(description="Current execution phase")

    # Completion flags; every phase starts out incomplete.
    phase_1_complete: bool = False
    phase_2_complete: bool = False
    phase_3_complete: bool = False

    # Fresh empty lists per instance.
    errors: List[str] = Field(default_factory=list)
    mcp_calls: List[MCPCall] = Field(default_factory=list)

    started_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
    completed_at: Optional[datetime] = Field(default=None)
    execution_time_ms: Optional[int] = Field(default=None)
|