Spaces:
Running
on
Zero
Running
on
Zero
| """Tax calculation models. | |
| This module defines Pydantic models for tax calculations, transactions, | |
| cost basis tracking, and wash sale detection. | |
| """ | |
| from datetime import datetime, date | |
| from decimal import Decimal | |
| from typing import Optional, List, Dict, Any | |
| from enum import Enum | |
| from pydantic import BaseModel, Field, field_validator, ConfigDict | |
class TaxFilingStatus(str, Enum):
    """Filing status options used when selecting tax rates.

    Members mix in ``str``, so each member compares equal to its raw value,
    e.g. ``TaxFilingStatus.SINGLE == "single"``.
    """

    SINGLE = "single"
    MARRIED_JOINT = "married_joint"
    MARRIED_SEPARATE = "married_separate"
    HEAD_OF_HOUSEHOLD = "head_of_household"
class CostBasisMethod(str, Enum):
    """Methods for computing the cost basis of shares being sold.

    - ``FIFO``: First In, First Out
    - ``LIFO``: Last In, First Out
    - ``HIFO``: Highest In, First Out
    - ``SPECIFIC_ID``: Specific Lot Identification
    - ``AVERAGE``: Average Cost

    Members mix in ``str`` and compare equal to their lowercase values.
    """

    FIFO = "fifo"
    LIFO = "lifo"
    HIFO = "hifo"
    SPECIFIC_ID = "specific_id"
    AVERAGE = "average"
class TransactionType(str, Enum):
    """Classification of a portfolio transaction.

    Members mix in ``str`` and compare equal to their lowercase values.
    """

    # Trades
    BUY = "buy"
    SELL = "sell"
    # Income and share-count events
    DIVIDEND = "dividend"
    SPLIT = "split"
    # Transfers
    TRANSFER_IN = "transfer_in"
    TRANSFER_OUT = "transfer_out"
class Transaction(BaseModel):
    """Individual portfolio transaction.

    String fields are whitespace-stripped (``str_strip_whitespace``) and every
    assignment is re-validated (``validate_assignment``), so the ticker
    normalisation below also applies when ``ticker`` is reassigned.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    transaction_id: Optional[str] = Field(None, description="Unique transaction ID")
    ticker: str = Field(..., description="Stock ticker symbol")
    transaction_type: TransactionType = Field(..., description="Type of transaction")
    transaction_date: date = Field(..., description="Transaction date")
    # NOTE(review): quantity carries no sign/zero constraint; confirm whether
    # SELL / TRANSFER_OUT rows are expected to use negative quantities.
    quantity: Decimal = Field(..., description="Number of shares")
    price: Decimal = Field(..., ge=0, description="Price per share")
    fees: Decimal = Field(default=Decimal("0"), ge=0, description="Transaction fees")
    lot_id: Optional[str] = Field(None, description="Specific lot identifier for tracking")

    # The decorators were missing, so this method was never invoked by
    # Pydantic and tickers were stored un-normalised (and could be empty).
    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Normalise the ticker to upper case and reject empty values.

        Args:
            v: Raw ticker string.

        Returns:
            The stripped, upper-cased ticker.

        Raises:
            ValueError: if the ticker is empty after stripping whitespace.
        """
        v = v.strip().upper()
        if not v:
            raise ValueError("Ticker cannot be empty")
        return v

    @property
    def total_amount(self) -> Decimal:
        """Total transaction amount: ``quantity * price`` plus ``fees``."""
        return (self.quantity * self.price) + self.fees

    @property
    def adjusted_cost_basis(self) -> Decimal:
        """Per-share cost basis with fees folded in.

        Returns ``(quantity * price + fees) / quantity``, or ``Decimal("0")``
        when ``quantity`` is zero (avoids division by zero).
        """
        # NOTE(review): with a negative quantity the fees reduce the per-share
        # figure; confirm this is intended for sell-side rows.
        if self.quantity == 0:
            return Decimal("0")
        return (self.quantity * self.price + self.fees) / self.quantity
class TaxLot(BaseModel):
    """Tax lot representing a specific purchase of securities.

    ``quantity`` is what remains in the lot; ``original_quantity`` is what was
    bought. Holding-period properties are evaluated against ``date.today()``.
    """

    model_config = ConfigDict(validate_assignment=True)

    lot_id: str = Field(..., description="Unique lot identifier")
    ticker: str = Field(..., description="Stock ticker symbol")
    acquisition_date: date = Field(..., description="Date of acquisition")
    quantity: Decimal = Field(..., ge=0, description="Remaining quantity in lot")
    original_quantity: Decimal = Field(..., gt=0, description="Original quantity purchased")
    cost_basis_per_share: Decimal = Field(..., ge=0, description="Cost basis per share")
    total_cost_basis: Decimal = Field(..., ge=0, description="Total cost basis for lot")

    @property
    def holding_period_days(self) -> int:
        """Whole days from ``acquisition_date`` to today (negative if in the future)."""
        return (date.today() - self.acquisition_date).days

    @property
    def is_long_term(self) -> bool:
        """Whether a sale today would be held for more than one year.

        The holding period begins the day after acquisition and includes the
        sale day, so selling on the one-year anniversary is still short-term
        and selling on any later day is long-term. Comparing with the calendar
        anniversary (instead of ``holding_period_days > 365``) stays correct
        when the year contains 29 February, where the anniversary is 366 days
        after acquisition.
        """
        acquired = self.acquisition_date
        if (acquired.month, acquired.day) == (2, 29):
            # The following year is never a leap year, so there is no 29 Feb.
            # NOTE(review): 28 Feb is used as the anniversary here; confirm
            # this convention with the tax rules the project follows.
            anniversary = date(acquired.year + 1, 2, 28)
        else:
            anniversary = acquired.replace(year=acquired.year + 1)
        return date.today() > anniversary
class CapitalGain(BaseModel):
    """Capital gain/loss from a sale transaction.

    ``gain_loss`` is signed: positive for a gain, negative for a loss. All
    fields are re-validated on assignment.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Identification and dates
    ticker: str = Field(..., description="Stock ticker symbol")
    sale_date: date = Field(..., description="Date of sale")
    acquisition_date: date = Field(..., description="Date of acquisition")
    quantity: Decimal = Field(..., gt=0, description="Quantity sold")

    # Per-share figures, totals and the resulting gain/loss
    sale_price: Decimal = Field(..., ge=0, description="Sale price per share")
    cost_basis: Decimal = Field(..., ge=0, description="Cost basis per share")
    proceeds: Decimal = Field(..., ge=0, description="Total sale proceeds")
    total_cost_basis: Decimal = Field(..., ge=0, description="Total cost basis")
    gain_loss: Decimal = Field(
        ...,
        description="Capital gain (positive) or loss (negative)",
    )

    # Holding-period classification
    is_long_term: bool = Field(..., description="Whether gain qualifies as long-term")
    holding_period_days: int = Field(..., ge=0, description="Holding period in days")

    # Lot linkage and wash-sale bookkeeping
    lot_id: Optional[str] = Field(default=None, description="Associated tax lot ID")
    is_wash_sale: bool = Field(default=False, description="Whether this is a wash sale")
    wash_sale_deferred_loss: Decimal = Field(
        default=Decimal("0"),
        description="Loss amount deferred due to wash sale",
    )
class WashSale(BaseModel):
    """Wash sale violation and deferred loss tracking.

    ``loss_amount`` and ``deferred_loss`` must be negative and
    ``adjusted_basis_increase`` positive. ``replacement_date`` must lie within
    30 days before or after ``sale_date`` (a 61-day window including the sale
    day).
    """

    model_config = ConfigDict(validate_assignment=True)

    ticker: str = Field(..., description="Stock ticker symbol")
    sale_date: date = Field(..., description="Date of loss sale")
    sale_quantity: Decimal = Field(..., gt=0, description="Quantity sold at loss")
    loss_amount: Decimal = Field(..., lt=0, description="Loss amount (negative)")
    # Declared after sale_date so that sale_date is already in ``info.data``
    # when replacement_date is validated.
    replacement_date: date = Field(..., description="Date of replacement purchase")
    replacement_quantity: Decimal = Field(..., gt=0, description="Replacement quantity")
    deferred_loss: Decimal = Field(..., lt=0, description="Amount of loss deferred")
    adjusted_basis_increase: Decimal = Field(
        ..., gt=0, description="Increase to replacement lot's basis"
    )

    # The decorators were missing, so Pydantic never ran this check and
    # out-of-window replacements were accepted silently.
    # NOTE(review): reassigning sale_date alone does not re-run this
    # validator; a model-level validator would be needed to cover that.
    @field_validator("replacement_date")
    @classmethod
    def validate_wash_sale_period(cls, v: date, info) -> date:
        """Validate that replacement is within 61-day wash sale window.

        Args:
            v: Candidate ``replacement_date``.
            info: Pydantic validation info; ``info.data`` holds the fields
                that have already validated successfully.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: if ``v`` is more than 30 days from ``sale_date``.
        """
        sale_date = info.data.get("sale_date")
        # sale_date is missing when it failed its own validation; skip the
        # window check instead of raising a second, misleading error.
        if sale_date is not None:
            days_diff = abs((v - sale_date).days)
            if days_diff > 30:
                # 30 days before and 30 days after = 61 day window
                raise ValueError("Replacement must be within 61-day wash sale window")
        return v
class TaxBracket(BaseModel):
    """Capital-gains tax bracket for a filing status and income level.

    Both rates are fractions bounded to ``[0, 1]`` (e.g. ``Decimal("0.15")``).
    """

    model_config = ConfigDict(validate_assignment=True)

    filing_status: TaxFilingStatus = Field(..., description="Tax filing status")
    income: Decimal = Field(..., ge=0, description="Taxable income")

    # Rates applied to short- and long-term gains respectively
    short_term_rate: Decimal = Field(
        ..., ge=0, le=1, description="Short-term gains tax rate"
    )
    long_term_rate: Decimal = Field(
        ..., ge=0, le=1, description="Long-term gains tax rate"
    )
class TaxImpactSummary(BaseModel):
    """Aggregated tax impact of a set of portfolio operations.

    All monetary totals default to ``Decimal("0")`` and the two lists default
    to fresh empty lists, so an empty summary can be built with no arguments.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Short-term totals
    total_short_term_gains: Decimal = Field(default=Decimal("0"), description="Total ST gains")
    total_short_term_losses: Decimal = Field(default=Decimal("0"), description="Total ST losses")
    # Long-term totals
    total_long_term_gains: Decimal = Field(default=Decimal("0"), description="Total LT gains")
    total_long_term_losses: Decimal = Field(default=Decimal("0"), description="Total LT losses")
    # NOTE(review): the sign convention for the *_losses totals is not
    # enforced here; confirm whether callers store them negative or positive.

    # Netted results and estimated liability (liability may not be negative)
    net_short_term: Decimal = Field(default=Decimal("0"), description="Net short-term gain/loss")
    net_long_term: Decimal = Field(default=Decimal("0"), description="Net long-term gain/loss")
    total_tax_liability: Decimal = Field(
        default=Decimal("0"), ge=0, description="Estimated tax"
    )

    # Wash-sale statistics
    wash_sales_count: int = Field(default=0, ge=0, description="Number of wash sales")
    total_deferred_losses: Decimal = Field(
        default=Decimal("0"),
        description="Total losses deferred from wash sales",
    )

    # Underlying detail records
    capital_gains: List[CapitalGain] = Field(default_factory=list)
    wash_sales: List[WashSale] = Field(default_factory=list)
class TaxLossHarvestingOpportunity(BaseModel):
    """A position flagged as a candidate for tax-loss harvesting.

    ``unrealised_loss`` and ``loss_percentage`` must be strictly negative;
    ``lot_ids`` defaults to an empty list.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Position and its current loss
    ticker: str = Field(..., description="Stock ticker symbol")
    current_price: Decimal = Field(..., ge=0, description="Current market price")
    unrealised_loss: Decimal = Field(..., lt=0, description="Unrealised loss amount")
    loss_percentage: Decimal = Field(..., lt=0, description="Loss percentage")
    quantity: Decimal = Field(..., gt=0, description="Quantity to sell")

    # Holding period
    holding_period_days: int = Field(..., ge=0, description="Current holding period")
    is_long_term: bool = Field(..., description="Whether holding is long-term")

    # Assessment and recommendation
    estimated_tax_savings: Decimal = Field(..., ge=0, description="Estimated tax savings")
    wash_sale_risk: bool = Field(
        ...,
        description="Whether there's recent activity creating wash sale risk",
    )
    recommended_action: str = Field(..., description="Recommended action to take")
    lot_ids: List[str] = Field(default_factory=list, description="Tax lot IDs to sell")
class TaxOptimizedSale(BaseModel):
    """Recommended sale of a position, with lots chosen to optimise tax.

    ``lots_to_sell`` is required and lists the lots in the order to sell.
    """

    model_config = ConfigDict(validate_assignment=True)

    # What to sell
    ticker: str = Field(..., description="Stock ticker symbol")
    quantity_to_sell: Decimal = Field(..., gt=0, description="Total quantity to sell")
    lots_to_sell: List[TaxLot] = Field(..., description="Specific lots to sell in order")
    cost_basis_method: CostBasisMethod = Field(..., description="Cost basis method used")

    # Expected outcome
    estimated_gain_loss: Decimal = Field(..., description="Estimated gain or loss")
    estimated_tax: Decimal = Field(..., ge=0, description="Estimated tax liability")
    is_long_term: bool = Field(..., description="Whether majority is long-term")

    # Human-readable justification
    rationale: str = Field(..., description="Explanation of optimisation strategy")
class PortfolioTaxAnalysis(BaseModel):
    """Complete tax analysis for a portfolio and tax year.

    ``analysis_date`` defaults to today's date at construction time. Every
    collection and summary field has a default, so only the identifying and
    income fields are required.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Identification and taxpayer context
    portfolio_id: str = Field(..., description="Portfolio identifier")
    analysis_date: date = Field(default_factory=date.today)
    tax_year: int = Field(..., description="Tax year for analysis")
    filing_status: TaxFilingStatus = Field(..., description="Tax filing status")
    taxable_income: Decimal = Field(..., ge=0, description="Annual taxable income")

    # Current positions
    total_unrealised_gains: Decimal = Field(
        default=Decimal("0"), description="Total unrealised gains"
    )
    total_unrealised_losses: Decimal = Field(
        default=Decimal("0"), description="Total unrealised losses"
    )

    # Realised gains (YTD)
    ytd_realised_gains: TaxImpactSummary = Field(
        default_factory=TaxImpactSummary,
        description="Year-to-date realised gains",
    )

    # Opportunities
    harvesting_opportunities: List[TaxLossHarvestingOpportunity] = Field(
        default_factory=list,
        description="Tax-loss harvesting opportunities",
    )
    estimated_harvesting_savings: Decimal = Field(
        default=Decimal("0"),
        ge=0,
        description="Total potential tax savings from harvesting",
    )

    # Recommendations
    recommendations: List[str] = Field(
        default_factory=list,
        description="Tax optimisation recommendations",
    )