Blanca committed
Commit 6553d2e · verified · 1 Parent(s): 0a759e0

Update app.py

Files changed (1)
  1. app.py +83 -8
app.py CHANGED
@@ -17,8 +17,6 @@ from sentence_transformers import SentenceTransformer
 from scorer import question_scorer
 from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, SUBMISSION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
 
-similarity_model = SentenceTransformer("stsb-mpnet-base-v2")
-
 TOKEN = os.environ.get("TOKEN", None)
 
 OWNER="Blanca"
@@ -29,9 +27,16 @@ SUBMISSION_DATASET_PUBLIC = f"{OWNER}/submissions_public"
 #CONTACT_DATASET = f"{OWNER}/contact_info"
 RESULTS_DATASET = f"{OWNER}/results_public"
 LEADERBOARD_PATH = f"HiTZ/Critical_Questions_Leaderboard"
-METRIC = 'similarity'
+METRIC = 'similarity' # 'gemma'
 api = HfApi()
 
+if METRIC == 'similarity':
+    similarity_model = SentenceTransformer("stsb-mpnet-base-v2")
+
+if METRIC == 'gemma':
+    model = AutoModelForCausalLM.from_pretrained('google/gemma-2-9b-it', device_map="auto", attn_implementation='eager')
+    tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it')
+
 YEAR_VERSION = "2025"
 ref_scores_len = {"test": 34}
 #ref_level_len = {"validation": {1: 53, 2: 86, 3: 26}, "test": {1: 93, 2: 159, 3: 49}}
@@ -82,6 +87,61 @@ def restart_space():
 
 TYPES = ["markdown", "number", "number", "number", "number", "str", "str", "str"]
 
+
+def run_model(model, tokenizer, prompt):
+    chat = [{"role": "user", "content": prompt}]
+    chat_formated = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    #print(chat_formated, flush=True)
+    inputs = tokenizer(chat_formated, return_tensors="pt")
+
+    inputs = inputs.to('cuda')
+
+    generated_ids = model.generate(**inputs, max_new_tokens=512)
+
+    #generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)] # this does not work for Gemma
+
+    out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    #print(out, flush=True)
+
+    try:
+        output = out.split('model\n')[1].replace('\n', '')
+    except IndexError:
+        print('EVAL ERROR: '+output, flush=True)
+
+    #import pdb; pdb.set_trace()
+    output = output.strip()
+
+    return output
+
+def get_prompts(cq, references):
+    return {
+        'compare': f"""You will be given a set of reference questions, each with an identifying ID, and a newly generated question. Your task is to determine if any of the reference questions are asking for the same information as the new question.
+
+Here is the set of reference questions with their IDs:
+
+<reference_questions>
+{references}
+</reference_questions>
+
+Here is the newly generated question:
+
+<new_question>
+{cq}
+</new_question>
+
+Compare the new question to each of the reference questions. Look for questions that are asking for the same information, even if they are worded differently. Consider the core meaning and intent of each question, not just the exact wording.
+
+If you find a reference question that is asking for the same information as the new question, output only the ID of that reference question.
+
+If none of the reference questions are asking for the same information as the new question, output exactly 'Similar reference not found.' (without quotes).
+
+Your final output should consist of only one of the following:
+1. The ID of the most similar reference question
+2. The exact phrase 'Similar reference not found.'
+
+Do not include any explanation, reasoning, or additional text in your output."""}
+
 def add_new_eval(
     model: str,
     model_family: str,
@@ -165,7 +225,6 @@ def add_new_eval(
         reference_set = [row['cq'] for row in references[indx]]
         #print(reference_set, flush=True)
         for cq in line['cqs']:
-            # TODO: compare to each reference and get a value
            cq_text = cq['cq']
            #print(cq_text, flush=True)
 
@@ -180,10 +239,26 @@
            # make sure the similarity of the winning reference sentence is at least 0.65
            if sims[winner] > 0.65:
                label = references[indx][winner]['label']
-                if label == 'Useful':
-                    score += 1/3
-            #else:
-            #    label = 'not_able_to_evaluate'
+            else:
+                label = 'not_able_to_evaluate'
+
+            if METRIC == 'gemma':
+                prompts = get_prompts(cq_text, '\n'.join(reference_set))
+                winner = run_model(model, tokenizer, prompts['compare'])
+                try: # here make sure the output is the id of a reference cq
+                    if winner.strip() != 'Similar reference not found.':
+                        label = references[index][int(winner)]['label']
+                    else:
+                        label = 'not_able_to_evaluate'
+                except IndexError:
+                    label = 'evaluation_issue'
+                except ValueError:
+                    label = 'evaluation_issue'
+
+            print(label, flush=True)
+            if label == 'Useful':
+                score += 1/3
+
        print(indx, score, flush=True)
        #return format_error(score)
 
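
For context, a minimal self-contained sketch of the similarity path that remains the default (METRIC = 'similarity'): the generated question is matched against the reference questions with the sentence-transformers model, and only a match above the 0.65 threshold whose reference is labeled 'Useful' adds 1/3 to the score. The helper name score_generated_question and the encode/cos_sim calls are illustrative assumptions; the diff itself only shows the thresholding and labeling logic around sims, winner, and label.

# Illustrative sketch only (not part of app.py): approximate METRIC == 'similarity' scoring
# for one generated critical question against the references of one intervention.
from sentence_transformers import SentenceTransformer, util

similarity_model = SentenceTransformer("stsb-mpnet-base-v2")

def score_generated_question(cq_text, references):
    # references: list of dicts like {'cq': str, 'label': str} (assumed structure).
    reference_set = [row['cq'] for row in references]
    # Encode the generated question and every reference question.
    cq_emb = similarity_model.encode(cq_text, convert_to_tensor=True)
    ref_embs = similarity_model.encode(reference_set, convert_to_tensor=True)
    # Cosine similarity against each reference; keep the best match.
    sims = util.cos_sim(cq_emb, ref_embs)[0]
    winner = int(sims.argmax())
    # Trust the match only if it clears the 0.65 threshold used in app.py.
    if sims[winner] > 0.65:
        label = references[winner]['label']
    else:
        label = 'not_able_to_evaluate'
    # A 'Useful' match contributes one third of a point per question.
    return 1/3 if label == 'Useful' else 0.0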
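
And a rough sketch of how the new METRIC == 'gemma' branch resolves a label from the model's reply, assuming the run_model and get_prompts helpers from the diff are in scope. Numbering the references so the reply can be parsed with int() is an assumption (the diff joins the raw reference texts with '\n'), as is the helper name gemma_label_for_question.

# Illustrative sketch only: map the Gemma reply back to a reference label.
# Assumes run_model and get_prompts from the diff above are in scope.
def gemma_label_for_question(model, tokenizer, cq_text, references):
    # Assumption: give each reference an explicit ID so int(winner) can index it.
    reference_set = [f"{i}: {row['cq']}" for i, row in enumerate(references)]
    prompts = get_prompts(cq_text, '\n'.join(reference_set))
    winner = run_model(model, tokenizer, prompts['compare'])
    try:
        if winner.strip() != 'Similar reference not found.':
            # Expected reply: just the numeric ID of the matching reference question.
            return references[int(winner)]['label']
        return 'not_able_to_evaluate'
    except (IndexError, ValueError):
        # Non-numeric or out-of-range replies cannot be mapped to a reference.
        return 'evaluation_issue'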