arubaDev commited on
Commit
ae1a461
·
verified ·
1 Parent(s): c0c9f2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -10,7 +10,7 @@ from datasets import load_dataset
10
  # ---------------------------
11
  MODELS = {
12
  "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
13
- "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
14
  }
15
 
16
  DATASETS = ["The Stack", "CodeXGLUE"]
@@ -123,18 +123,18 @@ def build_api_messages(session_id: int, system_message: str):
123
  msgs.extend(get_messages(session_id))
124
  return msgs
125
 
126
- # def get_client(model_choice: str):
127
- # return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
128
- def get_client(model_choice):
129
- # Normalize model_choice -> must be a string
130
- if isinstance(model_choice, list):
131
- model_choice = model_choice[0] if model_choice else None
132
 
133
- if not model_choice:
134
- model_choice = list(MODELS.keys())[0]
135
 
136
- model_id = MODELS.get(model_choice, list(MODELS.values())[0])
137
- return InferenceClient(model_id, token=HF_TOKEN)
138
 
139
 
140
  def load_dataset_by_name(name: str):
 
10
  # ---------------------------
11
  MODELS = {
12
  "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
13
+ "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.2",
14
  }
15
 
16
  DATASETS = ["The Stack", "CodeXGLUE"]
 
123
  msgs.extend(get_messages(session_id))
124
  return msgs
125
 
126
+ def get_client(model_choice: str):
127
+ return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
128
+ # def get_client(model_choice):
129
+ # # Normalize model_choice -> must be a string
130
+ # if isinstance(model_choice, list):
131
+ # model_choice = model_choice[0] if model_choice else None
132
 
133
+ # if not model_choice:
134
+ # model_choice = list(MODELS.keys())[0]
135
 
136
+ # model_id = MODELS.get(model_choice, list(MODELS.values())[0])
137
+ # return InferenceClient(model_id, token=HF_TOKEN)
138
 
139
 
140
  def load_dataset_by_name(name: str):