Blanca committed on
Commit
0c85f3a
·
verified ·
1 Parent(s): 38ffc99

Upload scorer.py

Browse files
Files changed (1) hide show
  1. scorer.py +104 -0
scorer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import string
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+
9
def normalize_number_str(number_str: str) -> float:
    """Convert a human-formatted number string to a float.

    The characters ``$``, ``%`` and ``,`` are removed before conversion so
    that values such as ``"$1,234.50"`` or ``"45%"`` parse as plain numbers.

    Parameters:
    - number_str: str, the text to convert. Non-string values (e.g. an int
      returned directly by a model) are converted with ``str()`` first.
    Returns:
    - float, the parsed value, or ``float("inf")`` when the cleaned text is
      not a valid float literal (a message is printed in that case).
    """
    # Strip the common unit/grouping characters in a single pass.
    cleaned = str(number_str).translate(str.maketrans("", "", "$%,"))
    try:
        return float(cleaned)
    except ValueError:
        print(f"String {cleaned} cannot be normalized to number str.")
        # Sentinel for "not a number"; callers compare it against the
        # expected value, which fails for any finite expected number.
        return float("inf")
19
+
20
+
21
def split_string(
    s: str,
    char_list: "list[str] | None" = None,
) -> list[str]:
    """Split a string on any of the given delimiter characters.

    Parameters:
    - s: str, the string to split
    - char_list: list of delimiter characters (default: ``","`` and ``";"``)
    Returns:
    - list[str], the pieces of ``s`` between delimiters; ``[s]`` when no
      delimiter characters are supplied
    """
    if char_list is None:
        char_list = [",", ";"]

    # Escape each delimiter so characters that are special inside a regex
    # character class ("-", "^", "]", "\") are matched literally instead of
    # forming ranges/negations or raising re.error.
    escaped = "".join(re.escape(char) for char in char_list)
    if not escaped:
        # An empty character class "[]" is not a valid pattern.
        return [s]
    return re.split(f"[{escaped}]", s)
27
+
28
+
29
def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    """Decide whether a model answer matches the ground truth.

    The comparison mode is chosen from the ground truth:
    - if it parses as a float, the answer is normalized with
      ``normalize_number_str`` and compared numerically;
    - else if it contains ``","`` or ``";"``, both strings are split with
      ``split_string`` and compared element by element (numerically where
      the ground-truth element parses as a float, otherwise as normalized
      strings with punctuation kept);
    - otherwise both are compared as normalized strings.

    Parameters:
    - model_answer: str, the answer to score; ``None`` is treated as the
      string ``"None"`` and other non-string values are converted with
      ``str()``
    - ground_truth: str, the expected answer
    Returns:
    - bool, True when the answer matches the ground truth
    """

    def is_float(element: object) -> bool:
        # True when float() accepts the value.
        try:
            float(element)
        except ValueError:
            return False
        return True

    def elements_match(ma_elem: str, gt_elem: str) -> bool:
        # Compare one list element as a number or as a string, depending on
        # what the ground-truth element looks like.
        if is_float(gt_elem):
            return normalize_number_str(ma_elem) == float(gt_elem)
        # we do not remove punct since comparisons can include punct
        return normalize_str(ma_elem, remove_punct=False) == normalize_str(
            gt_elem, remove_punct=False
        )

    if model_answer is None:
        model_answer = "None"
    elif not isinstance(model_answer, str):
        # The helpers below use str methods; accept e.g. a raw int answer.
        model_answer = str(model_answer)

    # if gt is a number
    if is_float(ground_truth):
        print(f"Evaluating {model_answer} as a number.")
        # NOTE(review): normalize_number_str returns inf for unparseable
        # text, which would equal a ground truth of "inf" — confirm no such
        # ground truths exist.
        return normalize_number_str(model_answer) == float(ground_truth)

    # if gt is a list
    if any(char in ground_truth for char in [",", ";"]):
        print(f"Evaluating {model_answer} as a comma separated list.")

        gt_elems = split_string(ground_truth)
        ma_elems = split_string(model_answer)

        # check length is the same
        if len(gt_elems) != len(ma_elems):
            warnings.warn(
                "Answer lists have different lengths, returning False.", UserWarning
            )
            return False

        # Build the full list first so every element is evaluated (and any
        # normalization message printed) before reducing with all().
        comparisons = [
            elements_match(ma_elem, gt_elem)
            for ma_elem, gt_elem in zip(ma_elems, gt_elems)
        ]
        return all(comparisons)

    # if gt is a str
    print(f"Evaluating {model_answer} as a string.")
    return normalize_str(model_answer) == normalize_str(ground_truth)
82
+
83
+
84
def normalize_str(input_str: str, remove_punct: bool = True) -> str:
    """
    Normalize a string by:
    - Removing all white spaces
    - Optionally removing punctuation (if remove_punct is True)
    - Converting to lowercase
    Parameters:
    - input_str: str, the string to normalize
    - remove_punct: bool, whether to remove punctuation (default: True)
    Returns:
    - str, the normalized string
    """
    # Remove all white spaces. Required e.g for seagull vs. sea gull
    lowered = re.sub(r"\s", "", input_str).lower()

    if not remove_punct:
        return lowered

    # Drop every character listed in string.punctuation (ASCII punctuation).
    translator = str.maketrans("", "", string.punctuation)
    return lowered.translate(translator)