import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from diffusers import UNet2DConditionModel, DDIMScheduler
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import CLIPVisionModel
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.utils import ModelOutput
import numpy as np
import argparse
from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
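# Example (illustrative only): str2bool is intended as an argparse type converter,
# so that "--flag True" / "--flag no" are parsed as booleans. The parser and flag
# name below are hypothetical, not defined in this file.
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use_ensemble", type=str2bool, nargs="?", const=True, default=False)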
# perturb_tensor() adds Gaussian noise with standard deviation perturb_std to the tensor.
# If perturb_std_is_relative, perturb_std is interpreted as a fraction of the tensor's own std.
def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=False,
                   std_dim=-1, norm_dim=-1, verbose=True):
    orig_ts = ts
    if perturb_std_is_relative:
        ts_std_mean = ts.std(dim=std_dim).mean().detach()
        perturb_std *= ts_std_mean
        # ts_std_mean: 50~80 for unnormalized images, perturb_std: 2.5-4 for 0.05 noise.
        if verbose:
            print(f"ts_std_mean: {ts_std_mean:.03f}, perturb_std: {perturb_std:.03f}")

    noise = torch.randn_like(ts) * perturb_std
    if keep_norm:
        orig_norm = ts.norm(dim=norm_dim, keepdim=True)
        ts = ts + noise
        new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
        ts = ts * orig_norm / (new_norm + 1e-8)
    else:
        ts = ts + noise

    if verbose:
        print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).detach().item():.03f}")

    return ts
def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim=-1):
    ts = torch.from_numpy(np_array).to(dtype=torch.float32)
    ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
    return ts.numpy().astype(np_array.dtype)
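# Example (illustrative only): add 5% relative Gaussian noise to a batch of ID
# embeddings while keeping their norms. The tensors below are stand-ins, not real data.
#   id_embs = torch.randn(4, 512)
#   noisy_id_embs = perturb_tensor(id_embs, perturb_std=0.05, keep_norm=True, verbose=False)
#   noisy_pixels  = perturb_np_array(np.random.rand(224, 224, 3).astype(np.float32), 0.05)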
def calc_stats(emb_name, embeddings, mean_dim=-1):
    print("%s:" %emb_name)
    repeat_count = [1] * embeddings.ndim
    repeat_count[mean_dim] = embeddings.shape[mean_dim]
    # Average across the mean_dim dim.
    # Make emb_mean the same size as embeddings, as required by F.l1_loss.
    emb_mean = embeddings.mean(mean_dim, keepdim=True).repeat(repeat_count)
    l1_loss = F.l1_loss(embeddings, emb_mean)
    # F.mse_loss doesn't take the sqrt, so the loss is very small.
    # Compute the RMSE manually.
    l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
    norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
    print("L1: %.4f, L2: %.4f" %(l1_loss.detach().item(), l2_loss.detach().item()))
    print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
# new_token_embeddings: [new_num_tokens, 768].
def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
    emb_dim        = old_nn_embedding.embedding_dim
    num_old_tokens = old_nn_embedding.num_embeddings
    num_new_tokens = new_token_embeddings.shape[0]
    num_tokens2    = num_old_tokens + num_new_tokens

    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
                                    device=old_nn_embedding.weight.device,
                                    dtype=old_nn_embedding.weight.dtype)

    # Copy the first num_old_tokens embeddings from old_nn_embedding to new_nn_embedding.
    new_nn_embedding.weight.data[:num_old_tokens] = old_nn_embedding.weight.data
    # Copy the new embeddings to new_nn_embedding.
    new_nn_embedding.weight.data[num_old_tokens:] = new_token_embeddings

    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
    return new_nn_embedding
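# Example (hypothetical usage sketch): extend a CLIP text encoder's input embedding
# with two new placeholder-token embeddings. `text_encoder` and `new_embs` are
# assumptions for illustration, not defined in this file.
#   new_embs = torch.randn(2, 768) * 0.02
#   text_encoder.text_model.embeddings.token_embedding = \
#       extend_nn_embedding(text_encoder.text_model.embeddings.token_embedding, new_embs)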
# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().detach().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
        else:
            grad_output2 = None
        return grad_output2, None, None
class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters, and simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        _debug = self._debug if hasattr(self, '_debug') else False
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        # Don't use lambda function here, otherwise the object can't be pickled.
        return torch.detach
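# Example (illustrative only): scale the gradient flowing back into an encoder's
# output by 0.1 during training; alpha=1 is a no-op and alpha=0 acts as a detach.
# `encoder_output` below is a stand-in, not defined in this file.
#   grad_scaler = gen_gradient_scaler(0.1)
#   feats = grad_scaler(encoder_output)   # forward pass is an identity mapping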
def pad_image_obj_to_square(image_obj, new_size=-1):
    # Remove alpha channel if it exists.
    if image_obj.mode == 'RGBA':
        image_obj = image_obj.convert('RGB')

    # Pad input to be width == height.
    width, height = orig_size = image_obj.size
    new_width, new_height = max(width, height), max(width, height)

    if width != height:
        if width > height:
            pads = (0, (width - height) // 2)
        elif height > width:
            pads = ((height - width) // 2, 0)

        square_image_obj = Image.new("RGB", (new_width, new_height))
        # pads indicates the upper left corner to paste the input.
        square_image_obj.paste(image_obj, pads)
        #square_image_obj = square_image_obj.resize((512, 512))
        print(f"{width}x{height} -> {new_width}x{new_height} -> {square_image_obj.size}")
        long_short_ratio = max(width, height) / min(width, height)
    else:
        square_image_obj = image_obj
        pads = (0, 0)
        long_short_ratio = 1

    if new_size > 0:
        # Resize the square so that the original shorter edge becomes new_size.
        square_image_obj = square_image_obj.resize([int(new_size * long_short_ratio), int(new_size * long_short_ratio)])

    return square_image_obj, pads, orig_size
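# Example (illustrative only): pad a portrait photo to a square and resize it so the
# original shorter edge becomes 512 pixels. The file path is a hypothetical placeholder.
#   image_obj = Image.open("examples/portrait.jpg")
#   square_image_obj, pads, orig_size = pad_image_obj_to_square(image_obj, new_size=512)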
class UNetEnsemble(nn.Module):
    # The first unet is the unet already loaded in a pipeline.
    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', torch_dtype=torch.float16):
        super().__init__()

        if unets is not None:
            unets = [ unet.to(device) for unet in unets ]
        else:
            unets = []

        if unet_types is not None:
            for unet_type in unet_types:
                if unet_type == "arc2face":
                    from adaface.arc2face_models import create_arc2face_pipeline
                    unet = create_arc2face_pipeline(unet_only=True)
                elif unet_type == "consistentID":
                    unet = create_consistentid_pipeline(unet_only=True)
                else:
                    breakpoint()
                unets.append(unet.to(device=device))

        if extra_unet_dirpaths is not None:
            for unet_path in extra_unet_dirpaths:
                unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
                unets.append(unet.to(device=device))

        if unet_weights_in_ensemble is None:
            unet_weights_in_ensemble = [1.] * len(unets)
        elif len(unets) < len(unet_weights_in_ensemble):
            unet_weights_in_ensemble = unet_weights_in_ensemble[:len(unets)]
        elif len(unets) > len(unet_weights_in_ensemble):
            breakpoint()

        unet_weights_in_ensemble = torch.tensor(unet_weights_in_ensemble, dtype=torch_dtype)
        unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()

        self.unets = nn.ModuleList(unets)
        # Put the weights in a Parameter so that they will be moved to the same device as the model.
        self.unet_weights_in_ensemble = nn.Parameter(unet_weights_in_ensemble, requires_grad=False)
        print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights_in_ensemble.data.cpu().numpy()}")

        # Set these fields to be compatible with diffusers.
        self.dtype  = self.unets[0].dtype
        self.device = self.unets[0].device
        self.config = self.unets[0].config

    def forward(self, *args, **kwargs):
        return_dict = kwargs.get('return_dict', True)
        teacher_contexts = kwargs.pop('encoder_hidden_states', None)
        # If only one teacher_context is provided, all unets will use the same teacher_context.
        # We repeat the teacher_contexts to match the number of unets.
        if not isinstance(teacher_contexts, (list, tuple)):
            teacher_contexts = [teacher_contexts]
        if len(teacher_contexts) == 1 and len(self.unets) > 1:
            teacher_contexts = teacher_contexts * len(self.unets)

        samples = []

        for unet, teacher_context in zip(self.unets, teacher_contexts):
            sample = unet(encoder_hidden_states=teacher_context, *args, **kwargs)
            if not return_dict:
                sample = sample[0]
            else:
                sample = sample.sample
            samples.append(sample)

        samples = torch.stack(samples, dim=0)
        unet_weights_in_ensemble = self.unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
        sample = (samples * unet_weights_in_ensemble).sum(dim=0)

        if not return_dict:
            return (sample,)
        else:
            return UNet2DConditionOutput(sample=sample)
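# Example (hypothetical usage sketch): replace a Stable Diffusion pipeline's UNet with
# an equally-weighted ensemble of its own UNet and a ConsistentID UNet. `pipe` is an
# assumption for illustration, not defined in this file.
#   pipe.unet = UNetEnsemble(unets=[pipe.unet], unet_types=["consistentID"],
#                            extra_unet_dirpaths=None, unet_weights_in_ensemble=[1.0, 1.0])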
def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetensors",
                                 dtype=torch.float16, unet_only=False):
    pipe = ConsistentIDPipeline.from_single_file(base_model_path)
    # consistentID specific modules are still in fp32. Will be converted to fp16
    # later with .to(device, torch_dtype) by the caller.
    pipe.load_ConsistentID_model(
        consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
        bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
    )
    # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
    # because we've overloaded .to() to convert consistentID specific modules as well,
    # but diffusers will call .to(dtype) in .from_single_file(),
    # and at that moment, the consistentID specific modules are not loaded yet.
    pipe.to(dtype=dtype)

    # We load the pipeline first, then use the unet in the pipeline.
    # Since the pipeline initialization will load LoRA into the unet,
    # now we have the unet with LoRA loaded.
    if unet_only:
        # We release text_encoder and VAE to save memory.
        pipe.release_components(["text_encoder", "vae"])
        return pipe.unet

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    pipe.scheduler = noise_scheduler
    return pipe
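# Example (illustrative only): build the full ConsistentID pipeline in fp16, or fetch
# just its LoRA-loaded UNet for use in a UNetEnsemble. The default weight paths above
# must exist locally for this to run.
#   pipe = create_consistentid_pipeline()
#   consistentid_unet = create_consistentid_pipeline(unet_only=True)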
@dataclass
class BaseModelOutputWithPooling2(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attn_mask: Optional[torch.FloatTensor] = None
# Revised from CLIPVisionTransformer to support an attention mask.
# self: a CLIPVisionTransformer instance.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
# attn_mask: B*H*W attention mask.
def CLIPVisionTransformer_forward_with_mask(self, pixel_values = None, attn_mask=None,
                                            output_attentions = None,
                                            output_hidden_states = None, return_dict = None):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")

    # Visual tokens are flattened in embeddings().
    # self.embeddings: CLIPVisionEmbeddings.
    # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
    # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
    hidden_states = self.embeddings(pixel_values)
    hidden_states = self.pre_layrnorm(hidden_states)

    if attn_mask is not None:
        # feat_edge_size: 16.
        feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
        # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
        attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
        # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
        attn_mask = attn_mask.flatten(2)
        # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
        # This 1 corresponds to class_embeds, which is always attended to.
        attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
        attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
    else:
        attn_mask_pairs = None

    # encoder: CLIPEncoder.
    encoder_outputs = self.encoder(
        inputs_embeds=hidden_states,
        # New feature: (***The official documentation is wrong***)
        # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
        # Mask to avoid performing attention on pairs of tokens. Mask values selected in `[0, 1]`:
        # - 1 for pairs that are **not masked**,
        # - 0 for pairs that are **masked**.
        # attention_mask is eventually used by CLIPEncoderLayer:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
        attention_mask=attn_mask_pairs,
        output_attentions=output_attentions,        # False
        output_hidden_states=output_hidden_states,  # True
        return_dict=return_dict,                    # True
    )

    # last_hidden_state: [BS, 257, 1280]
    last_hidden_state = encoder_outputs[0]
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.post_layernorm(pooled_output)

    # return_dict is True.
    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPooling2(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        # Newly added: return resized flattened attention mask.
        # [BS, 1, 257] -> [BS, 257, 1]
        attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
    )
def CLIPVisionModel_forward_with_mask(self, pixel_values = None, attn_mask = None, output_attentions = None,
                                      output_hidden_states = None, return_dict = None):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    return self.vision_model(
        pixel_values=pixel_values,
        attn_mask=attn_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

# patch_clip_image_encoder_with_mask() is applicable to both CLIPVisionModel and CLIPVisionModelWithProjection.
def patch_clip_image_encoder_with_mask(clip_image_encoder):
    clip_image_encoder.vision_model.forward = CLIPVisionTransformer_forward_with_mask.__get__(clip_image_encoder.vision_model)
    clip_image_encoder.forward = CLIPVisionModel_forward_with_mask.__get__(clip_image_encoder)
    return clip_image_encoder
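# Example (hypothetical usage sketch): patch an existing CLIP image encoder in place so
# its forward() accepts a foreground mask. The checkpoint name and inputs are illustrative
# assumptions; `pixel_values` and `fg_mask` are not defined in this file.
#   clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
#   clip_encoder = patch_clip_image_encoder_with_mask(clip_encoder)
#   out = clip_encoder(pixel_values=pixel_values, attn_mask=fg_mask, output_hidden_states=True)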
class CLIPVisionModelWithMask(CLIPVisionModel):
    def __init__(self, config):
        super().__init__(config)
        # Replace vision_model.forward() with the new one that supports mask.
        patch_clip_image_encoder_with_mask(self)
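# Example (hypothetical usage sketch): CLIPVisionModelWithMask behaves like CLIPVisionModel,
# except that forward() takes an extra attn_mask of shape [BS, H, W]. The checkpoint name
# below is an illustrative assumption.
#   model = CLIPVisionModelWithMask.from_pretrained("openai/clip-vit-large-patch14")
#   pixel_values = torch.randn(1, 3, 224, 224)
#   fg_mask = torch.ones(1, 224, 224)
#   out = model(pixel_values=pixel_values, attn_mask=fg_mask)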