BrianIsaac's picture
feat: implement P1 features and production infrastructure
76897aa
"""Tax calculation models.
This module defines Pydantic models for tax calculations, transactions,
cost basis tracking, and wash sale detection.
"""
from datetime import datetime, date
from decimal import Decimal
from typing import Optional, List, Dict, Any
from enum import Enum
from pydantic import BaseModel, Field, field_validator, ConfigDict
class TaxFilingStatus(str, Enum):
    """Filing statuses a taxpayer can be classified under.

    Every member is also a ``str``, so it compares equal to its
    lowercase string value.
    """

    SINGLE = "single"
    MARRIED_JOINT = "married_joint"
    MARRIED_SEPARATE = "married_separate"
    HEAD_OF_HOUSEHOLD = "head_of_household"
class CostBasisMethod(str, Enum):
    """Supported strategies for choosing which lots a sale draws from.

    Every member is also a ``str``, so it compares equal to its
    lowercase string value.
    """

    # First In, First Out
    FIFO = "fifo"
    # Last In, First Out
    LIFO = "lifo"
    # Highest In, First Out
    HIFO = "hifo"
    # Specific Lot Identification
    SPECIFIC_ID = "specific_id"
    # Average Cost
    AVERAGE = "average"
class TransactionType(str, Enum):
    """Kinds of ledger entries a portfolio transaction can represent.

    Every member is also a ``str``, so it compares equal to its
    lowercase string value.
    """

    BUY = "buy"
    SELL = "sell"
    DIVIDEND = "dividend"
    SPLIT = "split"
    TRANSFER_IN = "transfer_in"
    TRANSFER_OUT = "transfer_out"
class Transaction(BaseModel):
    """A single entry in a portfolio's transaction history.

    String inputs are whitespace-stripped and attribute assignments are
    re-validated (see ``model_config``).
    """

    model_config = ConfigDict(str_strip_whitespace=True, validate_assignment=True)

    transaction_id: Optional[str] = Field(
        default=None,
        description="Unique transaction ID",
    )
    ticker: str = Field(description="Stock ticker symbol")
    transaction_type: TransactionType = Field(description="Type of transaction")
    transaction_date: date = Field(description="Transaction date")
    quantity: Decimal = Field(description="Number of shares")
    price: Decimal = Field(ge=0, description="Price per share")
    fees: Decimal = Field(default=Decimal("0"), ge=0, description="Transaction fees")
    lot_id: Optional[str] = Field(
        default=None,
        description="Specific lot identifier for tracking",
    )

    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Return the ticker stripped and upper-cased.

        Raises:
            ValueError: If nothing is left after stripping whitespace.
        """
        symbol = v.strip().upper()
        if symbol == "":
            raise ValueError("Ticker cannot be empty")
        return symbol

    @property
    def total_amount(self) -> Decimal:
        """Return ``quantity * price`` plus fees."""
        gross = self.quantity * self.price
        return gross + self.fees

    @property
    def adjusted_cost_basis(self) -> Decimal:
        """Return the fee-inclusive cost per share, or zero for a zero quantity."""
        if self.quantity != 0:
            # Same arithmetic as (quantity * price + fees) / quantity.
            return self.total_amount / self.quantity
        return Decimal("0")
class TaxLot(BaseModel):
    """Tax lot representing a specific purchase of securities.

    Both holding-period properties are measured against ``date.today()``,
    so they describe the lot as if it were sold today.
    """

    model_config = ConfigDict(validate_assignment=True)

    lot_id: str = Field(..., description="Unique lot identifier")
    ticker: str = Field(..., description="Stock ticker symbol")
    acquisition_date: date = Field(..., description="Date of acquisition")
    quantity: Decimal = Field(..., ge=0, description="Remaining quantity in lot")
    original_quantity: Decimal = Field(
        ..., gt=0, description="Original quantity purchased"
    )
    cost_basis_per_share: Decimal = Field(
        ..., ge=0, description="Cost basis per share"
    )
    total_cost_basis: Decimal = Field(
        ..., ge=0, description="Total cost basis for lot"
    )

    @property
    def holding_period_days(self) -> int:
        """Return the number of days from the acquisition date to today."""
        return (date.today() - self.acquisition_date).days

    @property
    def is_long_term(self) -> bool:
        """Return True when the lot has been held for more than one year.

        The comparison is calendar-based, not a fixed 365-day count. A lot
        held exactly one year across a 29 February spans 366 days but is
        still not "more than one year", so a plain ``days > 365`` test
        would wrongly report it as long-term.
        """
        acquired = self.acquisition_date
        try:
            anniversary = acquired.replace(year=acquired.year + 1)
        except ValueError:
            # Acquired on 29 Feb: the following year has no such date, so the
            # one-year mark is 28 Feb and long-term status begins on 1 Mar.
            anniversary = acquired.replace(year=acquired.year + 1, day=28)
        # Long-term only once today is strictly after the anniversary.
        return date.today() > anniversary
class CapitalGain(BaseModel):
    """Realised gain or loss recorded for one sale.

    Holds per-share and total figures for the sale together with its
    holding-period classification and wash-sale flags. The model does
    not cross-check the figures against one another.
    """

    model_config = ConfigDict(validate_assignment=True)

    # What was sold and when
    ticker: str = Field(description="Stock ticker symbol")
    sale_date: date = Field(description="Date of sale")
    acquisition_date: date = Field(description="Date of acquisition")
    quantity: Decimal = Field(gt=0, description="Quantity sold")

    # Per-share and total amounts
    sale_price: Decimal = Field(ge=0, description="Sale price per share")
    cost_basis: Decimal = Field(ge=0, description="Cost basis per share")
    proceeds: Decimal = Field(ge=0, description="Total sale proceeds")
    total_cost_basis: Decimal = Field(ge=0, description="Total cost basis")
    gain_loss: Decimal = Field(
        description="Capital gain (positive) or loss (negative)",
    )

    # Holding-period classification
    is_long_term: bool = Field(description="Whether gain qualifies as long-term")
    holding_period_days: int = Field(ge=0, description="Holding period in days")
    lot_id: Optional[str] = Field(default=None, description="Associated tax lot ID")

    # Wash-sale tracking
    is_wash_sale: bool = Field(default=False, description="Whether this is a wash sale")
    wash_sale_deferred_loss: Decimal = Field(
        default=Decimal("0"),
        description="Loss amount deferred due to wash sale",
    )
class WashSale(BaseModel):
    """Record of a loss sale matched to a replacement purchase.

    ``loss_amount`` and ``deferred_loss`` must be negative, while
    ``adjusted_basis_increase`` must be positive.
    """

    model_config = ConfigDict(validate_assignment=True)

    ticker: str = Field(description="Stock ticker symbol")
    sale_date: date = Field(description="Date of loss sale")
    sale_quantity: Decimal = Field(gt=0, description="Quantity sold at loss")
    loss_amount: Decimal = Field(lt=0, description="Loss amount (negative)")
    replacement_date: date = Field(description="Date of replacement purchase")
    replacement_quantity: Decimal = Field(gt=0, description="Replacement quantity")
    deferred_loss: Decimal = Field(lt=0, description="Amount of loss deferred")
    adjusted_basis_increase: Decimal = Field(
        gt=0,
        description="Increase to replacement lot's basis",
    )

    @field_validator("replacement_date")
    @classmethod
    def validate_wash_sale_period(cls, v: date, info) -> date:
        """Reject a replacement date more than 30 days from the sale date.

        The check is skipped when ``sale_date`` is not present in the
        already-validated data exposed by ``info.data``.

        Raises:
            ValueError: If the two dates are more than 30 days apart.
        """
        sale_date = info.data.get("sale_date")
        if not sale_date:
            return v

        # 30 days before and 30 days after = 61 day window
        gap = v - sale_date
        if abs(gap.days) > 30:
            raise ValueError("Replacement must be within 61-day wash sale window")
        return v
class TaxBracket(BaseModel):
    """Capital-gains rates for a filing status and income level.

    Both rates are constrained to the closed interval [0, 1].
    """

    model_config = ConfigDict(validate_assignment=True)

    filing_status: TaxFilingStatus = Field(description="Tax filing status")
    income: Decimal = Field(ge=0, description="Taxable income")
    short_term_rate: Decimal = Field(
        ge=0,
        le=1,
        description="Short-term gains tax rate",
    )
    long_term_rate: Decimal = Field(
        ge=0,
        le=1,
        description="Long-term gains tax rate",
    )
class TaxImpactSummary(BaseModel):
    """Aggregated tax figures for a set of portfolio operations.

    Every field has a default, so an empty summary can be created with
    no arguments. Totals are stored as given; the model does not derive
    them from the attached lists.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Short-term / long-term totals
    total_short_term_gains: Decimal = Field(
        default=Decimal("0"),
        description="Total ST gains",
    )
    total_short_term_losses: Decimal = Field(
        default=Decimal("0"),
        description="Total ST losses",
    )
    total_long_term_gains: Decimal = Field(
        default=Decimal("0"),
        description="Total LT gains",
    )
    total_long_term_losses: Decimal = Field(
        default=Decimal("0"),
        description="Total LT losses",
    )

    # Netted results and estimated liability
    net_short_term: Decimal = Field(
        default=Decimal("0"),
        description="Net short-term gain/loss",
    )
    net_long_term: Decimal = Field(
        default=Decimal("0"),
        description="Net long-term gain/loss",
    )
    total_tax_liability: Decimal = Field(
        default=Decimal("0"),
        ge=0,
        description="Estimated tax",
    )

    # Wash-sale aggregates
    wash_sales_count: int = Field(default=0, ge=0, description="Number of wash sales")
    total_deferred_losses: Decimal = Field(
        default=Decimal("0"),
        description="Total losses deferred from wash sales",
    )

    # Underlying detail records
    capital_gains: List[CapitalGain] = Field(default_factory=list)
    wash_sales: List[WashSale] = Field(default_factory=list)
class TaxLossHarvestingOpportunity(BaseModel):
    """A position with an unrealised loss that could be harvested.

    ``unrealised_loss`` and ``loss_percentage`` must both be negative.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Position and size of the loss
    ticker: str = Field(description="Stock ticker symbol")
    current_price: Decimal = Field(ge=0, description="Current market price")
    unrealised_loss: Decimal = Field(lt=0, description="Unrealised loss amount")
    loss_percentage: Decimal = Field(lt=0, description="Loss percentage")
    quantity: Decimal = Field(gt=0, description="Quantity to sell")

    # Holding period
    holding_period_days: int = Field(ge=0, description="Current holding period")
    is_long_term: bool = Field(description="Whether holding is long-term")

    # Assessment and suggested action
    estimated_tax_savings: Decimal = Field(
        ge=0,
        description="Estimated tax savings",
    )
    wash_sale_risk: bool = Field(
        description="Whether there's recent activity creating wash sale risk",
    )
    recommended_action: str = Field(description="Recommended action to take")
    lot_ids: List[str] = Field(
        default_factory=list,
        description="Tax lot IDs to sell",
    )
class TaxOptimizedSale(BaseModel):
    """Recommended sale of specific lots, with its estimated tax effect.

    All fields are required.
    """

    model_config = ConfigDict(validate_assignment=True)

    # What to sell
    ticker: str = Field(description="Stock ticker symbol")
    quantity_to_sell: Decimal = Field(gt=0, description="Total quantity to sell")
    lots_to_sell: List[TaxLot] = Field(description="Specific lots to sell in order")
    cost_basis_method: CostBasisMethod = Field(description="Cost basis method used")

    # Estimated outcome
    estimated_gain_loss: Decimal = Field(description="Estimated gain or loss")
    estimated_tax: Decimal = Field(ge=0, description="Estimated tax liability")
    is_long_term: bool = Field(description="Whether majority is long-term")
    rationale: str = Field(description="Explanation of optimisation strategy")
class PortfolioTaxAnalysis(BaseModel):
    """Top-level tax analysis record for one portfolio and tax year.

    ``analysis_date`` defaults to the day the instance is created, and
    ``ytd_realised_gains`` defaults to an empty ``TaxImpactSummary``.
    """

    model_config = ConfigDict(validate_assignment=True)

    portfolio_id: str = Field(description="Portfolio identifier")
    analysis_date: date = Field(default_factory=date.today)
    tax_year: int = Field(description="Tax year for analysis")
    filing_status: TaxFilingStatus = Field(description="Tax filing status")
    taxable_income: Decimal = Field(ge=0, description="Annual taxable income")

    # Current positions
    total_unrealised_gains: Decimal = Field(
        default=Decimal("0"),
        description="Total unrealised gains",
    )
    total_unrealised_losses: Decimal = Field(
        default=Decimal("0"),
        description="Total unrealised losses",
    )

    # Realised gains (YTD)
    ytd_realised_gains: TaxImpactSummary = Field(
        default_factory=TaxImpactSummary,
        description="Year-to-date realised gains",
    )

    # Opportunities
    harvesting_opportunities: List[TaxLossHarvestingOpportunity] = Field(
        default_factory=list,
        description="Tax-loss harvesting opportunities",
    )
    estimated_harvesting_savings: Decimal = Field(
        default=Decimal("0"),
        ge=0,
        description="Total potential tax savings from harvesting",
    )

    # Recommendations
    recommendations: List[str] = Field(
        default_factory=list,
        description="Tax optimisation recommendations",
    )