Spaces:
Running
on
Zero
Running
on
Zero
update models
Browse files- app.py +1 -1
- backbone.py +102 -21
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Author: Huzheng Yang
|
| 2 |
# %%
|
| 3 |
-
USE_SPACES =
|
| 4 |
|
| 5 |
if USE_SPACES:
|
| 6 |
import spaces
|
|
|
|
| 1 |
# Author: Huzheng Yang
|
| 2 |
# %%
|
| 3 |
+
USE_SPACES = True
|
| 4 |
|
| 5 |
if USE_SPACES:
|
| 6 |
import spaces
|
backbone.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import Optional, Tuple
|
|
| 2 |
from einops import rearrange
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
from torch import nn
|
| 7 |
import numpy as np
|
|
@@ -13,18 +14,16 @@ import gradio as gr
|
|
| 13 |
MODEL_DICT = {}
|
| 14 |
|
| 15 |
|
| 16 |
-
def
|
| 17 |
-
|
| 18 |
# Convert to torch tensor
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
# Normalize
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
return images
|
| 27 |
-
|
| 28 |
|
| 29 |
class MobileSAM(nn.Module):
|
| 30 |
def __init__(self, **kwargs):
|
|
@@ -283,7 +282,6 @@ class DiNOv2(torch.nn.Module):
|
|
| 283 |
|
| 284 |
MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
|
| 285 |
|
| 286 |
-
|
| 287 |
class CLIP(torch.nn.Module):
|
| 288 |
def __init__(self):
|
| 289 |
super().__init__()
|
|
@@ -291,6 +289,18 @@ class CLIP(torch.nn.Module):
|
|
| 291 |
from transformers import CLIPProcessor, CLIPModel
|
| 292 |
|
| 293 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
|
| 295 |
self.model = model.eval()
|
| 296 |
|
|
@@ -360,26 +370,90 @@ class CLIP(torch.nn.Module):
|
|
| 360 |
MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
|
| 361 |
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
def extract_features(images, model_name, node_type, layer):
|
|
|
|
|
|
|
|
|
|
| 364 |
resolution_dict = {
|
| 365 |
-
"
|
| 366 |
-
"SAM(sam_vit_b)": (1024, 1024),
|
| 367 |
-
"DiNO(dinov2_vitb14_reg)": (448, 448),
|
| 368 |
-
"CLIP(openai/clip-vit-base-patch16)": (224, 224),
|
| 369 |
}
|
| 370 |
-
|
|
|
|
| 371 |
|
| 372 |
model = MODEL_DICT[model_name]
|
| 373 |
|
| 374 |
-
use_cuda = torch.cuda.is_available()
|
| 375 |
if use_cuda:
|
| 376 |
model = model.cuda()
|
| 377 |
|
| 378 |
outputs = []
|
| 379 |
-
for i in range(images
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
inp = inp.cuda()
|
| 383 |
attn_output, mlp_output, block_output = model(inp)
|
| 384 |
out_dict = {
|
| 385 |
"attn": attn_output,
|
|
@@ -392,3 +466,10 @@ def extract_features(images, model_name, node_type, layer):
|
|
| 392 |
outputs = torch.cat(outputs, dim=0)
|
| 393 |
|
| 394 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from einops import rearrange
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
+
import timm
|
| 6 |
from PIL import Image
|
| 7 |
from torch import nn
|
| 8 |
import numpy as np
|
|
|
|
| 14 |
MODEL_DICT = {}
|
| 15 |
|
| 16 |
|
| 17 |
+
def transform_image(image, resolution=(1024, 1024), use_cuda=False):
    """Turn an image into a normalized, channel-first float tensor.

    Args:
        image: Object exposing ``convert('RGB')`` and ``resize(size, resample)``
            (presumably a ``PIL.Image.Image`` -- confirm against callers).
        resolution: Size passed to ``resize``; defaults to ``(1024, 1024)``.
        use_cuda: When true, the tensor is moved to the GPU before scaling.

    Returns:
        A float tensor laid out channel-first, with values mapped from
        ``[0, 255]`` to ``[-1, 1]``.
    """
    rgb = image.convert('RGB')
    resized = rgb.resize(resolution, Image.Resampling.NEAREST)

    # Convert to torch tensor: (H, W, C) pixels -> (C, H, W) float.
    channel_first = np.array(resized).transpose(2, 0, 1)
    pixels = torch.tensor(channel_first).float()
    if use_cuda:
        pixels = pixels.cuda()

    # Scale to [0, 1] first, then normalize with mean 0.5 / std 0.5.
    unit_range = pixels / 255
    return (unit_range - 0.5) / 0.5
|
|
|
|
|
|
|
| 27 |
|
| 28 |
class MobileSAM(nn.Module):
|
| 29 |
def __init__(self, **kwargs):
|
|
|
|
| 282 |
|
| 283 |
MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
|
| 284 |
|
|
|
|
| 285 |
class CLIP(torch.nn.Module):
|
| 286 |
def __init__(self):
|
| 287 |
super().__init__()
|
|
|
|
| 289 |
from transformers import CLIPProcessor, CLIPModel
|
| 290 |
|
| 291 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
|
| 292 |
+
|
| 293 |
+
# resample the patch embeddings to 64x64, take 1024x1024 input
|
| 294 |
+
embeddings = model.vision_model.embeddings.position_embedding.weight
|
| 295 |
+
cls_embeddings = embeddings[0]
|
| 296 |
+
patch_embeddings = embeddings[1:] # [14*14, 768]
|
| 297 |
+
patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
|
| 298 |
+
patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
|
| 299 |
+
patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
|
| 300 |
+
embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
|
| 301 |
+
model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
|
| 302 |
+
model.vision_model.embeddings.position_ids = torch.arange(0, 1+64*64)
|
| 303 |
+
|
| 304 |
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
|
| 305 |
self.model = model.eval()
|
| 306 |
|
|
|
|
| 370 |
MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
|
| 371 |
|
| 372 |
|
| 373 |
+
class MAE(timm.models.vision_transformer.VisionTransformer):
    """timm ViT initialised from the MAE pre-training checkpoint.

    The positional embedding is resampled to a 64x64 patch grid so the model
    takes 1024x1024 inputs, and every transformer block records its
    attention, MLP and block outputs so ``forward`` can return them.
    """

    def __init__(self, **kwargs):
        """Build the ViT, load the MAE weights and install feature recording.

        Args:
            **kwargs: Forwarded unchanged to the timm ``VisionTransformer``
                constructor.
        """
        import types

        super(MAE, self).__init__(**kwargs)

        sd = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
        )

        checkpoint_model = sd["model"]
        state_dict = self.state_dict()
        # Drop classifier weights whose shape does not match this model's head.
        for k in ["head.weight", "head.bias"]:
            if (
                k in checkpoint_model
                and checkpoint_model[k].shape != state_dict[k].shape
            ):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # load pre-trained model
        msg = self.load_state_dict(checkpoint_model, strict=False)
        print(msg)

        # resample the patch embeddings to 64x64, take 1024x1024 input
        # No autograd graph is needed for this one-off weight surgery.
        with torch.no_grad():
            pos_embed = self.pos_embed[0]
            # The first row is treated as the class-token embedding and is
            # carried over untouched; only the patch rows are interpolated.
            cls_embeddings = pos_embed[0]
            patch_embeddings = pos_embed[1:]  # [14*14, 768]
            patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
            patch_embeddings = F.interpolate(
                patch_embeddings.unsqueeze(0),
                size=(64, 64),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
            patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
            pos_embed = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
        self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        self.img_size = (1024, 1024)
        self.patch_embed.img_size = (1024, 1024)

        self.requires_grad_(False)
        self.eval()

        def _recording_forward(block, x):
            # Same residual structure as the block's own forward, but each
            # intermediate result is kept on the block for later retrieval.
            block.saved_attn_node = block.ls1(block.attn(block.norm1(x)))
            x = x + block.saved_attn_node.clone()
            block.saved_mlp_node = block.ls2(block.mlp(block.norm2(x)))
            x = x + block.saved_mlp_node.clone()
            block.saved_block_output = x.clone()
            return x

        # Bind the recording forward to THIS model's blocks only. Patching the
        # shared block class would silently change every other timm ViT in
        # the process and make them retain activations too.
        for block in self.blocks:
            block.forward = types.MethodType(_recording_forward, block)

    def forward(self, x):
        """Run the model and return the features recorded by every block.

        Args:
            x: Input batch passed to the parent ``forward``.

        Returns:
            ``(attn_nodes, mlp_nodes, block_outputs)``: three lists with one
            tensor per block, each with the first token removed and the
            remaining tokens reshaped to ``(b, h, w, c)`` on a square grid.
        """
        import math

        # The parent's return value is not used; the call is made only so
        # that each block runs and stores its intermediate tensors.
        super().forward(x)

        def remove_cls_and_reshape(tokens):
            tokens = tokens[:, 1:].clone()
            # Exact integer square root; avoids float rounding and numpy ints.
            side = math.isqrt(tokens.shape[1])
            return rearrange(tokens, "b (h w) c -> b h w c", h=side)

        attn_nodes = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
        mlp_nodes = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
        block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
        return attn_nodes, mlp_nodes, block_outputs
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# Instantiate the MAE backbone at import time (its constructor fetches the
# pre-trained checkpoint) and register it under the name that
# extract_features uses to look the model up in MODEL_DICT.
MODEL_DICT["MAE(vit_base)"] = MAE()
|
| 436 |
+
|
| 437 |
+
|
| 438 |
def extract_features(images, model_name, node_type, layer):
|
| 439 |
+
use_cuda = torch.cuda.is_available()
|
| 440 |
+
|
| 441 |
+
resolution = (1024, 1024)
|
| 442 |
resolution_dict = {
|
| 443 |
+
"DiNO(dinov2_vitb14_reg)": (896, 896),
|
|
|
|
|
|
|
|
|
|
| 444 |
}
|
| 445 |
+
if model_name in resolution_dict:
|
| 446 |
+
resolution = resolution_dict[model_name]
|
| 447 |
|
| 448 |
model = MODEL_DICT[model_name]
|
| 449 |
|
|
|
|
| 450 |
if use_cuda:
|
| 451 |
model = model.cuda()
|
| 452 |
|
| 453 |
outputs = []
|
| 454 |
+
for i in range(len(images)):
|
| 455 |
+
image = transform_image(images[i], resolution=resolution, use_cuda=use_cuda)
|
| 456 |
+
inp = image.unsqueeze(0)
|
|
|
|
| 457 |
attn_output, mlp_output, block_output = model(inp)
|
| 458 |
out_dict = {
|
| 459 |
"attn": attn_output,
|
|
|
|
| 466 |
outputs = torch.cat(outputs, dim=0)
|
| 467 |
|
| 468 |
return outputs
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == '__main__':
    # Smoke test: push one random 1024x1024 RGB batch through MAE and print
    # the shapes of the first three entries of the first returned list.
    dummy_batch = torch.rand(1, 3, 1024, 1024)
    mae = MAE()
    first_list = mae(dummy_batch)[0]
    print(*(node.shape for node in first_list[:3]))
|