Update app.py
Browse files
app.py
CHANGED
|
@@ -17,8 +17,6 @@ from sentence_transformers import SentenceTransformer
|
|
| 17 |
from scorer import question_scorer
|
| 18 |
from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, SUBMISSION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
|
| 19 |
|
| 20 |
-
similarity_model = SentenceTransformer("stsb-mpnet-base-v2")
|
| 21 |
-
|
| 22 |
TOKEN = os.environ.get("TOKEN", None)
|
| 23 |
|
| 24 |
OWNER="Blanca"
|
|
@@ -29,9 +27,16 @@ SUBMISSION_DATASET_PUBLIC = f"{OWNER}/submissions_public"
|
|
| 29 |
#CONTACT_DATASET = f"{OWNER}/contact_info"
|
| 30 |
RESULTS_DATASET = f"{OWNER}/results_public"
|
| 31 |
LEADERBOARD_PATH = f"HiTZ/Critical_Questions_Leaderboard"
|
| 32 |
-
METRIC = 'similarity'
|
| 33 |
api = HfApi()
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
YEAR_VERSION = "2025"
|
| 36 |
ref_scores_len = {"test": 34}
|
| 37 |
#ref_level_len = {"validation": {1: 53, 2: 86, 3: 26}, "test": {1: 93, 2: 159, 3: 49}}
|
|
@@ -82,6 +87,61 @@ def restart_space():
|
|
| 82 |
|
| 83 |
TYPES = ["markdown", "number", "number", "number", "number", "str", "str", "str"]
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def add_new_eval(
|
| 86 |
model: str,
|
| 87 |
model_family: str,
|
|
@@ -165,7 +225,6 @@ def add_new_eval(
|
|
| 165 |
reference_set = [row['cq'] for row in references[indx]]
|
| 166 |
#print(reference_set, flush=True)
|
| 167 |
for cq in line['cqs']:
|
| 168 |
-
# TODO: compare to each reference and get a value
|
| 169 |
cq_text = cq['cq']
|
| 170 |
#print(cq_text, flush=True)
|
| 171 |
|
|
@@ -180,10 +239,26 @@ def add_new_eval(
|
|
| 180 |
# make sure the similarity of the winning reference sentence is at least 0.65
|
| 181 |
if sims[winner] > 0.65:
|
| 182 |
label = references[indx][winner]['label']
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
print(indx, score, flush=True)
|
| 188 |
#return format_error(score)
|
| 189 |
|
|
|
|
| 17 |
from scorer import question_scorer
|
| 18 |
from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, SUBMISSION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
|
| 19 |
|
|
|
|
|
|
|
| 20 |
TOKEN = os.environ.get("TOKEN", None)
|
| 21 |
|
| 22 |
OWNER="Blanca"
|
|
|
|
| 27 |
#CONTACT_DATASET = f"{OWNER}/contact_info"
|
| 28 |
RESULTS_DATASET = f"{OWNER}/results_public"
|
| 29 |
LEADERBOARD_PATH = f"HiTZ/Critical_Questions_Leaderboard"
|
| 30 |
+
METRIC = 'similarity' # 'gemma'
|
| 31 |
api = HfApi()
|
| 32 |
|
| 33 |
+
if METRIC == 'similarity':
|
| 34 |
+
similarity_model = SentenceTransformer("stsb-mpnet-base-v2")
|
| 35 |
+
|
| 36 |
+
if METRIC == 'gemma':
|
| 37 |
+
model = AutoModelForCausalLM.from_pretrained('google/gemma-2-9b-it', device_map="auto", attn_implementation='eager')
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it')
|
| 39 |
+
|
| 40 |
YEAR_VERSION = "2025"
|
| 41 |
ref_scores_len = {"test": 34}
|
| 42 |
#ref_level_len = {"validation": {1: 53, 2: 86, 3: 26}, "test": {1: 93, 2: 159, 3: 49}}
|
|
|
|
| 87 |
|
| 88 |
TYPES = ["markdown", "number", "number", "number", "number", "str", "str", "str"]
|
| 89 |
|
| 90 |
+
|
| 91 |
+
def run_model(model, tokenizer, prompt):
|
| 92 |
+
chat = [{"role": "user", "content": prompt}]
|
| 93 |
+
chat_formated = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
| 94 |
+
#print(chat_formated, flush=True)
|
| 95 |
+
inputs = tokenizer(chat_formated, return_tensors="pt")
|
| 96 |
+
|
| 97 |
+
inputs = inputs.to('cuda')
|
| 98 |
+
|
| 99 |
+
generated_ids = model.generate(**inputs, max_new_tokens=512)
|
| 100 |
+
|
| 101 |
+
#generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)] # this does not work for Gemma
|
| 102 |
+
|
| 103 |
+
out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 104 |
+
|
| 105 |
+
#print(out, flush=True)
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
output = out.split('model\n')[1].replace('\n', '')
|
| 109 |
+
except IndexError:
|
| 110 |
+
print('EVAL ERROR: '+output, flush=True)
|
| 111 |
+
|
| 112 |
+
#import pdb; pdb.set_trace()
|
| 113 |
+
output = output.strip()
|
| 114 |
+
|
| 115 |
+
return output
|
| 116 |
+
|
| 117 |
+
def get_prompts(cq, references):
|
| 118 |
+
return {
|
| 119 |
+
'compare': f"""You will be given a set of reference questions, each with an identifying ID, and a newly generated question. Your task is to determine if any of the reference questions are asking for the same information as the new question.
|
| 120 |
+
|
| 121 |
+
Here is the set of reference questions with their IDs:
|
| 122 |
+
|
| 123 |
+
<reference_questions>
|
| 124 |
+
{references}
|
| 125 |
+
</reference_questions>
|
| 126 |
+
|
| 127 |
+
Here is the newly generated question:
|
| 128 |
+
|
| 129 |
+
<new_question>
|
| 130 |
+
{cq}
|
| 131 |
+
</new_question>
|
| 132 |
+
|
| 133 |
+
Compare the new question to each of the reference questions. Look for questions that are asking for the same information, even if they are worded differently. Consider the core meaning and intent of each question, not just the exact wording.
|
| 134 |
+
|
| 135 |
+
If you find a reference question that is asking for the same information as the new question, output only the ID of that reference question.
|
| 136 |
+
|
| 137 |
+
If none of the reference questions are asking for the same information as the new question, output exactly 'Similar reference not found.' (without quotes).
|
| 138 |
+
|
| 139 |
+
Your final output should consist of only one of the following:
|
| 140 |
+
1. The ID of the most similar reference question
|
| 141 |
+
2. The exact phrase 'Similar reference not found.'
|
| 142 |
+
|
| 143 |
+
Do not include any explanation, reasoning, or additional text in your output."""}
|
| 144 |
+
|
| 145 |
def add_new_eval(
|
| 146 |
model: str,
|
| 147 |
model_family: str,
|
|
|
|
| 225 |
reference_set = [row['cq'] for row in references[indx]]
|
| 226 |
#print(reference_set, flush=True)
|
| 227 |
for cq in line['cqs']:
|
|
|
|
| 228 |
cq_text = cq['cq']
|
| 229 |
#print(cq_text, flush=True)
|
| 230 |
|
|
|
|
| 239 |
# make sure the similarity of the winning reference sentence is at least 0.65
|
| 240 |
if sims[winner] > 0.65:
|
| 241 |
label = references[indx][winner]['label']
|
| 242 |
+
else:
|
| 243 |
+
label = 'not_able_to_evaluate'
|
| 244 |
+
|
| 245 |
+
if METRIC == 'gemma':
|
| 246 |
+
prompts = get_prompts(cq_text, '\n'.join(reference_set))
|
| 247 |
+
winner = run_model(model, tokenizer, prompts['compare'])
|
| 248 |
+
try: # here make sure the output is the id of a reference cq
|
| 249 |
+
if winner.strip() != 'Similar reference not found.':
|
| 250 |
+
label = references[index][int(winner)]['label']
|
| 251 |
+
else:
|
| 252 |
+
label = 'not_able_to_evaluate'
|
| 253 |
+
except IndexError:
|
| 254 |
+
label = 'evaluation_issue'
|
| 255 |
+
except ValueError:
|
| 256 |
+
label = 'evaluation_issue'
|
| 257 |
+
|
| 258 |
+
print(label, flush=True)
|
| 259 |
+
if label == 'Useful':
|
| 260 |
+
score += 1/3
|
| 261 |
+
|
| 262 |
print(indx, score, flush=True)
|
| 263 |
#return format_error(score)
|
| 264 |
|