import os
import numpy as np
import torch
from contextlib import nullcontext
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from einops import rearrange
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from transformers import CLIPImageProcessor
from torch import autocast
from torchvision import transforms
def load_model_from_config(config, ckpt, device, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    # strict=False tolerates checkpoints whose keys do not exactly match the
    # instantiated architecture; mismatches are reported when verbose=True.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    model.to(device)
    model.eval()
    return model
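
# Usage sketch (the config file matches init_model below; the checkpoint
# path is illustrative only):
#   config = OmegaConf.load('configs/sd-objaverse-finetune-c_concat-256.yaml')
#   model = load_model_from_config(config, 'zero123-xl.ckpt', device='cuda:0')
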
def init_model(device, ckpt):
    config = os.path.join(os.path.dirname(__file__), 'configs/sd-objaverse-finetune-c_concat-256.yaml')
    config = OmegaConf.load(config)

    models = dict()
    print('Instantiating LatentDiffusion...')
    models['turncam'] = torch.compile(load_model_from_config(config, ckpt, device=device))
    print('Instantiating StableDiffusionSafetyChecker...')
    models['nsfw'] = StableDiffusionSafetyChecker.from_pretrained(
        'CompVis/stable-diffusion-safety-checker').to(device)
    models['clip_fe'] = CLIPImageProcessor.from_pretrained(
        'openai/clip-vit-large-patch14')
    # Raising the safety-checker thresholds makes it less likely to flag
    # images (presumably to avoid false positives on synthetic object renders).
    models['nsfw'].concept_embeds_weights *= 1.2
    models['nsfw'].special_care_embeds_weights *= 1.2

    return models
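
# Usage sketch (checkpoint path is illustrative only):
#   models = init_model('cuda:0', 'zero123-xl.ckpt')
#   models['turncam']  # torch.compile-wrapped LatentDiffusion for novel views
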
@torch.no_grad()
def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope("cuda"):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
            # Each camera offset is encoded as a 4-vector:
            # (delta-elevation in radians, sin(delta-azimuth), cos(delta-azimuth), delta-radius).
            T = []
            for x, y in zip(xs, ys):
                T.append([np.radians(x), np.sin(np.radians(y)), np.cos(np.radians(y)), 0])
            T = torch.tensor(np.array(T))[:, None, :].float().to(c.device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach()
                                .repeat(n_samples, 1, 1, 1)]
            if scale != 1.0:
                # Classifier-free guidance: all-zero conditioning serves as
                # the unconditional branch.
                uc = {}
                uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)
            print(samples_ddim.shape)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map decoded latents from [-1, 1] to [0, 1] for image export.
            ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            del cond, c, x_samples_ddim, samples_ddim, uc, input_im
            torch.cuda.empty_cache()
            return ret_imgs
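
# Call sketch: render four views at small offsets around the input camera,
# mirroring the stage-2 deltas used further below (input_im prepared as in
# predict_stage1_gradio):
#   sampler = DDIMSampler(model)
#   imgs = sample_model_batch(model, sampler, input_im,
#                             xs=[-10, 10, 0, 0], ys=[0, 0, -10, 10], n_samples=4)
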
@torch.no_grad()
def predict_stage1_gradio(model, raw_im, save_path="", adjust_set=[], device="cuda", ddim_steps=75, scale=3.0):
    # Normalize the input image to [-1, 1].
    input_im_init = np.asarray(raw_im, dtype=np.float32) / 255.0
    input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
    input_im = input_im * 2 - 1

    # Stage-1 camera offsets: 12 views in three elevation bands (0, +30, -30 deg).
    # The 0-deg band sweeps azimuth 0/90/180/270; the +/-30 bands sweep
    # 30/120/210/300. (Note: despite the `_1_8` names, the lists hold 12 offsets.)
    delta_x_1_8 = [0] * 4 + [30] * 4 + [-30] * 4
    delta_y_1_8 = [0+90*(i%4) if i < 4 else 30+90*(i%4) for i in range(8)] + [30+90*(i%4) for i in range(4)]

    ret_imgs = []
    sampler = DDIMSampler(model)
    if adjust_set != []:
        # Re-sample only the requested subset of the 12 stage-1 views.
        x_samples_ddims_8 = sample_model_batch(model, sampler, input_im,
                                               [delta_x_1_8[i] for i in adjust_set], [delta_y_1_8[i] for i in adjust_set],
                                               n_samples=len(adjust_set), ddim_steps=ddim_steps, scale=scale)
    else:
        x_samples_ddims_8 = sample_model_batch(model, sampler, input_im, delta_x_1_8, delta_y_1_8, n_samples=len(delta_x_1_8), ddim_steps=ddim_steps, scale=scale)
    sample_idx = 0
    for stage1_idx in range(len(delta_x_1_8)):
        if adjust_set != [] and stage1_idx not in adjust_set:
            continue
        x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c')
        out_image = Image.fromarray(x_sample.astype(np.uint8))
        ret_imgs.append(out_image)
        if save_path:
            out_image.save(os.path.join(save_path, '%d.png' % (stage1_idx)))
        sample_idx += 1
    del x_samples_ddims_8
    del sampler
    torch.cuda.empty_cache()
    return ret_imgs
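
# Example (output directory is illustrative and must already exist):
#   raw_im = Image.open('input_256.png').convert('RGB')  # 256x256 input
#   views = predict_stage1_gradio(models['turncam'], raw_im,
#                                 save_path='stage1_8', device='cuda:0')
#   # -> 12 PIL images, also saved as stage1_8/0.png .. stage1_8/11.png
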
def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_2, indices, device, ddim_steps=75, scale=3.0):
    for stage1_idx in indices:
        # Load the stage-1 output view for this index.
        stage1_image_path = os.path.join(save_path_stage1, '%d.png' % (stage1_idx))

        raw_im = Image.open(stage1_image_path)
        input_im_init = np.asarray(raw_im, dtype=np.float32)
        # Snap near-white pixels to pure white before normalizing to [-1, 1].
        input_im_init[input_im_init >= 253.0] = 255.0
        input_im_init = input_im_init / 255.0
        input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
        input_im = input_im * 2 - 1

        sampler = DDIMSampler(model)
        # Sample the stage-2 refinement views around this stage-1 view.
        x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
        for stage2_idx in range(len(delta_x_2)):
            x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
            Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png' % (stage1_idx, stage2_idx)))
        del input_im
        del x_samples_ddims_stage2
        torch.cuda.empty_cache()
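
# Each stage-1 view i yields save_path_stage2/'i_j.png' for j in
# 0..len(delta_x_2)-1, one file per (delta_x_2[j], delta_y_2[j]) offset pair.
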
def zero123_infer(model, input_dir_path, start_idx=0, end_idx=12, indices=None, device="cuda", ddim_steps=75, scale=3.0):
    save_path_8 = os.path.join(input_dir_path, "stage1_8")
    save_path_8_2 = os.path.join(input_dir_path, "stage2_8")
    os.makedirs(save_path_8_2, exist_ok=True)

    # Stage-2 camera offsets: four small perturbations (+/-10 degrees in
    # elevation and azimuth) around each stage-1 view.
    delta_x_2 = [-10, 10, 0, 0]
    delta_y_2 = [0, 0, -10, 10]

    infer_stage_2(model, save_path_8, save_path_8_2, delta_x_2, delta_y_2,
                  indices=indices if indices else list(range(start_idx, end_idx)),
                  device=device, ddim_steps=ddim_steps, scale=scale)
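
# End-to-end sketch (directory layout and checkpoint name are illustrative):
#   models = init_model('cuda:0', 'zero123-xl.ckpt')
#   os.makedirs('demo/stage1_8', exist_ok=True)
#   raw_im = Image.open('demo/input_256.png').convert('RGB')  # 256x256 input
#   predict_stage1_gradio(models['turncam'], raw_im, save_path='demo/stage1_8')
#   zero123_infer(models['turncam'], 'demo')  # refined views -> demo/stage2_8/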