#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to fine-tune Stable Video Diffusion."""
import argparse
import math
import os

import numpy as np
import torch
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.utils import check_min_version, export_to_video

from simple_pipeline import StableVideoDiffusionPipeline

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.24.0.dev0")

logger = get_logger(__name__, log_level="INFO")
def parse_args():
    parser = argparse.ArgumentParser(description="SVD Inference Script")
    parser.add_argument(
        "--config",
        type=str,
        default="/datasets/sai/focal-burst-learning/svd/training/configs/outside_photos.yaml",
        help="Path to the config file.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to the input image.",
    )
parser.add_argument(
"--seed",
type=int,
default=0,
help="A seed for reproducible training.",
)
parser.add_argument(
"--learn2refocus_hf_repo_path",
type=str,
default="tedlasai/learn2refocus",
help="hf repo containing the weight files",
)
parser.add_argument(
"--pretrained_model_path",
type=str,
default="stabilityai/stable-video-diffusion-img2vid",
help="repo id or path for pretrained StableVideo Diffusion model",
)
parser.add_argument(
"--output_dir",
type=str,
default="outputs/simple_inference",
help="path to output",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=25,
help="number of DDPM steps",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="inference device",
)
args = parser.parse_args()
return args
def find_scale(height, width):
    """Return (height, width) floored to multiples of 64, shrinking them until the pixel count is at most max_pixels."""
    max_pixels = 500000
    # Start with no scaling
    scale = 1.0
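    # Example: a 1080x1920 (H x W) input is over 500,000 pixels, so the loop below
    # lowers the scale in 0.01 steps until about 0.53, returning 512x960
    # (both multiples of 64, 491,520 pixels).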
while True:
# Calculate the scaled dimensions
scaled_height = math.floor((height * scale) / 64) * 64
scaled_width = math.floor((width * scale) / 64) * 64
# Check if the scaled dimensions meet the pixel constraint
if scaled_height * scaled_width <= max_pixels:
return scaled_height, scaled_width
# Reduce the scale slightly
scale -= 0.01
def convert_to_batch(img, input_focal_position, sample_frames=9):
    """Build a single-sample batch with the input image placed at one slot of an otherwise empty focal stack."""
    focal_stack_num = input_focal_position
    icc_profile = img.info.get("icc_profile")
    if icc_profile is None:
        icc_profile = "none"
original_pixels = torch.from_numpy(np.array(img)).float().permute(2,0,1)
original_pixels = original_pixels / 255
width, height = img.size
scaled_width, scaled_height = find_scale(width, height)
img_resized = img.resize((scaled_width, scaled_height))
img_tensor = torch.from_numpy(np.array(img_resized)).float()
img_normalized = img_tensor / 127.5 - 1
img_normalized = img_normalized.permute(2, 0, 1)
pixels = torch.zeros((1, sample_frames, 3, scaled_height, scaled_width))
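    # Only the slot at focal_stack_num holds the (normalized) input image;
    # the remaining frames stay zero and are left for the model to generate.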
pixels[0, focal_stack_num] = img_normalized
return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile}
def inference_on_image(args, batch, pipeline, device):
    """Run the SVD pipeline on a single batch and return the normalized output frames."""
    pipeline.set_progress_bar_config(disable=True)
num_frames = 9
pixel_values = batch["pixel_values"].to(device)
focal_stack_num = batch["focal_stack_num"]
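    # Guidance is held constant across frames (min == max == 1.5); motion_bucket_id=0
    # requests minimal motion and noise_aug_strength=0 adds no noise to the
    # conditioning image (standard Stable Video Diffusion pipeline arguments).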
svd_output, _ = pipeline(
pixel_values,
height=pixel_values.shape[3],
width=pixel_values.shape[4],
num_frames=num_frames,
decode_chunk_size=8,
motion_bucket_id=0,
min_guidance_scale=1.5,
max_guidance_scale=1.5,
fps=7,
noise_aug_strength=0,
        focal_stack_num=focal_stack_num,
num_inference_steps=args.num_inference_steps,
)
video_frames = svd_output.frames[0]
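    # The custom pipeline returns frames in [-1, 1]; map them to [0, 1], clamp,
    # and resize to even spatial dimensions before writing video/PNG outputs.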
    video_frames_normalized = video_frames * 0.5 + 0.5
    video_frames_normalized = torch.clamp(video_frames_normalized, 0, 1)
    video_frames_normalized = video_frames_normalized.permute(1, 0, 2, 3)
    video_frames_normalized = torch.nn.functional.interpolate(
        video_frames_normalized,
        size=((pixel_values.shape[3] // 2) * 2, (pixel_values.shape[4] // 2) * 2),
        mode="bilinear",
    )
return video_frames_normalized, focal_stack_num
def write_output(output_dir, frames, focal_stack_num, icc_profile):
    """Save the generated focal stack as an MP4 and as individual PNG frames."""
    print("Validation images will be saved to", output_dir)
    os.makedirs(output_dir, exist_ok=True)
    print("Frames shape:", frames.shape)
    export_to_video(frames.permute(0, 2, 3, 1).cpu().numpy(), os.path.join(output_dir, "stack.mp4"), fps=5)
    # Save each frame as a PNG with Pillow.
    for i in range(frames.shape[0]):
        img = Image.fromarray((frames[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        if icc_profile != "none":
            # Pass the ICC profile explicitly so Pillow embeds it in the saved PNG.
            img.save(os.path.join(output_dir, f"frame_{i}.png"), icc_profile=icc_profile)
        else:
            img.save(os.path.join(output_dir, f"frame_{i}.png"))
def load_model(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# inference-only modules
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
args.pretrained_model_path, subfolder="image_encoder"
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
args.pretrained_model_path, subfolder="vae", variant="fp16"
)
weight_dtype = torch.float32
image_encoder.requires_grad_(False).to(device, dtype=weight_dtype)
vae.requires_grad_(False).to(device, dtype=weight_dtype)
    # Load the fine-tuned UNet (unet/config.json + diffusion_pytorch_model.safetensors)
    # from the checkpoint-200000/unet subfolder of the learn2refocus repo.
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.learn2refocus_hf_repo_path, subfolder="checkpoint-200000/unet"
    ).to(device)
    unet.eval()
    image_encoder.eval()
    vae.eval()
pipeline = StableVideoDiffusionPipeline.from_pretrained(
args.pretrained_model_path,
unet=unet,
image_encoder=image_encoder,
vae=vae,
torch_dtype=weight_dtype,
)
return pipeline, device
def main():
args = parse_args()
if args.seed is not None:
set_seed(args.seed)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
pipeline, device = load_model(args)
with torch.no_grad():
img = Image.open(args.image_path)
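        # The input image is placed at index 6 of the 9-frame focal stack;
        # the pipeline generates the remaining focal positions.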
batch = convert_to_batch(img, input_focal_position=6)
output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
name = os.path.splitext(os.path.basename(args.image_path))[0]
val_save_dir = os.path.join(args.output_dir, "validation_images", name)
write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
if __name__ == "__main__":
main()
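
# Example invocation (the script filename below is assumed):
#   python simple_inference.py --image_path path/to/photo.jpg --output_dir outputs/simple_inference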