Spaces:
Paused
Paused
| import io | |
| import os | |
| import pathlib | |
| import ssl | |
| import sys | |
| from unittest.mock import MagicMock, patch | |
| import httpx | |
| import pytest | |
| from aiohttp import ClientSession, TCPConnector | |
| sys.path.insert( | |
| 0, os.path.abspath("../../../..") | |
| ) # Adds the parent directory to the system path | |
| import litellm | |
| from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
@pytest.mark.asyncio
async def test_ssl_security_level(monkeypatch):
    """Check that the aiohttp connector carries an ``ssl.SSLContext``.

    ``SSL_SECURITY_LEVEL`` is set for the duration of the test and the
    handler is created with ``ssl_verify=False``.  Only the *type* of the
    connector's SSL setting is asserted; the cipher string itself is not
    inspected.
    """
    monkeypatch.setenv("SSL_SECURITY_LEVEL", "DEFAULT@SECLEVEL=1")
    # Pin the transport choice so the result does not depend on global state
    # left behind by other tests; monkeypatch restores it afterwards.
    monkeypatch.setattr(litellm, "disable_aiohttp_transport", False)

    client = AsyncHTTPHandler(ssl_verify=False)

    # Fail with a clear message (instead of an AttributeError below) if a
    # different transport was selected.
    transport = client.client._transport
    assert isinstance(transport, LiteLLMAiohttpTransport)

    # transport -> aiohttp ClientSession -> connector -> SSL setting
    ssl_context = transport._get_valid_client_session().connector._ssl
    assert isinstance(
        ssl_context, ssl.SSLContext
    ), f"unexpected connector ssl value: {ssl_context!r}"
@pytest.mark.asyncio
async def test_force_ipv4_transport():
    """Test transport creation with force_ipv4 enabled.

    With ``litellm.force_ipv4`` on and the aiohttp transport disabled, the
    handler is expected to return a plain ``httpx.AsyncHTTPTransport``.
    """
    # patch.object restores both settings on exit, so the overrides do not
    # leak into tests that run afterwards.
    with patch.object(litellm, "force_ipv4", True), patch.object(
        litellm, "disable_aiohttp_transport", True
    ):
        transport = AsyncHTTPHandler._create_async_transport()

    assert isinstance(transport, httpx.AsyncHTTPTransport)

    # NOTE(review): this performs a real network request to example.com, so
    # the test needs outbound connectivity and can be flaky offline.
    async with httpx.AsyncClient(transport=transport) as client:
        response = await client.get("http://example.com")
        assert response.status_code == 200
@pytest.mark.asyncio
async def test_ssl_context_transport():
    """Test transport creation with an explicit SSL context.

    A transport must always be returned.  When it is the aiohttp-based
    transport, its session/connector types are checked and the connector
    must have an SSL setting attached.
    """
    ssl_context = ssl.create_default_context()

    # Pin the setting so a sibling test that disabled aiohttp cannot silently
    # turn the guarded assertions below into a no-op; restored on exit.
    with patch.object(litellm, "disable_aiohttp_transport", False):
        transport = AsyncHTTPHandler._create_async_transport(ssl_context=ssl_context)

    assert transport is not None

    # Guard kept from the original test: the detailed checks only apply to
    # the aiohttp-backed transport.
    if isinstance(transport, LiteLLMAiohttpTransport):
        client_session = transport._get_valid_client_session()
        assert isinstance(client_session, ClientSession)
        assert isinstance(client_session.connector, TCPConnector)
        # The connector should have received an SSL setting.
        assert client_session.connector._ssl is not None
@pytest.mark.asyncio
async def test_aiohttp_disabled_transport():
    """Test transport creation with aiohttp disabled.

    With the aiohttp transport disabled and ``force_ipv4`` off, no custom
    transport should be created at all.
    """
    # Scoped overrides: both settings are restored when the block exits so
    # later tests see the values they started with.
    with patch.object(litellm, "disable_aiohttp_transport", True), patch.object(
        litellm, "force_ipv4", False
    ):
        transport = AsyncHTTPHandler._create_async_transport()

    assert transport is None
@pytest.mark.asyncio
async def test_ssl_verification_with_aiohttp_transport():
    """
    Test aiohttp respects ssl_verify=False

    We validate that the ssl settings for a litellm transport match what a
    plain aiohttp client created with SSL verification turned off would have.
    """
    import aiohttp

    # Pin the transport choice so the test does not depend on global state
    # left behind by other tests; restored on exit.
    with patch.object(litellm, "disable_aiohttp_transport", False):
        litellm_async_client = AsyncHTTPHandler(ssl_verify=False)

    transport_connector = (
        litellm_async_client.client._transport._get_valid_client_session().connector
    )

    # ``ssl=False`` replaces the deprecated ``verify_ssl=False`` spelling.
    # ``async with`` closes the reference session, which was previously leaked.
    async with aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=False)
    ) as aiohttp_session:
        reference_ssl = aiohttp_session.connector._ssl

        # Both the litellm transport and the reference aiohttp session must
        # carry the same "verification disabled" SSL setting.
        assert transport_connector._ssl == reference_ssl, (
            f"litellm connector ssl={transport_connector._ssl!r}, "
            f"aiohttp reference ssl={reference_ssl!r}"
        )
@pytest.mark.asyncio
async def test_disable_aiohttp_trust_env_with_env_variable(monkeypatch):
    """Test aiohttp transport respects DISABLE_AIOHTTP_TRUST_ENV environment variable"""
    monkeypatch.setenv("DISABLE_AIOHTTP_TRUST_ENV", "True")
    # Use monkeypatch (not a bare assignment) so the module-level setting is
    # restored after the test, just like the environment variable.
    monkeypatch.setattr(litellm, "disable_aiohttp_transport", False)

    client = AsyncHTTPHandler()

    transport = client.client._transport
    assert isinstance(transport, LiteLLMAiohttpTransport)

    # trust_env must be off when DISABLE_AIOHTTP_TRUST_ENV is "True".
    client_session = transport._get_valid_client_session()
    assert client_session._trust_env is False
@pytest.mark.asyncio
async def test_disable_aiohttp_trust_env_with_litellm_setting():
    """Test aiohttp transport respects litellm.disable_aiohttp_trust_env setting"""
    # Scoped overrides: previously ``disable_aiohttp_trust_env`` stayed True
    # after this test and could affect any test that ran later.
    with patch.object(litellm, "disable_aiohttp_trust_env", True), patch.object(
        litellm, "disable_aiohttp_transport", False
    ):
        client = AsyncHTTPHandler()

        transport = client.client._transport
        assert isinstance(transport, LiteLLMAiohttpTransport)

        # The session is fetched while the overrides are still active.
        client_session = transport._get_valid_client_session()

        # trust_env must be off when litellm.disable_aiohttp_trust_env is True.
        assert client_session._trust_env is False
@pytest.mark.asyncio
async def test_enable_aiohttp_trust_env_default():
    """Test aiohttp transport enables trust_env by default"""
    # patch.dict snapshots os.environ and restores it on exit, so removing
    # the variable below cannot affect other tests.  patch.object does the
    # same for the two litellm settings.
    with patch.dict(os.environ), patch.object(
        litellm, "disable_aiohttp_trust_env", False
    ), patch.object(litellm, "disable_aiohttp_transport", False):
        # An ambient DISABLE_AIOHTTP_TRUST_ENV would defeat the "default"
        # scenario this test is meant to cover.
        os.environ.pop("DISABLE_AIOHTTP_TRUST_ENV", None)

        client = AsyncHTTPHandler()

        transport = client.client._transport
        assert isinstance(transport, LiteLLMAiohttpTransport)

        client_session = transport._get_valid_client_session()

        # With nothing disabling it, trust_env is expected to be on.
        assert client_session._trust_env is True