BrianIsaac's picture
feat: implement P1 features and production infrastructure
76897aa
"""Tax impact calculator for portfolio operations.
This module implements comprehensive tax calculations including:
- Holding period determination
- Capital gains tax rate calculation
- Wash sale detection
- Cost basis methods (FIFO, LIFO, HIFO, Specific ID, Average)
- Tax liability estimation
All calculations use the US federal tax brackets for tax year 2024.
"""
from datetime import date, timedelta
from decimal import Decimal
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import logging
from backend.tax.models import (
Transaction,
TaxLot,
CapitalGain,
WashSale,
TaxBracket,
TaxImpactSummary,
TaxFilingStatus,
CostBasisMethod,
TransactionType,
)
logger = logging.getLogger(__name__)

# US federal long-term capital gains brackets for tax year 2024.
# Each entry is (lower bound, upper bound, rate) on taxable income; the
# calculator treats the lower bound as inclusive and the upper bound as
# exclusive. Decimal("inf") marks the open-ended top bracket.
# Only 2024 thresholds are encoded here; later years are not included.
# https://www.irs.gov/newsroom/irs-provides-tax-inflation-adjustments-for-tax-year-2024
LONG_TERM_CAPITAL_GAINS_BRACKETS_2024 = {
    TaxFilingStatus.SINGLE: [
        (Decimal("0"), Decimal("47025"), Decimal("0.00")),  # 0%
        (Decimal("47025"), Decimal("518900"), Decimal("0.15")),  # 15%
        (Decimal("518900"), Decimal("inf"), Decimal("0.20")),  # 20%
    ],
    TaxFilingStatus.MARRIED_JOINT: [
        (Decimal("0"), Decimal("94050"), Decimal("0.00")),  # 0%
        (Decimal("94050"), Decimal("583750"), Decimal("0.15")),  # 15%
        (Decimal("583750"), Decimal("inf"), Decimal("0.20")),  # 20%
    ],
    TaxFilingStatus.MARRIED_SEPARATE: [
        (Decimal("0"), Decimal("47025"), Decimal("0.00")),  # 0%
        (Decimal("47025"), Decimal("291850"), Decimal("0.15")),  # 15%
        (Decimal("291850"), Decimal("inf"), Decimal("0.20")),  # 20%
    ],
    TaxFilingStatus.HEAD_OF_HOUSEHOLD: [
        (Decimal("0"), Decimal("63000"), Decimal("0.00")),  # 0%
        (Decimal("63000"), Decimal("551350"), Decimal("0.15")),  # 15%
        (Decimal("551350"), Decimal("inf"), Decimal("0.20")),  # 20%
    ],
}

# US federal ordinary income tax brackets for tax year 2024. Short-term
# capital gains are taxed at these ordinary rates.
# Same tuple layout as above: (lower bound, upper bound, marginal rate).
# https://www.irs.gov/newsroom/irs-provides-tax-inflation-adjustments-for-tax-year-2024
ORDINARY_INCOME_BRACKETS_2024 = {
    TaxFilingStatus.SINGLE: [
        (Decimal("0"), Decimal("11600"), Decimal("0.10")),
        (Decimal("11600"), Decimal("47150"), Decimal("0.12")),
        (Decimal("47150"), Decimal("100525"), Decimal("0.22")),
        (Decimal("100525"), Decimal("191950"), Decimal("0.24")),
        (Decimal("191950"), Decimal("243725"), Decimal("0.32")),
        (Decimal("243725"), Decimal("609350"), Decimal("0.35")),
        (Decimal("609350"), Decimal("inf"), Decimal("0.37")),
    ],
    TaxFilingStatus.MARRIED_JOINT: [
        (Decimal("0"), Decimal("23200"), Decimal("0.10")),
        (Decimal("23200"), Decimal("94300"), Decimal("0.12")),
        (Decimal("94300"), Decimal("201050"), Decimal("0.22")),
        (Decimal("201050"), Decimal("383900"), Decimal("0.24")),
        (Decimal("383900"), Decimal("487450"), Decimal("0.32")),
        (Decimal("487450"), Decimal("731200"), Decimal("0.35")),
        (Decimal("731200"), Decimal("inf"), Decimal("0.37")),
    ],
    TaxFilingStatus.MARRIED_SEPARATE: [
        (Decimal("0"), Decimal("11600"), Decimal("0.10")),
        (Decimal("11600"), Decimal("47150"), Decimal("0.12")),
        (Decimal("47150"), Decimal("100525"), Decimal("0.22")),
        (Decimal("100525"), Decimal("191950"), Decimal("0.24")),
        (Decimal("191950"), Decimal("243725"), Decimal("0.32")),
        (Decimal("243725"), Decimal("365600"), Decimal("0.35")),
        (Decimal("365600"), Decimal("inf"), Decimal("0.37")),
    ],
    TaxFilingStatus.HEAD_OF_HOUSEHOLD: [
        (Decimal("0"), Decimal("16550"), Decimal("0.10")),
        (Decimal("16550"), Decimal("63100"), Decimal("0.12")),
        (Decimal("63100"), Decimal("100500"), Decimal("0.22")),
        (Decimal("100500"), Decimal("191950"), Decimal("0.24")),
        (Decimal("191950"), Decimal("243700"), Decimal("0.32")),
        (Decimal("243700"), Decimal("609350"), Decimal("0.35")),
        (Decimal("609350"), Decimal("inf"), Decimal("0.37")),
    ],
}
class TaxCalculator:
    """Calculate tax impact for portfolio transactions.

    The calculator keeps open tax lots per ticker (see ``add_tax_lot``),
    realises capital gains from those lots under FIFO, LIFO, HIFO or average
    cost, detects wash sales in a transaction list, and estimates federal tax
    on realised gains using the module-level 2024 bracket tables.
    """

    def __init__(
        self,
        filing_status: TaxFilingStatus = TaxFilingStatus.SINGLE,
        taxable_income: Decimal = Decimal("0"),
        tax_year: int = 2024,
    ):
        """Initialise tax calculator.

        Args:
            filing_status: Tax filing status used for bracket lookups.
            taxable_income: Annual taxable income before capital gains.
            tax_year: Tax year for calculations. Only 2024 bracket tables are
                available; any other year is still calculated with the 2024
                tables and a warning is logged.
        """
        self.filing_status = filing_status
        self.taxable_income = taxable_income
        self.tax_year = tax_year
        if tax_year != 2024:
            logger.warning(
                "Tax year %s requested but only 2024 brackets are available; "
                "using 2024 brackets",
                tax_year,
            )
        # Open tax lots tracked by ticker
        self.tax_lots: Dict[str, List[TaxLot]] = defaultdict(list)

    def calculate_holding_period_days(
        self, acquisition_date: date, sale_date: date
    ) -> int:
        """Calculate holding period in days.

        Per IRS rules the holding period starts the day after acquisition and
        includes the sale date. Counting that way gives
        ``(sale - (acquisition + 1)) + 1``, which is simply the calendar-day
        difference between the two dates.

        Args:
            acquisition_date: Date of acquisition
            sale_date: Date of sale

        Returns:
            Number of days held (negative if sale_date precedes acquisition_date)
        """
        return (sale_date - acquisition_date).days

    @staticmethod
    def _one_year_anniversary(acquisition_date: date) -> date:
        """Return the same calendar date one year after ``acquisition_date``.

        29 February maps to 28 February of the following (non-leap) year.
        """
        try:
            return acquisition_date.replace(year=acquisition_date.year + 1)
        except ValueError:
            # Only 29 Feb can fail here: the year after a leap year never is one.
            return acquisition_date.replace(year=acquisition_date.year + 1, day=28)

    def is_long_term_holding(
        self, acquisition_date: date, sale_date: date
    ) -> bool:
        """Determine if holding qualifies for long-term capital gains.

        Long-term treatment requires holding the asset for more than one year,
        counted from the day after acquisition through the sale date. The year
        is measured by calendar date, not as a fixed 365 days. Holdings that
        span 29 February are therefore classified correctly. For example,
        bought 2024-02-05 and sold 2025-02-05 is exactly one year, and so is
        short-term, even though it spans 366 days.

        Args:
            acquisition_date: Date of acquisition
            sale_date: Date of sale

        Returns:
            True if the sale falls after the one-year anniversary of acquisition
        """
        return sale_date > self._one_year_anniversary(acquisition_date)

    def get_capital_gains_rate(
        self,
        is_long_term: bool,
        income: Optional[Decimal] = None,
        filing_status: Optional[TaxFilingStatus] = None,
    ) -> Decimal:
        """Get the marginal capital gains tax rate for an income level.

        Args:
            is_long_term: Whether gain is long-term (held more than one year)
            income: Taxable income (uses instance income if not provided)
            filing_status: Filing status (uses instance status if not provided)

        Returns:
            Tax rate as decimal (e.g., 0.15 for 15%)
        """
        income_to_use = income if income is not None else self.taxable_income
        status = filing_status if filing_status is not None else self.filing_status

        # Short-term gains are taxed as ordinary income.
        table = (
            LONG_TERM_CAPITAL_GAINS_BRACKETS_2024
            if is_long_term
            else ORDINARY_INCOME_BRACKETS_2024
        )
        brackets = table[status]

        # Income below the first bracket (e.g. negative taxable income) falls
        # in the lowest bracket, not the highest.
        if income_to_use < brackets[0][0]:
            return brackets[0][2]

        for lower, upper, rate in brackets:
            if lower <= income_to_use < upper:
                return rate

        # The top bracket is open-ended, so this is only a defensive fallback.
        return brackets[-1][2]

    def calculate_tax_liability(
        self,
        capital_gains: List[CapitalGain],
        filing_status: Optional[TaxFilingStatus] = None,
        income: Optional[Decimal] = None,
    ) -> Tuple[Decimal, TaxImpactSummary]:
        """Calculate total tax liability from capital gains.

        Gains and losses are first netted within each term (short/long).
        The two net results are then netted against each other, as on
        Schedule D: a net loss in one term reduces a net gain in the other.
        Each remaining positive amount is taxed at a single marginal rate,
        looked up at the pre-gain taxable income. This is a simplification,
        because the gains are not stacked on top of income across brackets.

        Args:
            capital_gains: List of capital gains/losses
            filing_status: Override default filing status
            income: Override default taxable income

        Returns:
            Tuple of (total tax liability, detailed summary)
        """
        filing_status = (
            filing_status if filing_status is not None else self.filing_status
        )
        income = income if income is not None else self.taxable_income
        summary = TaxImpactSummary(capital_gains=capital_gains)

        # Separate gains and losses by term
        for gain in capital_gains:
            if gain.is_long_term:
                if gain.gain_loss > 0:
                    summary.total_long_term_gains += gain.gain_loss
                else:
                    summary.total_long_term_losses += abs(gain.gain_loss)
            else:
                if gain.gain_loss > 0:
                    summary.total_short_term_gains += gain.gain_loss
                else:
                    summary.total_short_term_losses += abs(gain.gain_loss)
            if gain.is_wash_sale:
                summary.wash_sales_count += 1
                summary.total_deferred_losses += abs(gain.wash_sale_deferred_loss)

        # Net gains/losses within each term
        summary.net_short_term = (
            summary.total_short_term_gains - summary.total_short_term_losses
        )
        summary.net_long_term = (
            summary.total_long_term_gains - summary.total_long_term_losses
        )

        # Net across terms: a loss in one term offsets a gain in the other.
        zero = Decimal("0")
        net_st = summary.net_short_term
        net_lt = summary.net_long_term
        if net_st > 0 and net_lt < 0:
            taxable_st, taxable_lt = max(net_st + net_lt, zero), zero
        elif net_lt > 0 and net_st < 0:
            taxable_st, taxable_lt = zero, max(net_lt + net_st, zero)
        else:
            taxable_st, taxable_lt = max(net_st, zero), max(net_lt, zero)

        total_tax = Decimal("0")
        if taxable_st > 0:
            st_rate = self.get_capital_gains_rate(
                is_long_term=False, income=income, filing_status=filing_status
            )
            total_tax += taxable_st * st_rate
        if taxable_lt > 0:
            lt_rate = self.get_capital_gains_rate(
                is_long_term=True, income=income, filing_status=filing_status
            )
            total_tax += taxable_lt * lt_rate

        # Note: excess net losses are deductible against ordinary income only
        # up to $3,000/year; that deduction and carryforwards are not modelled.
        summary.total_tax_liability = total_tax
        return total_tax, summary

    def detect_wash_sales(
        self,
        transactions: List[Transaction],
        lookback_days: int = 30,
        lookforward_days: int = 30,
    ) -> List[WashSale]:
        """Detect wash sale violations.

        A wash sale occurs when:
        1. You sell stock at a loss
        2. Within 30 days before or after (61-day window), you buy substantially
           identical stock

        Limitations of this check:
        - Only the first purchase in the window is matched to each loss sale.
        - A purchase may be matched to more than one sale.
        - Purchases are not distinguished from the lot being sold, so selling
          shares bought within the look-back window is flagged.

        Args:
            transactions: List of all transactions
            lookback_days: Days to look back for replacement purchases (default 30)
            lookforward_days: Days to look forward for replacement purchases (default 30)

        Returns:
            List of detected wash sales
        """
        wash_sales: List[WashSale] = []

        # Group transactions by ticker
        by_ticker: Dict[str, List[Transaction]] = defaultdict(list)
        for txn in transactions:
            by_ticker[txn.ticker].append(txn)

        for ticker, txns in by_ticker.items():
            txns.sort(key=lambda t: t.transaction_date)
            buys = [t for t in txns if t.transaction_type == TransactionType.BUY]

            for sale_txn in txns:
                if sale_txn.transaction_type != TransactionType.SELL:
                    continue
                # Nothing was sold; also avoids a 0/0 when prorating below.
                if sale_txn.quantity <= 0:
                    continue
                # Skip if sale was at a gain or break-even
                if sale_txn.price >= sale_txn.adjusted_cost_basis:
                    continue

                loss_amount = (
                    sale_txn.price - sale_txn.adjusted_cost_basis
                ) * sale_txn.quantity

                # Replacement purchases must fall inside the window around the sale
                wash_sale_start = sale_txn.transaction_date - timedelta(days=lookback_days)
                wash_sale_end = sale_txn.transaction_date + timedelta(days=lookforward_days)
                buy_txn = next(
                    (
                        b
                        for b in buys
                        if wash_sale_start <= b.transaction_date <= wash_sale_end
                    ),
                    None,
                )
                if buy_txn is None:
                    continue

                # Only the portion of the loss covered by replacement shares is deferred
                replacement_qty = min(sale_txn.quantity, buy_txn.quantity)
                deferred_loss = loss_amount * replacement_qty / sale_txn.quantity
                wash_sales.append(
                    WashSale(
                        ticker=ticker,
                        sale_date=sale_txn.transaction_date,
                        sale_quantity=sale_txn.quantity,
                        loss_amount=loss_amount,
                        replacement_date=buy_txn.transaction_date,
                        replacement_quantity=replacement_qty,
                        deferred_loss=deferred_loss,
                        adjusted_basis_increase=abs(deferred_loss),
                    )
                )

        return wash_sales

    def add_tax_lot(
        self,
        ticker: str,
        acquisition_date: date,
        quantity: Decimal,
        cost_basis_per_share: Decimal,
        lot_id: Optional[str] = None,
    ) -> TaxLot:
        """Add a tax lot to tracking.

        Args:
            ticker: Stock ticker symbol
            acquisition_date: Date of acquisition
            quantity: Number of shares
            cost_basis_per_share: Cost basis per share
            lot_id: Optional lot identifier. When omitted, an id of the form
                ``{ticker}_{date}_{n}`` is generated. It is unique among the
                ticker's currently tracked lots.

        Returns:
            Created TaxLot
        """
        existing = self.tax_lots[ticker]
        if lot_id is None:
            # len() alone can repeat an id once emptied lots have been removed,
            # so advance the suffix until it is unused.
            taken = {lot.lot_id for lot in existing}
            seq = len(existing)
            lot_id = f"{ticker}_{acquisition_date.isoformat()}_{seq}"
            while lot_id in taken:
                seq += 1
                lot_id = f"{ticker}_{acquisition_date.isoformat()}_{seq}"

        lot = TaxLot(
            lot_id=lot_id,
            ticker=ticker,
            acquisition_date=acquisition_date,
            quantity=quantity,
            original_quantity=quantity,
            cost_basis_per_share=cost_basis_per_share,
            total_cost_basis=quantity * cost_basis_per_share,
        )
        existing.append(lot)
        return lot

    def calculate_sale_fifo(
        self,
        ticker: str,
        sale_date: date,
        quantity: Decimal,
        sale_price: Decimal,
    ) -> List[CapitalGain]:
        """Calculate capital gains using FIFO (oldest lots sold first).

        Args:
            ticker: Stock ticker symbol
            sale_date: Date of sale
            quantity: Quantity to sell
            sale_price: Sale price per share

        Returns:
            List of capital gains (one per lot sold)
        """
        return self._calculate_sale_with_method(
            ticker, sale_date, quantity, sale_price, CostBasisMethod.FIFO
        )

    def calculate_sale_lifo(
        self,
        ticker: str,
        sale_date: date,
        quantity: Decimal,
        sale_price: Decimal,
    ) -> List[CapitalGain]:
        """Calculate capital gains using LIFO (newest lots sold first).

        Args:
            ticker: Stock ticker symbol
            sale_date: Date of sale
            quantity: Quantity to sell
            sale_price: Sale price per share

        Returns:
            List of capital gains (one per lot sold)
        """
        return self._calculate_sale_with_method(
            ticker, sale_date, quantity, sale_price, CostBasisMethod.LIFO
        )

    def calculate_sale_hifo(
        self,
        ticker: str,
        sale_date: date,
        quantity: Decimal,
        sale_price: Decimal,
    ) -> List[CapitalGain]:
        """Calculate capital gains using HIFO (highest-cost lots sold first).

        Args:
            ticker: Stock ticker symbol
            sale_date: Date of sale
            quantity: Quantity to sell
            sale_price: Sale price per share

        Returns:
            List of capital gains (one per lot sold)
        """
        return self._calculate_sale_with_method(
            ticker, sale_date, quantity, sale_price, CostBasisMethod.HIFO
        )

    def calculate_sale_average(
        self,
        ticker: str,
        sale_date: date,
        quantity: Decimal,
        sale_price: Decimal,
    ) -> List[CapitalGain]:
        """Calculate capital gains using the average cost method.

        The cost basis is the average per-share basis of all remaining shares,
        and shares are removed from every lot in proportion to its size.

        The holding period is measured from the earliest remaining acquisition
        date. This is a simplification: the whole sale can be classified as
        long-term even if some of the averaged shares were held for a year or
        less.

        If more shares are requested than are held, only the held shares are
        sold and a warning is logged, as with the FIFO/LIFO/HIFO methods.

        Args:
            ticker: Stock ticker symbol
            sale_date: Date of sale
            quantity: Quantity to sell
            sale_price: Sale price per share

        Returns:
            List with single capital gain using average cost

        Raises:
            ValueError: If no shares of ``ticker`` are held.
        """
        # Ignore emptied lots: they add no shares but would skew the earliest date.
        lots = [lot for lot in self.tax_lots.get(ticker, []) if lot.quantity > 0]
        if not lots:
            raise ValueError(f"No tax lots available for {ticker}")

        total_shares = sum((lot.quantity for lot in lots), Decimal("0"))
        # Derive cost from the remaining quantity: total_cost_basis may still
        # reflect the quantity at acquisition after earlier partial sales.
        total_cost = sum(
            (lot.quantity * lot.cost_basis_per_share for lot in lots), Decimal("0")
        )
        avg_cost_basis = total_cost / total_shares

        sell_qty = quantity
        if quantity > total_shares:
            logger.warning(
                "Insufficient shares for sale: %s, requested %s, sold %s",
                ticker,
                quantity,
                total_shares,
            )
            sell_qty = total_shares

        earliest_date = min(lot.acquisition_date for lot in lots)
        holding_days = self.calculate_holding_period_days(earliest_date, sale_date)
        is_long_term = self.is_long_term_holding(earliest_date, sale_date)

        proceeds = sell_qty * sale_price
        cost_basis_total = sell_qty * avg_cost_basis
        gain = CapitalGain(
            ticker=ticker,
            sale_date=sale_date,
            acquisition_date=earliest_date,
            quantity=sell_qty,
            sale_price=sale_price,
            cost_basis=avg_cost_basis,
            proceeds=proceeds,
            total_cost_basis=cost_basis_total,
            gain_loss=proceeds - cost_basis_total,
            is_long_term=is_long_term,
            holding_period_days=holding_days,
        )

        # Reduce every lot proportionally and keep its total basis in sync
        reduction_ratio = sell_qty / total_shares
        for lot in lots:
            lot.quantity -= lot.quantity * reduction_ratio
            lot.total_cost_basis = lot.quantity * lot.cost_basis_per_share

        self.tax_lots[ticker] = [lot for lot in lots if lot.quantity > 0]
        return [gain]

    def _calculate_sale_with_method(
        self,
        ticker: str,
        sale_date: date,
        quantity: Decimal,
        sale_price: Decimal,
        method: CostBasisMethod,
    ) -> List[CapitalGain]:
        """Realise a sale by consuming lots in the order given by ``method``.

        The stored lot list keeps its insertion order; lots are consumed from
        a sorted copy. Emptied lots are removed afterwards. If there are not
        enough shares, the available shares are sold and a warning is logged.

        Args:
            ticker: Stock ticker symbol
            sale_date: Date of sale
            quantity: Quantity to sell
            sale_price: Sale price per share
            method: Cost basis method to use (FIFO, LIFO or HIFO)

        Returns:
            List of capital gains, one per lot drawn from

        Raises:
            ValueError: If no lots exist for ``ticker`` or ``method`` is not
                FIFO, LIFO or HIFO.
        """
        lots = self.tax_lots.get(ticker, [])
        if not lots:
            raise ValueError(f"No tax lots available for {ticker}")

        if method == CostBasisMethod.FIFO:
            ordered = sorted(lots, key=lambda lot: lot.acquisition_date)
        elif method == CostBasisMethod.LIFO:
            ordered = sorted(lots, key=lambda lot: lot.acquisition_date, reverse=True)
        elif method == CostBasisMethod.HIFO:
            ordered = sorted(lots, key=lambda lot: lot.cost_basis_per_share, reverse=True)
        else:
            raise ValueError(f"Method {method} not supported in this function")

        capital_gains: List[CapitalGain] = []
        remaining_qty = quantity

        for lot in ordered:
            if remaining_qty <= 0:
                break
            # Empty lots would otherwise produce zero-quantity gain records
            if lot.quantity <= 0:
                continue

            qty_from_lot = min(remaining_qty, lot.quantity)
            proceeds = qty_from_lot * sale_price
            cost_basis = qty_from_lot * lot.cost_basis_per_share

            capital_gains.append(
                CapitalGain(
                    ticker=ticker,
                    sale_date=sale_date,
                    acquisition_date=lot.acquisition_date,
                    quantity=qty_from_lot,
                    sale_price=sale_price,
                    cost_basis=lot.cost_basis_per_share,
                    proceeds=proceeds,
                    total_cost_basis=cost_basis,
                    gain_loss=proceeds - cost_basis,
                    is_long_term=self.is_long_term_holding(
                        lot.acquisition_date, sale_date
                    ),
                    holding_period_days=self.calculate_holding_period_days(
                        lot.acquisition_date, sale_date
                    ),
                    lot_id=lot.lot_id,
                )
            )

            # Reduce lot quantity and keep its total basis in sync
            lot.quantity -= qty_from_lot
            lot.total_cost_basis = lot.quantity * lot.cost_basis_per_share
            remaining_qty -= qty_from_lot

        # Remove empty lots
        self.tax_lots[ticker] = [lot for lot in lots if lot.quantity > 0]

        if remaining_qty > 0:
            logger.warning(
                "Insufficient shares for sale: %s, requested %s, sold %s",
                ticker,
                quantity,
                quantity - remaining_qty,
            )

        return capital_gains

    def get_unrealised_gains(
        self, ticker: str, current_price: Decimal
    ) -> Tuple[Decimal, Decimal]:
        """Calculate unrealised gains/losses for current holdings.

        Args:
            ticker: Stock ticker symbol
            current_price: Current market price

        Returns:
            Tuple of (unrealised gain/loss, total quantity)
        """
        lots = self.tax_lots.get(ticker, [])
        if not lots:
            return Decimal("0"), Decimal("0")

        total_quantity = sum((lot.quantity for lot in lots), Decimal("0"))
        # Cost of the shares still held. total_cost_basis may reflect the
        # quantity at acquisition after partial sales, so it is not used.
        remaining_cost = sum(
            (lot.quantity * lot.cost_basis_per_share for lot in lots), Decimal("0")
        )
        return total_quantity * current_price - remaining_cost, total_quantity