# Tests for the litellm proxy server: startup jobs, master-key/config loading,
# team-config masking, and embedding passthrough.
# (The "Spaces: / Paused" lines here were page-scrape artifacts, not code.)
| import asyncio | |
| import importlib | |
| import json | |
| import os | |
| import socket | |
| import subprocess | |
| import sys | |
| from unittest import mock | |
| from unittest.mock import AsyncMock, MagicMock, mock_open, patch | |
| import click | |
| import httpx | |
| import pytest | |
| import yaml | |
| from fastapi import FastAPI | |
| from fastapi.testclient import TestClient | |
| sys.path.insert( | |
| 0, os.path.abspath("../../..") | |
| ) # Adds the parent directory to the system-path | |
| import litellm | |
| from litellm.proxy.proxy_server import app, initialize | |
# Canned response in the OpenAI /v1/embeddings wire format. Returned by the
# mocked `llm_router.aembedding` (see mock_patch_aembedding) so embedding tests
# never hit a real provider.
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ],
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
def mock_patch_aembedding():
    """Build (without starting) a patcher that swaps the proxy router's
    ``aembedding`` coroutine for one returning a canned embedding payload."""
    target = "litellm.proxy.proxy_server.llm_router.aembedding"
    return mock.patch(target, return_value=example_embedding_result)
def client_no_auth():
    """Return a ``TestClient`` initialized from the no-auth test config.

    Router globals are reset first: ``initialize`` mutates module-level state,
    and tests running in parallel would otherwise pick up each other's config.
    """
    # Assuming litellm.proxy.proxy_server is an object
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{base_dir}/test_configs/test_config_no_auth.yaml"
    # initialize can get run in parallel, it sets specific variables for the
    # fast api app, since it gets run in parallel different tests use the
    # wrong variables
    asyncio.run(initialize(config=config_fp, debug=True))
    return TestClient(app)
async def test_initialize_scheduled_jobs_credentials(monkeypatch):
    """
    Test that get_credentials is only called when store_model_in_db is True
    """
    # Clear env flags so the path under test is driven only by the patched
    # module globals below, not by the developer's environment.
    monkeypatch.delenv("DISABLE_PRISMA_SCHEMA_UPDATE", raising=False)
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    # Mock dependencies
    mock_prisma_client = MagicMock()
    mock_proxy_logging = MagicMock(spec=ProxyLogging)
    mock_proxy_logging.slack_alerting_instance = MagicMock()
    mock_proxy_config = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
        "litellm.proxy.proxy_server.store_model_in_db", False
    ):  # set store_model_in_db to False
        # Test when store_model_in_db is False
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=mock_prisma_client,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=mock_proxy_logging,
        )

        # Verify get_credentials was not called
        mock_proxy_config.get_credentials.assert_not_called()

    # Now test with store_model_in_db = True
    with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
        "litellm.proxy.proxy_server.store_model_in_db", True
    ), patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=mock_prisma_client,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=mock_proxy_logging,
        )

        # Verify get_credentials was called both directly and scheduled
        assert mock_proxy_config.get_credentials.call_count == 1  # Direct call

        # Verify a scheduled job was added for get_credentials
        # NOTE(review): this inspects mock_calls on the mock itself rather than
        # the scheduler — presumably the scheduled job routes through
        # get_credentials; confirm against initialize_scheduled_background_jobs.
        mock_scheduler_calls = [
            call[0] for call in mock_proxy_config.get_credentials.mock_calls
        ]
        assert len(mock_scheduler_calls) > 0
| # Mock Prisma | |
| class MockPrisma: | |
| def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None): | |
| self.database_url = database_url | |
| self.proxy_logging_obj = proxy_logging_obj | |
| self.http_client = http_client | |
| async def connect(self): | |
| pass | |
| async def disconnect(self): | |
| pass | |
| mock_prisma = MockPrisma() | |
async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
    """
    Test that master_key is correctly loaded from either config.yaml or environment variables
    """
    import yaml
    from fastapi import FastAPI

    # Import happens here - this is when the module probably reads the config path
    from litellm.proxy.proxy_server import proxy_startup_event

    # Mock the Prisma import so startup never touches a real database.
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    # Create test app
    app = FastAPI()

    # Test Case 1: Master key from config.yaml
    test_master_key = "sk-12345"
    test_config = {"general_settings": {"master_key": test_master_key}}

    # Create a temporary config file
    config_path = tmp_path / "config.yaml"
    with open(config_path, "w") as f:
        yaml.dump(test_config, f)

    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
    # Second setting of CONFIG_FILE_PATH to a different value
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
    print(f"config_path: {config_path}")
    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")

    # Startup is driven through the lifespan context manager; master_key is
    # re-imported inside the block because it is module-level state that
    # startup rewrites.
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_master_key

    # Test Case 2: Master key from environment variable
    test_env_master_key = "sk-67890"

    # Create empty config so the env var is the only source of the key.
    empty_config = {"general_settings": {}}
    with open(config_path, "w") as f:
        yaml.dump(empty_config, f)

    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    print("test_env_master_key: {}".format(test_env_master_key))
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_env_master_key

    # Test Case 3: Master key with os.environ prefix
    test_resolved_key = "sk-resolved-key"
    test_config_with_prefix = {
        "general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}
    }

    # Create config with os.environ prefix; startup should resolve the
    # referenced env var rather than use the literal string.
    with open(config_path, "w") as f:
        yaml.dump(test_config_with_prefix, f)

    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    async with proxy_startup_event(app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_resolved_key
def test_team_info_masking():
    """
    Sensitive team callback credentials must never leak into exception text.
    Ref: https://huntr.com/bounties/661b388a-44d8-4ad5-862b-4dc5b80be30a
    """
    from litellm.proxy.proxy_server import ProxyConfig

    config = ProxyConfig()

    # Team object carrying secrets that must be masked in any error output.
    sensitive_team = {
        "success_callback": "['langfuse', 's3']",
        "langfuse_secret": "secret-test-key",
        "langfuse_public_key": "public-test-key",
    }

    with pytest.raises(Exception) as exc_info:
        config._get_team_config(
            team_id="test_dev",
            all_teams_config=[sensitive_team],
        )

    print("Got exception: {}".format(exc_info.value))
    message = str(exc_info.value)
    for secret in ("secret-test-key", "public-test-key"):
        assert secret not in message
def test_embedding_input_array_of_tokens(mock_aembedding, client_no_auth):
    """
    A token-array embedding input must be forwarded to the router untouched
    (no decoding into text) for selected providers.
    Ref: https://github.com/BerriAI/litellm/issues/10113
    """
    payload = {
        "model": "vllm_embed_model",
        "input": [[2046, 13269, 158208]],
    }
    try:
        response = client_no_auth.post("/v1/embeddings", json=payload)

        # The router must receive the raw token array, not decoded text.
        mock_aembedding.assert_called_once_with(
            model="vllm_embed_model",
            input=[[2046, 13269, 158208]],
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
        )

        assert response.status_code == 200
        embedding = response.json()["data"][0]["embedding"]
        print(len(embedding))
        # Real responses are typically ~1536-dimensional; just sanity-check.
        assert len(embedding) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")