import matplotlib.pyplot as plt
import time
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import tracemalloc

# Andromeda model and optimizer
from Andromeda.model import Andromeda
from Andromeda.utils.stable_adamw import StableAdamWUnfused
|
|
# Fix random seeds for reproducibility
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
# Metric dependencies: cross-entropy for perplexity, BLEU, ROUGE, and F1
import torch.nn.functional as F
from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge
from sklearn.metrics import f1_score
|
|
|
|
class AccuracyMetrics:
    """Helpers for the language-model quality metrics used below."""

    def __init__(self):
        self.rouge = Rouge()

    def calculate_perplexity(self, model, data_loader):
        """Exponentiated mean cross-entropy over the loader."""
        model.eval()
        device = next(model.parameters()).device
        total_loss = 0.0
        with torch.no_grad():
            for batch in data_loader:
                input_ids, labels = batch
                input_ids, labels = input_ids.to(device), labels.to(device)
                output = model(input_ids)
                loss = F.cross_entropy(output.view(-1, output.size(-1)), labels.view(-1))
                total_loss += loss.item()
        return torch.exp(torch.tensor(total_loss / len(data_loader)))

    def calculate_bleu(self, references, hypotheses):
        # corpus_bleu expects token lists (one list of reference token lists per
        # hypothesis), so plain strings are whitespace-tokenised here.
        return corpus_bleu(
            [[ref.split()] for ref in references],
            [hyp.split() for hyp in hypotheses],
        )

    def calculate_rouge(self, references, hypotheses):
        # rouge.get_scores operates directly on strings
        return self.rouge.get_scores(hypotheses, references, avg=True)

    def calculate_f1(self, true_labels, pred_labels):
        return f1_score(true_labels, pred_labels, average="weighted")
|
|
|
|
|
|
|
|
|
|
# Placeholder evaluation dataset. FakeData yields (image, class label) pairs,
# so it cannot feed a token-based language model directly; see the token sketch below.
test_dataset = datasets.FakeData(size=1000, transform=transforms.ToTensor())
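# The perplexity pass below needs integer token ids rather than images, so this
# minimal sketch builds a synthetic token dataset instead. The sequence length,
# batch size, and sample count here are illustrative assumptions, not settings
# from the original benchmark; only the 50304-token vocabulary matches the model.
from torch.utils.data import TensorDataset

seq_len, vocab_size = 128, 50304
fake_tokens = torch.randint(0, vocab_size, (256, seq_len))
# Next-token targets: inputs are the sequence, labels are the sequence shifted by one.
token_dataset = TensorDataset(fake_tokens[:, :-1], fake_tokens[:, 1:])
data_loader = DataLoader(token_dataset, batch_size=8, shuffle=False)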
|
|
# Andromeda configuration used for the accuracy benchmark
model = Andromeda(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    alibi_num_heads=4,
).to(device)
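# StableAdamWUnfused is imported above but is only needed for training-side
# benchmarks. A minimal sketch of wiring it up, assuming it follows the usual
# torch.optim constructor shape (params, lr=...); the learning rate is an
# arbitrary example value. Uncomment for a training run:
# optimizer = StableAdamWUnfused(model.parameters(), lr=1e-4)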
|
|
|
|
|
|
# Metrics helper
accuracy_metrics = AccuracyMetrics()
|
|
# Perplexity over the synthetic evaluation loader (assumes the model returns per-token logits)
perplexity = accuracy_metrics.calculate_perplexity(model, data_loader)
print('Perplexity:', perplexity)
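# The BLEU/ROUGE calls below need reference and hypothesis texts, which this
# script never defines or generates from the model; the two short string lists
# here are illustrative placeholders only.
references = [
    'the cat sat on the mat',
    'there is a book on the table',
]
hypotheses = [
    'the cat sat on mat',
    'a book is on the table',
]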
|
|
# Corpus-level BLEU
bleu = accuracy_metrics.calculate_bleu(references, hypotheses)
print('BLEU Score:', bleu)
|
|
# Averaged ROUGE scores
rouge_scores = accuracy_metrics.calculate_rouge(references, hypotheses)
print('ROUGE Scores:', rouge_scores)
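# The weighted F1 below needs discrete class labels, which the script never
# defines; these integer lists are illustrative placeholders only.
true_labels = [0, 1, 1, 0, 2, 1]
pred_labels = [0, 1, 0, 0, 2, 1]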
|
|
# Weighted F1 over the placeholder labels
f1 = accuracy_metrics.calculate_f1(true_labels, pred_labels)
print('F1 Score:', f1)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The metric computations above run at module scope, so there is nothing
    # further to do when the file is executed directly.
    pass