BrianIsaac committed on
Commit
4cc5533
·
1 Parent(s): 797be0a

feat: integrate ensemble predictor into workflow Phase 2.5

Browse files

Add ML-based price forecasting to portfolio analysis workflow:
- Add Phase 2.5 between computation and LLM synthesis
- Update MCP router with ensemble predictor integration
- Extend AgentState to include ensemble_forecasts field
- Update PortfolioAnalystAgent to consume ML forecasts
- Generate 30-day forecasts for all portfolio holdings
- Graceful degradation if forecasts fail (workflow continues)
- Pass forecast data to LLM for enhanced analysis

backend/agents/portfolio_analyst.py CHANGED
@@ -280,6 +280,7 @@ class PortfolioAnalystAgent(BasePortfolioAgent[PortfolioAnalysisOutput]):
280
  economic_data: Dict[str, Any],
281
  optimization_results: Dict[str, Any],
282
  risk_analysis: Dict[str, Any],
 
283
  risk_tolerance: str = "moderate",
284
  ) -> "AgentResult[PortfolioAnalysisOutput]":
285
  """Analyze a complete portfolio with all available data.
@@ -292,6 +293,7 @@ class PortfolioAnalystAgent(BasePortfolioAgent[PortfolioAnalysisOutput]):
292
  economic_data: Macroeconomic indicators
293
  optimization_results: Portfolio optimization outputs
294
  risk_analysis: VaR, CVaR, and risk metrics
 
295
  risk_tolerance: Investor's risk tolerance
296
 
297
  Returns:
@@ -302,6 +304,11 @@ class PortfolioAnalystAgent(BasePortfolioAgent[PortfolioAnalysisOutput]):
302
  risk_xml = format_risk_analysis_xml(risk_analysis)
303
  optimisation_xml = format_optimisation_results_xml(optimization_results)
304
 
 
 
 
 
 
305
  prompt = f"""Analyze this investment portfolio:
306
 
307
  PORTFOLIO:
@@ -323,7 +330,7 @@ OPTIMIZATION ANALYSIS:
323
  {optimisation_xml}
324
 
325
  RISK ANALYSIS:
326
- {risk_xml}
327
 
328
  INVESTOR RISK TOLERANCE: {risk_tolerance}
329
 
 
280
  economic_data: Dict[str, Any],
281
  optimization_results: Dict[str, Any],
282
  risk_analysis: Dict[str, Any],
283
+ ensemble_forecasts: Optional[Dict[str, Any]] = None,
284
  risk_tolerance: str = "moderate",
285
  ) -> "AgentResult[PortfolioAnalysisOutput]":
286
  """Analyze a complete portfolio with all available data.
 
293
  economic_data: Macroeconomic indicators
294
  optimization_results: Portfolio optimization outputs
295
  risk_analysis: VaR, CVaR, and risk metrics
296
+ ensemble_forecasts: ML-based price forecasts (Chronos + statistical models)
297
  risk_tolerance: Investor's risk tolerance
298
 
299
  Returns:
 
304
  risk_xml = format_risk_analysis_xml(risk_analysis)
305
  optimisation_xml = format_optimisation_results_xml(optimization_results)
306
 
307
+ # Format ensemble forecasts if available
308
+ forecasts_section = ""
309
+ if ensemble_forecasts:
310
+ forecasts_section = f"\n\nML FORECASTS (30-day predictions):\n{ensemble_forecasts}"
311
+
312
  prompt = f"""Analyze this investment portfolio:
313
 
314
  PORTFOLIO:
 
330
  {optimisation_xml}
331
 
332
  RISK ANALYSIS:
333
+ {risk_xml}{forecasts_section}
334
 
335
  INVESTOR RISK TOLERANCE: {risk_tolerance}
336
 
backend/agents/workflow.py CHANGED
@@ -1,8 +1,9 @@
1
  """LangGraph workflow for multi-agent portfolio analysis.
2
 
3
- This implements the three-phase architecture:
4
  Phase 1: Data Layer MCPs (Yahoo Finance, FMP, Trading-MCP, FRED)
5
  Phase 2: Computation Layer MCPs (Portfolio Optimizer, Risk Analyzer)
 
6
  Phase 3: LLM Synthesis (Portfolio Analyst Agent)
7
  """
8
 
@@ -80,12 +81,14 @@ class PortfolioAnalysisWorkflow:
80
  # Add nodes for each phase
81
  workflow.add_node("phase_1_data_layer", self._phase_1_data_layer)
82
  workflow.add_node("phase_2_computation", self._phase_2_computation)
 
83
  workflow.add_node("phase_3_synthesis", self._phase_3_synthesis)
84
 
85
  # Define the flow
86
  workflow.set_entry_point("phase_1_data_layer")
87
  workflow.add_edge("phase_1_data_layer", "phase_2_computation")
88
- workflow.add_edge("phase_2_computation", "phase_3_synthesis")
 
89
  workflow.add_edge("phase_3_synthesis", END)
90
 
91
  return workflow.compile()
@@ -330,6 +333,79 @@ class PortfolioAnalysisWorkflow:
330
 
331
  return state
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  async def _phase_3_synthesis(self, state: AgentState) -> AgentState:
334
  """Phase 3: LLM synthesis of all data into actionable insights."""
335
  logger.info("PHASE 3: LLM Synthesis")
@@ -352,6 +428,7 @@ class PortfolioAnalysisWorkflow:
352
  economic_data=state.get("economic_data", {}),
353
  optimization_results=state.get("optimisation_results", {}),
354
  risk_analysis=state.get("risk_analysis", {}),
 
355
  risk_tolerance=state["risk_tolerance"],
356
  )
357
 
 
1
  """LangGraph workflow for multi-agent portfolio analysis.
2
 
3
+ This implements the multi-phase architecture:
4
  Phase 1: Data Layer MCPs (Yahoo Finance, FMP, Trading-MCP, FRED)
5
  Phase 2: Computation Layer MCPs (Portfolio Optimizer, Risk Analyzer)
6
+ Phase 2.5: ML Predictions (Ensemble Predictor with Chronos)
7
  Phase 3: LLM Synthesis (Portfolio Analyst Agent)
8
  """
9
 
 
81
  # Add nodes for each phase
82
  workflow.add_node("phase_1_data_layer", self._phase_1_data_layer)
83
  workflow.add_node("phase_2_computation", self._phase_2_computation)
84
+ workflow.add_node("phase_2_5_ml_predictions", self._phase_2_5_ml_predictions)
85
  workflow.add_node("phase_3_synthesis", self._phase_3_synthesis)
86
 
87
  # Define the flow
88
  workflow.set_entry_point("phase_1_data_layer")
89
  workflow.add_edge("phase_1_data_layer", "phase_2_computation")
90
+ workflow.add_edge("phase_2_computation", "phase_2_5_ml_predictions")
91
+ workflow.add_edge("phase_2_5_ml_predictions", "phase_3_synthesis")
92
  workflow.add_edge("phase_3_synthesis", END)
93
 
94
  return workflow.compile()
 
333
 
334
  return state
335
 
336
+ async def _phase_2_5_ml_predictions(self, state: AgentState) -> AgentState:
337
+ """Phase 2.5: Generate ML-based price forecasts using Ensemble Predictor.
338
+
339
+ MCP called:
340
+ - Ensemble Predictor: Chronos + statistical models for price forecasting
341
+ """
342
+ logger.info("PHASE 2.5: Generating ML predictions")
343
+ phase_start = time.perf_counter()
344
+
345
+ try:
346
+ # Generate forecasts for each holding
347
+ logger.debug("Running ensemble forecasts for portfolio holdings")
348
+ ensemble_forecasts = {}
349
+
350
+ for holding in state["holdings"]:
351
+ ticker = holding["ticker"]
352
+
353
+ # Get historical prices from Phase 1 data
354
+ hist_data = state["historical_prices"].get(ticker, {})
355
+ prices = hist_data.get("close_prices", [])
356
+
357
+ if not prices or len(prices) < 10:
358
+ logger.warning(f"Insufficient price data for {ticker}, skipping forecast")
359
+ continue
360
+
361
+ try:
362
+ # Call ensemble predictor
363
+ forecast_result = await self.mcp_router.call_ensemble_predictor_mcp(
364
+ "forecast_ensemble",
365
+ {
366
+ "ticker": ticker,
367
+ "prices": prices,
368
+ "forecast_horizon": 30, # 30-day forecast
369
+ "confidence_level": 0.95,
370
+ "use_returns": True, # Forecast returns for stability
371
+ "ensemble_method": "mean", # Simple averaging
372
+ }
373
+ )
374
+
375
+ ensemble_forecasts[ticker] = forecast_result
376
+ logger.debug(f"Generated forecast for {ticker} using {len(forecast_result.get('models_used', []))} models")
377
+
378
+ except Exception as e:
379
+ logger.warning(f"Forecast failed for {ticker}: {e}")
380
+ continue
381
+
382
+ # Update state
383
+ state["ensemble_forecasts"] = ensemble_forecasts
384
+ state["current_step"] = "phase_2_5_complete"
385
+
386
+ # Log MCP calls
387
+ state["mcp_calls"].extend([
388
+ MCPCall.model_validate({
389
+ "mcp": "ensemble_predictor",
390
+ "tool": "forecast_ensemble"
391
+ }).model_dump(),
392
+ ])
393
+
394
+ # Track phase duration
395
+ phase_duration_ms = int((time.perf_counter() - phase_start) * 1000)
396
+
397
+ logger.info(
398
+ f"PHASE 2.5 COMPLETE: Generated forecasts for {len(ensemble_forecasts)} assets ({phase_duration_ms}ms)"
399
+ )
400
+
401
+ except Exception as e:
402
+ logger.error(f"Error in Phase 2.5: {e}")
403
+ state["errors"].append(f"Phase 2.5 error: {str(e)}")
404
+ # Set empty forecasts to allow workflow to continue
405
+ state["ensemble_forecasts"] = {}
406
+
407
+ return state
408
+
409
  async def _phase_3_synthesis(self, state: AgentState) -> AgentState:
410
  """Phase 3: LLM synthesis of all data into actionable insights."""
411
  logger.info("PHASE 3: LLM Synthesis")
 
428
  economic_data=state.get("economic_data", {}),
429
  optimization_results=state.get("optimisation_results", {}),
430
  risk_analysis=state.get("risk_analysis", {}),
431
+ ensemble_forecasts=state.get("ensemble_forecasts", {}),
432
  risk_tolerance=state["risk_tolerance"],
433
  )
434
 
backend/mcp_router.py CHANGED
@@ -13,7 +13,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
13
 
14
  # Import all MCP servers
15
  from backend.mcp_servers import yahoo_finance_mcp, fmp_mcp, trading_mcp, fred_mcp
16
- from backend.mcp_servers import portfolio_optimizer_mcp, risk_analyzer_mcp
17
 
18
  logger = logging.getLogger(__name__)
19
 
@@ -23,7 +23,7 @@ class MCPRouter:
23
 
24
  Manages connections to:
25
  - P0 (Week 1): Yahoo Finance, FMP, Trading-MCP, FRED, Portfolio Optimizer, Risk Analyzer
26
- - P1 (Week 2): Ensemble Predictor (if time permits)
27
  """
28
 
29
  def __init__(self):
@@ -43,6 +43,7 @@ class MCPRouter:
43
  "fred": fred_mcp,
44
  "portfolio_optimizer": portfolio_optimizer_mcp,
45
  "risk_analyzer": risk_analyzer_mcp,
 
46
  }
47
 
48
  logger.info(f"Initialised {len(self.servers)} MCP servers")
@@ -229,6 +230,30 @@ class MCPRouter:
229
  return result.model_dump()
230
  return result
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  # High-level helper methods
233
  async def fetch_market_data(self, tickers: List[str]) -> Dict[str, Any]:
234
  """Fetch market data for given tickers.
 
13
 
14
  # Import all MCP servers
15
  from backend.mcp_servers import yahoo_finance_mcp, fmp_mcp, trading_mcp, fred_mcp
16
+ from backend.mcp_servers import portfolio_optimizer_mcp, risk_analyzer_mcp, ensemble_predictor_mcp
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
23
 
24
  Manages connections to:
25
  - P0 (Week 1): Yahoo Finance, FMP, Trading-MCP, FRED, Portfolio Optimizer, Risk Analyzer
26
+ - P1 (Week 2): Ensemble Predictor (Chronos + statistical models)
27
  """
28
 
29
  def __init__(self):
 
43
  "fred": fred_mcp,
44
  "portfolio_optimizer": portfolio_optimizer_mcp,
45
  "risk_analyzer": risk_analyzer_mcp,
46
+ "ensemble_predictor": ensemble_predictor_mcp,
47
  }
48
 
49
  logger.info(f"Initialised {len(self.servers)} MCP servers")
 
230
  return result.model_dump()
231
  return result
232
 
233
+ # Ensemble Predictor MCP methods
234
+ async def call_ensemble_predictor_mcp(self, tool: str, params: Dict[str, Any]) -> Dict[str, Any]:
235
+ """Call Ensemble Predictor MCP tool.
236
+
237
+ Args:
238
+ tool: Tool name
239
+ params: Tool parameters
240
+
241
+ Returns:
242
+ Tool result
243
+ """
244
+ logger.debug(f"Calling Ensemble Predictor MCP: {tool}")
245
+
246
+ if tool == "forecast_ensemble":
247
+ from backend.mcp_servers.ensemble_predictor_mcp import forecast_ensemble, ForecastRequest
248
+ request = ForecastRequest(**params)
249
+ result = await forecast_ensemble.fn(request)
250
+ else:
251
+ raise ValueError(f"Unknown Ensemble Predictor tool: {tool}")
252
+
253
+ if hasattr(result, 'model_dump'):
254
+ return result.model_dump()
255
+ return result
256
+
257
  # High-level helper methods
258
  async def fetch_market_data(self, tickers: List[str]) -> Dict[str, Any]:
259
  """Fetch market data for given tickers.
backend/models/agent_state.py CHANGED
@@ -51,6 +51,9 @@ class AgentState(TypedDict):
51
  optimisation_results: Annotated[Dict[str, Any], merge_dicts]
52
  risk_analysis: Annotated[Dict[str, Any], merge_dicts]
53
 
 
 
 
54
  # Phase 3: LLM Synthesis
55
  ai_synthesis: str
56
  recommendations: List[str]
 
51
  optimisation_results: Annotated[Dict[str, Any], merge_dicts]
52
  risk_analysis: Annotated[Dict[str, Any], merge_dicts]
53
 
54
+ # Phase 2.5: ML Predictions (P1)
55
+ ensemble_forecasts: Annotated[Dict[str, Any], merge_dicts]
56
+
57
  # Phase 3: LLM Synthesis
58
  ai_synthesis: str
59
  recommendations: List[str]