Spaces:
Running
on
Zero
Running
on
Zero
add playground
Browse files
app.py
CHANGED
|
@@ -993,26 +993,28 @@ def ncut_run(
|
|
| 993 |
|
| 994 |
def _ncut_run(*args, **kwargs):
|
| 995 |
n_ret = kwargs.get("n_ret", 1)
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
|
|
|
| 999 |
|
| 1000 |
-
|
| 1001 |
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
|
|
|
| 1012 |
|
| 1013 |
-
ret = ncut_run(*args, **kwargs)
|
| 1014 |
-
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 1015 |
-
return ret
|
| 1016 |
|
| 1017 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 1018 |
@spaces.GPU(duration=30)
|
|
@@ -3557,8 +3559,8 @@ with demo:
|
|
| 3557 |
else:
|
| 3558 |
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
| 3559 |
right = right[:n_eig]
|
| 3560 |
-
left = F.normalize(left, p=2, dim
|
| 3561 |
-
_right = F.normalize(right, p=2, dim
|
| 3562 |
heatmap = left @ _right.unsqueeze(-1)
|
| 3563 |
heatmap = heatmap.squeeze(-1)
|
| 3564 |
heatmap = 1 - heatmap
|
|
@@ -3707,7 +3709,192 @@ with demo:
|
|
| 3707 |
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
|
| 3708 |
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
|
| 3709 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3710 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3711 |
with gr.Tab('📄About'):
|
| 3712 |
with gr.Column():
|
| 3713 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|
|
|
|
| 993 |
|
| 994 |
def _ncut_run(*args, **kwargs):
|
| 995 |
n_ret = kwargs.get("n_ret", 1)
|
| 996 |
+
try:
|
| 997 |
+
gr.Info("NCUT Run Started", 2)
|
| 998 |
+
if torch.cuda.is_available():
|
| 999 |
+
torch.cuda.empty_cache()
|
| 1000 |
|
| 1001 |
+
ret = ncut_run(*args, **kwargs)
|
| 1002 |
|
| 1003 |
+
if torch.cuda.is_available():
|
| 1004 |
+
torch.cuda.empty_cache()
|
| 1005 |
|
| 1006 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 1007 |
+
gr.Info("NCUT Run Finished", 2)
|
| 1008 |
+
return ret
|
| 1009 |
+
except Exception as e:
|
| 1010 |
+
gr.Error(str(e))
|
| 1011 |
+
if torch.cuda.is_available():
|
| 1012 |
+
torch.cuda.empty_cache()
|
| 1013 |
+
return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 1014 |
|
| 1015 |
+
# ret = ncut_run(*args, **kwargs)
|
| 1016 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
| 1017 |
+
# return ret
|
| 1018 |
|
| 1019 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 1020 |
@spaces.GPU(duration=30)
|
|
|
|
| 3559 |
else:
|
| 3560 |
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
| 3561 |
right = right[:n_eig]
|
| 3562 |
+
left = F.normalize(left, p=2, dim=-1)
|
| 3563 |
+
_right = F.normalize(right, p=2, dim=-1)
|
| 3564 |
heatmap = left @ _right.unsqueeze(-1)
|
| 3565 |
heatmap = heatmap.squeeze(-1)
|
| 3566 |
heatmap = 1 - heatmap
|
|
|
|
| 3709 |
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
|
| 3710 |
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
|
| 3711 |
)
|
| 3712 |
+
|
| 3713 |
+
with gr.Tab('PlayGround', visible=True) as test_playground_tab2:
|
| 3714 |
+
eigvecs = gr.State(torch.tensor([]))
|
| 3715 |
+
with gr.Row():
|
| 3716 |
+
with gr.Column(scale=5, min_width=200):
|
| 3717 |
+
gr.Markdown("### Step 1: Load Images")
|
| 3718 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
|
| 3719 |
+
submit_button.visible = False
|
| 3720 |
+
num_images_slider.value = 30
|
| 3721 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 3722 |
+
|
| 3723 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 3724 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 3725 |
+
|
| 3726 |
+
|
| 3727 |
+
with gr.Column(scale=5, min_width=200):
|
| 3728 |
+
gr.Markdown("### Step 2a: Run Backbone and NCUT")
|
| 3729 |
+
with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
|
| 3730 |
+
[
|
| 3731 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 3732 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 3733 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 3734 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 3735 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 3736 |
+
] = make_parameters_section(parameter_dropdown=False)
|
| 3737 |
+
num_eig_slider.value = 1024
|
| 3738 |
+
num_eig_slider.visible = False
|
| 3739 |
+
submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
|
| 3740 |
+
submit_button.click(
|
| 3741 |
+
partial(run_fn, n_ret=1, only_eigvecs=True),
|
| 3742 |
+
inputs=[
|
| 3743 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 3744 |
+
positive_prompt, negative_prompt,
|
| 3745 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 3746 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 3747 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 3748 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
| 3749 |
+
],
|
| 3750 |
+
outputs=[eigvecs, logging_text],
|
| 3751 |
+
)
|
| 3752 |
+
gr.Markdown("### Step 2b: Pick an Image")
|
| 3753 |
+
from gradio_image_prompter import ImagePrompter
|
| 3754 |
+
with gr.Row():
|
| 3755 |
+
image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
|
| 3756 |
+
load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary')
|
| 3757 |
+
gr.Markdown("### Step 2c: Draw a Point")
|
| 3758 |
+
gr.Markdown("""
|
| 3759 |
+
<h5>
|
| 3760 |
+
🖱️ Left Click: Foreground </br>
|
| 3761 |
+
</h5>
|
| 3762 |
+
""")
|
| 3763 |
+
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
| 3764 |
+
def update_prompt_image(original_images, index):
|
| 3765 |
+
images = original_images
|
| 3766 |
+
if images is None:
|
| 3767 |
+
return
|
| 3768 |
+
total_len = len(images)
|
| 3769 |
+
if total_len == 0:
|
| 3770 |
+
return
|
| 3771 |
+
if index >= total_len:
|
| 3772 |
+
index = total_len - 1
|
| 3773 |
+
|
| 3774 |
+
return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
|
| 3775 |
+
# return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
|
| 3776 |
+
load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
| 3777 |
+
|
| 3778 |
+
child_idx = gr.State([])
|
| 3779 |
+
current_idx = gr.State(None)
|
| 3780 |
+
n_eig = gr.State(64)
|
| 3781 |
+
with gr.Column(scale=5, min_width=200):
|
| 3782 |
+
gr.Markdown("### Step 3: Check groupping")
|
| 3783 |
+
child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
|
| 3784 |
+
child_distance_slider.visible = False
|
| 3785 |
+
overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
|
| 3786 |
+
n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True)
|
| 3787 |
+
run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
|
| 3788 |
+
current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
|
| 3789 |
+
with gr.Row():
|
| 3790 |
+
doublue_eigs_button = gr.Button("⬇️ +eigvecs", elem_id="doublue_eigs_button", variant='secondary')
|
| 3791 |
+
half_eigs_button = gr.Button("⬆️ -eigvecs", elem_id="half_eigs_button", variant='secondary')
|
| 3792 |
|
| 3793 |
+
def relative_xy(prompts):
|
| 3794 |
+
image = prompts['image']
|
| 3795 |
+
points = np.asarray(prompts['points'])
|
| 3796 |
+
if points.shape[0] == 0:
|
| 3797 |
+
return [], []
|
| 3798 |
+
is_point = points[:, 5] == 4.0
|
| 3799 |
+
points = points[is_point]
|
| 3800 |
+
is_positive = points[:, 2] == 1.0
|
| 3801 |
+
is_negative = points[:, 2] == 0.0
|
| 3802 |
+
xy = points[:, :2].tolist()
|
| 3803 |
+
if isinstance(image, str):
|
| 3804 |
+
image = Image.open(image)
|
| 3805 |
+
image = np.array(image)
|
| 3806 |
+
h, w = image.shape[:2]
|
| 3807 |
+
new_xy = [(x/w, y/h) for x, y in xy]
|
| 3808 |
+
# print(new_xy)
|
| 3809 |
+
return new_xy, is_positive
|
| 3810 |
+
|
| 3811 |
+
def xy_eigvec(prompts, image_idx, eigvecs):
|
| 3812 |
+
eigvec = eigvecs[image_idx]
|
| 3813 |
+
xy, is_positive = relative_xy(prompts)
|
| 3814 |
+
for i, (x, y) in enumerate(xy):
|
| 3815 |
+
if not is_positive[i]:
|
| 3816 |
+
continue
|
| 3817 |
+
x = int(x * eigvec.shape[1])
|
| 3818 |
+
y = int(y * eigvec.shape[0])
|
| 3819 |
+
return eigvec[y, x], (y, x)
|
| 3820 |
+
|
| 3821 |
+
from ncut_pytorch.ncut_pytorch import _transform_heatmap
|
| 3822 |
+
def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
|
| 3823 |
+
left = eigvecs[..., :n_eig]
|
| 3824 |
+
if flat_idx is not None:
|
| 3825 |
+
right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
|
| 3826 |
+
y, x = None, None
|
| 3827 |
+
else:
|
| 3828 |
+
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
| 3829 |
+
right = right[:n_eig]
|
| 3830 |
+
left = F.normalize(left, p=2, dim=-1)
|
| 3831 |
+
_right = F.normalize(right, p=2, dim=-1)
|
| 3832 |
+
heatmap = left @ _right.unsqueeze(-1)
|
| 3833 |
+
heatmap = heatmap.squeeze(-1)
|
| 3834 |
+
# heatmap = 1 - heatmap
|
| 3835 |
+
# heatmap = _transform_heatmap(heatmap)
|
| 3836 |
+
if raw_heatmap:
|
| 3837 |
+
return heatmap
|
| 3838 |
+
# apply hot colormap and covert to PIL image 256x256
|
| 3839 |
+
# gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}")
|
| 3840 |
+
heatmap = heatmap.cpu().numpy()
|
| 3841 |
+
hot_map = matplotlib.cm.get_cmap('hot')
|
| 3842 |
+
heatmap = hot_map(heatmap)
|
| 3843 |
+
pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
| 3844 |
+
if overlay_image:
|
| 3845 |
+
overlaied_images = []
|
| 3846 |
+
for i_image in range(len(images)):
|
| 3847 |
+
rgb_image = images[i_image].resize((256, 256))
|
| 3848 |
+
rgb_image = np.array(rgb_image)
|
| 3849 |
+
heatmap_image = np.array(pil_images[i_image])[..., :3]
|
| 3850 |
+
blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
|
| 3851 |
+
blend_image = Image.fromarray(blend_image.astype(np.uint8))
|
| 3852 |
+
overlaied_images.append(blend_image)
|
| 3853 |
+
pil_images = overlaied_images
|
| 3854 |
+
return pil_images, (y, x)
|
| 3855 |
+
|
| 3856 |
+
@torch.no_grad()
|
| 3857 |
+
def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 3858 |
+
gr.Info(f"current number of eigenvectors: {n_eig}", 2)
|
| 3859 |
+
images = [image[0] for image in images]
|
| 3860 |
+
if isinstance(images[0], str):
|
| 3861 |
+
images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
|
| 3862 |
+
|
| 3863 |
+
current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
|
| 3864 |
+
|
| 3865 |
+
return current_heatmap
|
| 3866 |
+
|
| 3867 |
+
def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
| 3868 |
+
n_eig = int(n_eig*2)
|
| 3869 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 3870 |
+
n_eig = max(n_eig, 1)
|
| 3871 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image)
|
| 3872 |
+
|
| 3873 |
+
def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
|
| 3874 |
+
n_eig = int(n_eig/2)
|
| 3875 |
+
n_eig = min(n_eig, eigvecs.shape[-1])
|
| 3876 |
+
n_eig = max(n_eig, 1)
|
| 3877 |
+
return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image)
|
| 3878 |
+
|
| 3879 |
+
none_placeholder = gr.State(None)
|
| 3880 |
+
run_button.click(
|
| 3881 |
+
run_heatmap,
|
| 3882 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 3883 |
+
outputs=[current_plot],
|
| 3884 |
+
)
|
| 3885 |
+
|
| 3886 |
+
doublue_eigs_button.click(
|
| 3887 |
+
doublue_eigs_wrapper,
|
| 3888 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
| 3889 |
+
outputs=[n_eig_slider, current_plot],
|
| 3890 |
+
)
|
| 3891 |
+
|
| 3892 |
+
half_eigs_button.click(
|
| 3893 |
+
half_eigs_wrapper,
|
| 3894 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
|
| 3895 |
+
outputs=[n_eig_slider, current_plot],
|
| 3896 |
+
)
|
| 3897 |
+
|
| 3898 |
with gr.Tab('📄About'):
|
| 3899 |
with gr.Column():
|
| 3900 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|