| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import os, math, gc |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchvision as vision |
| | import pytorch_lightning as pl |
| | from pytorch_lightning.utilities import rank_zero_info, rank_zero_only |
| | from pytorch_lightning.strategies import DeepSpeedStrategy |
| | import deepspeed |
| | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam |
| | |
| |
|
def __nop(ob):
    """Return ``ob`` unchanged (identity helper; not referenced in the visible part of this file)."""
    return ob


# Base class used by the TorchScript-compiled sub-modules defined in this file.
MyModule = torch.jit.ScriptModule

# Decorator applied to the ``forward`` methods of those sub-modules.
MyFunction = torch.jit.script_method
| |
|
| | import clip |
| | from transformers import CLIPModel |
| |
|
class L2pooling(nn.Module):
    """Depthwise L2 pooling with a fixed, normalised Hanning window.

    The input is squared, convolved per channel with the window and the
    square root of the result is returned.

    Args:
        filter_size: Length of the Hanning window before its two end points
            are dropped (the effective kernel is ``filter_size - 2`` wide).
        stride: Stride of the depthwise convolution.
        channels: Number of channels the filter bank is built for.
        pad_off: Accepted for interface compatibility; not used.
    """

    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super(L2pooling, self).__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels

        # 1-D Hanning window without its end points, turned into a separable
        # 2-D kernel whose entries sum to one.
        window = np.hanning(filter_size)[1:-1]
        kernel = torch.Tensor(np.outer(window, window))
        kernel = kernel / kernel.sum()

        # One copy of the kernel per channel (depthwise filter bank).
        bank = kernel[None, None, :, :].repeat(self.channels, 1, 1, 1)
        self.register_buffer("filter", bank)

    def forward(self, input):
        """Return the L2-pooled version of ``input``."""
        squared = input ** 2
        pooled = F.conv2d(
            squared,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=input.shape[1],
        )
        # The small constant is added before the square root.
        return torch.sqrt(pooled + 1e-12)
| |
|
| |
|
class DISTS(torch.nn.Module):
    """DISTS-style perceptual distance on top of frozen VGG16 features.

    The VGG16 feature stack is split into five stages in which every max-pool
    layer is replaced by :class:`L2pooling`.  For the input image and each
    stage output, a mean-based similarity (weighted by ``alpha``) and a
    variance/covariance-based similarity (weighted by ``beta``) are computed
    per channel; the returned score is ``1 - (weighted similarities)``.

    Args:
        load_weights: When true (default) ``alpha`` and ``beta`` are read from
            ``weights_path``.  When false they keep their random
            ``N(0.1, 0.01)`` initialisation.
        weights_path: Checkpoint containing the ``"alpha"`` and ``"beta"``
            tensors.
    """

    def __init__(self, load_weights=True, weights_path="test/DISTS_weights.pt"):
        super(DISTS, self).__init__()
        vgg_pretrained_features = vision.models.vgg16(
            weights="VGG16_Weights.IMAGENET1K_V1"
        ).features
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Sub-module names are the original VGG layer indices; the indices
        # 4, 9, 16 and 23 are filled by L2pooling instead of the VGG layer.
        for x in range(0, 4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        # Per-channel normalisation constants applied in forward_once.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
        )

        # Channel counts of [input, stage1, ..., stage5].
        self.chns = [3, 64, 128, 256, 512, 512]
        total_channels = sum(self.chns)
        # Plain tensors rather than nn.Parameter: these are buffers, so the
        # freeze loop below never reaches them and a Parameter here would
        # keep requires_grad=True while never being optimised.
        self.register_buffer("alpha", torch.randn(1, total_channels, 1, 1))
        self.register_buffer("beta", torch.randn(1, total_channels, 1, 1))
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        if load_weights:
            # Load on CPU; the buffers follow the module when it is moved.
            weights = torch.load(weights_path, map_location="cpu")
            self.alpha.data = weights["alpha"]
            self.beta.data = weights["beta"]

        for param in self.parameters():
            param.requires_grad = False

    def forward_once(self, x):
        """Return ``[x, stage1(x), ..., stage5(x)]`` for one image batch.

        ``x`` is normalised with ``self.mean`` / ``self.std`` before entering
        the first stage; the un-normalised ``x`` is the first list element.
        """
        h = (x - self.mean) / self.std
        feats = [x]
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            h = stage(h)
            feats.append(h)
        return feats

    def forward(self, x, y, require_grad=False, batch_average=False):
        """Return the distance between image batches ``x`` and ``y``.

        Args:
            x, y: Image batches of identical shape ``(N, 3, H, W)``.
            require_grad: When false the feature extraction runs under
                ``torch.no_grad()``.
            batch_average: When true return the mean over the batch,
                otherwise the squeezed per-sample scores.
        """
        if require_grad:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.no_grad():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)
        dist1 = 0
        dist2 = 0
        # Stabilising constants for the two similarity ratios.
        c1 = 1e-6
        c2 = 1e-6
        # Normalise so that all alpha and beta weights together sum to one.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)

        for f0, f1, a_k, b_k in zip(feats0, feats1, alpha, beta):
            # Similarity of the spatial means.
            x_mean = f0.mean([2, 3], keepdim=True)
            y_mean = f1.mean([2, 3], keepdim=True)
            S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (a_k * S1).sum(1, keepdim=True)

            # Similarity based on spatial variances and covariance.
            x_var = ((f0 - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((f1 - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (f0 * f1).mean([2, 3], keepdim=True) - x_mean * y_mean
            S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (b_k * S2).sum(1, keepdim=True)

        score = 1 - (dist1 + dist2).squeeze()

        if batch_average:
            return score.mean()
        return score
| |
|
class ToBinary(torch.autograd.Function):
    """Round to the nearest integer with a pass-through gradient.

    The forward pass computes ``floor(x + 0.5)``; the backward pass hands the
    incoming gradient back unchanged (as a copy), ignoring the rounding.
    """

    @staticmethod
    def forward(ctx, x):
        shifted = x + 0.5
        return shifted.floor()

    @staticmethod
    def backward(ctx, grad_output):
        passthrough = grad_output.clone()
        return passthrough
| |
|
| | |
| |
|
class R_ENCODER(MyModule):
    """Convolutional encoder: 3-channel image -> ``args.my_img_bit`` code maps.

    The image is lifted to 8 channels, then passed through three
    pixel-unshuffle-by-2 stages (8 -> 32 -> 128 -> 512 channels), each made of
    two residual conv blocks.  A batch-normalised pixel-unshuffle-by-8 copy of
    the lifted image is added before the output convolution, and the result is
    squashed with a sigmoid.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        base = 8      # channels right after the input convolution
        hidden = 256  # inner width of every residual block after the first

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # Norm for the skip path (input features unshuffled by 8).
        self.Bxx = nn.BatchNorm2d(base * 64)

        # Full-resolution stem and residual block.
        self.CIN = conv3x3(3, base)
        self.Cx0 = conv3x3(base, 32)
        self.Cx1 = conv3x3(32, base)

        # Stage 0: after the first unshuffle (base * 4 channels).
        width0 = base * 4
        self.B00 = nn.BatchNorm2d(width0)
        self.C00 = conv3x3(width0, hidden)
        self.C01 = conv3x3(hidden, width0)
        self.C02 = conv3x3(width0, hidden)
        self.C03 = conv3x3(hidden, width0)

        # Stage 1: after the second unshuffle (base * 16 channels).
        width1 = base * 16
        self.B10 = nn.BatchNorm2d(width1)
        self.C10 = conv3x3(width1, hidden)
        self.C11 = conv3x3(hidden, width1)
        self.C12 = conv3x3(width1, hidden)
        self.C13 = conv3x3(hidden, width1)

        # Stage 2: after the third unshuffle (base * 64 channels).
        width2 = base * 64
        self.B20 = nn.BatchNorm2d(width2)
        self.C20 = conv3x3(width2, hidden)
        self.C21 = conv3x3(hidden, width2)
        self.C22 = conv3x3(width2, hidden)
        self.C23 = conv3x3(hidden, width2)

        self.COUT = conv3x3(width2, args.my_img_bit)

    @MyFunction
    def forward(self, img):
        feat = self.CIN(img)
        # Skip branch taken before any residual processing.
        skip = self.Bxx(F.pixel_unshuffle(feat, 8))
        feat = feat + self.Cx1(F.mish(self.Cx0(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C01(F.mish(self.C00(F.mish(self.B00(feat)))))
        feat = feat + self.C03(F.mish(self.C02(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C11(F.mish(self.C10(F.mish(self.B10(feat)))))
        feat = feat + self.C13(F.mish(self.C12(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C21(F.mish(self.C20(F.mish(self.B20(feat)))))
        feat = feat + self.C23(F.mish(self.C22(feat)))

        code = self.COUT(feat + skip)
        return torch.sigmoid(code)
| |
|
| | |
| |
|
class R_DECODER(MyModule):
    """Convolutional decoder: ``args.my_img_bit`` code maps -> 3-channel image.

    The code is lifted to 512 channels, then passed through three stages of
    two residual conv blocks, each followed by a pixel-shuffle-by-2
    (512 -> 128 -> 32 -> 8 channels).  A final full-resolution residual block
    and output convolution produce three channels squashed with a sigmoid.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        base = 8      # channels at full resolution
        hidden = 256  # inner width of every residual block except the last

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        width0 = base * 64
        self.CIN = conv3x3(args.my_img_bit, width0)

        # Stage 0: code resolution (base * 64 channels).
        self.B00 = nn.BatchNorm2d(width0)
        self.C00 = conv3x3(width0, hidden)
        self.C01 = conv3x3(hidden, width0)
        self.C02 = conv3x3(width0, hidden)
        self.C03 = conv3x3(hidden, width0)

        # Stage 1: after the first shuffle (base * 16 channels).
        width1 = base * 16
        self.B10 = nn.BatchNorm2d(width1)
        self.C10 = conv3x3(width1, hidden)
        self.C11 = conv3x3(hidden, width1)
        self.C12 = conv3x3(width1, hidden)
        self.C13 = conv3x3(hidden, width1)

        # Stage 2: after the second shuffle (base * 4 channels).
        width2 = base * 4
        self.B20 = nn.BatchNorm2d(width2)
        self.C20 = conv3x3(width2, hidden)
        self.C21 = conv3x3(hidden, width2)
        self.C22 = conv3x3(width2, hidden)
        self.C23 = conv3x3(hidden, width2)

        # Full-resolution residual block and output projection.
        self.Cx0 = conv3x3(base, 32)
        self.Cx1 = conv3x3(32, base)
        self.COUT = conv3x3(base, 3)

    @MyFunction
    def forward(self, code):
        feat = self.CIN(code)

        feat = feat + self.C01(F.mish(self.C00(F.mish(self.B00(feat)))))
        feat = feat + self.C03(F.mish(self.C02(feat)))
        feat = F.pixel_shuffle(feat, 2)

        feat = feat + self.C11(F.mish(self.C10(F.mish(self.B10(feat)))))
        feat = feat + self.C13(F.mish(self.C12(feat)))
        feat = F.pixel_shuffle(feat, 2)

        feat = feat + self.C21(F.mish(self.C20(F.mish(self.B20(feat)))))
        feat = feat + self.C23(F.mish(self.C22(feat)))
        feat = F.pixel_shuffle(feat, 2)

        feat = feat + self.Cx1(F.mish(self.Cx0(feat)))
        img = self.COUT(feat)
        return torch.sigmoid(img)
| |
|
| | |
| |
|
def cosine_loss(x, y):
    """Return ``1 - cosine_similarity`` for each row pair of two 2-D tensors.

    Both inputs are L2-normalised along their last dimension, then the
    row-wise dot product is subtracted from one.  The result has one entry
    per row.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    similarity = torch.einsum('ij,ij->i', x_unit, y_unit)
    return 1 - similarity
| |
|
class RWKV_IMG(pl.LightningModule):
    """Lightning module training a binary-code image auto-encoder.

    ``forward`` encodes an image with :class:`R_ENCODER`, rounds the code with
    :class:`ToBinary` and reconstructs it with :class:`R_DECODER`.  Training
    combines a DISTS loss, a CLIP image-embedding cosine loss (scaled by
    ``args.my_img_clip_scale``) and, when ``args.my_img_l1_scale > 0``, an L1
    loss between reconstruction and input.

    ``args`` attributes read by this class: ``my_img_clip``,
    ``my_img_clip_scale``, ``my_img_l1_scale``, ``lr_init``, ``betas``,
    ``adam_eps``, ``devices`` and ``run_name`` (``my_img_bit`` is read by the
    encoder/decoder).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = R_ENCODER(args)
        self.decoder = R_DECODER(args)

        # Resolve the short CLIP alias.  'OB32' is loaded through
        # transformers; every other name goes through ``clip.load``.
        self.clip_model = None
        clip_name = args.my_img_clip
        if clip_name == 'B32':
            clip_name = 'ViT-B/32'
        elif clip_name == 'B16':
            clip_name = 'ViT-B/16'
        elif clip_name == 'L14':
            clip_name = 'ViT-L/14'
        elif clip_name == 'OB32':
            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
            self.clip_model = CLIPModel.from_pretrained(clip_name)
            # Give the transformers model the same method name as the
            # ``clip`` package model so training_step can treat both alike.
            self.clip_model.encode_image = self.clip_model.get_image_features
        if self.clip_model is None:
            self.clip_model, _ = clip.load(clip_name, jit=True)

        # Per-channel normalisation applied to images before CLIP.
        self.register_buffer(
            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
        )

        # CLIP is only used as a fixed feature extractor.
        for n, p in self.named_parameters():
            if 'clip_model' in n:
                p.requires_grad = False

        self.loss_dists = DISTS()

    def configure_optimizers(self):
        """Return a DeepSpeed Adam optimizer over all module parameters.

        ``DeepSpeedCPUAdam`` is used when ``deepspeed_offload`` is true,
        ``FusedAdam`` otherwise; both run without weight decay.
        """
        optim_groups = [
            {"params": list(self.parameters()), "weight_decay": 0.0},
        ]
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adamw_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        return FusedAdam(
            optim_groups,
            lr=self.args.lr_init,
            betas=self.args.betas,
            eps=self.args.adam_eps,
            bias_correction=True,
            adam_w_mode=False,
            weight_decay=0,
            amsgrad=False,
        )

    @property
    def deepspeed_offload(self) -> bool:
        """Whether the DeepSpeed ZeRO config offloads optimizer state or params.

        Returns ``False`` when the trainer strategy is not a
        ``DeepSpeedStrategy``.
        """
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config["zero_optimization"]
            # The config entries are dicts or missing; coerce to a real bool
            # so the declared return type holds.
            return bool(config.get("offload_optimizer") or config.get("offload_param"))
        return False

    def forward(self, img):
        """Encode ``img``, round the code, and return the reconstruction."""
        z = self.encoder(img)
        z = ToBinary.apply(z)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one ``(img, txt)`` batch.

        On the global-zero rank, whenever ``global_step + 1`` is a multiple
        of ``100 * args.devices``, the first four source and reconstructed
        images are written to ``test/image_model/<run_name>/``.
        """
        args = self.args
        img, _ = batch  # the second batch element is not used
        out = self(img)
        if self.trainer.is_global_zero:
            step = self.trainer.global_step
            if (step + 1) % (100 * int(args.devices)) == 0:
                img_dir = f"test/image_model/{args.run_name}"
                # exist_ok avoids the check-then-create race.
                os.makedirs(img_dir, exist_ok=True)
                vision.utils.save_image(img[:4], f"{img_dir}/{step}-src.jpg")
                vision.utils.save_image(out[:4], f"{img_dir}/{step}-out.jpg")

        loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)

        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
        loss_clip = torch.mean(cosine_loss(iii, ooo))

        loss = loss_dists + loss_clip * args.my_img_clip_scale
        if args.my_img_l1_scale > 0:
            loss = loss + F.l1_loss(out, img) * args.my_img_l1_scale
        return loss

    def training_step_end(self, batch_parts):
        """Gather the step output across processes and stash it on the trainer.

        The gathered value is stored as ``trainer.my_loss_all`` on the
        global-zero rank only.
        """
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return a CPU copy of the state dict in the configured float mode.

        Tensors are cast to half precision when the ``RWKV_FLOAT_MODE``
        environment variable is ``"fp16"``, to bfloat16 when it is ``"bf16"``,
        and left unchanged otherwise.

        Raises:
            KeyError: If ``RWKV_FLOAT_MODE`` is not set.
        """
        print(
            """
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        float_mode = os.environ["RWKV_FLOAT_MODE"]
        m = {}
        # Build the state dict once instead of once per entry.
        for n, p in self.state_dict().items():
            p = p.cpu()
            if float_mode == "fp16":
                p = p.half()
            elif float_mode == "bf16":
                p = p.bfloat16()
            m[n] = p

        gc.collect()
        torch.cuda.empty_cache()
        return m
| |
|