| """LangGraph agent state models. | |
| This module defines the state schema for multi-agent orchestration using LangGraph. | |
| """ | |
| from typing import TypedDict, Annotated, Sequence, Dict, Any, Optional, List | |
| from datetime import datetime, timezone | |
| import operator | |
| from pydantic import BaseModel, Field | |


def merge_dicts(current: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two dictionaries for LangGraph state updates.

    Args:
        current: Current dictionary value
        new: New dictionary to merge in

    Returns:
        Merged dictionary with new values overwriting current
    """
    if current is None:
        return new or {}
    if new is None:
        return current
    return {**current, **new}


def last_value(current: Any, new: Any) -> Any:
    """Reducer that keeps the last non-None value.

    Args:
        current: Current value
        new: New value

    Returns:
        New value if not None, otherwise current
    """
    return new if new is not None else current


def sum_or_last(current: Optional[int], new: Optional[int]) -> Optional[int]:
    """Reducer for metrics that sums values or keeps last.

    Args:
        current: Current value
        new: New value

    Returns:
        Sum if both are set, otherwise the non-None value
    """
    if current is None:
        return new
    if new is None:
        return current
    return current + new
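

# Illustrative sketch (hypothetical values, not part of the original module) of
# how the reducers above behave when LangGraph merges concurrent state updates:
# merge_dicts favours incoming keys, last_value keeps the newest non-None value,
# and sum_or_last accumulates numeric metrics across parallel branches.
def _demo_reducers() -> None:
    """Run a few sanity checks on the reducer semantics (illustration only)."""
    merged = merge_dicts({"AAPL": {"close": 190.0}}, {"MSFT": {"close": 410.0}})
    assert merged == {"AAPL": {"close": 190.0}, "MSFT": {"close": 410.0}}
    assert last_value("data_collection", None) == "data_collection"
    assert last_value("data_collection", "synthesis") == "synthesis"
    assert sum_or_last(120, 80) == 200
    assert sum_or_last(None, 80) == 80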


class AgentState(TypedDict):
    """LangGraph state for multi-agent portfolio analysis workflow.

    This state is passed between agents in the workflow and accumulates
    results from each phase of the analysis.
    """
    # Input (use last_value reducer for parallel safety)
    portfolio_id: Annotated[str, last_value]
    user_query: Annotated[str, last_value]
    risk_tolerance: Annotated[str, last_value]
    holdings: Annotated[List[Dict[str, Any]], last_value]

    # Phase 1: Data Layer Results (from MCPs)
    historical_prices: Annotated[Dict[str, Any], merge_dicts]
    fundamentals: Annotated[Dict[str, Any], merge_dicts]
    economic_data: Annotated[Dict[str, Any], merge_dicts]
    realtime_data: Annotated[Dict[str, Any], merge_dicts]
    technical_indicators: Annotated[Dict[str, Any], merge_dicts]
    sentiment_data: Annotated[Dict[str, Any], merge_dicts]  # Enhancement #3: News Sentiment MCP

    # Phase 1.5: Feature Engineering
    feature_vectors: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2: Computation Layer Results
    optimisation_results: Annotated[Dict[str, Any], merge_dicts]
    risk_analysis: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2.5: ML Predictions (P1)
    ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]

    # Phase 3: LLM Synthesis
    ai_synthesis: Annotated[str, last_value]
    recommendations: Annotated[List[str], last_value]
    reasoning_steps: Annotated[List[str], operator.add]

    # Metadata
    current_step: Annotated[str, last_value]
    errors: Annotated[List[str], operator.add]
    mcp_calls: Annotated[List[Dict[str, Any]], operator.add]

    # Performance Metrics (phase durations keep the last value;
    # LLM token counts are summed across parallel branches)
    phase_1_duration_ms: Annotated[Optional[int], last_value]
    phase_1_5_duration_ms: Annotated[Optional[int], last_value]
    phase_2_duration_ms: Annotated[Optional[int], last_value]
    phase_2_5_duration_ms: Annotated[Optional[int], last_value]
    phase_3_duration_ms: Annotated[Optional[int], last_value]
    llm_input_tokens: Annotated[Optional[int], sum_or_last]
    llm_output_tokens: Annotated[Optional[int], sum_or_last]
    llm_total_tokens: Annotated[Optional[int], sum_or_last]
    llm_request_count: Annotated[Optional[int], sum_or_last]
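

# Illustrative sketch (hypothetical values) of the input slice that the graph's
# entry node might receive. Only the input fields are populated; the accumulator
# fields start empty so the reducers above can merge results in as phases run.
def _example_initial_state() -> Dict[str, Any]:
    """Return a partial AgentState used purely for illustration."""
    return {
        "portfolio_id": "demo-portfolio",
        "user_query": "How exposed is my portfolio to a tech-sector drawdown?",
        "risk_tolerance": "moderate",
        "holdings": [{"symbol": "AAPL", "quantity": 10}, {"symbol": "MSFT", "quantity": 5}],
        "reasoning_steps": [],
        "errors": [],
        "mcp_calls": [],
    }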


class MCPCall(BaseModel):
    """Record of an MCP tool call.

    Accepts both 'mcp_server' and 'mcp' field names for backward compatibility.
    """
    model_config = {"populate_by_name": True}

    mcp_server: str = Field(..., validation_alias="mcp", description="MCP server name")
    tool_name: str = Field(..., validation_alias="tool", description="Tool called")
    parameters: Dict[str, Any] = Field(default_factory=dict)
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    duration_ms: Optional[int] = Field(None, ge=0)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
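

# Illustrative sketch (hypothetical values) of the backward-compatibility
# aliases on MCPCall: with populate_by_name enabled, both the canonical field
# names and the legacy 'mcp'/'tool' keys validate to the same model.
def _demo_mcp_call_aliases() -> None:
    """Show that MCPCall accepts both field-name and alias payloads (illustration only)."""
    canonical = MCPCall(mcp_server="market_data", tool_name="get_quote")
    legacy = MCPCall.model_validate({"mcp": "market_data", "tool": "get_quote"})
    assert canonical.mcp_server == legacy.mcp_server == "market_data"
    assert canonical.tool_name == legacy.tool_name == "get_quote"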


class AgentMessage(BaseModel):
    """Message from an agent with metadata."""

    role: str = Field(..., description="Agent role (user/assistant/system)")
    content: str = Field(..., min_length=1)
    agent_name: Optional[str] = Field(None, description="Name of the agent that generated the message")
    thinking: Optional[str] = Field(None, description="Agent reasoning")
    tools_used: Optional[List[MCPCall]] = Field(default_factory=list)
    confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
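

# Illustrative sketch (hypothetical values): an AgentMessage carrying one MCP
# tool call and a confidence score, as a synthesis agent might emit it.
def _example_agent_message() -> AgentMessage:
    """Build a sample AgentMessage for illustration only."""
    return AgentMessage(
        role="assistant",
        content="Portfolio risk is concentrated in large-cap technology holdings.",
        agent_name="synthesis_agent",
        confidence=0.82,
        tools_used=[MCPCall(mcp_server="risk_engine", tool_name="compute_var")],
    )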


class WorkflowStatus(BaseModel):
    """Status of the multi-agent workflow."""

    session_id: str
    current_phase: str = Field(..., description="Current execution phase")
    phase_1_complete: bool = Field(default=False)
    phase_2_complete: bool = Field(default=False)
    phase_3_complete: bool = Field(default=False)
    errors: List[str] = Field(default_factory=list)
    mcp_calls: List[MCPCall] = Field(default_factory=list)
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None
    execution_time_ms: Optional[int] = None
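

# Minimal smoke test wiring the illustrative helpers together (a sketch, not
# part of the original module). Running this file directly exercises the
# reducer semantics and the MCPCall alias handling shown above.
if __name__ == "__main__":
    _demo_reducers()
    _demo_mcp_call_aliases()
    state = _example_initial_state()
    message = _example_agent_message()
    print(f"Example initial state keys: {sorted(state)}")
    print(f"Sample message from {message.agent_name}: confidence={message.confidence}")
    print("Reducer and MCPCall alias checks passed.")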