File size: 5,505 Bytes
9b88b42
 
 
 
 
 
6752363
9b88b42
 
 
 
 
a081162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3109d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b88b42
 
 
 
 
 
 
3109d12
 
 
 
 
9b88b42
 
a081162
 
 
 
 
9f411df
9b88b42
d038452
 
 
9b88b42
a081162
 
9b88b42
4cc5533
 
 
9b88b42
3109d12
 
9b88b42
 
 
3109d12
9b88b42
 
 
3109d12
 
 
 
 
 
 
 
 
 
d72878a
9b88b42
 
f2c29a4
9b88b42
f2c29a4
 
 
 
 
 
 
9b88b42
 
 
 
6752363
9b88b42
 
 
 
 
 
 
 
 
 
 
6752363
9b88b42
 
 
 
 
 
 
 
 
 
 
 
6752363
9b88b42
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""LangGraph agent state models.

This module defines the state schema for multi-agent orchestration using LangGraph.
"""

from typing import TypedDict, Annotated, Sequence, Dict, Any, Optional, List
from datetime import datetime, timezone
import operator

from pydantic import BaseModel, Field


def merge_dicts(
    current: Optional[Dict[str, Any]], new: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Merge two dictionaries for LangGraph state updates.

    Either argument may be None; None is treated as "nothing to merge".

    Args:
        current: Current dictionary value, or None.
        new: New dictionary to merge in, or None.

    Returns:
        A shallow merge in which keys from ``new`` overwrite those in
        ``current``. If only ``current`` is present it is returned as-is; if
        only ``new`` is present a shallow copy of it is returned; if both are
        None an empty dict is returned.
    """
    if new is None:
        return current if current is not None else {}
    if current is None:
        # Copy so the stored state never aliases the dict a node returned;
        # this matches the merge branch below, which always builds a new dict.
        return dict(new)
    return {**current, **new}


def last_value(current: Any, new: Any) -> Any:
    """Keep the most recent value unless the incoming update is None.

    An incoming None is interpreted as "no update", so the value already held
    is preserved. Falsy-but-not-None values (0, "", []) still replace it.

    Args:
        current: Value already stored.
        new: Incoming value from the latest update.

    Returns:
        ``new`` when it is not None; otherwise ``current``.
    """
    if new is None:
        return current
    return new


def sum_or_last(current: Optional[int], new: Optional[int]) -> Optional[int]:
    """Accumulate an optional integer metric.

    Args:
        current: Value already stored, or None.
        new: Incoming value, or None.

    Returns:
        ``current + new`` when both are present; otherwise whichever of the
        two is not None (None when neither is).
    """
    both_present = current is not None and new is not None
    if both_present:
        return current + new
    return new if current is None else current


class AgentState(TypedDict):
    """LangGraph state for multi-agent portfolio analysis workflow.

    This state is passed between agents in the workflow and accumulates
    results from each phase of the analysis.

    Every key is declared as ``Annotated[<type>, <reducer>]``; the reducer
    combines the stored value with an incoming update for that key:

    * ``last_value``   -- keep the newest non-None value.
    * ``merge_dicts``  -- shallow-merge dicts; keys from the update win.
    * ``operator.add`` -- concatenate lists, so entries accumulate.
    * ``sum_or_last``  -- add integers, treating None as "no value".
    """

    # Input (use last_value reducer for parallel safety)
    portfolio_id: Annotated[str, last_value]
    user_query: Annotated[str, last_value]
    risk_tolerance: Annotated[str, last_value]
    holdings: Annotated[List[Dict[str, Any]], last_value]

    # Phase 1: Data Layer Results (from MCPs)
    # Each is shallow-merged, so separate updates can contribute different keys.
    historical_prices: Annotated[Dict[str, Any], merge_dicts]
    fundamentals: Annotated[Dict[str, Any], merge_dicts]
    economic_data: Annotated[Dict[str, Any], merge_dicts]
    realtime_data: Annotated[Dict[str, Any], merge_dicts]
    technical_indicators: Annotated[Dict[str, Any], merge_dicts]
    sentiment_data: Annotated[Dict[str, Any], merge_dicts]  # Enhancement #3: News Sentiment MCP

    # Phase 1.5: Feature Engineering
    feature_vectors: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2: Computation Layer Results
    optimisation_results: Annotated[Dict[str, Any], merge_dicts]
    risk_analysis: Annotated[Dict[str, Any], merge_dicts]

    # Phase 2.5: ML Predictions (P1)
    ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]

    # Phase 3: LLM Synthesis
    # ai_synthesis / recommendations are replaced wholesale (last_value);
    # reasoning_steps is appended to (operator.add concatenates the lists).
    ai_synthesis: Annotated[str, last_value]
    recommendations: Annotated[List[str], last_value]
    reasoning_steps: Annotated[List[str], operator.add]

    # Metadata
    # errors and mcp_calls accumulate across updates (list concatenation).
    current_step: Annotated[str, last_value]
    errors: Annotated[List[str], operator.add]
    mcp_calls: Annotated[List[Dict[str, Any]], operator.add]

    # Performance Metrics
    # Phase durations use last_value: the newest non-None value wins. If two
    # parallel branches both report a duration for the same phase, only one
    # is kept (which one depends on the order the updates are applied).
    phase_1_duration_ms: Annotated[Optional[int], last_value]
    phase_1_5_duration_ms: Annotated[Optional[int], last_value]
    phase_2_duration_ms: Annotated[Optional[int], last_value]
    phase_2_5_duration_ms: Annotated[Optional[int], last_value]
    phase_3_duration_ms: Annotated[Optional[int], last_value]
    # LLM usage counters use sum_or_last, so values from separate updates
    # (including parallel branches) are added together.
    llm_input_tokens: Annotated[Optional[int], sum_or_last]
    llm_output_tokens: Annotated[Optional[int], sum_or_last]
    llm_total_tokens: Annotated[Optional[int], sum_or_last]
    llm_request_count: Annotated[Optional[int], sum_or_last]


class MCPCall(BaseModel):
    """Record of an MCP tool call.

    Accepts both 'mcp_server' and 'mcp' field names for backward compatibility,
    and likewise both 'tool_name' and 'tool'. The short names are validation
    aliases; ``populate_by_name`` keeps the attribute names usable as input
    keys as well.
    """

    # Allow input keyed by attribute name in addition to the validation aliases.
    model_config = {"populate_by_name": True}

    mcp_server: str = Field(..., validation_alias="mcp", description="MCP server name")
    tool_name: str = Field(..., validation_alias="tool", description="Tool called")
    # Parameters of the call; defaults to an empty dict.
    parameters: Dict[str, Any] = Field(default_factory=dict)
    # Optional outcome fields. NOTE(review): presumably `result` is set on
    # success and `error` on failure -- confirm against the code that records calls.
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    # Optional; when provided it must be >= 0.
    duration_ms: Optional[int] = Field(None, ge=0)
    # Defaults to the time the record is created, timezone-aware in UTC.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class AgentMessage(BaseModel):
    """Message from an agent with metadata.

    Validation enforces a non-empty ``content`` and, when given, a
    ``confidence`` within [0.0, 1.0]. ``timestamp`` defaults to the current
    UTC time.
    """

    role: str = Field(..., description="Agent role (user/assistant/system)")
    content: str = Field(..., min_length=1)
    agent_name: Optional[str] = Field(None, description="Name of agent that generated message")
    thinking: Optional[str] = Field(None, description="Agent reasoning")
    # MCP calls attached to this message; defaults to an empty list, and an
    # explicit None is also accepted by the Optional type.
    tools_used: Optional[List[MCPCall]] = Field(default_factory=list)
    confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class WorkflowStatus(BaseModel):
    """Status of the multi-agent workflow.

    Holds per-phase completion flags, accumulated errors and MCP call records,
    and start/finish timing for one session.
    """

    session_id: str
    current_phase: str = Field(..., description="Current execution phase")
    # Completion flags for phases 1-3, all False by default.
    # NOTE(review): phases 1.5 and 2.5 tracked elsewhere have no flag here --
    # confirm whether that is intentional.
    phase_1_complete: bool = Field(default=False)
    phase_2_complete: bool = Field(default=False)
    phase_3_complete: bool = Field(default=False)
    errors: List[str] = Field(default_factory=list)
    mcp_calls: List[MCPCall] = Field(default_factory=list)
    # Defaults to creation time (timezone-aware, UTC).
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Optional; presumably filled in once the workflow finishes -- confirm with callers.
    completed_at: Optional[datetime] = None
    execution_time_ms: Optional[int] = None