Spaces:
Paused
Paused
| import asyncio | |
| import httpx | |
| import json | |
| import pytest | |
| import sys | |
| from typing import Any, Dict, List | |
| from unittest.mock import MagicMock, Mock, patch | |
| import os | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import litellm | |
| from litellm import embedding | |
| from litellm.exceptions import BadRequestError | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
| from litellm.utils import ( | |
| CustomStreamWrapper, | |
| get_supported_openai_params, | |
| get_optional_params, | |
| get_optional_params_embeddings, | |
| ) | |
| import requests | |
| import base64 | |
| # test_example.py | |
| from abc import ABC, abstractmethod | |
# Sample image used by the image-embedding test, inlined as a base64 data URI.
# NOTE: this performs a network request at import time, so importing this
# module (and therefore pytest collection) depends on dummyimage.com being
# reachable.
url = "https://dummyimage.com/100/100/fff&text=Test+image"
# requests applies no timeout by default; without one, an unresponsive host
# would hang module import (and pytest collection) indefinitely.
response = requests.get(url, timeout=30)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
# NOTE(review): the MIME type is hard-coded to PNG rather than read from the
# response's Content-Type header — confirm it matches what the host serves.
base64_image = f"data:image/png;base64,{encoded_file}"
class BaseLLMEmbeddingTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.

    Provider-specific test classes subclass this and implement
    ``get_base_embedding_call_args`` and ``get_custom_llm_provider``; the
    ``test_*`` methods defined here then run against each subclass.
    """

    @abstractmethod
    def get_base_embedding_call_args(self) -> dict:
        """Must return the base embedding call args"""
        pass

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Must return the custom llm provider"""
        pass

    # ``sync_mode`` is not a fixture: it must be supplied by parametrize, and the
    # coroutine needs the asyncio marker to be executed by pytest-asyncio.
    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_embedding(self, sync_mode):
        """Embed two strings via ``litellm.embedding`` (sync) or
        ``litellm.aembedding`` (async) and check the response validates against
        OpenAI's ``CreateEmbeddingResponse`` model.
        """
        litellm.set_verbose = True
        embedding_call_args = self.get_base_embedding_call_args()
        if sync_mode is True:
            response = litellm.embedding(
                **embedding_call_args,
                input=["hello", "world"],
            )
            print("embedding response: ", response)
        else:
            response = await litellm.aembedding(
                **embedding_call_args,
                input=["hello", "world"],
            )
            print("async embedding response: ", response)

        from openai.types.create_embedding_response import CreateEmbeddingResponse

        # Raises a validation error if the response shape diverges from OpenAI's.
        CreateEmbeddingResponse.model_validate(response.model_dump())

    def test_embedding_optional_params_max_retries(self):
        """``max_retries`` passed alongside the base call args must be carried
        through into the params returned by ``get_optional_params_embeddings``.
        """
        embedding_call_args = self.get_base_embedding_call_args()
        optional_params = get_optional_params_embeddings(
            **embedding_call_args, max_retries=20
        )
        assert optional_params["max_retries"] == 20

    def test_image_embedding(self):
        """Embed the module-level base64 sample image; skipped when the model is
        not marked as supporting image input for embeddings.

        The environment variable and ``litellm.model_cost`` changed here are
        restored afterwards (including when the test is skipped) so they do not
        leak into other tests.
        """
        litellm.set_verbose = True
        from litellm.utils import supports_embedding_image_input

        env_key = "LITELLM_LOCAL_MODEL_COST_MAP"
        missing = object()
        previous_env = os.environ.get(env_key, missing)
        previous_model_cost = litellm.model_cost

        # presumably makes get_model_cost_map load litellm's bundled cost map
        # instead of fetching it remotely — TODO confirm.
        os.environ[env_key] = "True"
        try:
            litellm.model_cost = litellm.get_model_cost_map(url="")
            base_embedding_call_args = self.get_base_embedding_call_args()
            model = base_embedding_call_args["model"]
            if not supports_embedding_image_input(model, None):
                print("Model does not support embedding image input")
                pytest.skip("Model does not support embedding image input")
            embedding(**base_embedding_call_args, input=[base64_image])
        finally:
            # Undo the global mutations; pytest.skip raises, so this also runs
            # on the skip path.
            litellm.model_cost = previous_model_cost
            if previous_env is missing:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_env