ankz22 commited on
Commit
9871674
·
1 Parent(s): e7436cd
Files changed (1) hide show
  1. app.py +58 -25
app.py CHANGED
@@ -1,15 +1,33 @@
1
  import gradio as gr
2
- from transformers import AutoImageProcessor, AutoModelForImageClassification
3
- from PIL import Image
4
  import torch
5
- import torch.nn.functional as F
6
  import pandas as pd
7
 
8
- # Chargement du processeur et du modèle
9
- processor = AutoImageProcessor.from_pretrained("tribber93/my-trash-classification")
10
- model = AutoModelForImageClassification.from_pretrained("tribber93/my-trash-classification")
 
 
 
 
 
11
 
12
- # Dictionnaire de correspondance entre les labels et les types de poubelles
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  POUBELLES = {
14
  "cardboard": "papier/carton",
15
  "glass": "verre",
@@ -19,34 +37,49 @@ POUBELLES = {
19
  "trash": "ordures ménagères",
20
  }
21
 
22
- # Fonction de classification de l'image
23
- def classify_image(image):
24
- inputs = processor(images=image, return_tensors="pt")
25
- with torch.no_grad():
26
- logits = model(**inputs).logits
27
- probs = F.softmax(logits, dim=-1)
28
 
29
- top_probs, top_idxs = torch.topk(probs, 3)
30
- top_probs = top_probs.squeeze().tolist()
31
- top_idxs = top_idxs.squeeze().tolist()
32
 
 
33
  rows = []
34
- for idx, prob in zip(top_idxs, top_probs):
35
- label = model.config.id2label[idx]
 
36
  poubelle = POUBELLES.get(label.lower(), "inconnue")
37
  rows.append({
38
  "Objet": label,
39
  "Poubelle": poubelle,
40
- "Confiance (%)": round(prob * 100, 2),
41
  })
42
 
 
43
  return pd.DataFrame(rows)
44
 
45
- # Création de l'interface Gradio
46
- gr.Interface(
 
 
47
  fn=classify_image,
48
  inputs=gr.Image(type="pil"),
49
- outputs=gr.Dataframe(),
50
- title="🗑️ Classifieur de Déchets",
51
- description="Dépose une image de déchet pour savoir dans quelle poubelle le trier."
52
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image, ImageDraw
4
  import torch
5
+ from torchvision import transforms
6
  import pandas as pd
7
 
8
+ # -------------------------------------
9
+ # 1. DATA AUGMENTATION POUR ROBUSTESSE
10
+ # -------------------------------------
11
+ augment = transforms.Compose([
12
+ transforms.RandomHorizontalFlip(p=0.5),
13
+ transforms.RandomRotation(10),
14
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
15
+ ])
16
 
17
+ # -------------------------------------
18
+ # 2. CHARGEMENT DU MODELE FINE-TUNED TRASHNET
19
+ # -------------------------------------
20
+ MODEL_ID = "mrm8488/resnet50-finetuned-trashnet"
21
+ trash_classifier = pipeline(
22
+ "image-classification",
23
+ model=MODEL_ID,
24
+ device=0 if torch.cuda.is_available() else -1,
25
+ top_k=3
26
+ )
27
+
28
+ # -------------------------------------
29
+ # 3. DICTIONNAIRE DE MAPPING
30
+ # -------------------------------------
31
  POUBELLES = {
32
  "cardboard": "papier/carton",
33
  "glass": "verre",
 
37
  "trash": "ordures ménagères",
38
  }
39
 
40
+ # -------------------------------------
41
+ # 4. FONCTION DE CLASSIFICATION
42
+ # -------------------------------------
43
+ def classify_image(image: Image.Image):
44
+ # Appliquer data augmentation pour robustesse
45
+ image_aug = augment(image)
46
 
47
+ # Prédiction via pipeline
48
+ results = trash_classifier(image_aug)
 
49
 
50
+ # Préparer les DataFrame rows
51
  rows = []
52
+ for r in results:
53
+ label = r["label"]
54
+ score = r["score"]
55
  poubelle = POUBELLES.get(label.lower(), "inconnue")
56
  rows.append({
57
  "Objet": label,
58
  "Poubelle": poubelle,
59
+ "Confiance (%)": round(score * 100, 2)
60
  })
61
 
62
+ # Renvoie un tableau pandas pour Gradio
63
  return pd.DataFrame(rows)
64
 
65
+ # -------------------------------------
66
+ # 5. LANCEMENT DE L'INTERFACE GRADIO
67
+ # -------------------------------------
68
+ interface = gr.Interface(
69
  fn=classify_image,
70
  inputs=gr.Image(type="pil"),
71
+ outputs=gr.Dataframe(
72
+ headers=["Objet", "Poubelle", "Confiance (%)"],
73
+ row_count=(1, 10)
74
+ ),
75
+ title="🗑️ Classifieur de Déchets Amélioré",
76
+ description=(
77
+ "Dépose une image de déchet pour savoir dans quelle poubelle la trier. "
78
+ "Le modèle est fine-tuné sur TrashNet et bénéficie de data augmentation pour une meilleure robustesse."
79
+ ),
80
+ examples=None,
81
+ allow_flagging="never"
82
+ )
83
+
84
+ if __name__ == "__main__":
85
+ interface.launch()