geitta commited on
Commit
71a443b
·
1 Parent(s): 2832406

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from model import create_vit_model
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
+ # setup class names
10
+ class_names = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
11
+ ### model and transforms preparation
12
+ #create vit model
13
+ vit, vit_transforms = create_vit_model(num_classes=5):
14
+ #load saved weights
15
+ vit.load_state_dict(torch.load(f="pretrained_vit_festure_extractor_flower_classification.pth",
16
+ map_location=torch.device('cpu')))
17
+ ### predict function
18
+ # create predict function
19
+ def predict(img) -> Tuple[Dict, float]:
20
+ """transforms and perfroms a prediction on img and returns prediction and time taken"""
21
+ #start the time
22
+ start_time = time.time()
23
+ #transform the target image and add a batch dim
24
+ img = vit_transforms(img).unsqueeze(0)
25
+ #put the model in eval mode and turn on inference
26
+ vit.eval()
27
+ with torch.inference_mode():
28
+ # pass the transformed imag thru the model and turn the prediction logits into prediction probabilities
29
+ pred_probs = torch.softmax(vit(img), dim =1)
30
+ # create a prediction label and prediction probability
31
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
32
+ # calculate prediction time
33
+ end_time = time.time()
34
+ pred_time = round(end_time - start_time, 5)
35
+ # return the prediction dictionary and prediction time
36
+ return pred_labels_and_probs, pred_time
37
+ ### gradio app
38
+ # create title, description and article strings
39
+ title = "flower classification"
40
+ description = "a vit_16 feature extractor computer vision model to classify images of flowers as: daisy, dandelion, rose, sunflower, and tulip"
41
+ article = "created at [flower classification](githhublink to repository)"
42
+
43
+ # create examples list from examples dir
44
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
45
+
46
+ # create gradio demo
47
+ demo = gr.Interface(fn=predict, #mapping function from input to output
48
+ inputs=gr.Image(type="pil"), # what are the inputs
49
+ outputs=[gr.Label(num_top_classes=5, label="predictions"), #what are the outputs
50
+ gr.Number(label="prediction time (s)")], #our fn has two outputs, therefore we have 2 outputs
51
+ examples = example_list,
52
+ title=title,
53
+ description=description,
54
+ article=article
55
+ )
56
+ #launch the demo
57
+ demo.launch()
examples/Image_194.jpg ADDED
examples/Image_300.jpg ADDED
examples/Image_728.jpg ADDED
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+
4
+ from torch import nn
5
+ def create_vit_model(num_classes: int = 5):
6
+ # create vit pretrained weights, transforms and model
7
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
8
+ transforms = weights.transforms()
9
+ model = torchvision.models.vit_b_16(weights=weights)
10
+ # freeze all layers in base model
11
+ for param in model.parameters():
12
+ param.requires_grad = False
13
+ # change the classifier head to suit our problem
14
+ vit.heads = nn.Sequential(nn.Linear(in_features=768,
15
+ out_features=5,
16
+ bias=True))
17
+ return model, transforms
pretrained_vit_festure_extractor_flower_classification.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d284db415f480342b18b33cb2f7f16c8b54ce588ef0e6b783e7eb47340b0dad0
3
+ size 343276305
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio