| """ |
| SPARKNET Cache Manager |
| Redis-based caching for RAG queries and embeddings. |
| """ |
|
|
| from typing import Optional, Any, List, Dict |
| from datetime import timedelta |
| import hashlib |
| import json |
| import os |
| from loguru import logger |
|
|
| |
| _redis_client = None |


def get_redis_client():
    """Get or create the module-level Redis client."""
    global _redis_client
    if _redis_client is None:
        try:
            import redis
            redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
            _redis_client = redis.from_url(redis_url, decode_responses=True)
            # Fail fast if the server is unreachable.
            _redis_client.ping()
            logger.info(f"Redis connected: {redis_url}")
        except Exception as e:
            logger.warning(f"Redis not available: {e}. Using in-memory cache.")
            _redis_client = None
    return _redis_client
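

# Illustrative usage (sketch; the host below is hypothetical):
#
#     os.environ["REDIS_URL"] = "redis://cache.internal:6379/0"
#     client = get_redis_client()  # redis.Redis instance, or None if unreachable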


class CacheManager:
    """
    Unified cache manager supporting Redis with an in-memory fallback.
    """

    def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600):
        """
        Initialize cache manager.

        Args:
            prefix: Key prefix for namespacing
            default_ttl: Default TTL in seconds (1 hour)
        """
        self.prefix = prefix
        self.default_ttl = default_ttl
        self._memory_cache: Dict[str, Dict[str, Any]] = {}
        self._redis = get_redis_client()

    def _make_key(self, key: str) -> str:
        """Create namespaced cache key."""
        return f"{self.prefix}:{key}"

    def _hash_key(self, *args, **kwargs) -> str:
        """Create a deterministic hash key from arguments."""
        # default=str keeps non-JSON-serializable arguments hashable; MD5 is
        # fine here because keys are cache lookups, not security tokens.
        content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=str)
        return hashlib.md5(content.encode()).hexdigest()
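
    # For example, equal arguments always map to the same key
    # (digest shown is illustrative, not precomputed):
    #
    #     CacheManager()._hash_key("q", top_k=5)  # -> "3f2a..." on every call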

    def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value, or None on a miss
        """
        full_key = self._make_key(key)

        # Try Redis first.
        if self._redis:
            try:
                value = self._redis.get(full_key)
                if value is not None:
                    return json.loads(value)
            except Exception as e:
                logger.warning(f"Redis get failed: {e}")

        # Fall back to the in-memory cache, evicting an expired entry lazily.
        if full_key in self._memory_cache:
            entry = self._memory_cache[full_key]
            if entry.get("expires_at", 0) > time.time():
                return entry.get("value")
            else:
                del self._memory_cache[full_key]

        return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache (must be JSON-serializable)
            ttl: Time-to-live in seconds (default: self.default_ttl)

        Returns:
            True if successful
        """
        full_key = self._make_key(key)
        ttl = ttl if ttl is not None else self.default_ttl

        # Write through to Redis when available.
        if self._redis:
            try:
                self._redis.setex(full_key, ttl, json.dumps(value))
                return True
            except Exception as e:
                logger.warning(f"Redis set failed: {e}")

        # Fall back to the in-memory cache.
        self._memory_cache[full_key] = {
            "value": value,
            "expires_at": time.time() + ttl
        }

        # Bound memory usage by evicting when the cache grows too large.
        if len(self._memory_cache) > 10000:
            self._cleanup_memory_cache()

        return True

    def delete(self, key: str) -> bool:
        """Delete a cache entry."""
        full_key = self._make_key(key)

        if self._redis:
            try:
                self._redis.delete(full_key)
            except Exception as e:
                logger.warning(f"Redis delete failed: {e}")

        if full_key in self._memory_cache:
            del self._memory_cache[full_key]

        return True

    def clear_prefix(self, prefix: str) -> int:
        """Clear all keys matching a prefix. Returns the number of keys removed."""
        pattern = self._make_key(f"{prefix}:*")
        count = 0

        if self._redis:
            try:
                # scan_iter avoids blocking the server the way KEYS can.
                keys = list(self._redis.scan_iter(match=pattern))
                if keys:
                    count = self._redis.delete(*keys)
            except Exception as e:
                logger.warning(f"Redis clear failed: {e}")

        # Remove matching entries from the in-memory cache as well.
        to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(prefix))]
        for k in to_delete:
            del self._memory_cache[k]
            count += 1

        return count

    def _cleanup_memory_cache(self):
        """Remove expired entries from memory cache."""
        now = time.time()
        expired = [
            k for k, v in self._memory_cache.items()
            if v.get("expires_at", 0) < now
        ]
        for k in expired:
            del self._memory_cache[k]

        # If still over the size limit, evict the half expiring soonest.
        if len(self._memory_cache) > 10000:
            sorted_keys = sorted(
                self._memory_cache.keys(),
                key=lambda k: self._memory_cache[k].get("expires_at", 0)
            )
            for k in sorted_keys[:len(sorted_keys) // 2]:
                del self._memory_cache[k]
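

# Illustrative usage (sketch; uses Redis when reachable, else the in-memory fallback):
#
#     cache = CacheManager(prefix="sparknet:demo", default_ttl=60)
#     cache.set("greeting", {"msg": "hello"})
#     cache.get("greeting")   # -> {"msg": "hello"}
#     cache.delete("greeting")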


class QueryCache(CacheManager):
    """
    Specialized cache for RAG queries.
    """

    def __init__(self, ttl: int = 3600):
        super().__init__(prefix="sparknet:query", default_ttl=ttl)

    def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str:
        """Generate cache key for a query."""
        doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all"
        content = f"{query.lower().strip()}:{doc_str}"
        return hashlib.md5(content.encode()).hexdigest()

    def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]:
        """Get cached query response."""
        key = self.get_query_key(query, doc_ids)
        return self.get(key)

    def cache_query_response(
        self,
        query: str,
        response: Dict,
        doc_ids: Optional[List[str]] = None,
        ttl: Optional[int] = None
    ) -> bool:
        """Cache a query response."""
        key = self.get_query_key(query, doc_ids)
        return self.set(key, response, ttl)
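

# Illustrative usage (query text and response are hypothetical). Keys are
# normalized, so casing, surrounding whitespace, and doc_id order don't matter:
#
#     qc = QueryCache(ttl=600)
#     qc.cache_query_response("what is sparknet?", {"answer": "..."}, doc_ids=["d2", "d1"])
#     qc.get_query_response("  What is SPARKNET?", doc_ids=["d1", "d2"])  # cache hit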


class EmbeddingCache(CacheManager):
    """
    Specialized cache for embeddings.
    """

    def __init__(self, ttl: int = 86400):
        super().__init__(prefix="sparknet:embed", default_ttl=ttl)

    def get_embedding_key(self, text: str, model: str = "default") -> str:
        """Generate cache key for embedding."""
        content = f"{model}:{text}"
        return hashlib.md5(content.encode()).hexdigest()

    def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]:
        """Get cached embedding."""
        key = self.get_embedding_key(text, model)
        return self.get(key)

    def cache_embedding(
        self,
        text: str,
        embedding: List[float],
        model: str = "default"
    ) -> bool:
        """Cache an embedding."""
        key = self.get_embedding_key(text, model)
        return self.set(key, embedding)
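

# Illustrative usage (vector values and model name are hypothetical):
#
#     ec = EmbeddingCache()  # 24-hour default TTL
#     ec.cache_embedding("hello world", [0.1, 0.2, 0.3], model="text-embed-v1")
#     ec.get_embedding("hello world", model="text-embed-v1")  # -> [0.1, 0.2, 0.3]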


# Module-level singletons, created lazily on first use.
_query_cache: Optional[QueryCache] = None
_embedding_cache: Optional[EmbeddingCache] = None


def get_query_cache() -> QueryCache:
    """Get or create the shared query cache instance."""
    global _query_cache
    if _query_cache is None:
        _query_cache = QueryCache()
    return _query_cache


def get_embedding_cache() -> EmbeddingCache:
    """Get or create the shared embedding cache instance."""
    global _embedding_cache
    if _embedding_cache is None:
        _embedding_cache = EmbeddingCache()
    return _embedding_cache


def cached(prefix: str = "func", ttl: int = 3600):
    """
    Decorator to cache function results.

    Note: a result of None is never cached, because None also signals a miss.

    Usage:
        @cached(prefix="my_func", ttl=600)
        def expensive_function(arg1, arg2):
            ...
    """
    def decorator(func):
        cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Key on the function name plus a hash of the call arguments.
            key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}"

            # Serve from the cache on a hit.
            result = cache.get(key)
            if result is not None:
                return result

            # Otherwise compute, store, and return.
            result = func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper
    return decorator
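

# Minimal smoke test (illustrative; `slow_square` is a hypothetical example,
# not part of SPARKNET):
if __name__ == "__main__":
    @cached(prefix="demo", ttl=60)
    def slow_square(n: int) -> int:
        print(f"computing {n} * {n}")
        return n * n

    print(slow_square(4))  # computes and caches 16
    print(slow_square(4))  # served from cache; no "computing" line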