| | import os
|
| | import yaml
|
| | import torch
|
| | import numpy as np
|
| | import pandas as pd
|
| | from tqdm import tqdm
|
| | from sklearn.metrics import (
|
| | average_precision_score,
|
| | roc_auc_score,
|
| | f1_score,
|
| | precision_score,
|
| | recall_score,
|
| | accuracy_score
|
| | )
|
| |
|
| |
|
| | from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| | from thyroid_dataset import create_dataloaders, TARGET_CLASSES
|
| | '''
|
| | # 18类标签定义 (与训练时保持一致)
|
| | TARGET_CLASSES = [
|
| | "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
|
| | "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
|
| | "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
|
| | "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
|
| | ]
|
| | '''
|
| |
|
def get_best_checkpoint_path(save_dir):
    """Locate the checkpoint to evaluate, preferring 'best' over 'latest'.

    Args:
        save_dir: directory where training checkpoints were written.

    Returns:
        Path to 'checkpoint_best.pth' if present, otherwise
        'checkpoint_latest.pth' (with a warning).

    Raises:
        FileNotFoundError: if neither checkpoint file exists in save_dir.
    """
    for filename, is_fallback in (('checkpoint_best.pth', False),
                                  ('checkpoint_latest.pth', True)):
        candidate = os.path.join(save_dir, filename)
        if os.path.exists(candidate):
            if is_fallback:
                print(f"Warning: 'checkpoint_best.pth' not found. Using '{candidate}' instead.")
            return candidate
    raise FileNotFoundError(f"No checkpoints found in {save_dir}")
|
| |
|
def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    """Compute global and per-class multi-label classification metrics.

    Args:
        y_true: [N, num_classes] binary ground-truth matrix (0 or 1).
        y_pred_probs: [N, num_classes] predicted probabilities in [0.0, 1.0].
        threshold: cutoff used to binarize probabilities for the
            F1/precision/recall/accuracy metrics.

    Returns:
        Tuple of (dict of global metrics, pandas.DataFrame with one row
        per class: Class, Support, AP, AUROC, F1, Precision, Recall).
    """
    binarized = (y_pred_probs >= threshold).astype(int)

    summary = {
        'mAP': average_precision_score(y_true, y_pred_probs, average='macro'),
        'weighted_mAP': average_precision_score(y_true, y_pred_probs, average='weighted'),
    }

    # AUROC is undefined when some class column contains a single label
    # value; in that case report 0.0 for both averages.
    try:
        summary['macro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='macro')
        summary['micro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='micro')
    except ValueError:
        summary['macro_auroc'] = 0.0
        summary['micro_auroc'] = 0.0

    summary['micro_f1'] = f1_score(y_true, binarized, average='micro')
    summary['macro_f1'] = f1_score(y_true, binarized, average='macro')
    # Subset (exact-match) accuracy: all labels of a sample must be correct.
    summary['subset_accuracy'] = accuracy_score(y_true, binarized)

    rows = []
    for idx, cls in enumerate(TARGET_CLASSES):
        truth = y_true[:, idx]
        probs = y_pred_probs[:, idx]
        preds = binarized[:, idx]
        positives = int(truth.sum())

        if positives:
            ap = average_precision_score(truth, probs)
            try:
                auroc = roc_auc_score(truth, probs)
            except ValueError:
                auroc = 0.5  # degenerate column: fall back to chance level
            f1 = f1_score(truth, preds)
            rec = recall_score(truth, preds)
            prec = precision_score(truth, preds, zero_division=0)
        else:
            # No positive samples for this class in the evaluation set.
            ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0

        rows.append({
            "Class": cls,
            "Support": positives,
            "AP": ap,
            "AUROC": auroc,
            "F1": f1,
            "Precision": prec,
            "Recall": rec
        })

    return summary, pd.DataFrame(rows)
|
| |
|
def _build_model(config):
    """Instantiate the model with architecture hyper-parameters from the config."""
    return TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=False,   # weights come from the checkpoint, not ImageNet
        use_checkpointing=False,   # no backward pass during evaluation
        use_ppeg=config['model'].get('use_ppeg', False)
    )


def _load_weights(model, ckpt_path, device):
    """Load checkpoint weights into `model`, stripping any DataParallel 'module.' prefix."""
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)


def _collect_predictions(model, loader, device):
    """Run inference over `loader`; return (y_true, y_pred_probs) as numpy arrays."""
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(loader):
            images = batch['images'].to(device)
            num_instances = batch['num_instances_per_case']
            labels = batch['labels'].numpy()

            logits = model(images, num_instances)
            probs = torch.sigmoid(logits).cpu().numpy()

            all_preds.append(probs)
            all_targets.append(labels)
    return np.concatenate(all_targets, axis=0), np.concatenate(all_preds, axis=0)


def _display_width(text):
    """Approximate terminal display width of `text` (CJK chars count as 2 columns).

    Uses the GBK byte length like the original code, but with errors='replace'
    so class names containing non-GBK characters degrade gracefully instead of
    raising UnicodeEncodeError.
    """
    return len(text.encode('gbk', errors='replace'))


def _print_report(global_metrics, class_df):
    """Pretty-print the global summary and the per-class table to stdout."""
    print("\n" + "="*60)
    print(" GLOBAL PERFORMANCE SUMMARY ")
    print("="*60)
    print(f" mAP (Macro) : {global_metrics['mAP']:.4f}")
    print(f" mAP (Weighted): {global_metrics['weighted_mAP']:.4f}")
    print(f" AUROC (Macro) : {global_metrics['macro_auroc']:.4f}")
    print(f" AUROC (Micro) : {global_metrics['micro_auroc']:.4f}")
    print(f" F1 (Micro) : {global_metrics['micro_f1']:.4f}")
    print(f" F1 (Macro) : {global_metrics['macro_f1']:.4f}")
    print(f" Subset Acc : {global_metrics['subset_accuracy']:.4f}")
    print("-" * 60)

    print("\n" + "="*100)
    print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
    print("="*100)

    headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]
    head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
    print(head_fmt.format(*headers))
    print("-" * 100)

    for _, row in class_df.iterrows():
        cls_name = row['Class']
        # Pad the class name by display width (CJK-aware), not len(), so the
        # columns line up; clamp at 0 in case a name exceeds the column width.
        padding = max(0, 24 - _display_width(cls_name))
        aligned_name = cls_name + " " * padding
        print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} "
              f"{row['AUROC']:>10.4f} {row['F1']:>10.4f} "
              f"{row['Precision']:>12.4f} {row['Recall']:>10.4f}")

    print("="*100)


def main():
    """Evaluate the best saved checkpoint on the test split.

    Loads config.yaml, rebuilds the model, runs inference over the test
    loader, prints global and per-class metrics, and writes the per-class
    report (sorted by support) to evaluation_report.csv in the save dir.
    """
    config_path = 'config.yaml'
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Evaluating on {device}")

    print("Loading Test Data...")
    _, _, test_loader = create_dataloaders(config)

    print("Initializing Model...")
    model = _build_model(config)

    ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
    print(f"Loading checkpoint from: {ckpt_path}")
    _load_weights(model, ckpt_path, device)

    model.to(device)
    model.eval()

    print("Running Inference...")
    y_true, y_pred_probs = _collect_predictions(model, test_loader, device)

    print("\nComputing Metrics...")
    global_metrics, class_df = compute_metrics(y_true, y_pred_probs)

    # Sort once; both the printed table and the saved CSV use this order.
    class_df = class_df.sort_values(by='Support', ascending=False)
    _print_report(global_metrics, class_df)

    result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
    class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
    print(f"\nDetailed report saved to: {result_csv}")
|
| |
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()