import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import utils
from accelerate import Accelerator
from diffusers import StableDiffusionPipeline
from diffusers.image_processor import PipelineImageInput
from losses import *
from tqdm import tqdm


class ADPipeline(StableDiffusionPipeline):
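    """Attention Distillation (AD) pipeline built on StableDiffusionPipeline.

    `self.classifier` is expected to be a UNet attached by the caller; its
    self-attention features are cached via `utils.register_attn_control`, and
    the latents are optimized so that their attention statistics match those
    of a style image (and, optionally, preserve a content image) using
    `ad_loss` / `q_loss` from `losses`.
    """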

    def freeze(self):
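        """Disable gradients for all sub-models; only the latents are optimized."""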
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.classifier.requires_grad_(False)

    @torch.no_grad()
    def image2latent(self, image):
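        """Encode an image in [0, 1] into a scaled VAE latent."""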
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
        latent = self.vae.encode(image)["latent_dist"].mean
        latent = latent * self.vae.config.scaling_factor
        return latent

    @torch.no_grad()
    def latent2image(self, latent):
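        """Decode a scaled VAE latent back to an image in [0, 1]."""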
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        latent = latent.to(device=device, dtype=dtype)
        latent = latent / self.vae.config.scaling_factor
        image = self.vae.decode(latent)[0]
        return (image * 0.5 + 0.5).clamp(0, 1)

    def init(self, enable_gradient_checkpoint):
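        """Freeze the models, pick the weight dtype matching the requested
        mixed-precision mode, and move everything to the accelerator device."""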
        self.freeze()
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.unet.to(self.accelerator.device, dtype=weight_dtype)
        self.vae.to(self.accelerator.device, dtype=weight_dtype)
        self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier = self.accelerator.prepare(self.classifier)
        if enable_gradient_checkpoint:
            self.classifier.enable_gradient_checkpointing()

    def sample(
        self,
        lr=0.05,
        iters=1,
        attn_scale=1,
        adain=False,
        weight=0.25,
        controller=None,
        style_image=None,
        content_image=None,
        mixed_precision="no",
        start_time=999,
        enable_gradient_checkpoint=False,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):
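        """Text-to-image sampling with attention-distillation guidance.

        Runs the usual Stable Diffusion denoising loop; after each scheduler
        step (when `iters > 0` and `t < start_time`) the latents are further
        refined by `self.AD` so their self-attention features match those of
        `style_image` (and, optionally, keep the queries of `content_image`).
        """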
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        do_cfg = guidance_scale > 1.0

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_cfg,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                do_cfg,
            )

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        self.style_latent = self.image2latent(style_image)
        if content_image is not None:
            self.content_latent = self.image2latent(content_image)
        else:
            self.content_latent = None
        # Unconditional ("null") embeddings are used when extracting features
        # with the classifier UNet.
        null_embeds = self.encode_prompt("", device, 1, False)[0]
        self.null_embeds = null_embeds
        self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
        self.null_embeds_for_style = torch.cat(
            [null_embeds] * self.style_latent.shape[0]
        )

        self.adain = adain
        self.attn_scale = attn_scale
        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self-attention layers of UNet: ", controller.num_self_layers)
        print("Self-attention layers for AD: ", controller.self_layers)

        pbar = tqdm(timesteps, desc="Sample")
        for i, t in enumerate(pbar):
            with torch.no_grad():
                latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                if do_cfg:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]
            # Attention-distillation refinement of the denoised latents.
            if iters > 0 and t < start_time:
                latents = self.AD(latents, t, lr, iters, pbar, weight)

        images = self.latent2image(latents)

        self.maybe_free_model_hooks()
        return images

    def optimize(
        self,
        latents=None,
        attn_scale=1.0,
        lr=0.05,
        iters=1,
        weight=0,
        width=512,
        height=512,
        batch_size=1,
        controller=None,
        style_image=None,
        content_image=None,
        mixed_precision="no",
        num_inference_steps=50,
        enable_gradient_checkpoint=False,
        source_mask=None,
        target_mask=None,
    ):
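        """Optimize latents directly against a style image, without denoising.

        At every timestep the reference self-attention features are extracted
        once under `torch.no_grad()`, then the latents take `iters` Adam steps
        on `ad_loss` (plus `weight * q_loss` when `content_image` is given).
        """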
        # Work in latent resolution.
        height = height // self.vae_scale_factor
        width = width // self.vae_scale_factor

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        style_latent = self.image2latent(style_image)
        latents = torch.randn((batch_size, 4, height, width), device=self.device)
        null_embeds = self.encode_prompt("", self.device, 1, False)[0]
        null_embeds_for_latents = null_embeds.repeat(latents.shape[0], 1, 1)
        null_embeds_for_style = null_embeds.repeat(style_latent.shape[0], 1, 1)

        if content_image is not None:
            # Start from the content latent instead of random noise.
            content_latent = self.image2latent(content_image)
            latents = torch.cat([content_latent.clone()] * batch_size)
            null_embeds_for_content = null_embeds.repeat(content_latent.shape[0], 1, 1)

        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self-attention layers of UNet: ", controller.num_self_layers)
        print("Self-attention layers for AD: ", controller.self_layers)

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        latents = latents.detach().float()
        optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
        optimizer = self.accelerator.prepare(optimizer)
        pbar = tqdm(timesteps, desc="Optimize")
        for i, t in enumerate(pbar):
            with torch.no_grad():
                qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
                    style_latent,
                    t,
                    null_embeds_for_style,
                )
                if content_image is not None:
                    qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
                        content_latent,
                        t,
                        null_embeds_for_content,
                    )
            for j in range(iters):
                style_loss = 0
                content_loss = 0
                optimizer.zero_grad()
                q_list, k_list, v_list, self_out_list = self.extract_feature(
                    latents,
                    t,
                    null_embeds_for_latents,
                )
                style_loss = ad_loss(
                    q_list,
                    ks_list,
                    vs_list,
                    self_out_list,
                    scale=attn_scale,
                    source_mask=source_mask,
                    target_mask=target_mask,
                )
                if content_image is not None:
                    content_loss = q_loss(q_list, qc_list)

                loss = style_loss + content_loss * weight
                self.accelerator.backward(loss)
                optimizer.step()
                pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
        images = self.latent2image(latents)

        self.maybe_free_model_hooks()
        return images

    def panorama(
        self,
        lr=0.05,
        iters=1,
        attn_scale=1,
        adain=False,
        controller=None,
        style_image=None,
        mixed_precision="no",
        enable_gradient_checkpoint=False,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1,
        stride=8,
        view_batch_size: int = 16,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):
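        """Panorama sampling with sliding-window (MultiDiffusion-style) denoising plus AD.

        The latent canvas is split into overlapping 64x64 views; each view
        batch is denoised with its own scheduler state and refined with
        `self.AD`, then overlapping regions are averaged back into the full
        canvas at every timestep.
        """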
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_cfg = guidance_scale > 1.0

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_cfg,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )

        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Split the panorama into overlapping views and batch them.
        views = self.get_views_(height, width, window_size=64, stride=stride)
        views_batch = [
            views[i : i + view_batch_size]
            for i in range(0, len(views), view_batch_size)
        ]
        print(len(views), len(views_batch), views_batch)
        self.scheduler.set_timesteps(num_inference_steps)
        # Each view batch keeps its own copy of the scheduler state.
        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(
            views_batch
        )
        count = torch.zeros_like(latents)
        value = torch.zeros_like(latents)

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        timesteps = self.scheduler.timesteps
        self.style_latent = self.image2latent(style_image)
        self.content_latent = None
        null_embeds = self.encode_prompt("", device, 1, False)[0]
        self.null_embeds = null_embeds
        self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
        self.null_embeds_for_style = torch.cat(
            [null_embeds] * self.style_latent.shape[0]
        )
        self.adain = adain
        self.attn_scale = attn_scale
        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self-attention layers of UNet: ", controller.num_self_layers)
        print("Self-attention layers for AD: ", controller.self_layers)

        pbar = tqdm(timesteps, desc="Sample")
        for i, t in enumerate(pbar):
            count.zero_()
            value.zero_()

            for j, batch_view in enumerate(views_batch):
                vb_size = len(batch_view)

                # Crop the current views out of the full latent canvas.
                latents_for_view = torch.cat(
                    [
                        latents[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )

                # Restore the scheduler state belonging to this view batch.
                self.scheduler.__dict__.update(views_scheduler_status[j])

                latent_model_input = (
                    latents_for_view.repeat_interleave(2, dim=0)
                    if do_cfg
                    else latents_for_view
                )

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)

                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds_input,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                    ).sample

                if do_cfg:
                    noise_pred_uncond, noise_pred_text = (
                        noise_pred[::2],
                        noise_pred[1::2],
                    )
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                latents_denoised_batch = self.scheduler.step(
                    noise_pred, t, latents_for_view, **extra_step_kwargs
                ).prev_sample
                if iters > 0:
                    self.null_embeds_for_latents = torch.cat(
                        [self.null_embeds] * noise_pred.shape[0]
                    )
                    # Attention-distillation refinement of the denoised views.
                    latents_denoised_batch = self.AD(
                        latents_denoised_batch, t, lr, iters, pbar
                    )

                views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

                # Accumulate the denoised views back into the full canvas.
                for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                    latents_denoised_batch.chunk(vb_size), batch_view
                ):
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

            # Average overlapping regions.
            latents = torch.where(count > 0, value / count, value)

        images = self.latent2image(latents)

        self.maybe_free_model_hooks()
        return images

    def AD(self, latents, t, lr, iters, pbar, weight=0):
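        """Refine `latents` by attention distillation around timestep `t`.

        `t` is first stepped back by one scheduler interval; style (and
        optional content) features are extracted from reference latents noised
        to that timestep, then the latents take `iters` Adam steps on
        `ad_loss` (+ `weight * q_loss`).
        """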
        # The incoming latents correspond to the next (smaller) timestep, so
        # step t back by one scheduler interval before extracting features.
        t = max(
            t
            - self.scheduler.config.num_train_timesteps
            // self.scheduler.num_inference_steps,
            torch.tensor([0], device=self.device),
        )
        if self.adain:
            # Align latent statistics with a noised version of the style latent.
            noise = torch.randn_like(self.style_latent)
            style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
            latents = utils.adain(latents, style_latent)

        with torch.no_grad():
            qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
                self.style_latent,
                t,
                self.null_embeds_for_style,
                add_noise=True,
            )
            if self.content_latent is not None:
                qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
                    self.content_latent,
                    t,
                    self.null_embeds,
                    add_noise=True,
                )

        latents = latents.detach()
        optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
        optimizer = self.accelerator.prepare(optimizer)

        for j in range(iters):
            style_loss = 0
            content_loss = 0
            optimizer.zero_grad()
            q_list, k_list, v_list, self_out_list = self.extract_feature(
                latents,
                t,
                self.null_embeds_for_latents,
                add_noise=False,
            )
            style_loss = ad_loss(
                q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale
            )
            if self.content_latent is not None:
                content_loss = q_loss(q_list, qc_list)

            loss = style_loss + content_loss * weight
            self.accelerator.backward(loss)
            optimizer.step()

            pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
        latents = latents.detach()
        return latents

    def extract_feature(
        self,
        latent,
        t,
        embeds,
        add_noise=False,
    ):
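        """Run the classifier UNet once on `latent` at timestep `t` (optionally
        after adding noise) and return the cached self-attention features
        (queries, keys, values, attention outputs)."""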
        self.cache.clear()
        self.controller.step()
        if add_noise:
            noise = torch.randn_like(latent)
            latent_ = self.scheduler.add_noise(latent, noise, t)
        else:
            latent_ = latent
        _ = self.classifier(latent_, t, embeds)[0]
        return self.cache.get()

    def get_views_(
        self,
        panorama_height: int,
        panorama_width: int,
        window_size: int = 64,
        stride: int = 8,
    ) -> List[Tuple[int, int, int, int]]:
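        """Compute (h_start, h_end, w_start, w_end) windows of size
        `window_size` x `window_size` (in latent space) that tile the panorama
        with the given `stride`."""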
        # Convert from pixel space to latent space (the VAE downsamples by 8).
        panorama_height //= 8
        panorama_width //= 8

        num_blocks_height = (
            math.ceil((panorama_height - window_size) / stride) + 1
            if panorama_height > window_size
            else 1
        )
        num_blocks_width = (
            math.ceil((panorama_width - window_size) / stride) + 1
            if panorama_width > window_size
            else 1
        )

        views = []
        for i in range(int(num_blocks_height)):
            for j in range(int(num_blocks_width)):
                h_start = int(min(i * stride, panorama_height - window_size))
                w_start = int(min(j * stride, panorama_width - window_size))

                h_end = h_start + window_size
                w_end = w_start + window_size

                views.append((h_start, h_end, w_start, w_end))

        return views