"""Redis caching implementation with Upstash and in-memory fallback. This module provides a production-ready caching solution with: - Hybrid caching (Redis primary, in-memory fallback) - Domain-specific TTL strategies for financial data - Event-driven cache invalidation - Cache warming for frequently accessed data - Thread-safe operations with comprehensive error handling - Cache statistics tracking and monitoring """ import json import logging import random import threading import time from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union from pydantic import BaseModel, Field from upstash_redis import Redis logger = logging.getLogger(__name__) T = TypeVar("T") class CacheBackend(str, Enum): """Cache backend type.""" REDIS = "redis" MEMORY = "memory" NONE = "none" class CacheDataType(str, Enum): """Data type categories for TTL strategy.""" MARKET_DATA = "market_data" PORTFOLIO_METRICS = "portfolio_metrics" ANALYSIS_RESULTS = "analysis_results" HISTORICAL_DATA = "historical_data" MCP_RESPONSE = "mcp_response" USER_DATA = "user_data" class CacheStats(BaseModel): """Cache statistics for monitoring and optimisation.""" hits: int = 0 misses: int = 0 sets: int = 0 deletes: int = 0 errors: int = 0 fallback_hits: int = 0 total_size_bytes: int = 0 backend: CacheBackend = CacheBackend.NONE last_reset: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @property def hit_rate(self) -> float: """Calculate cache hit rate. Returns: Hit rate as a percentage (0-100). """ total = self.hits + self.misses if total == 0: return 0.0 return (self.hits / total) * 100 @property def fallback_rate(self) -> float: """Calculate fallback usage rate. Returns: Fallback rate as a percentage (0-100). """ if self.hits == 0: return 0.0 return (self.fallback_hits / self.hits) * 100 class TTLStrategy: """TTL (Time To Live) strategy for different data types. Implements jitter to prevent cache stampede by adding randomness to TTL values. """ # Base TTL values in seconds TTL_CONFIG: Dict[CacheDataType, int] = { CacheDataType.MARKET_DATA: 60, # 1 minute (reduced for real-time data) CacheDataType.PORTFOLIO_METRICS: 1800, # 30 minutes CacheDataType.ANALYSIS_RESULTS: 14400, # 4 hours (extended from 1hr for cost savings) CacheDataType.HISTORICAL_DATA: 43200, # 12 hours (extended for older data) CacheDataType.MCP_RESPONSE: 300, # 5 minutes (extended from 10min) CacheDataType.USER_DATA: 7200, # 2 hours } # Jitter percentage (0-1) to add randomness JITTER_FACTOR = 0.1 @classmethod def get_ttl( cls, data_type: CacheDataType, custom_ttl: Optional[int] = None ) -> int: """Get TTL with jitter for a data type. Args: data_type: Type of data being cached. custom_ttl: Optional custom TTL override. Returns: TTL in seconds with jitter applied. """ base_ttl = custom_ttl if custom_ttl is not None else cls.TTL_CONFIG[data_type] jitter = int(base_ttl * cls.JITTER_FACTOR * random.random()) return base_ttl + jitter @classmethod def get_expiry_time( cls, data_type: CacheDataType, custom_ttl: Optional[int] = None ) -> datetime: """Get expiry timestamp for a data type. Args: data_type: Type of data being cached. custom_ttl: Optional custom TTL override. Returns: Expiry datetime in UTC. """ ttl = cls.get_ttl(data_type, custom_ttl) return datetime.now(timezone.utc) + timedelta(seconds=ttl) class CacheKey: """Standardised cache key generator with namespace support.""" SEPARATOR = ":" @classmethod def build( cls, namespace: str, identifier: str, *parts: str, tags: Optional[List[str]] = None, ) -> str: """Build a standardised cache key. Args: namespace: Primary namespace (e.g., 'portfolio', 'market'). identifier: Unique identifier (e.g., portfolio_id, ticker). *parts: Additional key components. tags: Optional tags for invalidation grouping. Returns: Formatted cache key. """ key_parts = [namespace, identifier] + list(parts) key = cls.SEPARATOR.join(str(p) for p in key_parts) if tags: tag_str = cls.SEPARATOR.join(sorted(tags)) key = f"{key}{cls.SEPARATOR}tags{cls.SEPARATOR}{tag_str}" return key @classmethod def parse_namespace(cls, key: str) -> Optional[str]: """Extract namespace from cache key. Args: key: Cache key to parse. Returns: Namespace if found, None otherwise. """ parts = key.split(cls.SEPARATOR) return parts[0] if parts else None class InMemoryCache: """Thread-safe in-memory cache with TTL support. Used as fallback when Redis is unavailable or for local development. """ def __init__(self, max_size: int = 1000): """Initialise in-memory cache. Args: max_size: Maximum number of items to store. """ self._cache: Dict[str, Dict[str, Any]] = {} self._lock = threading.RLock() self._max_size = max_size self.stats = CacheStats(backend=CacheBackend.MEMORY) def get(self, key: str) -> Optional[Any]: """Get value from cache. Args: key: Cache key. Returns: Cached value if found and not expired, None otherwise. """ with self._lock: if key not in self._cache: self.stats.misses += 1 return None entry = self._cache[key] expiry = entry.get("expiry") if expiry and datetime.now(timezone.utc) > expiry: del self._cache[key] self.stats.misses += 1 return None self.stats.hits += 1 return entry["value"] def set( self, key: str, value: Any, ttl: Optional[int] = None, expiry: Optional[datetime] = None, ) -> bool: """Set value in cache. Args: key: Cache key. value: Value to cache. ttl: Time to live in seconds. expiry: Explicit expiry datetime. Returns: True if successful. """ with self._lock: if len(self._cache) >= self._max_size: self._evict_oldest() if expiry is None and ttl is not None: expiry = datetime.now(timezone.utc) + timedelta(seconds=ttl) self._cache[key] = {"value": value, "expiry": expiry} self.stats.sets += 1 return True def delete(self, key: str) -> bool: """Delete key from cache. Args: key: Cache key. Returns: True if key was deleted. """ with self._lock: if key in self._cache: del self._cache[key] self.stats.deletes += 1 return True return False def delete_pattern(self, pattern: str) -> int: """Delete keys matching pattern. Args: pattern: Pattern to match (supports wildcards). Returns: Number of keys deleted. """ with self._lock: # Simple wildcard matching import fnmatch matching_keys = [ k for k in self._cache.keys() if fnmatch.fnmatch(k, pattern) ] for key in matching_keys: del self._cache[key] deleted_count = len(matching_keys) self.stats.deletes += deleted_count return deleted_count def clear(self) -> None: """Clear all cache entries.""" with self._lock: self._cache.clear() self.stats = CacheStats(backend=CacheBackend.MEMORY) def _evict_oldest(self) -> None: """Evict oldest entry based on insertion order.""" if self._cache: oldest_key = next(iter(self._cache)) del self._cache[oldest_key] def get_size(self) -> int: """Get current cache size. Returns: Number of items in cache. """ with self._lock: return len(self._cache) class HybridCache: """Hybrid cache with Redis primary and in-memory fallback. Implements cache-aside pattern with automatic fallback to in-memory storage when Redis is unavailable. """ def __init__( self, redis_url: Optional[str] = None, redis_token: Optional[str] = None, enable_fallback: bool = True, fallback_size: int = 1000, ): """Initialise hybrid cache. Args: redis_url: Upstash Redis REST URL. redis_token: Upstash Redis REST token. enable_fallback: Enable in-memory fallback. fallback_size: Maximum size of fallback cache. """ self.redis_client: Optional[Redis] = None self.fallback_cache: Optional[InMemoryCache] = None self.stats = CacheStats() if redis_url and redis_token: try: self.redis_client = Redis(url=redis_url, token=redis_token) self.redis_client.ping() self.stats.backend = CacheBackend.REDIS logger.info("Redis cache initialised successfully") except Exception as e: logger.warning(f"Failed to initialise Redis: {e}") self.redis_client = None if enable_fallback or self.redis_client is None: self.fallback_cache = InMemoryCache(max_size=fallback_size) logger.info("In-memory fallback cache enabled") if self.redis_client is None and self.fallback_cache is None: logger.warning("No cache backend available") self.stats.backend = CacheBackend.NONE def get(self, key: str, deserialiser: Optional[Callable[[str], T]] = None) -> Optional[T]: """Get value from cache with automatic fallback. Args: key: Cache key. deserialiser: Optional function to deserialise cached value. Returns: Cached value if found, None otherwise. """ try: if self.redis_client: value = self.redis_client.get(key) if value is not None: self.stats.hits += 1 if deserialiser: return deserialiser(value) return json.loads(value) if isinstance(value, str) else value self.stats.misses += 1 if self.fallback_cache: value = self.fallback_cache.get(key) if value is not None: self.stats.fallback_hits += 1 return value except Exception as e: logger.error(f"Cache get error for key {key}: {e}") self.stats.errors += 1 if self.fallback_cache: value = self.fallback_cache.get(key) if value is not None: self.stats.fallback_hits += 1 return value return None def set( self, key: str, value: Any, ttl: Optional[int] = None, serialiser: Optional[Callable[[Any], str]] = None, ) -> bool: """Set value in cache with automatic fallback. Args: key: Cache key. value: Value to cache. ttl: Time to live in seconds. serialiser: Optional function to serialise value. Returns: True if successful in at least one backend. """ success = False try: if self.redis_client: # Use custom serialiser if provided, otherwise use orjson with Decimal support if serialiser: serialised = serialiser(value) else: # Import here to avoid circular dependencies from backend.utils.serialisation import dumps serialised_bytes = dumps(value) # Upstash Redis uses REST API, needs string not bytes serialised = serialised_bytes.decode('utf-8') if ttl: self.redis_client.setex(key, ttl, serialised) else: self.redis_client.set(key, serialised) self.stats.sets += 1 success = True if self.fallback_cache: self.fallback_cache.set(key, value, ttl=ttl) success = True except Exception as e: logger.error(f"Cache set error for key {key}: {e}") self.stats.errors += 1 if self.fallback_cache and not success: try: self.fallback_cache.set(key, value, ttl=ttl) success = True except Exception as fallback_error: logger.error(f"Fallback cache set error: {fallback_error}") return success def delete(self, key: str) -> bool: """Delete key from cache. Args: key: Cache key. Returns: True if deleted from at least one backend. """ success = False try: if self.redis_client: deleted = self.redis_client.delete(key) if deleted: self.stats.deletes += 1 success = True if self.fallback_cache: if self.fallback_cache.delete(key): success = True except Exception as e: logger.error(f"Cache delete error for key {key}: {e}") self.stats.errors += 1 return success def delete_pattern(self, pattern: str) -> int: """Delete keys matching pattern. Args: pattern: Pattern to match (Redis glob-style: *, ?, []). Returns: Number of keys deleted. """ deleted_count = 0 try: if self.redis_client: cursor = 0 while True: cursor, keys = self.redis_client.scan(cursor, match=pattern, count=100) if keys: deleted = self.redis_client.delete(*keys) deleted_count += deleted self.stats.deletes += deleted if cursor == 0: break if self.fallback_cache: deleted_count += self.fallback_cache.delete_pattern(pattern) except Exception as e: logger.error(f"Cache delete pattern error for {pattern}: {e}") self.stats.errors += 1 return deleted_count def clear(self, namespace: Optional[str] = None) -> None: """Clear cache entries. Args: namespace: Optional namespace to clear. If None, clears all. """ try: if namespace: pattern = f"{namespace}:*" self.delete_pattern(pattern) else: if self.redis_client: self.redis_client.flushdb() if self.fallback_cache: self.fallback_cache.clear() self.stats = CacheStats(backend=self.stats.backend) except Exception as e: logger.error(f"Cache clear error: {e}") self.stats.errors += 1 def get_stats(self) -> CacheStats: """Get cache statistics. Returns: Current cache statistics. """ if self.fallback_cache and self.stats.backend == CacheBackend.MEMORY: return self.fallback_cache.stats return self.stats class PortfolioCachingStrategy: """Caching strategy for portfolio-specific data. Provides domain-specific methods for caching portfolio data with appropriate TTL strategies and key generation. """ def __init__(self, cache: HybridCache): """Initialise portfolio caching strategy. Args: cache: HybridCache instance. """ self.cache = cache def cache_portfolio_analysis( self, portfolio_id: str, analysis: Dict[str, Any], tags: Optional[List[str]] = None ) -> bool: """Cache portfolio analysis results. Args: portfolio_id: Portfolio identifier. analysis: Analysis data to cache. tags: Optional tags for invalidation. Returns: True if cached successfully. """ key = CacheKey.build("portfolio", portfolio_id, "analysis", tags=tags) ttl = TTLStrategy.get_ttl(CacheDataType.ANALYSIS_RESULTS) return self.cache.set(key, analysis, ttl=ttl) def get_portfolio_analysis( self, portfolio_id: str, tags: Optional[List[str]] = None ) -> Optional[Dict[str, Any]]: """Get cached portfolio analysis. Args: portfolio_id: Portfolio identifier. tags: Optional tags for key matching. Returns: Cached analysis if found. """ key = CacheKey.build("portfolio", portfolio_id, "analysis", tags=tags) return self.cache.get(key) def cache_portfolio_metrics( self, portfolio_id: str, metrics: Dict[str, Any] ) -> bool: """Cache portfolio metrics. Args: portfolio_id: Portfolio identifier. metrics: Metrics data to cache. Returns: True if cached successfully. """ key = CacheKey.build("portfolio", portfolio_id, "metrics") ttl = TTLStrategy.get_ttl(CacheDataType.PORTFOLIO_METRICS) return self.cache.set(key, metrics, ttl=ttl) def get_portfolio_metrics(self, portfolio_id: str) -> Optional[Dict[str, Any]]: """Get cached portfolio metrics. Args: portfolio_id: Portfolio identifier. Returns: Cached metrics if found. """ key = CacheKey.build("portfolio", portfolio_id, "metrics") return self.cache.get(key) def invalidate_portfolio(self, portfolio_id: str) -> int: """Invalidate all cache entries for a portfolio. Args: portfolio_id: Portfolio identifier. Returns: Number of keys invalidated. """ pattern = f"portfolio:{portfolio_id}:*" return self.cache.delete_pattern(pattern) class CacheInvalidationManager: """Event-driven cache invalidation manager. Manages cache invalidation based on events and tags, supporting batch invalidation and cascading updates. """ def __init__(self, cache: HybridCache): """Initialise invalidation manager. Args: cache: HybridCache instance. """ self.cache = cache self._tag_registry: Dict[str, Set[str]] = {} self._lock = threading.RLock() def register_tag(self, key: str, tags: List[str]) -> None: """Register tags for a cache key. Args: key: Cache key. tags: List of tags. """ with self._lock: for tag in tags: if tag not in self._tag_registry: self._tag_registry[tag] = set() self._tag_registry[tag].add(key) def invalidate_by_tag(self, tag: str) -> int: """Invalidate all keys with a specific tag. Args: tag: Tag to match. Returns: Number of keys invalidated. """ with self._lock: if tag not in self._tag_registry: return 0 keys = self._tag_registry[tag] count = 0 for key in keys: if self.cache.delete(key): count += 1 del self._tag_registry[tag] return count def invalidate_by_event(self, event_type: str, identifier: str) -> int: """Invalidate cache based on event. Args: event_type: Type of event (e.g., 'portfolio_update', 'market_data'). identifier: Event identifier (e.g., portfolio_id, ticker). Returns: Number of keys invalidated. """ pattern_map = { "portfolio_update": f"portfolio:{identifier}:*", "market_data_update": f"market:{identifier}:*", "user_update": f"user:{identifier}:*", } pattern = pattern_map.get(event_type) if not pattern: logger.warning(f"Unknown event type: {event_type}") return 0 return self.cache.delete_pattern(pattern) class FinancialDataCacheManager: """Domain-specific cache manager for financial data. Provides high-level caching operations for market data, historical data, and MCP responses with appropriate TTL strategies. """ def __init__(self, cache: HybridCache): """Initialise financial data cache manager. Args: cache: HybridCache instance. """ self.cache = cache self.portfolio_strategy = PortfolioCachingStrategy(cache) self.invalidation = CacheInvalidationManager(cache) def cache_market_data( self, ticker: str, data: Dict[str, Any], source: str = "default" ) -> bool: """Cache market data for a ticker. Args: ticker: Stock ticker symbol. data: Market data to cache. source: Data source identifier. Returns: True if cached successfully. """ key = CacheKey.build("market", ticker, source) ttl = TTLStrategy.get_ttl(CacheDataType.MARKET_DATA) return self.cache.set(key, data, ttl=ttl) def get_market_data( self, ticker: str, source: str = "default" ) -> Optional[Dict[str, Any]]: """Get cached market data. Args: ticker: Stock ticker symbol. source: Data source identifier. Returns: Cached market data if found. """ key = CacheKey.build("market", ticker, source) return self.cache.get(key) def cache_historical_data( self, ticker: str, period: str, data: Dict[str, Any] ) -> bool: """Cache historical price data. Args: ticker: Stock ticker symbol. period: Time period (e.g., '1y', '5y'). data: Historical data to cache. Returns: True if cached successfully. """ key = CacheKey.build("historical", ticker, period) ttl = TTLStrategy.get_ttl(CacheDataType.HISTORICAL_DATA) return self.cache.set(key, data, ttl=ttl) def get_historical_data( self, ticker: str, period: str ) -> Optional[Dict[str, Any]]: """Get cached historical data. Args: ticker: Stock ticker symbol. period: Time period. Returns: Cached historical data if found. """ key = CacheKey.build("historical", ticker, period) return self.cache.get(key) def cache_mcp_response( self, mcp_server: str, method: str, params_hash: str, response: Any ) -> bool: """Cache MCP server response. Args: mcp_server: MCP server identifier. method: Method name. params_hash: Hash of request parameters. response: MCP response data. Returns: True if cached successfully. """ key = CacheKey.build("mcp", mcp_server, method, params_hash) ttl = TTLStrategy.get_ttl(CacheDataType.MCP_RESPONSE) return self.cache.set(key, response, ttl=ttl) def get_mcp_response( self, mcp_server: str, method: str, params_hash: str ) -> Optional[Any]: """Get cached MCP response. Args: mcp_server: MCP server identifier. method: Method name. params_hash: Hash of request parameters. Returns: Cached response if found. """ key = CacheKey.build("mcp", mcp_server, method, params_hash) return self.cache.get(key) def warm_cache( self, tickers: List[str], data_fetcher: Callable[[str], Dict[str, Any]], ) -> int: """Warm cache with frequently accessed data. Args: tickers: List of tickers to warm. data_fetcher: Function to fetch data for a ticker. Returns: Number of items cached. """ cached_count = 0 for ticker in tickers: try: data = data_fetcher(ticker) if self.cache_market_data(ticker, data): cached_count += 1 except Exception as e: logger.error(f"Failed to warm cache for {ticker}: {e}") logger.info(f"Cache warming completed: {cached_count}/{len(tickers)} items cached") return cached_count def get_stats(self) -> CacheStats: """Get cache statistics. Returns: Current cache statistics. """ return self.cache.get_stats()