BrianIsaac's picture
fix: resolve portfolio optimization constraints and API deprecations
6752363
"""Portfolio data models.
This module defines Pydantic models for portfolio data structures, holdings,
and analysis results.
"""
from datetime import datetime, timezone
from decimal import Decimal
from typing import Optional, List, Dict, Any
from enum import Enum
from pydantic import BaseModel, Field, field_validator, ConfigDict
class RiskTolerance(str, Enum):
    """Enumerated risk-tolerance levels for a portfolio.

    Because members also subclass ``str``, each one compares equal to its
    lowercase string value (e.g. ``RiskTolerance.MODERATE == "moderate"``).
    """

    # Declaration order (conservative -> moderate -> aggressive) is the
    # order in which the members are iterated.
    CONSERVATIVE = "conservative"
    MODERATE = "moderate"
    AGGRESSIVE = "aggressive"
class AssetType(str, Enum):
    """Enumerated classification of what kind of instrument a holding is.

    Members subclass ``str`` and therefore compare equal to their lowercase
    string values; ``OTHER`` is the catch-all bucket.
    """

    # Member order below is preserved as the enum's iteration order.
    STOCK = "stock"
    ETF = "etf"
    CRYPTO = "crypto"
    BOND = "bond"
    CASH = "cash"
    OTHER = "other"
class Holding(BaseModel):
    """A single position held within a portfolio.

    String inputs have surrounding whitespace stripped, assignments after
    construction are re-validated, and the ticker is normalised to upper case
    by ``validate_ticker``.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    # Required symbol; upper-cased and length-checked by ``validate_ticker``.
    ticker: str = Field(description="Stock ticker symbol")
    # Required, strictly positive number of shares/units.
    quantity: Decimal = Field(gt=0, description="Number of shares/units")
    # Optional, non-negative per-unit purchase price.
    cost_basis: Optional[Decimal] = Field(
        default=None, ge=0, description="Purchase price per share"
    )
    # Instrument classification; falls back to a plain stock.
    asset_type: AssetType = Field(
        default=AssetType.STOCK, description="Type of asset"
    )
    # Optional, non-negative latest per-unit price.
    current_price: Optional[Decimal] = Field(
        default=None, ge=0, description="Current market price"
    )
    # Optional, non-negative total position value (not derived here).
    current_value: Optional[Decimal] = Field(
        default=None, ge=0, description="Current total value"
    )

    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Normalise a ticker to upper case and bound its length.

        Args:
            v: The raw ticker string.

        Returns:
            The stripped, upper-cased ticker.

        Raises:
            ValueError: If the normalised ticker is empty or longer than
                10 characters.
        """
        symbol = v.strip().upper()
        if len(symbol) == 0:
            raise ValueError("Ticker cannot be empty")
        if len(symbol) > 10:
            raise ValueError("Ticker too long (max 10 characters)")
        return symbol
class Portfolio(BaseModel):
    """Portfolio containing one or more holdings.

    ``holdings`` must contain between 1 and 100 entries. The check runs at
    construction time, including when ``holdings`` is omitted and the empty
    default is used. Because ``validate_assignment`` is enabled, it also runs
    whenever ``holdings`` is reassigned.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    portfolio_id: Optional[str] = Field(None, description="Unique portfolio ID")
    user_id: Optional[str] = Field(None, description="Owner user ID")
    name: str = Field(..., min_length=1, max_length=200, description="Portfolio name")
    description: Optional[str] = Field(None, max_length=1000)
    # Pydantic does not validate defaults unless asked to. Without
    # ``validate_default=True``, a Portfolio built without ``holdings`` would
    # silently get an empty list and violate the one-holding minimum.
    # The size bounds are enforced in ``validate_holdings`` rather than via
    # ``min_length``: that constraint fires before after-mode validators,
    # which made the validator's own messages unreachable. The schema hints
    # keep the bounds visible in the generated JSON schema.
    holdings: List[Holding] = Field(
        default_factory=list,
        validate_default=True,
        json_schema_extra={"minItems": 1, "maxItems": 100},
    )
    risk_tolerance: RiskTolerance = Field(default=RiskTolerance.MODERATE)
    total_value: Optional[Decimal] = Field(None, ge=0)
    # Both timestamps default to the (timezone-aware, UTC) creation time;
    # ``updated_at`` is not refreshed automatically by this model.
    created_at: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc))

    @field_validator("holdings")
    @classmethod
    def validate_holdings(cls, v: List[Holding]) -> List[Holding]:
        """Ensure the portfolio holds between 1 and 100 positions.

        Args:
            v: The already-validated list of holdings.

        Returns:
            The list unchanged.

        Raises:
            ValueError: If ``v`` is empty or has more than 100 entries.
        """
        if not v:
            raise ValueError("Portfolio must have at least one holding")
        if len(v) > 100:
            raise ValueError("Portfolio cannot have more than 100 holdings")
        return v
class MarketData(BaseModel):
    """Point-in-time quote and fundamentals for a single security.

    All numeric fields except ``pe_ratio`` must be non-negative, and
    ``dividend_yield`` is additionally capped at 1.
    """

    ticker: str
    # Required latest price.
    price: Decimal = Field(ge=0)

    # Optional session quote fields.
    previous_close: Optional[Decimal] = Field(default=None, ge=0)
    open_price: Optional[Decimal] = Field(default=None, ge=0)
    high: Optional[Decimal] = Field(default=None, ge=0)
    low: Optional[Decimal] = Field(default=None, ge=0)
    volume: Optional[int] = Field(default=None, ge=0)

    # Optional fundamentals.
    market_cap: Optional[Decimal] = Field(default=None, ge=0)
    # No sign constraint is placed on the P/E ratio.
    pe_ratio: Optional[Decimal] = None
    # Bounded to [0, 1]; presumably a fraction rather than a percentage —
    # confirm against the data provider.
    dividend_yield: Optional[Decimal] = Field(default=None, ge=0, le=1)
    fifty_two_week_high: Optional[Decimal] = Field(default=None, ge=0)
    fifty_two_week_low: Optional[Decimal] = Field(default=None, ge=0)

    # Defaults to the timezone-aware UTC time at instantiation.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class HistoricalData(BaseModel):
    """Historical price series for one ticker.

    NOTE(review): the list fields are not cross-validated here; callers
    presumably keep ``dates`` and ``prices`` index-aligned — confirm.
    """

    ticker: str
    # Observation timestamps.
    dates: List[datetime]
    # Price observations; assumed to pair with ``dates`` by position.
    prices: List[Decimal]
    # Optional traded volumes.
    volumes: Optional[List[int]] = None
    # Optional return series (its length relative to ``prices`` is not
    # enforced).
    returns: Optional[List[Decimal]] = None
class OptimisationResult(BaseModel):
    """Output of a single portfolio optimisation run."""

    # Name of the optimisation method that produced this result.
    method: str = Field(description="Optimisation method used")
    # Allocation per ticker; no constraint on the weights' sum or sign.
    weights: Dict[str, Decimal] = Field(description="Ticker to weight mapping")
    expected_return: Decimal = Field(description="Expected annual return")
    # Must be non-negative.
    volatility: Decimal = Field(ge=0, description="Expected volatility")
    sharpe_ratio: Decimal = Field(description="Sharpe ratio")
    # Free-form extras; each instance gets its own fresh empty dict.
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class RiskMetrics(BaseModel):
    """Risk statistics computed for a portfolio.

    Drawdown, VaR and CVaR values must be <= 0, so they are carried with a
    non-positive sign (presumably losses as negative returns — confirm
    against the calculator that fills them).
    """

    # Must be non-negative.
    volatility: Decimal = Field(ge=0, description="Annualised volatility")
    sharpe_ratio: Decimal = Field(description="Sharpe ratio")
    sortino_ratio: Optional[Decimal] = Field(default=None, description="Sortino ratio")

    # Non-positive tail/drawdown measures.
    max_drawdown: Decimal = Field(le=0, description="Maximum drawdown")
    var_95: Decimal = Field(le=0, description="Value at Risk (95%)")
    var_99: Decimal = Field(le=0, description="Value at Risk (99%)")
    cvar_95: Decimal = Field(le=0, description="Conditional VaR (95%)")
    cvar_99: Decimal = Field(le=0, description="Conditional VaR (99%)")

    # Optional benchmark-relative statistics.
    beta: Optional[Decimal] = Field(default=None, description="Beta vs benchmark")
    alpha: Optional[Decimal] = Field(default=None, description="Alpha vs benchmark")
class PortfolioAnalysis(BaseModel):
    """Aggregated result of analysing one portfolio.

    Bundles the portfolio snapshot, per-ticker market data, risk metrics,
    optional optimiser outputs, generated narrative and run metadata.
    """

    portfolio_id: str
    # Defaults to the timezone-aware UTC time at instantiation.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    # --- Snapshot of the portfolio at analysis time ---
    total_value: Decimal = Field(ge=0)
    holdings_count: int = Field(ge=1)

    # --- Quotes keyed by ticker ---
    market_data: Dict[str, MarketData] = Field(default_factory=dict)

    # --- Risk ---
    risk_metrics: RiskMetrics

    # --- Optimiser outputs (each optional) ---
    optimisation_hrp: Optional[OptimisationResult] = None
    optimisation_black_litterman: Optional[OptimisationResult] = None
    optimisation_mean_variance: Optional[OptimisationResult] = None

    # --- Generated narrative ---
    summary: str = Field(min_length=10)
    recommendations: List[str] = Field(default_factory=list)
    risk_assessment: str
    # Score constrained to the inclusive range 0-100.
    health_score: int = Field(ge=0, le=100)

    # --- Trace of how the analysis was produced ---
    reasoning_steps: Optional[List[str]] = Field(default_factory=list)
    mcp_calls: Optional[List[Dict[str, Any]]] = Field(default_factory=list)

    # --- Run metadata ---
    execution_time_ms: Optional[int] = Field(default=None, ge=0)
    model_version: Optional[str] = None
class MCPProvenance(BaseModel):
    """Record of where data came from when fetched through MCP servers."""

    # Required identifier of the originating source.
    source: str = Field(description="MCP source identifier")
    # Required list of MCP servers that were called.
    mcps_used: List[str] = Field(description="List of MCP servers called")
    # Defaults to the timezone-aware UTC time at instantiation.
    fetch_timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Whether the data was served from a cache; false unless stated.
    cache_hit: bool = False