broadwell committed on
Commit 8aa2d09 · verified · 1 Parent(s): d4b83d9

Updates to ViT CAM viz, add ResNet CAM viz

CLIP_Explainability/rn_cam.py ADDED
@@ -0,0 +1,208 @@
+import torch
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import cv2
+import re
+
+from .image_utils import show_cam_on_image, show_overlapped_cam
+
+
+def rn_relevance(
+    image,
+    target_features,
+    img_encoder,
+    method,
+    device,
+    neg_saliency=False,
+    img_dim=224,
+):
+    target_layers = [img_encoder.layer4[-1]]
+
+    cam = method(
+        model=img_encoder,
+        target_layers=target_layers,
+        use_cuda=torch.cuda.is_available(),
+    )
+
+    if neg_saliency:
+        target_encoding = -target_features
+    else:
+        target_encoding = target_features
+
+    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
+        0
+    ].squeeze()
+    image_relevance = torch.FloatTensor(image_relevance)
+
+    resize_dim = int(list(image_relevance.shape)[0])
+
+    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
+
+    # image_relevance = image_relevance.reshape(1, 1, 7, 7)
+    image_relevance = torch.nn.functional.interpolate(
+        image_relevance, size=img_dim, mode="bilinear"
+    )
+    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
+    image_relevance = (image_relevance - image_relevance.min()) / (
+        1e-7 + image_relevance.max() - image_relevance.min()
+    )
+    image = image[0].permute(1, 2, 0).data.cpu().numpy()
+    image = (image - image.min()) / (image.max() - image.min())
+
+    return image_relevance, image
+
+
+def interpret_rn(
+    image,
+    target_features,
+    img_encoder,
+    method,
+    device,
+    neg_saliency=False,
+    img_dim=224,
+):
+    image_relevance, image = rn_relevance(
+        image,
+        target_features,
+        img_encoder,
+        method,
+        device,
+        neg_saliency=neg_saliency,
+        img_dim=img_dim,
+    )
+    vis = show_cam_on_image(image, image_relevance, neg_saliency=neg_saliency)
+    vis = np.uint8(255 * vis)
+    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+
+    return vis
+    # plt.imshow(vis)
+
+
+def interpret_rn_overlapped(
+    image, target_features, img_encoder, method, device, img_dim=224
+):
+    pos_image_relevance, _ = rn_relevance(
+        image,
+        target_features,
+        img_encoder,
+        method,
+        device,
+        neg_saliency=False,
+        img_dim=img_dim,
+    )
+    neg_image_relevance, image = rn_relevance(
+        image,
+        target_features,
+        img_encoder,
+        method,
+        device,
+        neg_saliency=True,
+        img_dim=img_dim,
+    )
+
+    vis = show_overlapped_cam(image, neg_image_relevance, pos_image_relevance)
+    vis = np.uint8(255 * vis)
+    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+
+    return vis
+    # plt.imshow(vis)
+
+
+def rn_perword_relevance(
+    image,
+    text,
+    clip_model,
+    clip_tokenizer,
+    method,
+    device,
+    masked_word="",
+    data_only=False,
+    img_dim=224,
+):
+    clip_model.eval()
+
+    main_text = clip_tokenizer(text).to(device)
+    # remove the word for which you want to visualize the saliency
+    masked_text = re.sub(masked_word, "", text)
+    masked_text = clip_tokenizer(masked_text).to(device)
+
+    image_features = clip_model.encode_image(image)
+    main_text_features = clip_model.encode_text(main_text)
+    masked_text_features = clip_model.encode_text(masked_text)
+
+    image_features_norm = image_features.norm(dim=-1, keepdim=True)
+    image_features_new = image_features / image_features_norm
+    main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
+    main_text_features_new = main_text_features / main_text_features_norm
+
+    masked_text_features_norm = masked_text_features.norm(dim=-1, keepdim=True)
+    masked_text_features_new = masked_text_features / masked_text_features_norm
+
+    target_encoding = main_text_features_new - masked_text_features_new
+
+    target_layers = [clip_model.visual.layer4[-1]]
+
+    cam = method(
+        model=clip_model.visual,
+        target_layers=target_layers,
+        use_cuda=torch.cuda.is_available(),
+    )
+
+    image_features = clip_model.visual(image)
+
+    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
+        0
+    ].squeeze()
+    image_relevance = torch.FloatTensor(image_relevance)
+
+    resize_dim = int(list(image_relevance.shape)[0])
+
+    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
+
+    # image_relevance = image_relevance.reshape(1, 1, 7, 7)
+    image_relevance = torch.nn.functional.interpolate(
+        image_relevance, size=img_dim, mode="bilinear"
+    )
+    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
+    image_relevance = (image_relevance - image_relevance.min()) / (
+        1e-7 + image_relevance.max() - image_relevance.min()
+    )
+
+    if data_only:
+        return image_relevance
+
+    image = image[0].permute(1, 2, 0).data.cpu().numpy()
+    image = (image - image.min()) / (image.max() - image.min())
+
+    return image_relevance, image
+
+
+def interpret_perword_rn(
+    image,
+    text,
+    clip_model,
+    clip_tokenizer,
+    method,
+    device,
+    masked_word="",
+    data_only=False,
+    img_dim=224,
+):
+    image_relevance, image = rn_perword_relevance(
+        image,
+        text,
+        clip_model,
+        clip_tokenizer,
+        method,
+        device,
+        masked_word,
+        data_only=data_only,
+        img_dim=img_dim,
+    )
+    vis = show_cam_on_image(image, image_relevance)
+    vis = np.uint8(255 * vis)
+    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+
+    return vis
+    # plt.imshow(vis)
CLIP_Explainability/vit_cam.py CHANGED
@@ -210,7 +210,8 @@ def interpret_vit_overlapped(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    plt.imshow(vis)
+    return vis
+    # plt.imshow(vis)
 
 
 def vit_perword_relevance(
@@ -322,4 +323,5 @@ def interpret_perword_vit(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    plt.imshow(vis)
+    return vis
+    # plt.imshow(vis)
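With this change, interpret_vit_overlapped and interpret_perword_vit return the visualization instead of calling plt.imshow themselves, matching the new ResNet helpers. Callers that relied on the implicit plot now display the array on their side; a small sketch of that step (obtaining vis requires a loaded CLIP ViT model and inputs, omitted here):

import cv2
import matplotlib.pyplot as plt

def show_vis(vis):
    # The interpret_* functions return BGR (OpenCV) channel order;
    # convert to RGB before handing the array to matplotlib.
    plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.show()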