Spaces:
Paused
Paused
| import os | |
| import sys | |
| import traceback | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import io | |
| import os | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import json | |
| import pytest | |
| import litellm | |
| from litellm import completion | |
| from litellm.llms.cohere.completion.transformation import CohereTextConfig | |
def test_cohere_generate_api_completion():
    """Check the request litellm builds for Cohere's ``/v1/generate`` endpoint.

    ``HTTPHandler.post`` is patched on a local client instance, so no network
    call is made; the test only inspects the keyword arguments that were
    passed to ``post`` (URL and JSON body).
    """
    from unittest.mock import patch

    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()
    litellm.set_verbose = True
    messages = [
        {"role": "system", "content": "You're a good bot"},
        {"role": "user", "content": "Hey"},
    ]
    with patch.object(client, "post") as mock_client:
        try:
            completion(
                model="cohere/command",
                messages=messages,
                max_tokens=10,
                client=client,
            )
        except Exception as e:
            # Deliberate best-effort: the patched ``post`` returns a MagicMock
            # instead of a real HTTP response, so handling the "response" may
            # raise. Only the outgoing request is under test here.
            print(e)

        mock_client.assert_called_once()
        print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
        assert (
            mock_client.call_args.kwargs["url"]
            == "https://api.cohere.ai/v1/generate"
        )
        json_data = json.loads(mock_client.call_args.kwargs["data"])
        assert json_data["model"] == "command"
        # System and user message contents are expected to be space-joined
        # into a single prompt string.
        assert json_data["prompt"] == "You're a good bot Hey"
        assert json_data["max_tokens"] == 10
@pytest.mark.asyncio
async def test_cohere_generate_api_stream():
    """Stream a short async completion from Cohere's generate API.

    This is a live call: no ``api_key`` is passed, so credentials presumably
    come from the environment (the module calls ``load_dotenv()``).
    The ``asyncio`` marker is required; without it, pytest does not await the
    coroutine and the test body never runs.
    """
    litellm.set_verbose = True
    messages = [
        {"role": "system", "content": "You're a good bot"},
        {"role": "user", "content": "Hey"},
    ]
    response = await litellm.acompletion(
        model="cohere/command",
        messages=messages,
        max_tokens=10,
        stream=True,
    )
    print("async cohere stream response", response)
    chunk_count = 0
    async for chunk in response:
        print(chunk)
        chunk_count += 1
    # Without this check an empty stream would pass silently.
    assert chunk_count > 0, "stream yielded no chunks"
def test_completion_cohere_stream_bad_key():
    """An invalid API key must surface as ``litellm.AuthenticationError``.

    The test fails if no exception is raised, or if a different exception
    type is raised.
    """
    api_key = "bad-key"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "how does a court case get to the Supreme Court?",
        },
    ]
    with pytest.raises(litellm.AuthenticationError):
        response = completion(
            model="command",
            messages=messages,
            stream=True,
            max_tokens=50,
            api_key=api_key,
        )
        # With stream=True the error may only be raised once the stream is
        # consumed, so drain it inside the ``raises`` block as well.
        for _ in response:
            pass
def test_cohere_transform_request():
    """``CohereTextConfig.transform_request`` builds the generate-API body.

    Expects the model name to be passed through, the message contents to be
    space-joined into ``prompt``, and ``optional_params`` to be merged into
    the top level of the request.
    """
    config = CohereTextConfig()
    messages = [
        {"role": "system", "content": "You're a helpful bot"},
        {"role": "user", "content": "Hello"},
    ]
    optional_params = {"max_tokens": 10, "temperature": 0.7}
    transformed_request = config.transform_request(
        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},
        headers={},
    )
    print("transformed_request", json.dumps(transformed_request, indent=4))
    assert transformed_request["model"] == "command"
    assert transformed_request["prompt"] == "You're a helpful bot Hello"
    assert transformed_request["max_tokens"] == 10
    assert transformed_request["temperature"] == 0.7
def test_cohere_transform_request_with_tools():
    """OpenAI-style ``tools`` survive ``transform_request`` in wrapped form.

    The transformed request is expected to carry the original tool list
    wrapped as ``{"tools": tools}`` under the ``"tools"`` key.
    """
    config = CohereTextConfig()
    messages = [{"role": "user", "content": "What's the weather?"}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ]
    optional_params = {"tools": tools}
    transformed_request = config.transform_request(
        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},
        headers={},
    )
    print("transformed_request", json.dumps(transformed_request, indent=4))
    assert "tools" in transformed_request
    assert transformed_request["tools"] == {"tools": tools}
def test_cohere_map_openai_params():
    """OpenAI parameter names are mapped to their Cohere equivalents.

    Renamed parameters checked here: ``n`` -> ``num_generations``,
    ``top_p`` -> ``p``, and ``stop`` -> ``stop_sequences``. The remaining
    parameters are expected to keep their names and values.
    """
    config = CohereTextConfig()
    openai_params = {
        "temperature": 0.7,
        "max_tokens": 100,
        "n": 2,
        "top_p": 0.9,
        "frequency_penalty": 0.5,
        "presence_penalty": 0.5,
        "stop": ["END"],
        "stream": True,
    }
    mapped_params = config.map_openai_params(
        non_default_params=openai_params,
        optional_params={},
        model="command",
        drop_params=False,
    )
    assert mapped_params["temperature"] == 0.7
    assert mapped_params["max_tokens"] == 100
    assert mapped_params["num_generations"] == 2
    assert mapped_params["p"] == 0.9
    assert mapped_params["frequency_penalty"] == 0.5
    assert mapped_params["presence_penalty"] == 0.5
    assert mapped_params["stop_sequences"] == ["END"]
    # Identity check on the bool singleton instead of ``== True`` (E712).
    assert mapped_params["stream"] is True