# tox21_leaderboard / backend/evaluator.py
import asyncio
import httpx
from typing import List, Dict, Any
BATCH_SIZE = 1000
TIMEOUT_S = 36000  # TODO: temporarily raised from 600 by js; consider reverting or changing
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds to wait between retry attempts
def chunks(xs: List[str], n: int):
    for i in range(0, len(xs), n):
        yield xs[i:i+n]
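
# Illustrative example (inputs are arbitrary test SMILES, not from the dataset):
# list(chunks(["C", "CC", "CCC"], 2)) -> [["C", "CC"], ["CCC"]]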
async def fetch_metadata(client: httpx.AsyncClient, base_url: str) -> Dict[str, Any]:
    for attempt in range(MAX_RETRIES):
        try:
            r = await client.get(f"{base_url}/metadata", timeout=30)
            r.raise_for_status()
            return r.json()
        except httpx.HTTPError:  # retry transient failures; re-raise on the last attempt
            if attempt == MAX_RETRIES - 1:
                raise
            await asyncio.sleep(RETRY_DELAY)
async def call_predict(client: httpx.AsyncClient, base_url: str, smiles_batch: List[str]) -> Dict[str, Any]:
    for attempt in range(MAX_RETRIES):
        try:
            r = await client.post(
                f"{base_url}/predict",
                json={"smiles": smiles_batch},
                timeout=TIMEOUT_S,
            )
            r.raise_for_status()
            return r.json()
        except httpx.HTTPError:  # retry transient failures; re-raise on the last attempt
            if attempt == MAX_RETRIES - 1:
                raise
            await asyncio.sleep(RETRY_DELAY)
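
# Response shape assumed by the parsing in evaluate_model below (inferred from
# this file, not an authoritative schema): /predict returns
# {"predictions": {"<smiles>": {...raw model outputs...}, ...}},
# keyed by the exact input SMILES strings.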
async def evaluate_model(hf_space_tag: str, smiles_list: List[str]) -> Dict[str, Any]:
    # Convert username/space-name to username-space-name.hf.space
    base_url = f"https://{hf_space_tag.replace('/', '-').replace('_', '-').lower()}.hf.space"
    results = []
    async with httpx.AsyncClient() as client:
        meta = await fetch_metadata(client, base_url)
        # Honor the model's advertised batch limit, capped at our own BATCH_SIZE.
        max_bs = min(meta.get("max_batch_size", BATCH_SIZE), BATCH_SIZE)
        for batch in chunks(smiles_list, max_bs):
            resp = await call_predict(client, base_url, batch)
            predictions_dict = resp["predictions"]
            for smiles in batch:
                if smiles in predictions_dict:
                    results.append({"smiles": smiles, "raw_predictions": predictions_dict[smiles]})
                else:
                    results.append({"smiles": smiles, "raw_predictions": {}, "error": "No prediction found"})
    return {"results": results, "metadata": meta}