ankz22 commited on
Commit
d542ab3
·
1 Parent(s): c6200dd
Files changed (1) hide show
  1. app.py +6 -23
app.py CHANGED
@@ -5,18 +5,13 @@ import torch
5
  from torchvision import transforms
6
  import pandas as pd
7
 
8
- # -------------------------------------
9
- # 1. DATA AUGMENTATION POUR ROBUSTESSE
10
- # -------------------------------------
11
  augment = transforms.Compose([
12
  transforms.RandomHorizontalFlip(p=0.5),
13
  transforms.RandomRotation(10),
14
  transforms.ColorJitter(brightness=0.2, contrast=0.2),
15
  ])
16
 
17
- # -------------------------------------
18
- # 2. CHARGEMENT DU MODELE FINE-TUNED TRASHNET
19
- # -------------------------------------
20
  MODEL_ID = "tribber93/my-trash-classification"
21
  trash_classifier = pipeline(
22
  "image-classification",
@@ -26,9 +21,7 @@ trash_classifier = pipeline(
26
  )
27
 
28
 
29
- # -------------------------------------
30
- # 3. DICTIONNAIRE DE MAPPING
31
- # -------------------------------------
32
  POUBELLES = {
33
  "cardboard": "papier/carton",
34
  "glass": "verre",
@@ -38,17 +31,11 @@ POUBELLES = {
38
  "trash": "ordures ménagères",
39
  }
40
 
41
- # -------------------------------------
42
- # 4. FONCTION DE CLASSIFICATION
43
- # -------------------------------------
44
  def classify_image(image: Image.Image):
45
- # Appliquer data augmentation pour robustesse
46
  image_aug = augment(image)
47
-
48
- # Prédiction via pipeline
49
  results = trash_classifier(image_aug)
50
 
51
- # Préparer les DataFrame rows
52
  rows = []
53
  for r in results:
54
  label = r["label"]
@@ -59,13 +46,9 @@ def classify_image(image: Image.Image):
59
  "Poubelle": poubelle,
60
  "Confiance (%)": round(score * 100, 2)
61
  })
62
-
63
- # Renvoie un tableau pandas pour Gradio
64
  return pd.DataFrame(rows)
65
 
66
- # -------------------------------------
67
- # 5. LANCEMENT DE L'INTERFACE GRADIO
68
- # -------------------------------------
69
  interface = gr.Interface(
70
  fn=classify_image,
71
  inputs=gr.Image(type="pil"),
@@ -73,9 +56,9 @@ interface = gr.Interface(
73
  headers=["Objet", "Poubelle", "Confiance (%)"],
74
  row_count=(1, 10)
75
  ),
76
- title="🗑️ Classifieur de Déchets Amélioré",
77
  description=(
78
- "Dépose une image de déchet pour savoir dans quelle poubelle la trier. "
79
  "Le modèle est fine-tuné sur TrashNet et bénéficie de data augmentation pour une meilleure robustesse."
80
  ),
81
  examples=None,
 
5
  from torchvision import transforms
6
  import pandas as pd
7
 
8
+ # DATA AUGMENTATION
 
 
9
  augment = transforms.Compose([
10
  transforms.RandomHorizontalFlip(p=0.5),
11
  transforms.RandomRotation(10),
12
  transforms.ColorJitter(brightness=0.2, contrast=0.2),
13
  ])
14
 
 
 
 
15
  MODEL_ID = "tribber93/my-trash-classification"
16
  trash_classifier = pipeline(
17
  "image-classification",
 
21
  )
22
 
23
 
24
+ # MAPPING
 
 
25
  POUBELLES = {
26
  "cardboard": "papier/carton",
27
  "glass": "verre",
 
31
  "trash": "ordures ménagères",
32
  }
33
 
34
+ #CLASSIFICATION
 
 
35
  def classify_image(image: Image.Image):
 
36
  image_aug = augment(image)
 
 
37
  results = trash_classifier(image_aug)
38
 
 
39
  rows = []
40
  for r in results:
41
  label = r["label"]
 
46
  "Poubelle": poubelle,
47
  "Confiance (%)": round(score * 100, 2)
48
  })
 
 
49
  return pd.DataFrame(rows)
50
 
51
+ #GRADIO
 
 
52
  interface = gr.Interface(
53
  fn=classify_image,
54
  inputs=gr.Image(type="pil"),
 
56
  headers=["Objet", "Poubelle", "Confiance (%)"],
57
  row_count=(1, 10)
58
  ),
59
+ title="🗑️ Classifieur de Déchets ",
60
  description=(
61
+ "Dépose une image de déchet pour savoir dans quelle poubelle la trier !! "
62
  "Le modèle est fine-tuné sur TrashNet et bénéficie de data augmentation pour une meilleure robustesse."
63
  ),
64
  examples=None,