tox21_leaderboard / frontend /leaderboard.py
Tschoui's picture
:lipstick: Improve frontend
3ba0b02
"""
Leaderboard-specific business logic.
Handles data processing, backend communication, and state management.
"""
import pandas as pd
from typing import Optional
from datetime import datetime
from config.leaderboard import MAX_DECIMALS, COLUMN_NAMES
def parse_parameter_count(value):
    """Parse a parameter count from various formats to a raw integer.

    Accepts:
        - Raw numbers: 120000000, "120000000"
        - Human-readable: "120M", "0.12B", "154K" (case-insensitive,
          optional whitespace between number and suffix; K/M/B/T supported)
        - Empty/None/NaN values

    Args:
        value: Parameter count in any supported format.

    Returns:
        int: Raw parameter count, or None for empty/invalid values
        (including malformed numbers such as "1.2.3M" and infinite floats).
    """
    import re
    from decimal import Decimal

    if value is None or pd.isna(value) or value == "":
        return None

    # Numeric input: truncate to int; infinity cannot be represented as int.
    if isinstance(value, (int, float)):
        if isinstance(value, float) and value in (float("inf"), float("-inf")):
            return None
        return int(value)

    # Exactly one well-formed decimal number (e.g. "12", "1.5", "1.", ".5")
    # followed by an optional magnitude suffix. Anything else is invalid.
    match = re.fullmatch(
        r"([0-9]+(?:\.[0-9]*)?|\.[0-9]+)\s*([KMBT]?)",
        str(value).strip().upper(),
    )
    if match is None:
        return None

    number, suffix = match.groups()
    multipliers = {"": 1, "K": 10**3, "M": 10**6, "B": 10**9, "T": 10**12}
    # Decimal keeps "0.12B" etc. exact, so int() truncation cannot drop a unit
    # because of binary floating-point error.
    return int(Decimal(number) * multipliers[suffix])
def format_parameter_count(value):
    """Format a raw parameter count as a short human-readable string.

    Counts of at least one thousand are scaled to K, M or B with one decimal
    place. The decimal is dropped when it is zero: 40_100_000 -> "40.1M" and
    154_000 -> "154K". Smaller counts are shown as a plain integer.

    If rounding pushes a value up to 1000 of its unit (e.g. 999_999), it is
    shown in the next unit ("1M") rather than as "1000K". B is the largest
    unit, so counts of 1e12 and above stay in B.

    Args:
        value: Raw parameter count (int/float or anything ``float`` accepts),
            NaN/None, or an empty string.

    Returns:
        Formatted string like '40.1M', '1.9M', '154K', or '' for NaN/None/empty.
    """
    if pd.isna(value) or value == "":
        return ""
    value = float(value)

    units = (("K", 1e3), ("M", 1e6), ("B", 1e9))  # ascending scale

    # Index of the largest unit the raw value reaches; -1 means "no suffix".
    idx = -1
    for i, (_, scale) in enumerate(units):
        if value >= scale:
            idx = i
    if idx < 0:
        return str(int(value))

    suffix, scale = units[idx]
    rounded = round(value / scale, 1)
    # e.g. 999_999 -> 1000.0K after rounding; express it as 1M instead.
    if rounded >= 1000 and idx + 1 < len(units):
        suffix, scale = units[idx + 1]
        rounded = round(value / scale, 1)

    if rounded == int(rounded):
        return f"{int(rounded)}{suffix}"
    return f"{rounded}{suffix}"
def refresh_leaderboard() -> pd.DataFrame:
    """Reload submissions from the backend and format them for display.

    Returns:
        The DataFrame produced by ``format_leaderboard_data`` for the entries
        returned by ``backend.data_loader.load_leaderboard_data``.
    """
    print("= Refreshing leaderboard data...")

    # Imported inside the function rather than at module top; the reason is not
    # visible here (possibly to avoid an import cycle) — TODO confirm.
    from backend.data_loader import load_leaderboard_data

    # NOTE(review): columns are not validated against COLUMN_NAMES here.
    return format_leaderboard_data(load_leaderboard_data())
def _model_type_emoji(pretrained: bool, zero_shot: bool, few_shot: bool) -> str:
    """Map a submission's training flags to the emoji shown in the Type column.

    Precedence is zero-shot, then few-shot, then pre-trained. A model with none
    of these flags is treated as trained on Tox21 only.
    """
    if zero_shot:
        return "0️⃣"  # Zero-shot
    if few_shot:
        return "1️⃣"  # Few-shot
    if pretrained:
        return "⤵️"  # Pre-training
    return "🔼"  # Standard (trained on Tox21 only)


def _format_date(date_raw) -> str:
    """Reduce a timestamp to a ``YYYY-MM-DD`` string ('' if missing).

    ISO-8601 input is parsed; a trailing 'Z' is removed first. Anything else
    falls back to the text before a 'T' separator or the first whitespace.
    """
    text = "" if date_raw is None else str(date_raw).strip()
    if not text:
        return ""
    try:
        # e.g. "2025-09-11T12:51:33.227003" -> "2025-09-11"
        return datetime.fromisoformat(text.replace("Z", "")).strftime("%Y-%m-%d")
    except ValueError:
        # Not ISO-parseable: keep whatever precedes 'T' / the first whitespace.
        head = text.split("T")[0].split()
        return head[0] if head else ""


def _build_row(entry: dict) -> dict:
    """Flatten one submission into a row keyed by (group, column) tuples."""
    config = entry["config"]
    results = entry["results"]
    overall = results["overall_score"]
    # Every results key other than "overall_score" is a per-task metric dict.
    task_results = [(task, res) for task, res in results.items() if task != "overall_score"]

    pretrained = config.get("pretrained", "") == "Yes"
    zero_shot = config.get("zero_shot", "") == "Yes"
    few_shot = config.get("few_shot", "") == "Yes"

    row = {
        ("", "Type"): _model_type_emoji(pretrained, zero_shot, few_shot),
        ("", "Model"): config["model_name"],
        ("", "HF_Space_Tag"): config.get("hf_space_tag", ""),  # hidden column for links
        ("", "Organization"): config.get("organization", ""),
        ("", "Publication"): config.get("publication_title", ""),
        ("", "Publication Link"): config.get("publication_link", ""),  # hidden column for links
        ("", "Model Description"): config["model_description"],
        ("", "Avg. AUC"): overall["roc_auc"],
        ("", "Avg. ΔAUC-PR"): overall.get("delta_auprc"),
        ("", "# Parameters"): config.get("model_size", ""),
    }
    # Per-task columns follow "# Parameters": all ROC-AUC columns first, then
    # all ΔAUC-PR columns, each in the order tasks appear in `results`.
    row.update({("ROC-AUC", task): res.get("roc_auc", "") for task, res in task_results})
    row.update({("ΔAUC-PR", task): res.get("delta_auprc", "") for task, res in task_results})

    # Prefer the approval date; fall back to submission date when it is
    # missing *or* empty/None.
    date_raw = config.get("date_approved") or config.get("date_submitted") or ""
    row.update({
        ("", "Pretrained"): pretrained,
        ("", "Pretraining Data"): config.get("pretraining_data", ""),
        ("", "Zero-shot"): zero_shot,
        ("", "Few-shot"): few_shot,
        ("", "N-shot"): config.get("n_shot", ""),
        ("", "Date Added"): _format_date(date_raw),
    })
    return row


def format_leaderboard_data(raw_data: list) -> pd.DataFrame:
    """Format raw leaderboard entries into the DataFrame shown in the UI.

    Args:
        raw_data: Iterable of submission entries. Each entry is a dict with a
            "config" dict (model metadata, including an "approved" flag) and a
            "results" dict mapping task names to metric dicts, plus an
            "overall_score" metric dict.

    Returns:
        DataFrame with two-level (group, column) headers and one row per
        approved entry. Rows are sorted by ("", "Avg. AUC") descending and
        rounded to MAX_DECIMALS. If nothing is approved, the result is an empty
        frame with COLUMN_NAMES as columns.
    """
    # Only approved entries are displayed.
    rows = [
        _build_row(entry)
        for entry in raw_data
        if entry["config"].get("approved", False)
    ]

    if not rows:
        # Must be handled before building the MultiIndex: from_tuples cannot
        # infer the number of levels from an empty column list.
        print(
            "No approved submissions found. Creating empty DataFrame with proper columns."
        )
        df = pd.DataFrame(columns=COLUMN_NAMES)
    else:
        df = pd.DataFrame(rows)
        df.columns = pd.MultiIndex.from_tuples(df.columns)
        # Rank according to overall score.
        df = df.sort_values(by=("", "Avg. AUC"), ascending=False).reset_index(drop=True)

    print(f"Created DataFrame with shape: {df.shape}")
    return df.round(decimals=MAX_DECIMALS)
def calculate_average_score(task_scores: dict) -> float:
    """Calculate the average score across all tasks, ignoring missing ones.

    A score counts as missing when it is None, NaN, or an empty/blank string.
    This module uses "" as the placeholder for absent task scores.

    Args:
        task_scores: Mapping of task name -> score.

    Returns:
        Mean of the non-missing scores, or 0.0 if there are none.
    """
    if not task_scores:
        return 0.0

    def _is_missing(score) -> bool:
        # Check strings first so pd.isna is only applied to non-string scalars.
        if isinstance(score, str):
            return not score.strip()
        return score is None or bool(pd.isna(score))

    valid_scores = [score for score in task_scores.values() if not _is_missing(score)]
    if not valid_scores:
        return 0.0
    return sum(valid_scores) / len(valid_scores)
def sort_by_performance(
    leaderboard_data: pd.DataFrame,
    *,
    score_column="Average",
    rank_column="Rank",
) -> pd.DataFrame:
    """Sort the leaderboard by score (descending) and assign 1-based ranks.

    Args:
        leaderboard_data: DataFrame with leaderboard data. It is not modified.
        score_column: Column to sort by. Pass a tuple key (e.g.
            ``("", "Avg. AUC")``) for frames with two-level column headers.
        rank_column: Column that receives the ranks 1..N; it is created or
            overwritten.

    Returns:
        A sorted copy with the rank column set. The original index labels are
        kept. Ties keep their input order (stable sort), so equal scores get
        deterministic ranks.
    """
    sorted_data = leaderboard_data.sort_values(
        by=score_column, ascending=False, kind="mergesort"
    )
    # Positional assignment: rows are already in sorted order.
    sorted_data[rank_column] = range(1, len(sorted_data) + 1)
    return sorted_data
def filter_leaderboard(
    data: pd.DataFrame,
    min_score: Optional[float] = None,
    model_type: Optional[str] = None,
    date_range: Optional[tuple] = None,
    *,
    score_column="Average",
    type_column="Type",
    date_column="Date Added",
) -> pd.DataFrame:
    """Filter leaderboard rows by score, model type and date.

    Each filter is applied only when its argument is not None. Rows must
    satisfy all active filters.

    Args:
        data: Original leaderboard data. It is not modified.
        min_score: Keep rows whose ``score_column`` value is >= this.
        model_type: Keep rows whose ``type_column`` value equals this.
        date_range: ``(start, end)`` pair of anything ``pd.Timestamp`` accepts.
            Either bound may be None. Bounds are inclusive. Rows whose date
            cannot be parsed are dropped when this filter is active.
        score_column, type_column, date_column: Column keys used by the
            filters. Pass tuple keys for frames with two-level column headers.

    Returns:
        Filtered copy of ``data``.
    """
    mask = pd.Series(True, index=data.index)

    if min_score is not None:
        mask &= data[score_column] >= min_score

    if model_type is not None:
        mask &= data[type_column] == model_type

    if date_range is not None:
        start, end = date_range
        # Unparseable dates become NaT, which fails both comparisons.
        dates = pd.to_datetime(data[date_column], errors="coerce")
        if start is not None:
            mask &= dates >= pd.Timestamp(start)
        if end is not None:
            mask &= dates <= pd.Timestamp(end)

    return data[mask].copy()