# 🔬 ViT Base Patch16 384 — GI Endoscopy Classifier

*A fine-tuned Vision Transformer for 23-class gastrointestinal endoscopy image classification*
## 📋 Overview

This repository contains a fine-tuned ViT Base Patch16 384 model for classifying gastrointestinal endoscopy images into 23 anatomical/pathological categories. It was trained on the Hyper-Kvasir dataset with MixUp augmentation, a Focal Loss objective with label smoothing, and supports Test-Time Augmentation (TTA) at inference.
## ✨ Key Features
| Feature | Description |
|---|---|
| 🎯 High Accuracy | 93.25% test accuracy with TTA |
| 🔥 Modern Architecture | ViT Base Patch16 @ 384×384 resolution |
| 🛡️ Robust Training | MixUp, Focal Loss, Label Smoothing, CoarseDropout |
| ⚡ Production Ready | TorchScript traced weights for fast inference |
| 🧪 TTA Support | Test-Time Augmentation for improved predictions |
## 📊 Performance Metrics

### Final Results
| Metric | Validation (Best) | Test (with TTA) |
|---|---|---|
| Accuracy | 92.18% | 93.25% |
| Precision | — | 92.19% |
| Recall | — | 93.25% |
| F1-Score | — | 92.59% |
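A per-class breakdown is not reproduced here, but the aggregate test metrics can be recomputed from test-loop predictions along these lines (a sketch; `average="weighted"` is an assumption, suggested by recall matching accuracy):

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def summarize(y_true, y_pred):
    """Recompute aggregate metrics from test-loop label arrays.

    Sketch only: "weighted" averaging is an assumption (recall equal to
    accuracy hints at class-weighted averages rather than macro).
    """
    acc = accuracy_score(y_true, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0)
    return {"accuracy": acc, "precision": p, "recall": r, "f1": f1}
```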
### Training Progression (selected epochs)
| Epoch | Train Acc | Val Acc | Learning Rate | Checkpoint |
|---|---|---|---|---|
| 1 | 50.58% | 81.93% | 4.00e-06 | ✅ |
| 2 | 67.99% | 86.68% | 6.00e-06 | ✅ |
| 3 | 74.18% | 87.87% | 8.00e-06 | ✅ |
| 4 | 74.81% | 88.81% | 1.00e-05 | ✅ |
| 5 | 77.37% | 89.12% | 1.00e-05 | ✅ |
| 6 | 77.56% | 89.49% | 9.94e-06 | ✅ |
| 8 | 80.09% | 90.56% | 9.46e-06 | ✅ |
| 9 | 80.08% | 90.68% | 9.05e-06 | ✅ |
| 10 | 80.44% | 90.81% | 8.54e-06 | ✅ |
| 12 | 82.21% | 91.62% | 7.27e-06 | ✅ |
| 16 | 85.41% | 91.74% | 4.22e-06 | ✅ |
| 18 | 84.59% | 92.06% | 2.73e-06 | ✅ |
| 20 | 86.29% | 92.12% | 1.46e-06 | ✅ |
| 21 | 85.86% | 92.18% | 9.55e-07 | ✅ Best |
| 25 | 86.17% | 92.12% | 0.00e+00 | — |
## 🏗️ Model Architecture

```text
┌────────────────────────────────────┐
│ ViT Base Patch16 384               │
├────────────────────────────────────┤
│ Input:      384 × 384 × 3 (RGB)    │
│ Patch Size: 16 × 16                │
│ Patches:    (384/16)² = 576        │
│ Hidden Dim: 768                    │
│ Layers:     12 Transformer blocks  │
│ Heads:      12 attention heads     │
│ Parameters: 86,108,183 (~86.1M)    │
│ Output:     23 classes (softmax)   │
└────────────────────────────────────┘
```
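The backbone can be instantiated from `timm` as follows. This recreates the architecture only, not the fine-tuned weights, which ship as the TorchScript file described under Quick Start:

```python
import timm

# Architecture only -- the fine-tuned weights live in vit_best_traced.pt.
# num_classes=23 swaps the ImageNet head for the 23-way GI classifier head.
model = timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=23)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} parameters")  # 86,108,183
```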
## 🗂️ Dataset: Hyper-Kvasir
| Split | Images | Classes |
|---|---|---|
| Train | 7,463 | 23 |
| Validation | 1,599 | 23 |
| Test | 1,600 | 23 |
| Total | 10,662 | 23 |
### 23 GI Classes

Anatomical landmarks and pathological findings from upper and lower GI tract endoscopy. An index-to-name mapping can be supplied via `class_mapping.json` (see Repository Structure below).
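If the labeled images are laid out one folder per class (the exact directory layout here is an illustrative assumption, not part of this repository), the class list can be recovered directly from the folder names:

```python
from torchvision import datasets

# Assumes a layout like data/train/<class_name>/*.jpg -- the path and
# layout are assumptions for illustration.
train_ds = datasets.ImageFolder("data/train")
print(len(train_ds.classes))   # expected: 23
print(train_ds.class_to_idx)   # folder name -> class index
```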
## ⚙️ Training Configuration

### Environment
- **PyTorch:** 2.x (CUDA 11.8)
- **GPU:** NVIDIA GPU with ~16GB VRAM
- **Python:** 3.12
- **Platform:** Google Colab
### Dependencies

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install timm "albumentations>=1.0.0" opencv-python Pillow numpy scikit-learn matplotlib seaborn tqdm
```
### Hyperparameters
| Parameter | Value |
|---|---|
| Model | `vit_base_patch16_384` |
| Image Size | 384 × 384 |
| Batch Size | 2 |
| Effective Batch Size | 16 (8× gradient accumulation) |
| Epochs | 25 |
| Base Learning Rate | 1e-5 |
| Optimizer | AdamW (weight_decay=0.01) |
| Scheduler | Cosine Annealing + 5-epoch Warmup |
| Loss | Focal Loss (γ=2.0) + Label Smoothing (0.1) |
| Mixed Precision | ✅ FP16 (GradScaler) |
| MixUp | ✅ (α=0.2, p=0.5) |
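A minimal sketch of how this configuration could be wired up, assuming `model`, `criterion`, and `train_loader` exist. The warmup factors below reproduce the learning rates logged in the training table (4e-6 → 1e-5 over epochs 1–4, then cosine to 0 by epoch 25), though the original implementation may differ in exact step alignment:

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Linear warmup, then cosine annealing to zero; stepped once per epoch.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.4, total_iters=3)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, [warmup, cosine], milestones=[5])

# FP16 mixed precision with gradient accumulation (2 x 8 = effective batch 16).
scaler = torch.cuda.amp.GradScaler()
accum_steps = 8
for step, (images, targets) in enumerate(train_loader):
    images, targets = images.cuda(), targets.cuda()
    with torch.autocast("cuda", dtype=torch.float16):
        loss = criterion(model(images), targets) / accum_steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)  # unscale + optimizer step
        scaler.update()
        optimizer.zero_grad()
```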
### Data Augmentation (Albumentations)

**Training:**

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.Resize(384, 384),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    A.GaussNoise(p=0.3),
    A.CoarseDropout(max_holes=1, max_height=32, max_width=32, p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
```
**Validation/Test:**

```python
val_transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
```
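Albumentations operates on NumPy arrays and keyword arguments, so a thin `Dataset` wrapper is needed to connect these pipelines to a `DataLoader`. A minimal sketch, assuming a list of `(image_path, label)` pairs (how that list is built is up to you):

```python
import cv2
from torch.utils.data import Dataset

class EndoscopyDataset(Dataset):
    """Minimal wrapper feeding the Albumentations pipelines above;
    `samples` is an assumed list of (image_path, label_index) pairs."""

    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Albumentations expects RGB numpy arrays, not PIL images
        image = cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)
        return self.transform(image=image)["image"], label
```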
## 🚀 Quick Start

### Installation

```bash
pip install torch torchvision timm albumentations
```
### Inference (TorchScript)

```python
import torch
from PIL import Image
from torchvision import transforms

# Load traced model
model = torch.jit.load("vit_best_traced.pt")
model.eval()

# Preprocessing (must match training)
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and classify image
img = Image.open("endoscopy_image.jpg").convert("RGB")
tensor = preprocess(img).unsqueeze(0)

with torch.no_grad():
    logits = model(tensor)
    probs = logits.softmax(dim=1)
    confidence, pred_class = probs.max(dim=1)

print(f"Predicted class: {pred_class.item()}")
print(f"Confidence: {confidence.item():.2%}")
```
### Inference with Test-Time Augmentation (TTA)

```python
import torch
from PIL import Image
from torchvision import transforms

model = torch.jit.load("vit_best_traced.pt")
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict_with_tta(model, tensor):
    """Test-Time Augmentation: average predictions across flips."""
    with torch.no_grad():
        # Original
        pred1 = model(tensor).softmax(dim=1)
        # Horizontal flip (dim 3 = width)
        pred2 = model(torch.flip(tensor, [3])).softmax(dim=1)
        # Vertical flip (dim 2 = height)
        pred3 = model(torch.flip(tensor, [2])).softmax(dim=1)
    # Average the three probability distributions
    return (pred1 + pred2 + pred3) / 3.0

img = Image.open("endoscopy_image.jpg").convert("RGB")
tensor = preprocess(img).unsqueeze(0)
probs = predict_with_tta(model, tensor)
confidence, pred_class = probs.max(dim=1)

print(f"Predicted class (TTA): {pred_class.item()}")
print(f"Confidence: {confidence.item():.2%}")
```
### Batch Inference

```python
import torch
from PIL import Image
from torchvision import transforms
from pathlib import Path

model = torch.jit.load("vit_best_traced.pt")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def classify_batch(image_paths, batch_size=8):
    results = []
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i + batch_size]
        tensors = []
        for path in batch_paths:
            img = Image.open(path).convert("RGB")
            tensors.append(preprocess(img))
        batch = torch.stack(tensors).to(device)
        with torch.no_grad():
            probs = model(batch).softmax(dim=1)
            confidences, preds = probs.max(dim=1)
        for path, pred, conf in zip(batch_paths, preds, confidences):
            results.append({
                "file": str(path),
                "class": pred.item(),
                "confidence": conf.item(),
            })
    return results

# Example usage
image_folder = Path("./test_images")
image_paths = list(image_folder.glob("*.jpg"))
results = classify_batch(image_paths)
```
## 📁 Repository Structure

```text
.
├── vit_best_traced.pt     # TorchScript traced weights (best checkpoint)
├── README.md              # This file
└── class_mapping.json     # (Optional) Class index to name mapping
```
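If `class_mapping.json` is present, predicted indices can be turned into human-readable names. A sketch, assuming the file stores a flat `{"index": "class_name"}` object (the schema is an assumption):

```python
import json

# Assumed schema: {"0": "<class_name>", "1": "<class_name>", ...};
# adjust the parsing if the file is structured differently.
with open("class_mapping.json") as f:
    idx_to_name = {int(k): v for k, v in json.load(f).items()}

# pred_class comes from any of the inference snippets above
print(idx_to_name[pred_class.item()])
```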
## 🔧 Advanced: Custom Training

### Focal Loss Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-sample cross entropy, so each term can be re-weighted
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # pt = model's probability for the true class
        pt = torch.exp(-ce_loss)
        # Down-weight easy examples (high pt), focus on hard ones
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss.sum() if self.reduction == 'sum' else focal_loss
```
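The hyperparameter table lists label smoothing (0.1) alongside Focal Loss, while the class above uses plain cross entropy. One way to combine the two is via the `label_smoothing` argument of `F.cross_entropy`; whether training used exactly this combination is an assumption:

```python
class SmoothedFocalLoss(FocalLoss):
    """Hypothetical variant: focal re-weighting on top of label-smoothed CE."""

    def __init__(self, alpha=1, gamma=2, smoothing=0.1, reduction='mean'):
        super().__init__(alpha, gamma, reduction)
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        # Same focal re-weighting, but the base CE is label-smoothed
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  label_smoothing=self.smoothing)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss.sum() if self.reduction == 'sum' else focal_loss
```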
### MixUp Implementation

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=0.2):
    # Mixing coefficient sampled from Beta(alpha, alpha)
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    # Convex combination of each image with a randomly paired one
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss mirrors the convex combination of the inputs
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```
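Putting the two pieces together, a sketch of one training step with MixUp applied at p=0.5 per the hyperparameter table; `model`, `optimizer`, and `train_loader` are assumed to exist:

```python
import random

criterion = FocalLoss(gamma=2.0)

for images, targets in train_loader:
    optimizer.zero_grad()
    if random.random() < 0.5:  # apply MixUp with probability 0.5
        mixed, y_a, y_b, lam = mixup_data(images, targets, alpha=0.2)
        loss = mixup_criterion(criterion, model(mixed), y_a, y_b, lam)
    else:
        loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()
```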
## ⚠️ Limitations & Responsible Use

### ⚕️ Medical Disclaimer

This model is a research artifact and is **NOT** a regulated medical device. It should **NOT** be used for clinical diagnosis without proper validation and regulatory approval.
### Known Limitations

- Trained on the Hyper-Kvasir dataset; may not generalize to other endoscopy equipment or patient populations
- Best performance requires 384×384 input resolution
- TTA improves accuracy but roughly triples inference time (three forward passes per image)
### Recommended Use

- ✅ Research and educational purposes
- ✅ Preliminary screening with human oversight
- ✅ Benchmarking GI image classification
- ❌ Standalone clinical diagnosis
- ❌ Life-critical medical decisions
## 📚 Citation

If you use this model in your research, please cite:

```bibtex
@misc{vit_gi_endoscopy_2025,
  author    = {Ayan Ahmed Khan},
  title     = {ViT Base Patch16 384 for GI Endoscopy Classification},
  year      = {2025},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/ayanahmedkhan/VIT-gi-endoscopy-classifier}
}
```
## 📝 Changelog
| Date | Version | Changes |
|---|---|---|
| 2025-12-29 | 1.0.0 | Initial release with traced weights and full documentation |
## 💬 Contact

- **Author:** Ayan Ahmed Khan
- **Hugging Face:** ayanahmedkhan

*Made with ❤️ for Medical AI Research*