Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Sample new images from a pre-trained DiT. | |
| """ | |
| import os | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import argparse | |
| import yaml | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| import gin | |
| import importlib | |
| import logging | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| logging.basicConfig( | |
| format="[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s", | |
| datefmt="%H:%M:%S", | |
| level=logging.INFO, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| import torch | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| from torch.utils.data import DataLoader | |
| from core.diffusion import create_diffusion | |
| from core.models import DiT_models | |
| from core.dataset import ImageParamsDataset | |
| from core.utils.train_utils import load_model | |
| from core.utils.math_utils import unnormalize_params | |
| from scripts.prepare_data import generate | |
def main(cfg, generator):
    """Sample DiT parameters conditioned on image features and evaluate them.

    For every batch in the test split: draw Gaussian noise, run the reverse
    diffusion loop conditioned on the image features, optionally regenerate the
    3D asset from the sampled (unnormalized) parameters, and accumulate the
    mean absolute error between sampled and ground-truth parameters. A bar
    chart of per-parameter average error is saved to cfg["save_dir"].

    Args:
        cfg: configuration dict loaded from YAML (seed, model, ckpt_path,
            data paths, batch size, camera settings, ...).
        generator: instantiated Blender procedural generator class, passed to
            `generate` when cfg["run_generate"] is truthy.
    """
    # Setup PyTorch: deterministic seed, inference-only (no autograd graphs).
    torch.manual_seed(cfg["seed"])
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model: latent size equals the number of procedural parameters.
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
    # load a custom DiT checkpoint from train.py:
    state_dict = load_model(cfg["ckpt_path"])
    model.load_state_dict(state_dict)
    model.eval()  # important! disables dropout / switches norm layers to eval

    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

    # Load dataset
    dataset = ImageParamsDataset(cfg["data_root"], cfg["test_file"], cfg["params_dict_file"])
    loader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=cfg["num_workers"],
        pin_memory=True,
        drop_last=False,
    )

    # FIX: close the params-dict file deterministically (was json.load(open(...))).
    with open(cfg["params_dict_file"]) as f:
        params_dict = json.load(f)

    idx = 0
    total_error = np.zeros(cfg["num_params"])
    for x, img_feat, img in loader:
        # FIX: with drop_last=False the final batch can be smaller than
        # cfg["batch_size"]; size the noise by the actual batch so z and the
        # conditioning features stay aligned.
        batch_size = img_feat.shape[0]
        # sample from random noise, conditioned on image features
        img_feat = img_feat.to(device)
        model_kwargs = dict(y=img_feat)
        z = torch.randn(batch_size, 1, latent_size, device=device)
        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
        )
        samples = samples.reshape(batch_size, 1, -1)
        samples = samples.squeeze(1).cpu().numpy()
        x = x.squeeze(1).cpu().numpy()
        img = img.cpu().numpy()
        if cfg["run_generate"]:
            # save GT & sampled params & images
            for x_, params, img_ in zip(x, samples, img):
                # generate 3D using sampled params (map back to native ranges)
                params_original = unnormalize_params(params, params_dict)
                save_dir = os.path.join(cfg["save_dir"], "{:05d}".format(idx))
                os.makedirs(save_dir, exist_ok=True)
                save_name = "sampled"
                asset, _ = generate(generator, params_original, seed=cfg["seed"], save_dir=save_dir, save_name=save_name,
                                    save_blend=True, save_img=True, save_gif=False, save_mesh=True,
                                    cam_dists=cfg["r_cam_dists"], cam_elevations=cfg["r_cam_elevations"], cam_azimuths=cfg["r_cam_azimuths"], zoff=cfg["r_zoff"],
                                    resolution='256x256', sample=100)
                np.save(os.path.join(save_dir, "params.npy"), params_original)
                print("Generating model using sampled parameters. Saved in {}".format(save_dir))
                # also save GT image & GT params (BGR flip for cv2.imwrite)
                x_original = unnormalize_params(x_, params_dict)
                np.save(os.path.join(save_dir, "gt_params.npy"), x_original)
                cv2.imwrite(os.path.join(save_dir, "gt.png"), img_[:, :, ::-1])
                idx += 1
        # calculate metrics for sampled params & GT params.
        # FIX: |x - samples| has shape (batch, num_params); reduce over the
        # batch axis before accumulating — adding it in-place to the
        # (num_params,) running total raises a broadcasting ValueError.
        total_error += np.abs(x - samples).sum(axis=0)

    # print the average error for each parameter
    avg_error = total_error / len(dataset)
    # FIX: materialize the keys so plotting gets a sized sequence, not a view.
    param_names = list(params_dict.keys())
    for param_name, error in zip(param_names, avg_error):
        print(f"{param_name}: {error:.4f}")

    # plot the error for each parameter
    fig, ax = plt.subplots()
    fig.set_size_inches(20, 15)
    ax.barh(param_names, avg_error)
    ax.set_xlabel("Average Error")
    ax.set_ylabel("Parameters")
    ax.set_title("Average Error for Each Parameter")
    plt.yticks(fontsize=10)
    fig.tight_layout()
    fig.savefig(os.path.join(cfg["save_dir"], "avg_error.png"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    # FIX: the config is plain data; safe_load avoids constructing arbitrary
    # Python objects, which yaml.load(FullLoader) can still do.
    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    # load the Blender procedural generator
    OBJECTS_PATH = Path(cfg["generator_root"])
    # FIX: assert is stripped under `python -O`; validate input explicitly.
    if not OBJECTS_PATH.exists():
        raise FileNotFoundError(f"Generator root does not exist: {OBJECTS_PATH}")
    generator = None
    # Scan asset modules in a deterministic order until one exports the
    # generator class named in the config.
    for subdir in sorted(OBJECTS_PATH.iterdir()):
        clsname = subdir.name.split(".")[0].strip()
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found %s in %s", cfg["generator"], subdir)
            break
        logger.debug("%s not found in %s", cfg["generator"], subdir)
    if generator is None:
        raise ModuleNotFoundError("{} not Found.".format(cfg["generator"]))
    gen = generator(cfg["seed"])
    # create visualize dir
    os.makedirs(cfg["save_dir"], exist_ok=True)
    main(cfg, gen)