BrianIsaac commited on
Commit
84b71cf
·
1 Parent(s): 021b909

fix: implement fixed window rate limiting and resolve analysis history persistence

Browse files

- Replace token bucket with fixed window rate limiter
- Quota now fully resets at midnight UTC (not 72-hour refill)
- Authenticated users: 3 requests per day
- Demo users: 1 request per day
- Redis-backed with in-memory fallback

- Fix analysis history not showing up
- Ensure demo user ID is used for demo mode sessions
- Portfolio creation now validates user_id exists
- History loading uses correct user_id for both authenticated and demo
- Enhanced logging for debugging

- Add comprehensive testing and documentation
- Test scripts for verification
- Migration guide for rate limiting
- Detailed logging for troubleshooting

Fixes: Analysis history persistence issue
Fixes: Rate limit refresh behaviour (72h → 24h)

app.py CHANGED
@@ -52,10 +52,10 @@ from backend.agents.personas import get_available_personas
52
  from backend.tax.interface import create_tax_analysis, format_tax_analysis_output
53
  from backend.config import settings
54
  from backend.rate_limiting import (
55
- TieredRateLimiter,
56
  GradioRateLimitMiddleware,
57
  UserTier,
58
  )
 
59
  from backend.auth import auth, UserSession
60
 
61
  def check_authentication(session_state: Dict) -> bool:
@@ -92,31 +92,75 @@ logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)
92
  # Initialize workflow
93
  workflow = PortfolioAnalysisWorkflow(mcp_router)
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Initialize rate limiter
96
  rate_limiter = None
97
  rate_limit_middleware = None
98
 
99
  if settings.rate_limit_enabled:
100
  try:
101
- rate_limiter = TieredRateLimiter(
102
  tier_limits={
103
- UserTier.ANONYMOUS: (
104
- settings.rate_limit_anonymous_capacity,
105
- settings.rate_limit_anonymous_refill_rate
106
- ),
107
- UserTier.AUTHENTICATED: (
108
- settings.rate_limit_authenticated_capacity,
109
- settings.rate_limit_authenticated_refill_rate
110
- ),
111
- UserTier.PREMIUM: (
112
- settings.rate_limit_premium_capacity,
113
- settings.rate_limit_premium_refill_rate
114
- ),
115
  },
116
  redis_url=settings.redis_url
117
  )
118
- rate_limit_middleware = GradioRateLimitMiddleware(rate_limiter)
119
- logger.info("Rate limiting enabled with tiered limits")
 
 
 
120
  except Exception as e:
121
  logger.error(f"Failed to initialise rate limiter: {e}")
122
  logger.warning("Continuing without rate limiting")
@@ -744,18 +788,27 @@ async def run_analysis_with_ui_update(
744
 
745
  portfolio_id = f"demo_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
746
 
747
- await db.ensure_demo_user_exists()
 
748
 
749
  try:
750
  session = UserSession.from_dict(session_state)
 
 
 
 
 
 
 
751
  await db.save_portfolio(
752
  portfolio_id=portfolio_id,
753
- user_id=session.user_id,
754
  name=f"Analysis {datetime.now().strftime('%Y-%m-%d %H:%M')}",
755
  risk_tolerance='moderate'
756
  )
 
757
  except Exception as e:
758
- logger.warning(f"Portfolio may already exist: {e}")
759
 
760
  initial_state: AgentState = {
761
  'portfolio_id': portfolio_id,
@@ -799,10 +852,18 @@ async def run_analysis_with_ui_update(
799
  # Save analysis to database (Enhancement #4 - Historical Analysis Storage)
800
  try:
801
  session = UserSession.from_dict(session_state)
802
- await db.save_analysis(portfolio_id, final_state)
803
- logger.info(f"Saved analysis for portfolio {portfolio_id}")
 
 
 
 
 
 
 
 
804
  except Exception as e:
805
- logger.warning(f"Failed to save analysis: {e}")
806
 
807
  progress(0.7, desc=random.choice(LOADING_MESSAGES))
808
  await asyncio.sleep(0.3)
@@ -1866,11 +1927,28 @@ def create_interface() -> gr.Blocks:
1866
  """Load analysis history from database."""
1867
  try:
1868
  session = UserSession.from_dict(session_state)
1869
- history = await db.get_analysis_history(session.user_id, limit=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1870
 
1871
  if not history:
 
1872
  return [], "No previous analyses found"
1873
 
 
 
1874
  # Format history for dataframe
1875
  rows = []
1876
  for record in history:
@@ -1955,7 +2033,7 @@ Please try again with different parameters.
1955
  # Enforce rate limiting
1956
  if rate_limit_middleware:
1957
  try:
1958
- rate_limit_middleware.enforce(request)
1959
  except Exception as e:
1960
  logger.warning(f"Rate limit exceeded: {e}")
1961
  yield {
 
52
  from backend.tax.interface import create_tax_analysis, format_tax_analysis_output
53
  from backend.config import settings
54
  from backend.rate_limiting import (
 
55
  GradioRateLimitMiddleware,
56
  UserTier,
57
  )
58
+ from backend.rate_limiting.fixed_window import TieredFixedWindowLimiter
59
  from backend.auth import auth, UserSession
60
 
61
  def check_authentication(session_state: Dict) -> bool:
 
92
  # Initialize workflow
93
  workflow = PortfolioAnalysisWorkflow(mcp_router)
94
 
95
# Custom get_user_tier function for session-aware rate limiting
def get_user_tier_from_session(request: Optional[gr.Request], session_state: Optional[dict] = None):
    """Resolve the rate-limiting identity for the current caller.

    Args:
        request: Gradio request object
        session_state: Session state dict containing user authentication info

    Returns:
        Tuple of (identifier, tier)
        - For authenticated users: (user_id, AUTHENTICATED)
        - For demo mode: (ip_hash, ANONYMOUS)
    """
    import hashlib

    # Authenticated, non-demo sessions are keyed by their user id.
    if isinstance(session_state, dict) and session_state:
        uid = session_state.get("user_id")
        if uid and not session_state.get("is_demo", False):
            return str(uid), UserTier.AUTHENTICATED

    # Demo mode or unauthenticated: fall back to IP-based identity.
    client_ip = "unknown"
    if request:
        try:
            client = getattr(request, "client", None)
            if client:
                if hasattr(client, "host"):
                    client_ip = client.host
                elif isinstance(client, str):
                    client_ip = client

            # Prefer the first hop in X-Forwarded-For when behind a proxy.
            headers = getattr(request, "headers", None)
            if headers is not None:
                forwarded = headers.get("X-Forwarded-For")
                if forwarded:
                    client_ip = forwarded.split(",")[0].strip()
        except Exception as e:
            logger.warning(f"Error extracting client IP: {e}")

    # Hash the IP so raw addresses are never stored in rate-limit keys.
    identifier = hashlib.sha256(client_ip.encode()).hexdigest()[:16]

    return identifier, UserTier.ANONYMOUS
143
+
144
+
145
  # Initialize rate limiter
146
  rate_limiter = None
147
  rate_limit_middleware = None
148
 
149
  if settings.rate_limit_enabled:
150
  try:
151
+ rate_limiter = TieredFixedWindowLimiter(
152
  tier_limits={
153
+ UserTier.ANONYMOUS: settings.rate_limit_anonymous_capacity,
154
+ UserTier.AUTHENTICATED: settings.rate_limit_authenticated_capacity,
155
+ UserTier.PREMIUM: settings.rate_limit_premium_capacity,
 
 
 
 
 
 
 
 
 
156
  },
157
  redis_url=settings.redis_url
158
  )
159
+ rate_limit_middleware = GradioRateLimitMiddleware(
160
+ rate_limiter,
161
+ get_user_tier=get_user_tier_from_session
162
+ )
163
+ logger.info("Rate limiting enabled with fixed window (daily reset at midnight UTC)")
164
  except Exception as e:
165
  logger.error(f"Failed to initialise rate limiter: {e}")
166
  logger.warning("Continuing without rate limiting")
 
788
 
789
  portfolio_id = f"demo_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
790
 
791
+ # Ensure demo user exists and get the user_id
792
+ demo_user_id = await db.ensure_demo_user_exists()
793
 
794
  try:
795
  session = UserSession.from_dict(session_state)
796
+ # Use demo user ID if user is in demo mode, otherwise use authenticated user ID
797
+ user_id = session.user_id if (session and session.user_id and not session_state.get("is_demo")) else demo_user_id
798
+
799
+ if not user_id:
800
+ logger.error("No valid user_id available for portfolio creation")
801
+ raise ValueError("User ID not available")
802
+
803
  await db.save_portfolio(
804
  portfolio_id=portfolio_id,
805
+ user_id=user_id,
806
  name=f"Analysis {datetime.now().strftime('%Y-%m-%d %H:%M')}",
807
  risk_tolerance='moderate'
808
  )
809
+ logger.info(f"Portfolio {portfolio_id} created for user {user_id}")
810
  except Exception as e:
811
+ logger.warning(f"Failed to save portfolio: {e}")
812
 
813
  initial_state: AgentState = {
814
  'portfolio_id': portfolio_id,
 
852
  # Save analysis to database (Enhancement #4 - Historical Analysis Storage)
853
  try:
854
  session = UserSession.from_dict(session_state)
855
+ is_demo = session_state.get("is_demo", False) if session_state else False
856
+ user_identifier = session.user_id if session and session.user_id else "demo"
857
+
858
+ logger.info(f"Saving analysis for portfolio {portfolio_id}, user: {user_identifier}, is_demo: {is_demo}")
859
+
860
+ save_result = await db.save_analysis(portfolio_id, final_state)
861
+ if save_result:
862
+ logger.info(f"✓ Successfully saved analysis for portfolio {portfolio_id}")
863
+ else:
864
+ logger.warning(f"✗ Failed to save analysis for portfolio {portfolio_id} (returned False)")
865
  except Exception as e:
866
+ logger.error(f" Exception saving analysis for portfolio {portfolio_id}: {e}", exc_info=True)
867
 
868
  progress(0.7, desc=random.choice(LOADING_MESSAGES))
869
  await asyncio.sleep(0.3)
 
1927
  """Load analysis history from database."""
1928
  try:
1929
  session = UserSession.from_dict(session_state)
1930
+ is_demo = session_state.get("is_demo", False) if session_state else False
1931
+
1932
+ # Get user_id - use demo user for demo sessions
1933
+ if is_demo or not session or not session.user_id:
1934
+ demo_user_id = await db.ensure_demo_user_exists()
1935
+ user_id = demo_user_id
1936
+ else:
1937
+ user_id = session.user_id
1938
+
1939
+ if not user_id:
1940
+ logger.error("No valid user_id for loading history")
1941
+ return [], "Unable to load history - user not found"
1942
+
1943
+ logger.info(f"Loading history for user: {user_id}, is_demo: {is_demo}")
1944
+ history = await db.get_analysis_history(user_id, limit=20)
1945
 
1946
  if not history:
1947
+ logger.info(f"No history found for user {user_id}")
1948
  return [], "No previous analyses found"
1949
 
1950
+ logger.info(f"Loaded {len(history)} analyses for user {user_id}")
1951
+
1952
  # Format history for dataframe
1953
  rows = []
1954
  for record in history:
 
2033
  # Enforce rate limiting
2034
  if rate_limit_middleware:
2035
  try:
2036
+ rate_limit_middleware.enforce(request, session_state=session_state)
2037
  except Exception as e:
2038
  logger.warning(f"Rate limit exceeded: {e}")
2039
  yield {
backend/config.py CHANGED
@@ -93,24 +93,26 @@ class Settings(BaseSettings):
93
  validation_alias="RATE_LIMIT_ENABLED"
94
  )
95
 
96
- # Rate limit tiers (capacity, refill_rate)
97
- # Anonymous (Demo Mode): 1 request capacity, refills at 1 request per 24 hours
98
  rate_limit_anonymous_capacity: int = Field(
99
  default=1,
100
- validation_alias="RATE_LIMIT_ANONYMOUS_CAPACITY"
 
101
  )
102
  rate_limit_anonymous_refill_rate: float = Field(
103
- default=1.0 / 86400.0, # 1 token per 24 hours (86400 seconds)
104
  validation_alias="RATE_LIMIT_ANONYMOUS_REFILL_RATE"
105
  )
106
 
107
- # Authenticated: 50 requests capacity, refills at 0.5 requests/second (1 every 2 seconds)
108
  rate_limit_authenticated_capacity: int = Field(
109
- default=50,
110
- validation_alias="RATE_LIMIT_AUTHENTICATED_CAPACITY"
 
111
  )
112
  rate_limit_authenticated_refill_rate: float = Field(
113
- default=0.5,
114
  validation_alias="RATE_LIMIT_AUTHENTICATED_REFILL_RATE"
115
  )
116
 
 
93
  validation_alias="RATE_LIMIT_ENABLED"
94
  )
95
 
96
+ # Rate limit tiers (fixed window - resets daily at midnight UTC)
97
+ # Anonymous (Demo Mode): 1 request per day
98
  rate_limit_anonymous_capacity: int = Field(
99
  default=1,
100
+ validation_alias="RATE_LIMIT_ANONYMOUS_CAPACITY",
101
+ description="Number of requests allowed per day for anonymous users"
102
  )
103
  rate_limit_anonymous_refill_rate: float = Field(
104
+ default=1.0 / 86400.0, # DEPRECATED: Not used with fixed window
105
  validation_alias="RATE_LIMIT_ANONYMOUS_REFILL_RATE"
106
  )
107
 
108
+ # Authenticated: 3 requests per day
109
  rate_limit_authenticated_capacity: int = Field(
110
+ default=3,
111
+ validation_alias="RATE_LIMIT_AUTHENTICATED_CAPACITY",
112
+ description="Number of requests allowed per day for authenticated users"
113
  )
114
  rate_limit_authenticated_refill_rate: float = Field(
115
+ default=1.0 / 86400.0, # DEPRECATED: Not used with fixed window
116
  validation_alias="RATE_LIMIT_AUTHENTICATED_REFILL_RATE"
117
  )
118
 
backend/rate_limiting/__init__.py CHANGED
@@ -42,6 +42,10 @@ from backend.rate_limiting.limiter import (
42
  UserTier,
43
  RateLimitInfo,
44
  )
 
 
 
 
45
 
46
  __all__ = [
47
  "ThreadSafeTokenBucket",
@@ -51,4 +55,6 @@ __all__ = [
51
  "GradioRateLimitMiddleware",
52
  "UserTier",
53
  "RateLimitInfo",
 
 
54
  ]
 
42
  UserTier,
43
  RateLimitInfo,
44
  )
45
+ from backend.rate_limiting.fixed_window import (
46
+ FixedWindowRateLimiter,
47
+ TieredFixedWindowLimiter,
48
+ )
49
 
50
  __all__ = [
51
  "ThreadSafeTokenBucket",
 
55
  "GradioRateLimitMiddleware",
56
  "UserTier",
57
  "RateLimitInfo",
58
+ "FixedWindowRateLimiter",
59
+ "TieredFixedWindowLimiter",
60
  ]
backend/rate_limiting/fixed_window.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed Window Rate Limiter Implementation.
3
+
4
+ Provides daily rate limits that reset at a specific time (midnight UTC).
5
+ Users get their full quota back every 24 hours.
6
+ """
7
+
8
+ import hashlib
9
+ import time
10
+ from datetime import datetime, timezone, timedelta
11
+ from typing import Optional, Tuple, Dict
12
+ from threading import Lock
13
+ import logging
14
+
15
+ try:
16
+ import redis
17
+ REDIS_AVAILABLE = True
18
+ except ImportError:
19
+ REDIS_AVAILABLE = False
20
+
21
+ from backend.rate_limiting.limiter import UserTier, RateLimitInfo
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class FixedWindowRateLimiter:
    """Fixed window rate limiter with daily reset.

    Implements a simple counter-based rate limiter where the quota
    resets at a specific time each day (midnight UTC by default).

    This is ideal for "X requests per day" use cases where users
    expect their quota to fully refresh every 24 hours.

    Features:
    - Redis-backed with in-memory fallback
    - Configurable daily limits per tier
    - Resets at midnight UTC
    - Thread-safe for concurrent requests

    Attributes:
        tier_limits: Dictionary mapping UserTier to daily request limit
        redis_client: Optional Redis client for distributed rate limiting
        use_redis: Boolean indicating if Redis is available
        memory_counters: In-memory fallback storage
    """

    def __init__(
        self,
        tier_limits: Dict[UserTier, int],
        redis_url: Optional[str] = None,
        key_prefix: str = "ratelimit_daily"
    ):
        """Initialise fixed window rate limiter.

        Args:
            tier_limits: Dictionary mapping UserTier to daily request limit
                Example: {UserTier.AUTHENTICATED: 3, UserTier.ANONYMOUS: 1}
            redis_url: Optional Redis connection URL
            key_prefix: Prefix for Redis keys (default: "ratelimit_daily")
        """
        self.tier_limits = tier_limits
        self.key_prefix = key_prefix
        # Maps "tier:identifier" -> (count, date_key); a stale date_key
        # signals the window rolled over and the count must restart at 0.
        self.memory_counters: Dict[str, Tuple[int, str]] = {}
        self._memory_lock = Lock()

        # Redis is optional; fall back to process-local counters when absent.
        self.redis_client = None
        self.use_redis = False

        if redis_url and REDIS_AVAILABLE:
            try:
                self.redis_client = redis.from_url(
                    redis_url,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
                self.redis_client.ping()
                self.use_redis = True
                logger.info("Fixed window rate limiter: Redis initialised")
            except Exception as e:
                logger.warning(f"Redis unavailable, using in-memory: {e}")
        else:
            logger.info("Fixed window rate limiter: Using in-memory storage")

    def _get_date_key(self) -> str:
        """Get current date key for rate limiting window.

        Returns:
            Date string in YYYY-MM-DD format (UTC)
        """
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    def _get_redis_key(self, identifier: str, tier: UserTier, date_key: str) -> str:
        """Construct Redis key for rate limit counter.

        Args:
            identifier: User identifier (user_id or IP hash)
            tier: User tier
            date_key: Date key (YYYY-MM-DD)

        Returns:
            Redis key string
        """
        return f"{self.key_prefix}:{tier.value}:{identifier}:{date_key}"

    def _seconds_until_reset(self) -> int:
        """Calculate seconds until next reset (midnight UTC).

        Returns:
            Seconds remaining until midnight UTC
        """
        now = datetime.now(timezone.utc)
        tomorrow = (now + timedelta(days=1)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        return int((tomorrow - now).total_seconds())

    def consume(
        self,
        identifier: str,
        tier: UserTier,
        tokens: int = 1
    ) -> RateLimitInfo:
        """Attempt to consume tokens from daily quota.

        Args:
            identifier: User identifier (user_id or IP hash)
            tier: User tier for quota lookup
            tokens: Number of tokens to consume (default: 1)

        Returns:
            RateLimitInfo with consumption result
        """
        date_key = self._get_date_key()
        # Unknown tiers default to the most restrictive quota (1/day).
        daily_limit = self.tier_limits.get(tier, 1)

        # Try Redis first; on any Redis failure degrade to in-memory.
        if self.use_redis and self.redis_client:
            try:
                return self._consume_redis(identifier, tier, date_key, daily_limit, tokens)
            except Exception as e:
                logger.error(f"Redis consume error, falling back to memory: {e}")

        return self._consume_memory(identifier, tier, date_key, daily_limit, tokens)

    def _consume_redis(
        self,
        identifier: str,
        tier: UserTier,
        date_key: str,
        daily_limit: int,
        tokens: int
    ) -> RateLimitInfo:
        """Consume tokens using Redis backend.

        Args:
            identifier: User identifier
            tier: User tier
            date_key: Current date key
            daily_limit: Daily request limit for this tier
            tokens: Tokens to consume

        Returns:
            RateLimitInfo with result
        """
        key = self._get_redis_key(identifier, tier, date_key)

        # Atomically add the requested number of tokens and read the TTL.
        # BUG FIX: the previous implementation used INCR, which always added
        # exactly 1 and silently ignored ``tokens`` — inconsistent with the
        # in-memory path, which honours ``tokens``.
        pipe = self.redis_client.pipeline()
        pipe.incrby(key, tokens)
        pipe.ttl(key)
        results = pipe.execute()

        current_count = results[0]
        current_ttl = results[1]

        # A freshly created counter has no TTL (-1): expire it at the next
        # midnight UTC so the window resets itself even if we never touch
        # this key again.
        if current_ttl == -1:
            self.redis_client.expire(key, self._seconds_until_reset())

        reset_time = time.time() + self._seconds_until_reset()

        if current_count <= daily_limit:
            # Request allowed; report how much of today's quota remains.
            return RateLimitInfo(
                allowed=True,
                remaining=max(0, daily_limit - current_count),
                reset_time=reset_time,
                retry_after=None
            )

        # Quota exhausted; caller may retry once the window resets.
        # (Denied attempts still bump the counter, which is harmless for
        # the allow/deny decision since the limit check is `<=`.)
        return RateLimitInfo(
            allowed=False,
            remaining=0,
            reset_time=reset_time,
            retry_after=float(self._seconds_until_reset())
        )

    def _consume_memory(
        self,
        identifier: str,
        tier: UserTier,
        date_key: str,
        daily_limit: int,
        tokens: int
    ) -> RateLimitInfo:
        """Consume tokens using in-memory backend.

        Args:
            identifier: User identifier
            tier: User tier
            date_key: Current date key
            daily_limit: Daily request limit for this tier
            tokens: Tokens to consume

        Returns:
            RateLimitInfo with result
        """
        memory_key = f"{tier.value}:{identifier}"

        with self._memory_lock:
            # Load today's counter; a different stored date means the
            # window rolled over, so start from zero.
            if memory_key in self.memory_counters:
                count, stored_date = self.memory_counters[memory_key]
                if stored_date != date_key:
                    count = 0
            else:
                count = 0

            count += tokens
            self.memory_counters[memory_key] = (count, date_key)

        reset_time = time.time() + self._seconds_until_reset()

        if count <= daily_limit:
            return RateLimitInfo(
                allowed=True,
                remaining=max(0, daily_limit - count),
                reset_time=reset_time,
                retry_after=None
            )

        return RateLimitInfo(
            allowed=False,
            remaining=0,
            reset_time=reset_time,
            retry_after=float(self._seconds_until_reset())
        )

    def reset(self, identifier: str, tier: UserTier) -> None:
        """Reset rate limit for identifier (admin/testing).

        Args:
            identifier: User identifier to reset
            tier: User tier
        """
        date_key = self._get_date_key()

        # Clear from Redis (best-effort; log but don't raise).
        if self.use_redis and self.redis_client:
            try:
                key = self._get_redis_key(identifier, tier, date_key)
                self.redis_client.delete(key)
            except Exception as e:
                logger.error(f"Redis reset error: {e}")

        # Clear from the in-memory fallback as well.
        memory_key = f"{tier.value}:{identifier}"
        with self._memory_lock:
            if memory_key in self.memory_counters:
                del self.memory_counters[memory_key]

    def get_current_count(self, identifier: str, tier: UserTier) -> int:
        """Get current request count for identifier (debugging).

        Args:
            identifier: User identifier
            tier: User tier

        Returns:
            Current request count for today
        """
        date_key = self._get_date_key()

        # Try Redis first; any failure falls through to memory.
        if self.use_redis and self.redis_client:
            try:
                key = self._get_redis_key(identifier, tier, date_key)
                count = self.redis_client.get(key)
                return int(count) if count else 0
            except Exception:
                pass

        memory_key = f"{tier.value}:{identifier}"
        with self._memory_lock:
            if memory_key in self.memory_counters:
                count, stored_date = self.memory_counters[memory_key]
                # Only report counts from today's window.
                if stored_date == date_key:
                    return count

        return 0
320
+
321
+
322
class TieredFixedWindowLimiter:
    """Multi-tier fixed window rate limiter.

    Manages a separate fixed window limiter for each user tier and
    provides a unified interface similar to TieredRateLimiter.
    """

    def __init__(
        self,
        tier_limits: Dict[UserTier, int],
        redis_url: Optional[str] = None
    ):
        """Initialise tiered fixed window limiter.

        Args:
            tier_limits: Dictionary mapping UserTier to daily limit
            redis_url: Optional Redis connection URL
        """
        self.limiters: Dict[UserTier, FixedWindowRateLimiter] = {}

        # One isolated limiter per tier. NOTE: this per-tier key_prefix plus
        # FixedWindowRateLimiter's own key layout repeats tier.value inside
        # the Redis key; kept as-is so existing counters stay addressable.
        for tier, daily_limit in tier_limits.items():
            self.limiters[tier] = FixedWindowRateLimiter(
                tier_limits={tier: daily_limit},
                redis_url=redis_url,
                key_prefix=f"ratelimit_daily:{tier.value}"
            )

        logger.info(f"Tiered fixed window limiter initialised with {len(tier_limits)} tiers")

    def consume(
        self,
        identifier: str,
        tier: UserTier,
        tokens: int = 1
    ) -> RateLimitInfo:
        """Consume tokens from the appropriate tier limiter.

        Args:
            identifier: User identifier
            tier: User tier for quota lookup
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with consumption result

        Raises:
            KeyError: If neither the requested tier nor the ANONYMOUS
                fallback has a configured limiter.
        """
        if tier not in self.limiters:
            logger.warning(f"Unknown tier {tier}, using ANONYMOUS")
            tier = UserTier.ANONYMOUS
            # BUG FIX: previously this fell straight through to
            # self.limiters[tier] and raised a bare KeyError when no
            # ANONYMOUS limiter was configured, which callers misreported
            # as a rate-limit rejection. Fail with a diagnosable message.
            if tier not in self.limiters:
                raise KeyError(
                    f"No rate limiter configured for tier {tier} and no "
                    f"ANONYMOUS fallback limiter is available"
                )

        return self.limiters[tier].consume(identifier, tier, tokens)

    def reset(self, identifier: str, tier: UserTier) -> None:
        """Reset rate limit for identifier.

        Args:
            identifier: User identifier to reset
            tier: User tier
        """
        # Silently no-op for unconfigured tiers (there is nothing to reset).
        if tier in self.limiters:
            self.limiters[tier].reset(identifier, tier)

    def get_current_count(self, identifier: str, tier: UserTier) -> int:
        """Get current request count for identifier.

        Args:
            identifier: User identifier
            tier: User tier

        Returns:
            Current request count for today (0 for unconfigured tiers)
        """
        if tier in self.limiters:
            return self.limiters[tier].get_current_count(identifier, tier)
        return 0
backend/rate_limiting/limiter.py CHANGED
@@ -555,25 +555,33 @@ class GradioRateLimitMiddleware:
555
  def check_rate_limit(
556
  self,
557
  request: Optional["gr.Request"] = None,
558
- tokens: int = 1
 
559
  ) -> RateLimitInfo:
560
  """Check rate limit for request.
561
 
562
  Args:
563
  request: Gradio request object
564
  tokens: Number of tokens to consume
 
565
 
566
  Returns:
567
  RateLimitInfo with result
568
  """
569
- identifier, tier = self.get_user_tier(request)
 
 
 
 
 
570
  return self.limiter.consume(identifier, tier, tokens)
571
 
572
  def enforce(
573
  self,
574
  request: Optional["gr.Request"] = None,
575
  tokens: int = 1,
576
- error_message: Optional[str] = None
 
577
  ) -> None:
578
  """Enforce rate limit, raise gr.Error if exceeded.
579
 
@@ -581,13 +589,14 @@ class GradioRateLimitMiddleware:
581
  request: Gradio request object
582
  tokens: Number of tokens to consume
583
  error_message: Custom error message (optional)
 
584
 
585
  Raises:
586
  gr.Error: If rate limit exceeded
587
  """
588
  import gradio as gr
589
 
590
- info = self.check_rate_limit(request, tokens)
591
 
592
  if not info.allowed:
593
  retry_seconds = int(info.retry_after) if info.retry_after else 60
 
555
  def check_rate_limit(
556
  self,
557
  request: Optional["gr.Request"] = None,
558
+ tokens: int = 1,
559
+ session_state: Optional[dict] = None
560
  ) -> RateLimitInfo:
561
  """Check rate limit for request.
562
 
563
  Args:
564
  request: Gradio request object
565
  tokens: Number of tokens to consume
566
+ session_state: Optional session state dict for authenticated users
567
 
568
  Returns:
569
  RateLimitInfo with result
570
  """
571
+ # If custom get_user_tier accepts session_state, pass it
572
+ try:
573
+ identifier, tier = self.get_user_tier(request, session_state)
574
+ except TypeError:
575
+ # Fallback for get_user_tier functions that don't accept session_state
576
+ identifier, tier = self.get_user_tier(request)
577
  return self.limiter.consume(identifier, tier, tokens)
578
 
579
  def enforce(
580
  self,
581
  request: Optional["gr.Request"] = None,
582
  tokens: int = 1,
583
+ error_message: Optional[str] = None,
584
+ session_state: Optional[dict] = None
585
  ) -> None:
586
  """Enforce rate limit, raise gr.Error if exceeded.
587
 
 
589
  request: Gradio request object
590
  tokens: Number of tokens to consume
591
  error_message: Custom error message (optional)
592
+ session_state: Optional session state dict for authenticated users
593
 
594
  Raises:
595
  gr.Error: If rate limit exceeded
596
  """
597
  import gradio as gr
598
 
599
+ info = self.check_rate_limit(request, tokens, session_state)
600
 
601
  if not info.allowed:
602
  retry_seconds = int(info.retry_after) if info.retry_after else 60