Spaces:
Running
on
Zero
Running
on
Zero
remove mask size filter
Browse files
app.py
CHANGED
|
@@ -287,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 287 |
right = F.normalize(right, dim=-1)
|
| 288 |
heatmap = left @ right.T
|
| 289 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 290 |
-
num_samples =
|
| 291 |
if num_samples > fps_idx.shape[0]:
|
| 292 |
num_samples = fps_idx.shape[0]
|
| 293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
@@ -305,6 +305,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 305 |
fps_heatmaps = {}
|
| 306 |
sort_values = []
|
| 307 |
top3_image_idx = {}
|
|
|
|
| 308 |
for _, idx in enumerate(fps_idx):
|
| 309 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
| 310 |
|
|
@@ -314,7 +315,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 314 |
# tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
|
| 315 |
# return tensor.quantile(p)
|
| 316 |
# top_p = top_percentile(heatmap, p=0.5)
|
| 317 |
-
top_p = 0.
|
| 318 |
|
| 319 |
heatmap = heatmap.reshape(-1, h, w)
|
| 320 |
mask = (heatmap > top_p).float()
|
|
@@ -324,8 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 324 |
mask = mask[mask_sort_idx[:3]]
|
| 325 |
sort_values.append(mask.mean().item())
|
| 326 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
| 327 |
-
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:
|
| 328 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
|
|
|
| 329 |
# do the sorting
|
| 330 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
| 331 |
fps_idx = fps_idx[_sort_idx]
|
|
@@ -342,13 +344,17 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 342 |
# shuffle the fps_idx
|
| 343 |
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
| 344 |
|
|
|
|
| 345 |
fig_images = []
|
| 346 |
i_cluster = 0
|
| 347 |
num_plots = 10 if not advanced else 20
|
| 348 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 349 |
for i_fig in range(num_plots):
|
| 350 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
| 352 |
for ax in axs.flatten():
|
| 353 |
ax.axis("off")
|
| 354 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
|
@@ -358,7 +364,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 358 |
size = (images.shape[1], images.shape[2])
|
| 359 |
heatmap = apply_reds_colormap(heatmap, size)
|
| 360 |
# for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
| 361 |
-
|
|
|
|
| 362 |
# _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
|
| 363 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
| 364 |
axs[i, j].imshow(_heatmap)
|
|
@@ -378,10 +385,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 378 |
|
| 379 |
fig_images.append(img)
|
| 380 |
plt.close()
|
| 381 |
-
|
| 382 |
-
# plt.imshow(img)
|
| 383 |
-
# plt.axis("off")
|
| 384 |
-
# plt.show()
|
| 385 |
return fig_images, ret_magnitude
|
| 386 |
|
| 387 |
|
|
@@ -647,26 +651,26 @@ def ncut_run(
|
|
| 647 |
|
| 648 |
def _ncut_run(*args, **kwargs):
|
| 649 |
n_ret = kwargs.pop("n_ret", 1)
|
| 650 |
-
try:
|
| 651 |
-
|
| 652 |
-
|
| 653 |
|
| 654 |
-
|
| 655 |
|
| 656 |
-
|
| 657 |
-
|
| 658 |
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
except Exception as e:
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
|
| 671 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 672 |
@spaces.GPU(duration=30)
|
|
@@ -1415,7 +1419,7 @@ with demo:
|
|
| 1415 |
with gr.Column(scale=5, min_width=200):
|
| 1416 |
output_gallery = make_output_images_section()
|
| 1417 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1418 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[
|
| 1419 |
[
|
| 1420 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 1421 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
|
| 287 |
right = F.normalize(right, dim=-1)
|
| 288 |
heatmap = left @ right.T
|
| 289 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 290 |
+
num_samples = 50 if not advanced else 100
|
| 291 |
if num_samples > fps_idx.shape[0]:
|
| 292 |
num_samples = fps_idx.shape[0]
|
| 293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
|
| 305 |
fps_heatmaps = {}
|
| 306 |
sort_values = []
|
| 307 |
top3_image_idx = {}
|
| 308 |
+
top10_image_idx = {}
|
| 309 |
for _, idx in enumerate(fps_idx):
|
| 310 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
| 311 |
|
|
|
|
| 315 |
# tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
|
| 316 |
# return tensor.quantile(p)
|
| 317 |
# top_p = top_percentile(heatmap, p=0.5)
|
| 318 |
+
top_p = 0.8
|
| 319 |
|
| 320 |
heatmap = heatmap.reshape(-1, h, w)
|
| 321 |
mask = (heatmap > top_p).float()
|
|
|
|
| 325 |
mask = mask[mask_sort_idx[:3]]
|
| 326 |
sort_values.append(mask.mean().item())
|
| 327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
| 328 |
+
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:10]].cpu()
|
| 329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
| 330 |
+
top10_image_idx[idx.item()] = mask_sort_idx[:10]
|
| 331 |
# do the sorting
|
| 332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
| 333 |
fps_idx = fps_idx[_sort_idx]
|
|
|
|
| 344 |
# shuffle the fps_idx
|
| 345 |
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
| 346 |
|
| 347 |
+
|
| 348 |
fig_images = []
|
| 349 |
i_cluster = 0
|
| 350 |
num_plots = 10 if not advanced else 20
|
| 351 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 352 |
for i_fig in range(num_plots):
|
| 353 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
| 354 |
+
if not advanced:
|
| 355 |
+
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
| 356 |
+
if advanced:
|
| 357 |
+
fig, axs = plt.subplots(6, 5, figsize=(15, 18))
|
| 358 |
for ax in axs.flatten():
|
| 359 |
ax.axis("off")
|
| 360 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
|
|
|
| 364 |
size = (images.shape[1], images.shape[2])
|
| 365 |
heatmap = apply_reds_colormap(heatmap, size)
|
| 366 |
# for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
| 367 |
+
image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
|
| 368 |
+
for i, image_idx in enumerate(image_idxs):
|
| 369 |
# _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
|
| 370 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
| 371 |
axs[i, j].imshow(_heatmap)
|
|
|
|
| 385 |
|
| 386 |
fig_images.append(img)
|
| 387 |
plt.close()
|
| 388 |
+
|
|
|
|
|
|
|
|
|
|
| 389 |
return fig_images, ret_magnitude
|
| 390 |
|
| 391 |
|
|
|
|
| 651 |
|
| 652 |
def _ncut_run(*args, **kwargs):
|
| 653 |
n_ret = kwargs.pop("n_ret", 1)
|
| 654 |
+
# try:
|
| 655 |
+
# if torch.cuda.is_available():
|
| 656 |
+
# torch.cuda.empty_cache()
|
| 657 |
|
| 658 |
+
# ret = ncut_run(*args, **kwargs)
|
| 659 |
|
| 660 |
+
# if torch.cuda.is_available():
|
| 661 |
+
# torch.cuda.empty_cache()
|
| 662 |
|
| 663 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
| 664 |
+
# return ret
|
| 665 |
+
# except Exception as e:
|
| 666 |
+
# gr.Error(str(e))
|
| 667 |
+
# if torch.cuda.is_available():
|
| 668 |
+
# torch.cuda.empty_cache()
|
| 669 |
+
# return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 670 |
+
|
| 671 |
+
ret = ncut_run(*args, **kwargs)
|
| 672 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 673 |
+
return ret
|
| 674 |
|
| 675 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 676 |
@spaces.GPU(duration=30)
|
|
|
|
| 1419 |
with gr.Column(scale=5, min_width=200):
|
| 1420 |
output_gallery = make_output_images_section()
|
| 1421 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1422 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
| 1423 |
[
|
| 1424 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 1425 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|