Spaces: Running on Zero

fix align3model load

app.py CHANGED
@@ -678,13 +678,14 @@ def plot_one_image_36_grid(original_image, tsne_rgb_images):
     return img
 
 def load_alignedthreemodel():
-
-    os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
-    # pull
-    os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
-    # add to path
     import sys
-
+
+    if "alignedthreeattn" not in sys.path:
+        for _ in range(3):
+            os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
+            os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
+        # add to path
+        sys.path.append("alignedthreeattn")
 
 
     from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes
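The rewritten loader is idempotent: the `sys.path` guard skips re-cloning once the repo is importable, and the three-iteration loop retries the clone/pull to ride out transient network failures on Space startup. A minimal standalone sketch of the same pattern, assuming a hypothetical `ensure_repo` helper and substituting `subprocess.run` for `os.system`:

```python
import subprocess
import sys
from pathlib import Path

def ensure_repo(url: str, dest: str, retries: int = 3) -> None:
    """Clone url into dest (pull if it already exists), retrying on failure."""
    for _ in range(retries):
        if Path(dest, ".git").exists():
            # Repo already cloned: just fetch the latest commit.
            result = subprocess.run(["git", "-C", dest, "pull"], capture_output=True)
        else:
            result = subprocess.run(["git", "clone", url, dest], capture_output=True)
        if result.returncode == 0:
            break

    # Guarded append keeps sys.path free of duplicate entries.
    if dest not in sys.path:
        sys.path.append(dest)

ensure_repo("https://huggingface.co/huzey/alignedthreeattn", "alignedthreeattn")
```

Checking for the `.git` directory separates the clone case from the pull case, so repeated calls stay cheap.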
@@ -692,11 +693,6 @@ def load_alignedthreemodel():
     align_weights = torch.load("alignedthreeattn/align_weights.pth")
     model = ThreeAttnNodes(align_weights)
 
-    # url = 'https://huggingface.co/huzey/aligned_model_test/resolve/main/3attn_nodes.pth'
-    # save_path = "alignedthreemodel.pth"
-    # if not os.path.exists(save_path):
-    #     os.system(f"wget {url} -O {save_path} -q")
-    # model = torch.load(save_path)
     return model
 
 promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
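The weights file is read straight out of the cloned repo with `torch.load`. One hedged variation, not part of this commit: passing `map_location` makes the load succeed even when the checkpoint was saved from a GPU and the current process has none.

```python
import torch

# Assumption for illustration: load onto CPU first so the call also works
# on GPU-less machines; move the tensors to a device afterwards if needed.
align_weights = torch.load("alignedthreeattn/align_weights.pth", map_location="cpu")
```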
@@ -1174,7 +1170,7 @@ with demo:
         with gr.Column(scale=5, min_width=200):
             input_gallery, submit_button, clear_images_button = make_input_images_section()
             dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
-            logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+            logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
         with gr.Column(scale=5, min_width=200):
             output_gallery = make_output_images_section()
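The only change in this hunk adds `autofocus=False, autoscroll=False`, which keeps the log box from grabbing keyboard focus on page load and from jumping to the bottom on every update. A minimal sketch of the widget in isolation (simplified from the app's layout):

```python
import gradio as gr

with gr.Blocks() as demo:
    # Log box that neither steals focus on load nor scrolls on update.
    logging_text = gr.Textbox(
        "Logging information",
        label="Logging",
        type="text",
        placeholder="Logging information",
        autofocus=False,
        autoscroll=False,
    )
```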
@@ -1490,17 +1486,65 @@ with demo:
         # logging text box
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-
-
-
-
-
-
-
-
-
+        clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            # outputs=galleries + [logging_text],
+            outputs=[output_gallery, logging_text],
+        )
+
+    with gr.Tab('Model Aligned (+Recursive)'):
+        gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
+        gr.Markdown('---')
+        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
+        gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.')
+        gr.Markdown('')
+        gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
+        gr.Markdown('---')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
 
-
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
+                num_images_slider.value = 100
+
+
+            with gr.Column(scale=5, min_width=200):
+                output_gallery = make_output_images_section()
+                gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+                gr.Markdown('---')
+                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
+                gr.Markdown('Layer type: attention output (attn), without sum of residual')
+                gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
+                gr.Markdown('---')
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+        model_dropdown.value = "AlignedThreeModelAttnNodes"
+        model_dropdown.visible = False
+        layer_slider.visible = False
+        node_type_dropdown.visible = False
+        num_sample_ncut_slider.value = 10000
+        num_sample_tsne_slider.value = 1000
+        # logging text box
+        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
         clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
         false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
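The new tab reuses one generic `run_fn` signature across tabs by feeding hidden placeholder components (`false_placeholder`, `no_prompt`) into the slots a given tab does not expose. A minimal runnable sketch of that wiring pattern, with a stand-in `run_fn` rather than the app's real one:

```python
import gradio as gr

def run_fn(images, flag, prompt):
    # Stand-in for the app's run_fn: report what the event handler received.
    return f"got {len(images or [])} image(s), flag={flag}, prompt={prompt!r}"

with gr.Blocks() as demo:
    input_gallery = gr.Gallery(label="Input")
    # Hidden placeholders fill the unused slots of the shared signature.
    false_placeholder = gr.Checkbox(value=False, visible=False)
    no_prompt = gr.Textbox("", visible=False)
    logging_text = gr.Textbox(label="Logging")
    submit_button = gr.Button("Submit")
    submit_button.click(
        run_fn,
        inputs=[input_gallery, false_placeholder, no_prompt],
        outputs=[logging_text],
    )

demo.launch()
```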
@@ -1520,6 +1564,7 @@ with demo:
             outputs=[output_gallery, logging_text],
         )
 
+
     with gr.Tab('Compare Models'):
         def add_one_model(i_model=1):
             with gr.Column(scale=5, min_width=200) as col:
|