Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,6 +22,14 @@ SUBJECT_TRANS = {
|
|
| 22 |
"组合": "Combinatorics"
|
| 23 |
}
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
MODEL_TRANS = {
|
| 26 |
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
|
| 27 |
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
|
|
@@ -65,6 +73,70 @@ DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
|
|
| 65 |
# 全局数据库实例
|
| 66 |
db = None
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
class ModelDatabase:
|
| 69 |
"""Database access class"""
|
| 70 |
|
|
@@ -360,6 +432,82 @@ class ModelDatabase:
|
|
| 360 |
# 清理所有缓存
|
| 361 |
self.clear_cache()
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
def format_latex(text):
|
| 364 |
if text is None: return ""
|
| 365 |
# Process the text for proper LaTeX rendering with KaTeX
|
|
@@ -372,12 +520,24 @@ def format_latex(text):
|
|
| 372 |
def format_markdown_with_math(text):
|
| 373 |
if text is None: return ""
|
| 374 |
|
| 375 |
-
#
|
| 376 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
# Convert newlines for markdown
|
| 379 |
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
# Return the cleaned text for Gradio's markdown component to render
|
| 382 |
return text
|
| 383 |
|
|
@@ -584,16 +744,9 @@ def handle_comparison_problem_update(problem_id, dataset_state):
|
|
| 584 |
# Use format_markdown_with_math for proper rendering
|
| 585 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 586 |
|
| 587 |
-
#
|
| 588 |
answer_text = problem_dict.get('answer', '')
|
| 589 |
-
|
| 590 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
| 591 |
-
|
| 592 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
| 593 |
-
if '$' not in answer_text and answer_text.strip():
|
| 594 |
-
answer_text = f"${answer_text}$"
|
| 595 |
-
|
| 596 |
-
answer_content = format_markdown_with_math(answer_text)
|
| 597 |
|
| 598 |
return problem_content, answer_content
|
| 599 |
except Exception as e:
|
|
@@ -634,16 +787,9 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 634 |
# Process problem and answer text for Markdown rendering
|
| 635 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 636 |
|
| 637 |
-
#
|
| 638 |
answer_text = problem_dict.get('answer', '')
|
| 639 |
-
|
| 640 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
| 641 |
-
|
| 642 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
| 643 |
-
if '$' not in answer_text and answer_text.strip():
|
| 644 |
-
answer_text = f"${answer_text}$"
|
| 645 |
-
|
| 646 |
-
answer_content = format_markdown_with_math(answer_text)
|
| 647 |
|
| 648 |
# For comparison without model, we don't have samples to display
|
| 649 |
return problem_content, answer_content, "", gr.State([])
|
|
@@ -673,16 +819,9 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 673 |
# Process problem and answer text for Markdown rendering
|
| 674 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 675 |
|
| 676 |
-
#
|
| 677 |
answer_text = problem_dict.get('answer', '')
|
| 678 |
-
|
| 679 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
| 680 |
-
|
| 681 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
| 682 |
-
if '$' not in answer_text and answer_text.strip():
|
| 683 |
-
answer_text = f"${answer_text}$"
|
| 684 |
-
|
| 685 |
-
answer_content = format_markdown_with_math(answer_text)
|
| 686 |
|
| 687 |
# Rest of the function remains the same
|
| 688 |
if not responses_data:
|
|
@@ -709,7 +848,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 709 |
samples_per_row = 16 if mode == 'comparison' else 32
|
| 710 |
|
| 711 |
# 第一行: 样本 0-samples_per_row
|
| 712 |
-
samples_grid_html = f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
| 713 |
|
| 714 |
for i, resp in enumerate(displayed_samples[:samples_per_row]):
|
| 715 |
correctness = resp.get('correctness', 0)
|
|
@@ -737,7 +876,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 737 |
# 如果有更多样本,显示第二行
|
| 738 |
if actual_display_count > samples_per_row:
|
| 739 |
row_samples = displayed_samples[samples_per_row:2*samples_per_row]
|
| 740 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
| 741 |
|
| 742 |
for i, resp in enumerate(row_samples):
|
| 743 |
actual_idx = i + samples_per_row
|
|
@@ -767,7 +906,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 767 |
# 第三行
|
| 768 |
row_samples = displayed_samples[2*samples_per_row:3*samples_per_row]
|
| 769 |
if row_samples:
|
| 770 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
| 771 |
|
| 772 |
for i, resp in enumerate(row_samples):
|
| 773 |
actual_idx = i + 2*samples_per_row
|
|
@@ -796,7 +935,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
| 796 |
if actual_display_count > 3*samples_per_row:
|
| 797 |
row_samples = displayed_samples[3*samples_per_row:4*samples_per_row]
|
| 798 |
if row_samples:
|
| 799 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
| 800 |
|
| 801 |
for i, resp in enumerate(row_samples):
|
| 802 |
actual_idx = i + 3*samples_per_row
|
|
@@ -886,6 +1025,54 @@ def create_ui(db_path):
|
|
| 886 |
global db
|
| 887 |
db = ModelDatabase(db_path)
|
| 888 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 889 |
AVAILABLE_DATASETS = db.get_available_datasets()
|
| 890 |
if not AVAILABLE_DATASETS:
|
| 891 |
AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback
|
|
@@ -896,9 +1083,9 @@ def create_ui(db_path):
|
|
| 896 |
body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; }
|
| 897 |
.sample-btn { transition: all 0.15s ease-in-out; }
|
| 898 |
.sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
|
| 899 |
-
.problem-grid-container { overflow
|
| 900 |
-
.math-content { overflow
|
| 901 |
-
.sample-response { overflow
|
| 902 |
h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); }
|
| 903 |
.gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; }
|
| 904 |
.gr-dropdown select { font-size: 0.9em; }
|
|
@@ -964,6 +1151,68 @@ def create_ui(db_path):
|
|
| 964 |
border: 1px solid #ddd;
|
| 965 |
padding: 4px 8px;
|
| 966 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
"""
|
| 968 |
|
| 969 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
|
|
@@ -989,6 +1238,64 @@ def create_ui(db_path):
|
|
| 989 |
# 创建占位符State组件替代None
|
| 990 |
dummy_state = gr.State(value=None)
|
| 991 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 992 |
with gr.Tabs():
|
| 993 |
with gr.TabItem("Single Model Analysis"):
|
| 994 |
with gr.Row(variant='compact'):
|
|
@@ -1228,6 +1535,83 @@ def create_ui(db_path):
|
|
| 1228 |
]
|
| 1229 |
)
|
| 1230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1231 |
# --- Event Handlers ---
|
| 1232 |
def update_available_models_for_dropdowns(selected_dataset):
|
| 1233 |
# This function can be used to update model lists if they are dataset-dependent
|
|
@@ -1549,6 +1933,37 @@ def create_ui(db_path):
|
|
| 1549 |
outputs=[sample_number_input]
|
| 1550 |
)
|
| 1551 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1552 |
return demo
|
| 1553 |
|
| 1554 |
def monitor_memory_usage():
|
|
@@ -1575,6 +1990,273 @@ def monitor_memory_usage():
|
|
| 1575 |
except Exception as e:
|
| 1576 |
return "Memory monitor error"
|
| 1577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1578 |
# 修改主函数以使用优化策略
|
| 1579 |
if __name__ == "__main__":
|
| 1580 |
DB_PATH = "data.db"
|
|
@@ -1582,22 +2264,15 @@ if __name__ == "__main__":
|
|
| 1582 |
# 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载
|
| 1583 |
if not os.path.exists(DB_PATH):
|
| 1584 |
try:
|
| 1585 |
-
# 从环境变量获取 HF_TOKEN
|
| 1586 |
-
hf_token = os.environ.get("HF_TOKEN")
|
| 1587 |
-
if not hf_token:
|
| 1588 |
-
raise ValueError("HF_TOKEN environment variable is not set")
|
| 1589 |
-
|
| 1590 |
-
# 从 Hugging Face 下载数据库文件
|
| 1591 |
DB_PATH = hf_hub_download(
|
| 1592 |
repo_id="CoderBak/OlymMATH-data",
|
| 1593 |
filename="data.db",
|
| 1594 |
-
repo_type="dataset"
|
| 1595 |
-
token=hf_token
|
| 1596 |
)
|
| 1597 |
except Exception as e:
|
| 1598 |
# 创建一个显示错误信息的简单 Gradio 应用
|
| 1599 |
with gr.Blocks() as error_demo:
|
| 1600 |
-
gr.Markdown(f"# Error: Database Download Failed\n{str(e)}
|
| 1601 |
error_demo.launch(server_name="0.0.0.0")
|
| 1602 |
exit(1)
|
| 1603 |
|
|
|
|
| 22 |
"组合": "Combinatorics"
|
| 23 |
}
|
| 24 |
|
| 25 |
+
# 英文到中文的翻译表
|
| 26 |
+
SUBJECT_TRANS_EN_TO_ZH = {
|
| 27 |
+
"Algebra": "代数",
|
| 28 |
+
"Number Theory": "数论",
|
| 29 |
+
"Geometry": "几何",
|
| 30 |
+
"Combinatorics": "组合"
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
MODEL_TRANS = {
|
| 34 |
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
|
| 35 |
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
|
|
|
|
| 73 |
# 全局数据库实例
|
| 74 |
db = None
|
| 75 |
|
| 76 |
+
# 全局缓存for Reference Solutions
|
| 77 |
+
reference_accuracy_cache = {}
|
| 78 |
+
|
| 79 |
+
def precompute_reference_accuracies(db, reference_loader):
    """Pre-compute all reference problem accuracies for fast loading.

    For every problem id returned by ``reference_loader.get_all_problem_ids()``
    the mean ``correctness`` of each model's responses is computed for the
    EN-HARD and ZH-HARD datasets, then averaged over the models that returned
    any responses. Results are stored in the module-level
    ``reference_accuracy_cache`` as ``{pid: {"EN": float, "ZH": float}}``;
    a language without usable data is recorded as 0.0.

    Args:
        db: Object exposing ``get_available_models()`` and
            ``get_problem_data(model, dataset, unique_id)``. The second item
            of the value returned by the latter is used as the response list.
        reference_loader: Object exposing ``get_all_problem_ids()``.

    Returns:
        None. Nothing is computed when either argument is falsy.
    """
    global reference_accuracy_cache

    if not db or not reference_loader:
        return

    print("Pre-computing reference problem accuracies...")
    start_time = time.time()

    problem_ids = reference_loader.get_all_problem_ids()
    reference_accuracy_cache = {}

    # Fetch the model list once instead of once per problem.
    all_models = db.get_available_models()
    print(f"Computing accuracies for {len(problem_ids)} problems across {len(all_models)} models...")

    # Per-model failures stay best-effort, but are counted so they are not
    # completely invisible in the logs.
    failed_lookups = 0

    def mean_accuracy(dataset, unique_id):
        """Average the per-model mean correctness for one dataset/problem."""
        nonlocal failed_lookups
        per_model = []
        for model in all_models:
            try:
                _, responses = db.get_problem_data(model, dataset, unique_id)
                if responses:
                    per_model.append(
                        sum(r['correctness'] for r in responses) / len(responses)
                    )
            except Exception:
                failed_lookups += 1
        return sum(per_model) / len(per_model) if per_model else 0.0

    for i, pid in enumerate(problem_ids):
        if i % 5 == 0:  # report progress every 5 problems
            print(f"Processing problem {i+1}/{len(problem_ids)}: {pid}")

        try:
            reference_accuracy_cache[pid] = {
                "EN": mean_accuracy("EN-HARD", f"OlymMATH-HARD-{pid}-EN"),
                "ZH": mean_accuracy("ZH-HARD", f"OlymMATH-HARD-{pid}-ZH"),
            }
        except Exception as e:
            print(f"Error computing accuracy for problem {pid}: {e}")
            reference_accuracy_cache[pid] = {"EN": 0.0, "ZH": 0.0}

    elapsed_time = time.time() - start_time
    if failed_lookups:
        print(f"Warning: {failed_lookups} model lookups failed and were skipped")
    print(f"✅ Pre-computation completed in {elapsed_time:.2f} seconds")
    print(f"✅ Cached accuracies for {len(reference_accuracy_cache)} problems")
|
| 139 |
+
|
| 140 |
class ModelDatabase:
|
| 141 |
"""Database access class"""
|
| 142 |
|
|
|
|
| 432 |
# 清理所有缓存
|
| 433 |
self.clear_cache()
|
| 434 |
|
| 435 |
+
class ReferenceDataLoader:
    """Load and manage reference solutions data.

    Records are read from a JSON Lines file and indexed by their
    ``unique_id`` field in ``self.reference_data``.
    """

    def __init__(self, jsonl_path):
        """Remember ``jsonl_path`` and load its records immediately."""
        self.jsonl_path = jsonl_path
        self.reference_data = {}
        self._load_data()

    def _load_data(self):
        """Populate ``self.reference_data`` from the JSON Lines file.

        Blank lines are ignored and a malformed record is reported and
        skipped, so one bad line no longer discards every record after it.
        A file that cannot be opened or decoded is reported and leaves
        whatever was loaded so far in place; no exception is propagated.
        """
        try:
            with open(self.jsonl_path, 'r', encoding='utf-8') as f:
                for line_number, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        self.reference_data[data['unique_id']] = data
                    except (ValueError, KeyError, TypeError) as e:
                        # ValueError covers json.JSONDecodeError; KeyError and
                        # TypeError cover records without a usable unique_id.
                        print(f"Skipping invalid reference record on line {line_number}: {e}")
        except (OSError, UnicodeError) as e:
            print(f"Error loading reference data: {e}")

    def get_problem_data(self, unique_id):
        """Return the record stored for ``unique_id``, or None if unknown."""
        return self.reference_data.get(unique_id)

    def get_all_problem_ids(self):
        """Return all loaded problem ids in sorted order."""
        return sorted(self.reference_data)
|
| 461 |
+
|
| 462 |
+
def calculate_reference_problem_accuracy(db, unique_id):
    """Calculate average accuracy for a reference problem across all models.

    The EN and ZH variants of the problem are looked up as
    ``OlymMATH-HARD-{unique_id}-EN`` / ``-ZH`` in the ``EN-HARD`` and
    ``ZH-HARD`` datasets. For each model the mean ``correctness`` of its
    responses is taken, and those per-model means are averaged.

    Args:
        db: Object exposing ``get_available_models()`` and
            ``get_problem_data(model, dataset, unique_id)``. The second item
            of the value returned by the latter is used as the response list.
        unique_id: Problem identifier interpolated into the dataset ids.

    Returns:
        Tuple ``(en_avg, zh_avg)``. A language with no responses yields 0.0,
        and ``(0.0, 0.0)`` is returned if the computation fails as a whole.
    """
    try:
        # Build the EN and ZH dataset-specific ids for this problem.
        en_unique_id = f"OlymMATH-HARD-{unique_id}-EN"
        zh_unique_id = f"OlymMATH-HARD-{unique_id}-ZH"

        print(f"Calculating accuracy for problem {unique_id}: EN={en_unique_id}, ZH={zh_unique_id}")

        accuracies = {"EN": [], "ZH": []}

        all_models = db.get_available_models()
        print(f"Found {len(all_models)} models in database")

        def record_model_accuracy(model, lang, dataset, dataset_unique_id):
            """Append one model's mean correctness for ``lang``; log failures."""
            try:
                _, responses = db.get_problem_data(model, dataset, dataset_unique_id)
                if responses:
                    avg_accuracy = sum(r['correctness'] for r in responses) / len(responses)
                    accuracies[lang].append(avg_accuracy)
                    print(f" Model {model} {lang}: {avg_accuracy:.2%} ({len(responses)} responses)")
            except Exception as e:
                # Best effort: a model without data must not abort the others.
                print(f" Error getting {lang} data for model {model}: {e}")

        for model in all_models:
            record_model_accuracy(model, "EN", "EN-HARD", en_unique_id)
            record_model_accuracy(model, "ZH", "ZH-HARD", zh_unique_id)

        en_avg = sum(accuracies["EN"]) / len(accuracies["EN"]) if accuracies["EN"] else 0.0
        zh_avg = sum(accuracies["ZH"]) / len(accuracies["ZH"]) if accuracies["ZH"] else 0.0

        print(f"Final averages for problem {unique_id}: EN={en_avg:.2%} (from {len(accuracies['EN'])} models), ZH={zh_avg:.2%} (from {len(accuracies['ZH'])} models)")

        return en_avg, zh_avg
    except Exception as e:
        print(f"Error calculating accuracy for problem {unique_id}: {e}")
        return 0.0, 0.0
|
| 510 |
+
|
| 511 |
def format_latex(text):
|
| 512 |
if text is None: return ""
|
| 513 |
# Process the text for proper LaTeX rendering with KaTeX
|
|
|
|
| 520 |
def format_markdown_with_math(text):
    """Normalize math delimiters and newlines in ``text`` for Markdown rendering.

    ``$$...$$`` becomes ``\\[...\\]`` (display math) and single-line
    ``$...$`` becomes ``\\(...\\)`` (inline math). Line endings are then
    normalized to ``\\n`` and runs of three or more blank-separated newlines
    are collapsed to one blank line.

    Args:
        text: Source string, or None.

    Returns:
        The converted string; an empty string when ``text`` is None.
    """
    if text is None:
        return ""

    # Display math first, so the inline pass below never sees a `$$` pair.
    text = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', text, flags=re.DOTALL)

    # Inline math. An opening `$` escaped as `\$` is left alone. No lookahead
    # is needed after the closing `$`: converted display math contains no `$`,
    # and rejecting a following `]` left text such as "[$a$, $b$]" half
    # converted.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Convert newlines for markdown
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Clean up excessive newlines
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)

    # Debug aid: report when an aligned environment is present.
    if '\\begin{aligned}' in text:
        print(f"LaTeX aligned environment detected in text (first 200 chars): {text[:200]}...")

    # Return the cleaned text for Gradio's markdown component to render
    return text
|
| 543 |
|
|
|
|
| 744 |
# Use format_markdown_with_math for proper rendering
|
| 745 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 746 |
|
| 747 |
+
# Use special answer formatting
|
| 748 |
answer_text = problem_dict.get('answer', '')
|
| 749 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
|
| 751 |
return problem_content, answer_content
|
| 752 |
except Exception as e:
|
|
|
|
| 787 |
# Process problem and answer text for Markdown rendering
|
| 788 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 789 |
|
| 790 |
+
# Use special answer formatting
|
| 791 |
answer_text = problem_dict.get('answer', '')
|
| 792 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
|
| 794 |
# For comparison without model, we don't have samples to display
|
| 795 |
return problem_content, answer_content, "", gr.State([])
|
|
|
|
| 819 |
# Process problem and answer text for Markdown rendering
|
| 820 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
| 821 |
|
| 822 |
+
# Use special answer formatting
|
| 823 |
answer_text = problem_dict.get('answer', '')
|
| 824 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
|
| 826 |
# Rest of the function remains the same
|
| 827 |
if not responses_data:
|
|
|
|
| 848 |
samples_per_row = 16 if mode == 'comparison' else 32
|
| 849 |
|
| 850 |
# 第一行: 样本 0-samples_per_row
|
| 851 |
+
samples_grid_html = f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
| 852 |
|
| 853 |
for i, resp in enumerate(displayed_samples[:samples_per_row]):
|
| 854 |
correctness = resp.get('correctness', 0)
|
|
|
|
| 876 |
# 如果有更多样本,显示第二行
|
| 877 |
if actual_display_count > samples_per_row:
|
| 878 |
row_samples = displayed_samples[samples_per_row:2*samples_per_row]
|
| 879 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
| 880 |
|
| 881 |
for i, resp in enumerate(row_samples):
|
| 882 |
actual_idx = i + samples_per_row
|
|
|
|
| 906 |
# 第三行
|
| 907 |
row_samples = displayed_samples[2*samples_per_row:3*samples_per_row]
|
| 908 |
if row_samples:
|
| 909 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
| 910 |
|
| 911 |
for i, resp in enumerate(row_samples):
|
| 912 |
actual_idx = i + 2*samples_per_row
|
|
|
|
| 935 |
if actual_display_count > 3*samples_per_row:
|
| 936 |
row_samples = displayed_samples[3*samples_per_row:4*samples_per_row]
|
| 937 |
if row_samples:
|
| 938 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
| 939 |
|
| 940 |
for i, resp in enumerate(row_samples):
|
| 941 |
actual_idx = i + 3*samples_per_row
|
|
|
|
| 1025 |
global db
|
| 1026 |
db = ModelDatabase(db_path)
|
| 1027 |
|
| 1028 |
+
# Initialize reference data loader with better path handling
|
| 1029 |
+
reference_loader = None
|
| 1030 |
+
# Try multiple possible paths for extra.jsonl
|
| 1031 |
+
possible_paths = [
|
| 1032 |
+
os.path.join(os.path.dirname(db_path), "extra.jsonl"),
|
| 1033 |
+
os.path.join(os.getcwd(), "extra.jsonl"),
|
| 1034 |
+
"extra.jsonl"
|
| 1035 |
+
]
|
| 1036 |
+
|
| 1037 |
+
for extra_jsonl_path in possible_paths:
|
| 1038 |
+
if os.path.exists(extra_jsonl_path):
|
| 1039 |
+
try:
|
| 1040 |
+
reference_loader = ReferenceDataLoader(extra_jsonl_path)
|
| 1041 |
+
print(f"Successfully loaded reference data from: {extra_jsonl_path}")
|
| 1042 |
+
break
|
| 1043 |
+
except Exception as e:
|
| 1044 |
+
print(f"Error loading reference data from {extra_jsonl_path}: {e}")
|
| 1045 |
+
continue
|
| 1046 |
+
|
| 1047 |
+
# If not found locally, try to download from Hugging Face
|
| 1048 |
+
if not reference_loader:
|
| 1049 |
+
try:
|
| 1050 |
+
print("Attempting to download extra.jsonl from Hugging Face...")
|
| 1051 |
+
extra_jsonl_path = hf_hub_download(
|
| 1052 |
+
repo_id="CoderBak/OlymMATH-data",
|
| 1053 |
+
filename="extra.jsonl",
|
| 1054 |
+
repo_type="dataset"
|
| 1055 |
+
)
|
| 1056 |
+
reference_loader = ReferenceDataLoader(extra_jsonl_path)
|
| 1057 |
+
print(f"Successfully downloaded and loaded reference data from: {extra_jsonl_path}")
|
| 1058 |
+
except Exception as e:
|
| 1059 |
+
print(f"Failed to download extra.jsonl from Hugging Face: {e}")
|
| 1060 |
+
|
| 1061 |
+
if not reference_loader:
|
| 1062 |
+
print("Warning: extra.jsonl not found in any of the expected locations:")
|
| 1063 |
+
for path in possible_paths:
|
| 1064 |
+
print(f" - {path}")
|
| 1065 |
+
print("Reference Solutions tab will not be available.")
|
| 1066 |
+
else:
|
| 1067 |
+
# Test the reference data availability
|
| 1068 |
+
test_reference_data_availability(db, reference_loader)
|
| 1069 |
+
|
| 1070 |
+
# Pre-compute reference problem accuracies for fast loading
|
| 1071 |
+
precompute_reference_accuracies(db, reference_loader)
|
| 1072 |
+
|
| 1073 |
+
# Test LaTeX formatting
|
| 1074 |
+
test_latex_formatting()
|
| 1075 |
+
|
| 1076 |
AVAILABLE_DATASETS = db.get_available_datasets()
|
| 1077 |
if not AVAILABLE_DATASETS:
|
| 1078 |
AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback
|
|
|
|
| 1083 |
body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; }
|
| 1084 |
.sample-btn { transition: all 0.15s ease-in-out; }
|
| 1085 |
.sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
|
| 1086 |
+
.problem-grid-container { overflow: visible !important; }
|
| 1087 |
+
.math-content { overflow: visible !important; padding: 5px; }
|
| 1088 |
+
.sample-response { overflow: visible !important; max-height: none !important; height: auto !important; }
|
| 1089 |
h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); }
|
| 1090 |
.gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; }
|
| 1091 |
.gr-dropdown select { font-size: 0.9em; }
|
|
|
|
| 1151 |
border: 1px solid #ddd;
|
| 1152 |
padding: 4px 8px;
|
| 1153 |
}
|
| 1154 |
+
|
| 1155 |
+
/* 隐藏滚动条但保留功能 */
|
| 1156 |
+
::-webkit-scrollbar {
|
| 1157 |
+
display: none !important;
|
| 1158 |
+
width: 0px !important;
|
| 1159 |
+
height: 0px !important;
|
| 1160 |
+
}
|
| 1161 |
+
|
| 1162 |
+
/* 主容器禁用滚动 */
|
| 1163 |
+
.gradio-container {
|
| 1164 |
+
overflow-x: hidden !important;
|
| 1165 |
+
}
|
| 1166 |
+
|
| 1167 |
+
/* Gradio组件容器 */
|
| 1168 |
+
.gradio-row, .gradio-column {
|
| 1169 |
+
overflow: visible !important;
|
| 1170 |
+
max-height: none !important;
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
/* HTML组件 */
|
| 1174 |
+
.gr-html {
|
| 1175 |
+
overflow: visible !important;
|
| 1176 |
+
max-height: none !important;
|
| 1177 |
+
}
|
| 1178 |
+
|
| 1179 |
+
/* Markdown组件保持可见 */
|
| 1180 |
+
.gr-markdown {
|
| 1181 |
+
overflow: visible !important;
|
| 1182 |
+
max-height: none !important;
|
| 1183 |
+
}
|
| 1184 |
+
|
| 1185 |
+
/* 特定的问题网格容器 */
|
| 1186 |
+
#ref-problem-grid-container, #problem-grid-container, #comp-problem-grid-container-left, #comp-problem-grid-container-right {
|
| 1187 |
+
overflow: visible !important;
|
| 1188 |
+
max-height: none !important;
|
| 1189 |
+
height: auto !important;
|
| 1190 |
+
}
|
| 1191 |
+
|
| 1192 |
+
/* 样本网格 */
|
| 1193 |
+
.sample-grid-btn {
|
| 1194 |
+
overflow: visible !important;
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
/* 确保内容区域不会产生滚动条 */
|
| 1198 |
+
.gr-form, .gr-box {
|
| 1199 |
+
overflow: visible !important;
|
| 1200 |
+
max-height: none !important;
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
/* Reference Solutions - 禁止Solution部分的滚动 */
|
| 1204 |
+
#ref-solution {
|
| 1205 |
+
overflow: hidden !important;
|
| 1206 |
+
max-height: none !important;
|
| 1207 |
+
height: auto !important;
|
| 1208 |
+
}
|
| 1209 |
+
|
| 1210 |
+
/* 确保Solution内容容器也禁止滚动 */
|
| 1211 |
+
#ref-solution .gr-markdown {
|
| 1212 |
+
overflow: hidden !important;
|
| 1213 |
+
max-height: none !important;
|
| 1214 |
+
height: auto !important;
|
| 1215 |
+
}
|
| 1216 |
"""
|
| 1217 |
|
| 1218 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
|
|
|
|
| 1238 |
# 创建占位符State组件替代None
|
| 1239 |
dummy_state = gr.State(value=None)
|
| 1240 |
|
| 1241 |
+
# Add JavaScript for handling problem grid clicks
|
| 1242 |
+
demo.load(lambda: None, js="""
|
| 1243 |
+
() => {
|
| 1244 |
+
// Handle problem button clicks for single model tab
|
| 1245 |
+
function setupProblemGridListeners() {
|
| 1246 |
+
document.addEventListener('click', function(e) {
|
| 1247 |
+
if (e.target.closest('.problem-btn')) {
|
| 1248 |
+
const problemBtn = e.target.closest('.problem-btn');
|
| 1249 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
| 1250 |
+
if (problemId) {
|
| 1251 |
+
const problemInput = document.getElementById('problem-state-input');
|
| 1252 |
+
if (problemInput) {
|
| 1253 |
+
problemInput.querySelector('input').value = problemId;
|
| 1254 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
| 1255 |
+
}
|
| 1256 |
+
}
|
| 1257 |
+
}
|
| 1258 |
+
|
| 1259 |
+
// Handle comparison problem button clicks
|
| 1260 |
+
if (e.target.closest('#comp-problem-grid-container-left .problem-btn') ||
|
| 1261 |
+
e.target.closest('#comp-problem-grid-container-right .problem-btn')) {
|
| 1262 |
+
const problemBtn = e.target.closest('.problem-btn');
|
| 1263 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
| 1264 |
+
if (problemId) {
|
| 1265 |
+
const problemInput = document.getElementById('comp-problem-state-input');
|
| 1266 |
+
if (problemInput) {
|
| 1267 |
+
problemInput.querySelector('input').value = problemId;
|
| 1268 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
| 1269 |
+
}
|
| 1270 |
+
}
|
| 1271 |
+
}
|
| 1272 |
+
|
| 1273 |
+
// Handle reference problem button clicks
|
| 1274 |
+
if (e.target.closest('#ref-problem-grid-container .ref-problem-btn')) {
|
| 1275 |
+
const problemBtn = e.target.closest('.ref-problem-btn');
|
| 1276 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
| 1277 |
+
if (problemId) {
|
| 1278 |
+
const problemInput = document.getElementById('ref-problem-state-input');
|
| 1279 |
+
if (problemInput) {
|
| 1280 |
+
problemInput.querySelector('input').value = problemId;
|
| 1281 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
| 1282 |
+
}
|
| 1283 |
+
}
|
| 1284 |
+
}
|
| 1285 |
+
});
|
| 1286 |
+
}
|
| 1287 |
+
|
| 1288 |
+
// Set up listeners initially and after any DOM changes
|
| 1289 |
+
setupProblemGridListeners();
|
| 1290 |
+
|
| 1291 |
+
// Re-setup listeners whenever the DOM changes (for dynamic content)
|
| 1292 |
+
const observer = new MutationObserver(function(mutations) {
|
| 1293 |
+
setupProblemGridListeners();
|
| 1294 |
+
});
|
| 1295 |
+
observer.observe(document.body, {childList: true, subtree: true});
|
| 1296 |
+
}
|
| 1297 |
+
""")
|
| 1298 |
+
|
| 1299 |
with gr.Tabs():
|
| 1300 |
with gr.TabItem("Single Model Analysis"):
|
| 1301 |
with gr.Row(variant='compact'):
|
|
|
|
| 1535 |
]
|
| 1536 |
)
|
| 1537 |
|
| 1538 |
+
with gr.TabItem("Reference Solutions"):
|
| 1539 |
+
with gr.Row(variant='compact'):
|
| 1540 |
+
with gr.Column(scale=1, min_width=280):
|
| 1541 |
+
ref_problem_state_input = gr.Textbox(
|
| 1542 |
+
value="",
|
| 1543 |
+
elem_id="ref-problem-state-input",
|
| 1544 |
+
visible=True,
|
| 1545 |
+
label="Enter Problem ID",
|
| 1546 |
+
container=True,
|
| 1547 |
+
interactive=True,
|
| 1548 |
+
every=0.5
|
| 1549 |
+
)
|
| 1550 |
+
|
| 1551 |
+
with gr.Column(scale=3, min_width=400):
|
| 1552 |
+
gr.Markdown("#### Problem Grid (OlymMATH-HARD: All models avg. acc. - Top: EN, Bottom: ZH)")
|
| 1553 |
+
ref_problem_grid_html_output = gr.HTML(
|
| 1554 |
+
value="<div>Loading reference data...</div>",
|
| 1555 |
+
elem_id="ref-problem-grid-container"
|
| 1556 |
+
)
|
| 1557 |
+
|
| 1558 |
+
# 问题内容显示区域 - 左右分布
|
| 1559 |
+
with gr.Row(variant='compact'):
|
| 1560 |
+
# 左侧:问题信息
|
| 1561 |
+
with gr.Column(scale=1):
|
| 1562 |
+
gr.Markdown("#### Problem (EN)")
|
| 1563 |
+
ref_problem_en_output = gr.Markdown(
|
| 1564 |
+
"Please select a problem.",
|
| 1565 |
+
latex_delimiters=[
|
| 1566 |
+
{"left": "$", "right": "$", "display": False},
|
| 1567 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 1568 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
| 1569 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
| 1570 |
+
]
|
| 1571 |
+
)
|
| 1572 |
+
|
| 1573 |
+
gr.Markdown("#### Problem (ZH)")
|
| 1574 |
+
ref_problem_zh_output = gr.Markdown(
|
| 1575 |
+
"Please select a problem.",
|
| 1576 |
+
latex_delimiters=[
|
| 1577 |
+
{"left": "$", "right": "$", "display": False},
|
| 1578 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 1579 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
| 1580 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
| 1581 |
+
]
|
| 1582 |
+
)
|
| 1583 |
+
|
| 1584 |
+
gr.Markdown("#### Subject")
|
| 1585 |
+
ref_subject_output = gr.Markdown("Please select a problem.")
|
| 1586 |
+
|
| 1587 |
+
gr.Markdown("#### Answer")
|
| 1588 |
+
ref_answer_output = gr.Markdown(
|
| 1589 |
+
"Please select a problem.",
|
| 1590 |
+
latex_delimiters=[
|
| 1591 |
+
{"left": "$", "right": "$", "display": False},
|
| 1592 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 1593 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
| 1594 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
| 1595 |
+
]
|
| 1596 |
+
)
|
| 1597 |
+
|
| 1598 |
+
# 右侧:解答
|
| 1599 |
+
with gr.Column(scale=1):
|
| 1600 |
+
gr.Markdown("#### Solution")
|
| 1601 |
+
ref_solution_output = gr.Markdown(
|
| 1602 |
+
"Please select a problem.",
|
| 1603 |
+
elem_id="ref-solution",
|
| 1604 |
+
latex_delimiters=[
|
| 1605 |
+
{"left": "$", "right": "$", "display": False},
|
| 1606 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 1607 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
| 1608 |
+
{"left": "\\[", "right": "\\]", "display": True},
|
| 1609 |
+
{"left": "\\begin{align}", "right": "\\end{align}", "display": True},
|
| 1610 |
+
{"left": "\\begin{aligned}", "right": "\\end{aligned}", "display": True},
|
| 1611 |
+
{"left": "\\begin{equation}", "right": "\\end{equation}", "display": True}
|
| 1612 |
+
]
|
| 1613 |
+
)
|
| 1614 |
+
|
| 1615 |
# --- Event Handlers ---
|
| 1616 |
def update_available_models_for_dropdowns(selected_dataset):
|
| 1617 |
# This function can be used to update model lists if they are dataset-dependent
|
|
|
|
| 1933 |
outputs=[sample_number_input]
|
| 1934 |
)
|
| 1935 |
|
| 1936 |
+
# 为引用解决方案标签页添加处理器
|
| 1937 |
+
# 初始化引用问题网格
|
| 1938 |
+
demo.load(
|
| 1939 |
+
fn=lambda: create_reference_problem_grid_html(reference_loader, db),
|
| 1940 |
+
inputs=[],
|
| 1941 |
+
outputs=[ref_problem_grid_html_output]
|
| 1942 |
+
)
|
| 1943 |
+
|
| 1944 |
+
# 引用问题选择事件
|
| 1945 |
+
ref_problem_state_input.change(
|
| 1946 |
+
fn=handle_reference_problem_select,
|
| 1947 |
+
inputs=[ref_problem_state_input, gr.State(reference_loader)],
|
| 1948 |
+
outputs=[ref_problem_en_output, ref_problem_zh_output, ref_subject_output, ref_answer_output, ref_solution_output]
|
| 1949 |
+
)
|
| 1950 |
+
|
| 1951 |
+
# This is the crucial link: problem_state_input is changed by user, triggers this Python callback.
|
| 1952 |
+
problem_state_input.change(
|
| 1953 |
+
fn=handle_problem_select,
|
| 1954 |
+
inputs=[problem_state_input, current_model_state, current_dataset_state],
|
| 1955 |
+
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]
|
| 1956 |
+
).then(
|
| 1957 |
+
# 重置Sample Number为0
|
| 1958 |
+
fn=lambda: "0",
|
| 1959 |
+
inputs=[],
|
| 1960 |
+
outputs=[sample_number_input]
|
| 1961 |
+
).then(
|
| 1962 |
+
fn=handle_first_sample,
|
| 1963 |
+
inputs=[current_samples_data_state],
|
| 1964 |
+
outputs=[sample_metadata_output, sample_response_output]
|
| 1965 |
+
)
|
| 1966 |
+
|
| 1967 |
return demo
|
| 1968 |
|
| 1969 |
def monitor_memory_usage():
|
|
|
|
| 1990 |
except Exception as e:
|
| 1991 |
return "Memory monitor error"
|
| 1992 |
|
| 1993 |
+
def _render_ref_problem_button(pid, lang, accuracy):
    """Return the HTML for a single clickable cell of the reference grid.

    Args:
        pid: Problem ID, written to ``data-problem-id`` and shown in the cell.
        lang: Language tag shown in the tooltip (``"EN"`` or ``"ZH"``).
        accuracy: Average accuracy as a fraction; shown as a truncated percent.
    """
    bg_color = get_gradient_color(accuracy)
    acc_pct = int(accuracy * 100)
    return f"""
    <div
        data-problem-id="{pid}"
        class="ref-problem-btn"
        title="ID: {pid} ({lang}) - Avg Acc: {acc_pct}%"
        style='background-color: {bg_color}; color: white !important;
            border-radius: 4px; padding: 5px; text-align: center; font-size: 0.7em;
            min-height: 36px; user-select: none; width: 100%;
            display: flex; flex-direction: column; justify-content: center;
            overflow: hidden; text-overflow: ellipsis; white-space: nowrap; cursor: pointer;'>
        <div style="font-weight: bold; color: white !important;">{pid}</div>
        <div style="color: white !important;">{acc_pct}%</div>
    </div>
    """


def create_reference_problem_grid_html(reference_loader, db):
    """Create HTML for the reference problem grid with average accuracies.

    Accuracies are read from the module-level ``reference_accuracy_cache``;
    nothing is computed here. The grid has two rows of buttons, English on top
    and Chinese below, ordered by numeric problem ID.

    Args:
        reference_loader: Object exposing ``get_all_problem_ids()``.
        db: Database handle; only checked for truthiness here.

    Returns:
        An HTML string: either the grid, or a short ``<div>`` message when the
        database, the reference data, the problem list or the cache is missing.
    """
    if not db:
        return "<div>Database not available.</div>"

    if not reference_loader:
        return (
            "<div><strong>No reference data available.</strong><br>"
            "Please ensure <code>extra.jsonl</code> file is in the same directory "
            "as the database file or in the current working directory.</div>"
        )

    problem_ids = reference_loader.get_all_problem_ids()
    if not problem_ids:
        return "<div>No reference problems found in extra.jsonl file.</div>"

    # An empty cache means the accuracies have not been computed yet.
    if not reference_accuracy_cache:
        return (
            "<div><strong>Computing problem accuracies...</strong><br>"
            "This may take a moment on first load.</div>"
        )

    print(f"Using cached accuracies for {len(problem_ids)} reference problems")

    custom_style = "<style>.ref-problem-btn, .ref-problem-btn div { color: white !important; }</style>"

    # Sort numerically so that e.g. "10" comes after "9".
    sorted_problem_ids = sorted(problem_ids, key=int)

    en_cells = []
    zh_cells = []
    for pid in sorted_problem_ids:
        # Problems (or languages) missing from the cache are shown as 0%.
        accuracy_data = reference_accuracy_cache.get(pid, {"EN": 0.0, "ZH": 0.0})
        en_cells.append(_render_ref_problem_button(pid, "EN", accuracy_data.get("EN", 0.0)))
        zh_cells.append(_render_ref_problem_button(pid, "ZH", accuracy_data.get("ZH", 0.0)))

    html_en = "".join(en_cells)
    html_zh = "".join(zh_cells)

    # One column per problem, capped at 30 columns per row.
    grid_cols = min(len(sorted_problem_ids), 30)

    return f"""
    {custom_style}
    <div style='margin-bottom: 10px;'>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_en}</div>
    </div>
    <div>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_zh}</div>
    </div>
    """
|
| 2078 |
+
|
| 2079 |
+
def handle_reference_problem_select(problem_id, reference_loader):
    """Build the five panel values shown for a selected reference problem.

    Returns a 5-tuple ``(en_problem, zh_problem, subject, answer, solution)``.
    Every slot holds a placeholder when nothing is selected or the loader is
    missing, a validation message when the ID is not an integer, and a
    not-found message when the loader has no entry for the ID.
    """
    if not (problem_id and reference_loader):
        return ("Please select a problem.",) * 5

    try:
        numeric_id = int(problem_id)
    except ValueError:
        return ("Please enter a valid problem ID.",) * 5

    record = reference_loader.get_problem_data(numeric_id)
    if not record:
        missing = f"Problem {numeric_id} not found in reference data."
        return (
            missing,
            missing,
            "No subject available.",
            "No answer available.",
            "Solution not available.",
        )

    # Problem statements in both languages.
    statement_en = format_markdown_with_math(
        record.get('en_problem', 'Problem (EN) not available.')
    )
    statement_zh = format_markdown_with_math(
        record.get('zh_problem', 'Problem (ZH) not available.')
    )

    # Answers go through the dedicated answer formatter.
    rendered_answer = format_answer_with_math(
        record.get('answer', 'No answer available.')
    )

    # Subject is shown as "English / Chinese"; unknown names are repeated as-is.
    subject_name = record.get('subject', 'Unknown')
    subject_local = SUBJECT_TRANS_EN_TO_ZH.get(subject_name, subject_name)
    subject_line = f"**{subject_name}** / **{subject_local}**"

    # Only a real solution is run through the LaTeX delimiter conversion.
    raw_solution = record.get('solution', 'Solution not available.')
    if raw_solution == 'Solution not available.':
        rendered_solution = raw_solution
    else:
        rendered_solution = format_solution_latex(raw_solution)

    return (statement_en, statement_zh, subject_line, rendered_answer, rendered_solution)
|
| 2117 |
+
|
| 2118 |
+
def test_reference_data_availability(db, reference_loader):
    """Print a diagnostic report about the database and the reference loader.

    Returns ``False`` as soon as ``db`` or ``reference_loader`` is found to be
    falsy, and ``True`` after all checks have been printed. Errors raised while
    inspecting the database schema are printed and do not stop the report.
    """
    print("=== Reference Data Availability Test ===")

    # Database presence.
    if not db:
        print("❌ Database is not available")
        return False

    # Schema inspection (best effort).
    try:
        cur = db.conn.cursor()

        def first_column(sql):
            # Run the query and keep the first column of every row.
            cur.execute(sql)
            return [record[0] for record in cur.fetchall()]

        def single_value(sql):
            # Run the query and return the first column of the first row.
            cur.execute(sql)
            return cur.fetchone()[0]

        tables = first_column("SELECT name FROM sqlite_master WHERE type='table'")
        print(f"✅ Database tables: {tables}")

        problem_count = single_value("SELECT COUNT(*) FROM problems")
        print(f"✅ Problems table: {problem_count} problems")

        response_count = single_value("SELECT COUNT(*) FROM responses")
        print(f"✅ Responses table: {response_count} responses")

        datasets = first_column("SELECT DISTINCT dataset FROM responses")
        print(f"✅ Available datasets: {datasets}")

        sample_ids = first_column("SELECT unique_id FROM problems LIMIT 10")
        print(f"✅ Sample problem unique_ids: {sample_ids}")

    except Exception as e:
        print(f"❌ Error checking database schema: {e}")

    model_names = db.get_available_models()
    print(f"✅ Database connected: {len(model_names)} models available")

    # Reference loader presence.
    if not reference_loader:
        print("❌ Reference loader is not available (extra.jsonl not found)")
        return False

    available_ids = reference_loader.get_all_problem_ids()
    print(f"✅ Reference loader: {len(available_ids)} problems available: {available_ids}")

    # Spot-check the first reference problem against the database.
    if available_ids:
        probe_id = available_ids[0]
        en_key = f"OlymMATH-HARD-{probe_id}-EN"
        zh_key = f"OlymMATH-HARD-{probe_id}-ZH"

        print(f"Testing with constructed IDs: {en_key}, {zh_key}")

        en_problem, en_responses = db.get_problem_data(None, "EN-HARD", en_key)
        zh_problem, zh_responses = db.get_problem_data(None, "ZH-HARD", zh_key)

        print(f"Test problem {probe_id}:")
        print(f" EN problem exists: {en_problem is not None}")
        print(f" ZH problem exists: {zh_problem is not None}")
        if en_responses:
            print(f" EN responses: {len(en_responses)} found")
        if zh_responses:
            print(f" ZH responses: {len(zh_responses)} found")

    print("=== End Test ===")
    return True
|
| 2190 |
+
|
| 2191 |
+
def test_latex_formatting():
    """Run a sample solution through ``format_markdown_with_math`` and report.

    Prints whether the ``aligned`` environment is present before and after
    formatting, plus the first 300 characters of the result, and returns the
    formatted text.
    """
    sample = """
易知,1, 4, 6, 7, 9 这五个数中的任意两个数之差均不为 4 或 7.

$$
\\begin{aligned}
\\sum_{n=1}^{2023}f_{n} &= \\sum_{k=0}^{183}\\sum_{i=0}^{10}f_{11k+i} \\\\
&= \\sum_{k=0}^{183}(11 \\times 5k+1+2+3+5 \\times 4+2 \\times 5) \\\\
&= 55 \\times \\frac{183 \\times 184}{2}+184 \\times 36 \\\\
&= 932604.
\\end{aligned}
$$

故答案为:$\\boxed{932604}$.
"""

    marker = "\\begin{aligned}"
    rendered = format_markdown_with_math(sample)

    print("=== LaTeX Formatting Test ===")
    print("Original text contains \\begin{aligned}:", marker in sample)
    print("Formatted text contains \\begin{aligned}:", marker in rendered)
    print("Formatted text (first 300 chars):", rendered[:300])
    print("=== End Test ===")
    return rendered
|
| 2215 |
+
|
| 2216 |
+
def format_solution_latex(text):
    r"""Convert dollar-delimited math in a solution to backslash delimiters.

    ``$$...$$`` becomes ``\[...\]`` (may span lines) and ``$...$`` becomes
    ``\(...\)`` (single line only). Line endings are normalised to ``\n`` and
    runs of three or more blank-ish newlines are collapsed to one blank line.

    Args:
        text: Solution text, or ``None``.

    Returns:
        The converted text; an empty string when ``text`` is ``None``.
    """
    if text is None:
        return ""

    # Normalise line endings first so a lone '\r' also bounds inline math.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Display math: non-greedy, DOTALL so the body may span several lines.
    text = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', text, flags=re.DOTALL)

    # Inline math. No "$$" pairs remain at this point, so no lookahead is
    # needed; rejecting a closing "$" followed by "]" used to leave "[$a$]"
    # unconverted and then mis-pair the remaining "$" on the same line.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Collapse excessive blank lines.
    return re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
|
| 2236 |
+
|
| 2237 |
+
def format_answer_with_math(text):
    r"""Format an answer field so that it renders as inline math.

    ``$$...$$`` is reduced to ``$...$``; an answer with no ``$`` at all is
    wrapped in ``$...$``; finally every single-line ``$...$`` is rewritten as
    ``\(...\)``. An answer that already starts with ``\(`` or ``\[`` is not
    wrapped again.

    Args:
        text: Answer text, or ``None``.

    Returns:
        The formatted answer. ``None``, blank strings and the placeholder
        ``"No answer available."`` are returned unchanged.
    """
    if text is None or text.strip() == "" or text == "No answer available.":
        return text

    # Normalise line endings.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Answers are shown inline, so display math is downgraded to inline math.
    text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', text, flags=re.DOTALL)

    # Wrap bare answers, but not ones that already carry backslash delimiters
    # (wrapping those would produce nested delimiters such as "\(\(x\)\)").
    has_backslash_delimiters = text.lstrip().startswith(('\\(', '\\['))
    if '$' not in text and not has_backslash_delimiters:
        text = f"${text}$"

    # Convert $...$ to \(...\) for rendering.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Collapse excessive blank lines.
    return re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
|
| 2259 |
+
|
| 2260 |
# 修改主函数以使用优化策略
|
| 2261 |
if __name__ == "__main__":
|
| 2262 |
DB_PATH = "data.db"
|
|
|
|
| 2264 |
# 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载
|
| 2265 |
if not os.path.exists(DB_PATH):
|
| 2266 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2267 |
DB_PATH = hf_hub_download(
|
| 2268 |
repo_id="CoderBak/OlymMATH-data",
|
| 2269 |
filename="data.db",
|
| 2270 |
+
repo_type="dataset"
|
|
|
|
| 2271 |
)
|
| 2272 |
except Exception as e:
|
| 2273 |
# 创建一个显示错误信息的简单 Gradio 应用
|
| 2274 |
with gr.Blocks() as error_demo:
|
| 2275 |
+
gr.Markdown(f"# Error: Database Download Failed\n{str(e)}")
|
| 2276 |
error_demo.launch(server_name="0.0.0.0")
|
| 2277 |
exit(1)
|
| 2278 |
|