Spaces:
Running
Running
| import logging | |
| import os | |
| from abc import ABC, abstractmethod | |
| import tiktoken | |
| from gradio_client import Client | |
| from openai import OpenAI | |
# OpenAI client shared by the ChatGPT-backed completers.
client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# Credentials and target Space for the Llama-2 completer (gradio_client).
HUB_TOKEN = os.getenv("HUB_TOKEN")
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"

# Module-level logger; INFO level so completer loading is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextCompletion(ABC):
    """Abstract interface for text-completion backends.

    Subclasses must implement token counting and prompt completion;
    ``get_score_multiplier`` defaults to 1.0 and may be overridden.
    """

    # Fix: original declarations omitted ``self`` and were not marked
    # ``@abstractmethod`` despite the import — calling an un-overridden
    # method through an instance would raise a TypeError.
    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens in *prompt* for this backend."""
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return the backend's completion for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        # Default scoring weight; backends may override (see 2.0 elsewhere).
        return 1.0
class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        # Counting uses the text-davinci-003 encoding as a stand-in.
        enc = tiktoken.encoding_for_model("text-davinci-003")
        return len(enc.encode(prompt))

    def complete(self, prompt: str):
        # Sentinel substrings in the prompt select the canned behaviour.
        if "$CORRECT$" in prompt:
            # returns the correct prompt
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # simulate an error
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"
class ChatGPTCompletor(TextCompletion):
    """Completer backed by the OpenAI chat-completions API."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        # Tokenize with the encoding that matches the configured model.
        return len(tiktoken.encoding_for_model(self.model).encode(prompt))

    def complete(self, prompt: str):
        # Single-turn chat: the whole prompt is sent as one user message,
        # with temperature 0 for deterministic output.
        chat = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            model=self.model,
        )
        # Get the response text
        return chat.choices[0].message.content

    def get_score_multiplier(self) -> float:
        # Scores for this backend are doubled relative to the 1.0 default.
        return 2.0
class LLama2Completor(TextCompletion):
    """Completer that proxies prompts to a hosted Llama-2 Gradio Space."""

    def __init__(self, model: str):
        self.model = model
        # Remote inference goes through the gradio_client Space endpoint.
        self.client = Client(src=LLAMA_HF_SPACE, hf_token=HUB_TOKEN)

    def get_token_count(self, prompt: str) -> int:
        # Token counting is not implemented for this backend.
        return 0

    def complete(self, prompt: str) -> str:
        return self.client.predict(prompt)
class ClaudeCompletor(TextCompletion):
    """Completer backed by Anthropic's legacy text-completions API."""

    def __init__(self, model="claude-2"):
        # Imported lazily so the anthropic package is only required
        # when this completer is actually constructed.
        from anthropic import Anthropic
        import os

        self.model = model
        self.anthropic = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def get_token_count(self, prompt: str):
        "Not used in actual competition, so just dummy value"
        return 0

    def complete(self, prompt: str):
        from anthropic import HUMAN_PROMPT, AI_PROMPT

        # Wrap the raw prompt in the Human/Assistant turn markers the
        # legacy completions endpoint expects.
        response = self.anthropic.completions.create(
            model=self.model,
            max_tokens_to_sample=300,
            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
        )
        return response.completion
# Public model name -> completer class. Entries here are the models exposed
# by the app; debug models are handled separately in get_completer.
completers = {
    "gpt-3.5-turbo": ChatGPTCompletor,
    "gpt-4o": ChatGPTCompletor,
    "gpt-4": ChatGPTCompletor,
    "gpt-3.5-turbo-0125": ChatGPTCompletor,
    "gpt-4-0613": ChatGPTCompletor,
    "gpt-4o-mini": ChatGPTCompletor,
    "gpt-4o-2024-05-13": ChatGPTCompletor,
}
def get_completer(model: str):
    """Return a completer instance for *model*.

    Raises:
        NotImplementedError: if *model* is neither a registered model
            nor the special ``gpt-debug`` name.
    """
    logger.info(f"Loading completer for {model=}")
    if model in completers:  # membership test; .keys() was redundant
        completer = completers[model](model=model)
    elif model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        completer = DebugCompleter(model=model)
    else:
        # Fix: message previously read "Model must be onf of".
        raise NotImplementedError(
            f"{model=} not implemented. Model must be one of {list(completers.keys())}"
        )
    return completer