---
license: apache-2.0
datasets:
- marmal88/skin_cancer
base_model:
- google/vit-base-patch16-224-in21k
pipeline_tag: image-classification
tags:
- medical
---

## Installation

First, clone the repository:

```bash
git clone https://github.com/ethicalabs-ai/SkinCancerViT.git
cd SkinCancerViT
```

Then, install the package in editable mode using uv (or pip):

```bash
uv sync   # Recommended if you use uv
# Or, if using pip:
# pip install -e .
```

## Quick Start / Usage

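This package loads the pre-trained SkinCancerViT model from the Hugging Face Hub and predicts a diagnosis from a skin-lesion image combined with tabular metadata (patient age and lesion localization).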

```python
import torch
from skincancer_vit.model import SkinCancerViTModel
from PIL import Image
from datasets import load_dataset   # To get a random sample

# Load the model from the Hugging Face Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerViTModel.from_pretrained("ethicalabs/SkinCancerViT")
model.to(device)   # Move model to the desired device
model.eval()   # Set model to evaluation mode

# Example Prediction from a Specific Image File
image_file_path = "images/patient-001.jpg"   # Specify your image file path here
specific_image = Image.open(image_file_path).convert("RGB")

# Example tabular data for this prediction
specific_age = 42
specific_localization = "face"   # Must match one of the localization categories seen during training

predicted_dx, confidence = model.full_predict(
    raw_image=specific_image,
    raw_age=specific_age,
    raw_localization=specific_localization,
    device=device
)

print(f"Predicted Diagnosis: {predicted_dx}")
print(f"Confidence: {confidence:.4f}")

# Example Prediction from a Random Test Sample from the Dataset
dataset = load_dataset("marmal88/skin_cancer", split="test")
random_sample = dataset.shuffle(seed=42).select(range(1))[0]   # Get the first shuffled sample

sample_image = random_sample["image"]
sample_age = random_sample["age"]
sample_localization = random_sample["localization"]
sample_true_dx = random_sample["dx"]

predicted_dx_sample, confidence_sample = model.full_predict(
    raw_image=sample_image,
    raw_age=sample_age,
    raw_localization=sample_localization,
    device=device
)

print(f"Predicted Diagnosis: {predicted_dx_sample}")
print(f"Confidence: {confidence_sample:.4f}")
print(f"Correct Prediction: {predicted_dx_sample == sample_true_dx}")
```
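The Quick Start checks a single random test sample. To get a rough sense of performance over more of the test split, one could loop over several rows and aggregate the results. The sketch below is a hypothetical helper, `evaluate_samples`, which is not part of the `skincancer_vit` package. It assumes only what the example above shows: `full_predict` takes an image, an age and a localization and returns a `(predicted_dx, confidence)` pair. Because it accepts any callable with that shape, it can be tried with a stub before wiring in the real model.

```python
from typing import Any, Callable, Iterable, Mapping, Tuple

# Hypothetical callable type mirroring how full_predict is used in the Quick Start:
# (image, age, localization) -> (predicted_dx, confidence)
PredictFn = Callable[[Any, Any, str], Tuple[str, float]]


def evaluate_samples(predict_fn: PredictFn, samples: Iterable[Mapping[str, Any]]) -> dict:
    """Run predict_fn over samples and report accuracy and mean confidence.

    Each sample is assumed to carry "image", "age", "localization" and "dx" keys,
    the same columns read from marmal88/skin_cancer in the Quick Start.
    """
    total = 0
    correct = 0
    confidence_sum = 0.0
    for sample in samples:
        predicted_dx, confidence = predict_fn(
            sample["image"], sample["age"], sample["localization"]
        )
        total += 1
        # Count an exact label match with the ground-truth diagnosis as correct
        correct += int(predicted_dx == sample["dx"])
        confidence_sum += float(confidence)
    if total == 0:
        raise ValueError("no samples to evaluate")
    return {
        "n": total,
        "accuracy": correct / total,
        "mean_confidence": confidence_sum / total,
    }


# Example usage with the objects defined in the Quick Start above:
# predict_fn = lambda img, age, loc: model.full_predict(
#     raw_image=img, raw_age=age, raw_localization=loc, device=device)
# print(evaluate_samples(predict_fn, dataset.shuffle(seed=42).select(range(100))))
```

Iterating a `datasets.Dataset` yields one dictionary per row, so a shuffled and selected subset of the test split can be passed directly as `samples`. Accuracy on a small subset is only a rough indication; the full test split gives a more stable estimate.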