Spaces:
Running
on
Zero
Running
on
Zero
update cluster fg bg
Browse files
app.py
CHANGED
|
@@ -308,7 +308,79 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
| 308 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 309 |
return blended.astype(np.uint8)
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
progress = gr.Progress()
|
| 313 |
progress(progess_start, desc="Finding Clusters by FPS")
|
| 314 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -318,10 +390,13 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 318 |
|
| 319 |
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
| 320 |
top_p_idx = torch.arange(eigvecs.shape[0])
|
|
|
|
|
|
|
| 321 |
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
| 322 |
# p = 0.8
|
| 323 |
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
| 324 |
|
|
|
|
| 325 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
| 326 |
|
| 327 |
|
|
@@ -338,7 +413,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 338 |
right = F.normalize(right, dim=-1)
|
| 339 |
heatmap = left @ right.T
|
| 340 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 341 |
-
num_samples =
|
| 342 |
if num_samples > fps_idx.shape[0]:
|
| 343 |
num_samples = fps_idx.shape[0]
|
| 344 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
@@ -398,10 +473,10 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 398 |
|
| 399 |
fig_images = []
|
| 400 |
i_cluster = 0
|
| 401 |
-
num_plots =
|
| 402 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 403 |
for i_fig in range(num_plots):
|
| 404 |
-
progress(progess_start + i_fig * plot_step_float, desc="Plotting
|
| 405 |
if not advanced:
|
| 406 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
| 407 |
if advanced:
|
|
@@ -421,7 +496,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 421 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
| 422 |
axs[i, j].imshow(_heatmap)
|
| 423 |
if i == 0:
|
| 424 |
-
axs[i, j].set_title(f"
|
| 425 |
i_cluster += 1
|
| 426 |
plt.tight_layout(h_pad=0.5, w_pad=0.3)
|
| 427 |
|
|
@@ -440,6 +515,39 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 440 |
|
| 441 |
return fig_images, ret_magnitude
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
def ncut_run(
|
| 445 |
model,
|
|
@@ -601,7 +709,7 @@ def ncut_run(
|
|
| 601 |
if torch.cuda.is_available():
|
| 602 |
images = images.cuda()
|
| 603 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 604 |
-
cluster_images, eig_magnitude =
|
| 605 |
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
| 606 |
|
| 607 |
norm_images = []
|
|
@@ -716,7 +824,10 @@ def ncut_run(
|
|
| 716 |
images = images.cuda()
|
| 717 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 718 |
advanced = kwargs.get("advanced", False)
|
| 719 |
-
|
|
|
|
|
|
|
|
|
|
| 720 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 721 |
|
| 722 |
norm_images = None
|
|
@@ -736,33 +847,33 @@ def ncut_run(
|
|
| 736 |
logging_str += "Eigenvector Magnitude\n"
|
| 737 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 738 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
| 739 |
-
|
| 740 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
| 741 |
|
| 742 |
|
| 743 |
|
| 744 |
def _ncut_run(*args, **kwargs):
|
| 745 |
n_ret = kwargs.pop("n_ret", 1)
|
| 746 |
-
try:
|
| 747 |
-
|
| 748 |
-
|
| 749 |
|
| 750 |
-
|
| 751 |
|
| 752 |
-
|
| 753 |
-
|
| 754 |
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
except Exception as e:
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
|
| 767 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 768 |
@spaces.GPU(duration=30)
|
|
@@ -1186,7 +1297,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
| 1186 |
images += [Image.open(new_image) for new_image in new_images]
|
| 1187 |
if isinstance(new_images, str):
|
| 1188 |
images.append(Image.open(new_images))
|
| 1189 |
-
|
| 1190 |
return images
|
| 1191 |
upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
|
| 1192 |
|
|
@@ -1402,6 +1513,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
| 1402 |
if existing_images is None:
|
| 1403 |
existing_images = []
|
| 1404 |
existing_images += new_images
|
|
|
|
| 1405 |
return existing_images
|
| 1406 |
|
| 1407 |
load_images_button.click(load_and_append,
|
|
@@ -1416,165 +1528,6 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
| 1416 |
|
| 1417 |
|
| 1418 |
|
| 1419 |
-
# def make_input_images_section(rows=1, cols=3, height="auto"):
|
| 1420 |
-
# gr.Markdown('### Input Images')
|
| 1421 |
-
# input_gallery = gr.Gallery(value=None, label="Select images", show_label=True, elem_id="images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False)
|
| 1422 |
-
# submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
|
| 1423 |
-
# clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
|
| 1424 |
-
# return input_gallery, submit_button, clear_images_button
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
# def make_dataset_images_section(advanced=False, is_random=False):
|
| 1428 |
-
|
| 1429 |
-
# gr.Markdown('### Load Datasets')
|
| 1430 |
-
# load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
|
| 1431 |
-
# advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
|
| 1432 |
-
# with gr.Column() as basic_block:
|
| 1433 |
-
# example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
|
| 1434 |
-
# with gr.Column() as advanced_block:
|
| 1435 |
-
# # dataset_names = DATASET_NAMES
|
| 1436 |
-
# # dataset_classes = DATASET_CLASSES
|
| 1437 |
-
# dataset_categories = list(DATASETS.keys())
|
| 1438 |
-
# defualt_cat = dataset_categories[0]
|
| 1439 |
-
# def get_choices(cat):
|
| 1440 |
-
# return [tup[0] for tup in DATASETS[cat]]
|
| 1441 |
-
# defualt_choices = get_choices(defualt_cat)
|
| 1442 |
-
# with gr.Row():
|
| 1443 |
-
# dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=defualt_cat, elem_id="dataset-radio", show_label=True, min_width=600)
|
| 1444 |
-
# # dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
|
| 1445 |
-
# dataset_dropdown = gr.Dropdown(defualt_choices, label="Dataset name", value=defualt_choices[0], elem_id="dataset", min_width=400)
|
| 1446 |
-
# dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
|
| 1447 |
-
# # num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
|
| 1448 |
-
# num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)
|
| 1449 |
-
# if not is_random:
|
| 1450 |
-
# filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
|
| 1451 |
-
# filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
|
| 1452 |
-
# # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
|
| 1453 |
-
# # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
|
| 1454 |
-
# is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
|
| 1455 |
-
# random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
|
| 1456 |
-
# if is_random:
|
| 1457 |
-
# filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
|
| 1458 |
-
# filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
|
| 1459 |
-
# is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
|
| 1460 |
-
# random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
# if advanced:
|
| 1464 |
-
# advanced_block.visible = True
|
| 1465 |
-
# basic_block.visible = False
|
| 1466 |
-
# else:
|
| 1467 |
-
# advanced_block.visible = False
|
| 1468 |
-
# basic_block.visible = True
|
| 1469 |
-
|
| 1470 |
-
# # change visibility
|
| 1471 |
-
# advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
|
| 1472 |
-
# advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
|
| 1473 |
-
|
| 1474 |
-
# def find_num_classes(dataset_name):
|
| 1475 |
-
# num_classes = None
|
| 1476 |
-
# for cat, datasets in DATASETS.items():
|
| 1477 |
-
# datasets = [tup[0] for tup in datasets]
|
| 1478 |
-
# if dataset_name in datasets:
|
| 1479 |
-
# num_classes = DATASETS[cat][datasets.index(dataset_name)][1]
|
| 1480 |
-
# break
|
| 1481 |
-
# return num_classes
|
| 1482 |
-
|
| 1483 |
-
# def change_filter_options(dataset_name):
|
| 1484 |
-
# num_classes = find_num_classes(dataset_name)
|
| 1485 |
-
# if num_classes is None:
|
| 1486 |
-
# return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
|
| 1487 |
-
# gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
|
| 1488 |
-
# return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
|
| 1489 |
-
# gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
|
| 1490 |
-
# dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
|
| 1491 |
-
|
| 1492 |
-
# def change_filter_by_class(is_filter, dataset_name):
|
| 1493 |
-
# num_classes = find_num_classes(dataset_name)
|
| 1494 |
-
# return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
|
| 1495 |
-
# filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
|
| 1496 |
-
|
| 1497 |
-
# def change_random_seed(is_random):
|
| 1498 |
-
# return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
|
| 1499 |
-
# is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
|
| 1500 |
-
|
| 1501 |
-
|
| 1502 |
-
# def load_dataset_images(is_advanced, dataset_name, num_images=10,
|
| 1503 |
-
# is_filter=True, filter_by_class_text="0,1,2",
|
| 1504 |
-
# is_random=False, seed=1):
|
| 1505 |
-
# progress = gr.Progress()
|
| 1506 |
-
# progress(0, desc="Loading Images")
|
| 1507 |
-
# if is_advanced == "Basic":
|
| 1508 |
-
# gr.Info("Loaded images from Ego-Exo4D")
|
| 1509 |
-
# return default_images
|
| 1510 |
-
# try:
|
| 1511 |
-
# progress(0.5, desc="Downloading Dataset")
|
| 1512 |
-
# dataset = load_dataset(dataset_name, trust_remote_code=True)
|
| 1513 |
-
# key = list(dataset.keys())[0]
|
| 1514 |
-
# dataset = dataset[key]
|
| 1515 |
-
# except Exception as e:
|
| 1516 |
-
# gr.Error(f"Error loading dataset {dataset_name}: {e}")
|
| 1517 |
-
# return None
|
| 1518 |
-
# if num_images > len(dataset):
|
| 1519 |
-
# num_images = len(dataset)
|
| 1520 |
-
|
| 1521 |
-
# if is_filter:
|
| 1522 |
-
# progress(0.8, desc="Filtering Images")
|
| 1523 |
-
# classes = [int(i) for i in filter_by_class_text.split(",")]
|
| 1524 |
-
# labels = np.array(dataset['label'])
|
| 1525 |
-
# unique_labels = np.unique(labels)
|
| 1526 |
-
# valid_classes = [i for i in classes if i in unique_labels]
|
| 1527 |
-
# invalid_classes = [i for i in classes if i not in unique_labels]
|
| 1528 |
-
# if len(invalid_classes) > 0:
|
| 1529 |
-
# gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
|
| 1530 |
-
# if len(valid_classes) == 0:
|
| 1531 |
-
# gr.Error(f"Classes {classes} not found in the dataset.")
|
| 1532 |
-
# return None
|
| 1533 |
-
# # shuffle each class
|
| 1534 |
-
# chunk_size = num_images // len(valid_classes)
|
| 1535 |
-
# image_idx = []
|
| 1536 |
-
# for i in valid_classes:
|
| 1537 |
-
# idx = np.where(labels == i)[0]
|
| 1538 |
-
# if is_random:
|
| 1539 |
-
# idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
|
| 1540 |
-
# else:
|
| 1541 |
-
# idx = idx[:chunk_size]
|
| 1542 |
-
# image_idx.extend(idx.tolist())
|
| 1543 |
-
# if not is_filter:
|
| 1544 |
-
# if is_random:
|
| 1545 |
-
# image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
|
| 1546 |
-
# else:
|
| 1547 |
-
# image_idx = list(range(num_images))
|
| 1548 |
-
# key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
|
| 1549 |
-
# images = [dataset[i][key] for i in image_idx]
|
| 1550 |
-
# gr.Info(f"Loaded {len(images)} images from {dataset_name}")
|
| 1551 |
-
# del dataset
|
| 1552 |
-
|
| 1553 |
-
# if dataset_name in CENTER_CROP_DATASETS:
|
| 1554 |
-
# def center_crop_image(img):
|
| 1555 |
-
# # image: PIL image
|
| 1556 |
-
# w, h = img.size
|
| 1557 |
-
# min_hw = min(h, w)
|
| 1558 |
-
# # center crop
|
| 1559 |
-
# left = (w - min_hw) // 2
|
| 1560 |
-
# top = (h - min_hw) // 2
|
| 1561 |
-
# right = left + min_hw
|
| 1562 |
-
# bottom = top + min_hw
|
| 1563 |
-
# img = img.crop((left, top, right, bottom))
|
| 1564 |
-
# return img
|
| 1565 |
-
# images = [center_crop_image(image) for image in images]
|
| 1566 |
-
|
| 1567 |
-
# return images
|
| 1568 |
-
|
| 1569 |
-
# load_images_button.click(load_dataset_images,
|
| 1570 |
-
# inputs=[advanced_radio, dataset_dropdown, num_images_slider,
|
| 1571 |
-
# filter_by_class_checkbox, filter_by_class_text,
|
| 1572 |
-
# is_random_checkbox, random_seed_slider],
|
| 1573 |
-
# outputs=[input_gallery])
|
| 1574 |
-
|
| 1575 |
-
# return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
|
| 1576 |
-
|
| 1577 |
-
|
| 1578 |
# def random_rotate_rgb_gallery(images):
|
| 1579 |
# if images is None or len(images) == 0:
|
| 1580 |
# gr.Warning("No images selected.")
|
|
@@ -1969,19 +1922,19 @@ with demo:
|
|
| 1969 |
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1970 |
add_output_images_buttons(l1_gallery)
|
| 1971 |
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1972 |
-
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
| 1973 |
with gr.Column(scale=5, min_width=200):
|
| 1974 |
gr.Markdown('### Output (Recursion #2)')
|
| 1975 |
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1976 |
add_output_images_buttons(l2_gallery)
|
| 1977 |
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1978 |
-
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
| 1979 |
with gr.Column(scale=5, min_width=200):
|
| 1980 |
gr.Markdown('### Output (Recursion #3)')
|
| 1981 |
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1982 |
add_output_images_buttons(l3_gallery)
|
| 1983 |
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1984 |
-
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=
|
| 1985 |
|
| 1986 |
with gr.Row():
|
| 1987 |
with gr.Column(scale=5, min_width=200):
|
|
@@ -2352,7 +2305,7 @@ with demo:
|
|
| 2352 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 2353 |
add_output_images_buttons(output_gallery)
|
| 2354 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 2355 |
-
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=
|
| 2356 |
[
|
| 2357 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 2358 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
|
|
|
| 308 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 309 |
return blended.astype(np.uint8)
|
| 310 |
|
| 311 |
+
|
| 312 |
+
def segment_fg_bg(images):
|
| 313 |
+
|
| 314 |
+
images = F.interpolate(images, (224, 224), mode="bilinear")
|
| 315 |
+
|
| 316 |
+
# model = load_alignedthreemodel()
|
| 317 |
+
model = load_model("CLIP(ViT-B-16/openai)")
|
| 318 |
+
from ncut_pytorch.backbone import resample_position_embeddings
|
| 319 |
+
pos_embed = model.model.visual.positional_embedding
|
| 320 |
+
pos_embed = resample_position_embeddings(pos_embed, 14, 14)
|
| 321 |
+
model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)
|
| 322 |
+
|
| 323 |
+
batch_size = 4
|
| 324 |
+
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 325 |
+
|
| 326 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 327 |
+
model.to(device)
|
| 328 |
+
means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
| 329 |
+
stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
| 330 |
+
|
| 331 |
+
fg_acts, bg_acts = [], []
|
| 332 |
+
for chunk_idx in chunk_idxs:
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
input_images = images[chunk_idx].to(device)
|
| 335 |
+
# transform the input images
|
| 336 |
+
input_images = (input_images - means) / stds
|
| 337 |
+
# output = model(input_images)[:, 5]
|
| 338 |
+
output = model(input_images)['attn'][6]
|
| 339 |
+
fg_act = output[:, 6, 6].mean(0)
|
| 340 |
+
bg_act = output[:, 0, 0].mean(0)
|
| 341 |
+
fg_acts.append(fg_act)
|
| 342 |
+
bg_acts.append(bg_act)
|
| 343 |
+
fg_act = torch.stack(fg_acts, dim=0).mean(0)
|
| 344 |
+
bg_act = torch.stack(bg_acts, dim=0).mean(0)
|
| 345 |
+
fg_act = F.normalize(fg_act, dim=-1)
|
| 346 |
+
bg_act = F.normalize(bg_act, dim=-1)
|
| 347 |
+
|
| 348 |
+
# ref_image = default_images[0]
|
| 349 |
+
# image = Image.open(ref_image).convert("RGB").resize((224, 224), Image.Resampling.BILINEAR)
|
| 350 |
+
# image = torch.tensor(np.array(image)).permute(2, 0, 1).float().to(device)
|
| 351 |
+
# image = (image / 255.0 - means) / stds
|
| 352 |
+
# output = model(image)['attn'][6][0]
|
| 353 |
+
# # print(output.shape)
|
| 354 |
+
# # bg on the center
|
| 355 |
+
# fg_act = output[5, 5]
|
| 356 |
+
# # bg on the bottom left
|
| 357 |
+
# bg_act = output[0, 0]
|
| 358 |
+
# fg_act = F.normalize(fg_act, dim=-1)
|
| 359 |
+
# bg_act = F.normalize(bg_act, dim=-1)
|
| 360 |
+
|
| 361 |
+
# print(images.mean(), images.std())
|
| 362 |
+
|
| 363 |
+
fg_act, bg_act = fg_act.to(device), bg_act.to(device)
|
| 364 |
+
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 365 |
+
heatmap_fgs, heatmap_bgs = [], []
|
| 366 |
+
for chunk_idx in chunk_idxs:
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
input_images = images[chunk_idx].to(device)
|
| 369 |
+
# transform the input images
|
| 370 |
+
input_images = (input_images - means) / stds
|
| 371 |
+
# output = model(input_images)[:, 5]
|
| 372 |
+
output = model(input_images)['attn'][6]
|
| 373 |
+
output = F.normalize(output, dim=-1)
|
| 374 |
+
heatmap_fg = output @ fg_act[:, None]
|
| 375 |
+
heatmap_bg = output @ bg_act[:, None]
|
| 376 |
+
heatmap_fgs.append(heatmap_fg.cpu())
|
| 377 |
+
heatmap_bgs.append(heatmap_bg.cpu())
|
| 378 |
+
heatmap_fg = torch.cat(heatmap_fgs, dim=0)
|
| 379 |
+
heatmap_bg = torch.cat(heatmap_bgs, dim=0)
|
| 380 |
+
return heatmap_fg, heatmap_bg
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
|
| 384 |
progress = gr.Progress()
|
| 385 |
progress(progess_start, desc="Finding Clusters by FPS")
|
| 386 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 390 |
|
| 391 |
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
| 392 |
top_p_idx = torch.arange(eigvecs.shape[0])
|
| 393 |
+
if eig_idx is not None:
|
| 394 |
+
top_p_idx = eig_idx
|
| 395 |
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
| 396 |
# p = 0.8
|
| 397 |
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
| 398 |
|
| 399 |
+
|
| 400 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
| 401 |
|
| 402 |
|
|
|
|
| 413 |
right = F.normalize(right, dim=-1)
|
| 414 |
heatmap = left @ right.T
|
| 415 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 416 |
+
num_samples = clusters + 20
|
| 417 |
if num_samples > fps_idx.shape[0]:
|
| 418 |
num_samples = fps_idx.shape[0]
|
| 419 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
|
| 473 |
|
| 474 |
fig_images = []
|
| 475 |
i_cluster = 0
|
| 476 |
+
num_plots = clusters // 5
|
| 477 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 478 |
for i_fig in range(num_plots):
|
| 479 |
+
progress(progess_start + i_fig * plot_step_float, desc=f"Plotting {title}")
|
| 480 |
if not advanced:
|
| 481 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
| 482 |
if advanced:
|
|
|
|
| 496 |
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
| 497 |
axs[i, j].imshow(_heatmap)
|
| 498 |
if i == 0:
|
| 499 |
+
axs[i, j].set_title(f"{title} {i_cluster+1}", fontsize=24)
|
| 500 |
i_cluster += 1
|
| 501 |
plt.tight_layout(h_pad=0.5, w_pad=0.3)
|
| 502 |
|
|
|
|
| 515 |
|
| 516 |
return fig_images, ret_magnitude
|
| 517 |
|
| 518 |
+
def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
|
| 519 |
+
heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
|
| 520 |
+
heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
|
| 521 |
+
heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
|
| 522 |
+
heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bilinear")
|
| 523 |
+
heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bilinear")
|
| 524 |
+
heatmap_fg = heatmap_fg.flatten()
|
| 525 |
+
heatmap_bg = heatmap_bg.flatten()
|
| 526 |
+
|
| 527 |
+
fg_minus_bg = heatmap_fg - heatmap_bg
|
| 528 |
+
fg_mask = fg_minus_bg > fg_minus_bg.quantile(0.8)
|
| 529 |
+
bg_mask = fg_minus_bg < fg_minus_bg.quantile(0.2)
|
| 530 |
+
|
| 531 |
+
# fg_mask = heatmap_fg > heatmap_fg.quantile(0.8)
|
| 532 |
+
# bg_mask = heatmap_bg > heatmap_bg.quantile(0.8)
|
| 533 |
+
other_mask = ~(fg_mask | bg_mask)
|
| 534 |
+
|
| 535 |
+
fg_idx = torch.arange(heatmap_fg.shape[0])[fg_mask]
|
| 536 |
+
bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
|
| 537 |
+
other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
|
| 538 |
+
|
| 539 |
+
fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg")
|
| 540 |
+
bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=bg_idx, title="bg")
|
| 541 |
+
other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=other_idx, title="other")
|
| 542 |
+
|
| 543 |
+
cluster_images = fg_images + bg_images + other_images
|
| 544 |
+
|
| 545 |
+
magitude = torch.norm(eigvecs, dim=-1)
|
| 546 |
+
magitude = magitude.reshape(-1, h, w)
|
| 547 |
+
|
| 548 |
+
# magitude = fg_minus_bg.reshape(-1, h, w) #TODO
|
| 549 |
+
|
| 550 |
+
return cluster_images, magitude
|
| 551 |
|
| 552 |
def ncut_run(
|
| 553 |
model,
|
|
|
|
| 709 |
if torch.cuda.is_available():
|
| 710 |
images = images.cuda()
|
| 711 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 712 |
+
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
| 713 |
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
| 714 |
|
| 715 |
norm_images = []
|
|
|
|
| 824 |
images = images.cuda()
|
| 825 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 826 |
advanced = kwargs.get("advanced", False)
|
| 827 |
+
if advanced:
|
| 828 |
+
cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
|
| 829 |
+
else:
|
| 830 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
|
| 831 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 832 |
|
| 833 |
norm_images = None
|
|
|
|
| 847 |
logging_str += "Eigenvector Magnitude\n"
|
| 848 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 849 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
| 850 |
+
|
| 851 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
| 852 |
|
| 853 |
|
| 854 |
|
| 855 |
def _ncut_run(*args, **kwargs):
|
| 856 |
n_ret = kwargs.pop("n_ret", 1)
|
| 857 |
+
# try:
|
| 858 |
+
# if torch.cuda.is_available():
|
| 859 |
+
# torch.cuda.empty_cache()
|
| 860 |
|
| 861 |
+
# ret = ncut_run(*args, **kwargs)
|
| 862 |
|
| 863 |
+
# if torch.cuda.is_available():
|
| 864 |
+
# torch.cuda.empty_cache()
|
| 865 |
|
| 866 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
| 867 |
+
# return ret
|
| 868 |
+
# except Exception as e:
|
| 869 |
+
# gr.Error(str(e))
|
| 870 |
+
# if torch.cuda.is_available():
|
| 871 |
+
# torch.cuda.empty_cache()
|
| 872 |
+
# return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 873 |
+
|
| 874 |
+
ret = ncut_run(*args, **kwargs)
|
| 875 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 876 |
+
return ret
|
| 877 |
|
| 878 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 879 |
@spaces.GPU(duration=30)
|
|
|
|
| 1297 |
images += [Image.open(new_image) for new_image in new_images]
|
| 1298 |
if isinstance(new_images, str):
|
| 1299 |
images.append(Image.open(new_images))
|
| 1300 |
+
gr.Info(f"Total images: {len(images)}")
|
| 1301 |
return images
|
| 1302 |
upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])
|
| 1303 |
|
|
|
|
| 1513 |
if existing_images is None:
|
| 1514 |
existing_images = []
|
| 1515 |
existing_images += new_images
|
| 1516 |
+
gr.Info(f"Total images: {len(existing_images)}")
|
| 1517 |
return existing_images
|
| 1518 |
|
| 1519 |
load_images_button.click(load_and_append,
|
|
|
|
| 1528 |
|
| 1529 |
|
| 1530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1531 |
# def random_rotate_rgb_gallery(images):
|
| 1532 |
# if images is None or len(images) == 0:
|
| 1533 |
# gr.Warning("No images selected.")
|
|
|
|
| 1922 |
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1923 |
add_output_images_buttons(l1_gallery)
|
| 1924 |
l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1925 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
| 1926 |
with gr.Column(scale=5, min_width=200):
|
| 1927 |
gr.Markdown('### Output (Recursion #2)')
|
| 1928 |
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1929 |
add_output_images_buttons(l2_gallery)
|
| 1930 |
l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1931 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
| 1932 |
with gr.Column(scale=5, min_width=200):
|
| 1933 |
gr.Markdown('### Output (Recursion #3)')
|
| 1934 |
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1935 |
add_output_images_buttons(l3_gallery)
|
| 1936 |
l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1937 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
| 1938 |
|
| 1939 |
with gr.Row():
|
| 1940 |
with gr.Column(scale=5, min_width=200):
|
|
|
|
| 2305 |
submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
|
| 2306 |
add_output_images_buttons(output_gallery)
|
| 2307 |
norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 2308 |
+
cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=500, show_share_button=True, preview=True, interactive=False)
|
| 2309 |
[
|
| 2310 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 2311 |
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|