# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import shutil
import subprocess
import tempfile
import time
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import hydra
import matplotlib.pyplot as plt
import torch
from hydra.utils import instantiate, get_original_cwd
from omegaconf import OmegaConf, DictConfig
from PIL import Image
from pytorch3d.vis.plotly_vis import get_camera_wireframe

from util.utils import seed_all_random_engines
from util.load_img_folder import load_and_preprocess_images
from util.geometry_guided_sampling import geometry_guided_sampling

def plot_cameras(ax, cameras, color: str = "blue"):
    """
    Plots a set of `cameras` objects into the matplotlib axis `ax` with
    color `color`.
    """
    # Move the canonical wireframe to the cameras' device before transforming it.
    cam_wires_canonical = get_camera_wireframe().to(cameras.device)[None]
    cam_trans = cameras.get_world_to_view_transform().inverse()
    cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
    plot_handles = []
    for wire in cam_wires_trans:
        # the Z and Y axes are flipped intentionally here!
        x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
        (h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
        plot_handles.append(h)
    return plot_handles

def create_matplotlib_figure(pred_cameras):
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.clear()
    handle_cam = plot_cameras(ax, pred_cameras, color="#FF7D1E")
    plot_radius = 3
    ax.set_xlim3d([-plot_radius, plot_radius])
    ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
    ax.set_zlim3d([-plot_radius, plot_radius])
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_zlabel("y")
    labels_handles = {
        "Estimated cameras": handle_cam[0],
    }
    ax.legend(
        labels_handles.values(),
        labels_handles.keys(),
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
    )
    # Return the figure object itself so it can be passed to Gradio's plot output.
    return fig

def convert_extrinsics_pytorch3d_to_opengl(extrinsics: torch.Tensor) -> torch.Tensor:
    """
    Convert extrinsics from the PyTorch3D coordinate system to the OpenGL coordinate system.
    Args:
        extrinsics (torch.Tensor): a 4x4 extrinsic matrix in the PyTorch3D coordinate system.
    Returns:
        torch.Tensor: a 4x4 extrinsic matrix in the OpenGL coordinate system.
    """
    # Create a transformation matrix that flips the X- and Z-axes
    flip_z = torch.eye(4)
    flip_z[2, 2] = -1
    flip_z[0, 0] = -1
    # Multiply the extrinsic matrix by the transformation matrix
    extrinsics_opengl = torch.mm(extrinsics, flip_z)
    return extrinsics_opengl
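
# The flip matrix above is diag(-1, 1, -1, 1), which is its own inverse, so applying
# the conversion twice recovers the original matrix. A minimal sanity check
# (hypothetical, not part of the demo flow):
#   eye = torch.eye(4)
#   assert torch.allclose(
#       convert_extrinsics_pytorch3d_to_opengl(
#           convert_extrinsics_pytorch3d_to_opengl(eye)
#       ),
#       eye,
#   )
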
def create_camera_json(
    extrinsics: List[torch.Tensor],
    focal_length_world: Any,
    principal_points: List[float],
    image_size: int,
) -> Dict[str, Any]:
    # Initialize the dictionary
    camera_dict = {
        "w": image_size,
        "h": image_size,
        "fl_x": float(focal_length_world[0]),
        "fl_y": float(focal_length_world[1]),
        "cx": float(principal_points[0]),
        "cy": float(principal_points[1]),
        "k1": 0.0,  # distortion coefficients are not predicted, so assume zero
        "k2": 0.0,
        "p1": 0.0,
        "p2": 0.0,
        "camera_model": "OPENCV",
        "frames": [],
    }
    # Add frames to the dictionary
    for i, extrinsic in enumerate(extrinsics):
        frame = {
            "file_path": f"images/frame_{str(i).zfill(5)}.jpg",
            "transform_matrix": extrinsic.tolist(),
            "colmap_im_id": i,
        }
        # Convert numpy float32 to Python's native float
        frame["transform_matrix"] = [
            [float(element) for element in row] for row in frame["transform_matrix"]
        ]
        camera_dict["frames"].append(frame)
    return camera_dict
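
# Note: this dictionary follows the transforms.json convention used by
# Nerfstudio / instant-ngp-style pipelines (fl_x/fl_y, cx/cy, k1/k2/p1/p2 and a
# "frames" list of per-image transform matrices), so the archive produced below
# can presumably be consumed by such tools.
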
def archive_images_and_transforms(images, pred_cameras, image_size):
    # Convert the (N, C, H, W) image tensor back to uint8 PIL images
    images_array = images.permute(0, 2, 3, 1).cpu().numpy() * 255
    images_pil = [Image.fromarray(image.astype('uint8')) for image in images_array]
    with tempfile.TemporaryDirectory() as temp_dir:
        images_dir = os.path.join(temp_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)
        images_path = []
        for i, image in enumerate(images_pil):
            image_path = os.path.join(images_dir, 'frame_{:05d}.jpg'.format(i))
            image.save(image_path)
            images_path.append(image_path)
        # Camera-to-world matrices, converted to the OpenGL convention
        cam_trans = pred_cameras.get_world_to_view_transform()
        extrinsics = cam_trans.inverse().get_matrix().cpu()
        extrinsics = [convert_extrinsics_pytorch3d_to_opengl(extrinsic.T) for extrinsic in extrinsics]
        # Convert the focal length from NDC to pixel units
        focal_length_ndc = pred_cameras.focal_length.mean(dim=0).cpu().numpy()
        focal_length_world = focal_length_ndc * image_size / 2
        principal_points = [image_size / 2, image_size / 2]
        camera_dict = create_camera_json(extrinsics, focal_length_world, principal_points, image_size)
        json_path = os.path.join(temp_dir, 'transforms.json')
        with open(json_path, 'w') as f:
            json.dump(camera_dict, f, indent=4)
        # Zip the images and transforms.json before the temporary directory is removed
        project_name = datetime.now().strftime("%Y%m%d-%H%M%S")
        shutil.make_archive(f'/tmp/{project_name}', 'zip', temp_dir)
    return f'/tmp/{project_name}.zip'

def estimate_images_pose(image_folder, mode):
    print("Selected mode:", mode)

    # Compose the Hydra config that matches the selected mode
    with hydra.initialize(config_path="./cfgs/"):
        cfg = hydra.compose(config_name=mode)

    OmegaConf.set_struct(cfg, False)
    print("Model Config:")
    print(OmegaConf.to_yaml(cfg))

    # Check for GPU availability and set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the model
    model = instantiate(cfg.MODEL, _recursive_=False)

    # Load and preprocess images
    images, image_info = load_and_preprocess_images(image_folder, cfg.image_size)

    # Load checkpoint
    ckpt_path = os.path.join(cfg.ckpt)
    if os.path.isfile(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint, strict=True)
        print(f"Loaded checkpoint from: {ckpt_path}")
    else:
        raise ValueError(f"No checkpoint found at: {ckpt_path}")

    # Move model and images to the GPU
    model = model.to(device)
    images = images.to(device)

    # Evaluation Mode
    model.eval()

    # Seed random engines
    seed_all_random_engines(cfg.seed)

    # Start the timer
    start_time = time.time()

    # GGS conditioning is disabled in this demo, so no match extraction is performed
    cond_fn = None
    print("\033[92m=====> Sampling without GGS <=====\033[0m")

    # Forward
    with torch.no_grad():
        # Obtain predicted camera parameters.
        # pred_cameras is a PerspectiveCameras object with attributes
        # pred_cameras.R, pred_cameras.T, pred_cameras.focal_length.
        # The poses and focal length are defined in the NDC coordinate system; see
        # https://github.com/facebookresearch/pytorch3d/blob/main/docs/notes/cameras.md
        pred_cameras = model(
            image=images, cond_fn=cond_fn, cond_start_step=cfg.GGS.start_step
        )

    # Stop the timer and calculate elapsed time
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Time taken: {:.4f} seconds".format(elapsed_time))

    zip_path = archive_images_and_transforms(images, pred_cameras, cfg.image_size)
    return create_matplotlib_figure(pred_cameras), zip_path

def extract_frames_from_video(video_path: str) -> str:
    """
    Extracts frames from a video file and saves them in a temporary directory.
    Returns the path to the directory containing the frames.
    """
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", "fps=1",
        output_path
    ]
    subprocess.run(command, check=True)
    return temp_dir
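
# Note: this step requires an ffmpeg binary on PATH; "fps=1" samples one frame per
# second, so longer clips produce proportionally more frames (and slower inference).
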
def estimate_video_pose(video_path: str, mode: str):
    """
    Estimates camera poses for frames extracted from a video and returns the
    matplotlib figure together with the path to the exported zip archive.
    """
    # Extract frames from the video
    image_folder = extract_frames_from_video(video_path)
    # Estimate the pose for each frame
    fig, zip_path = estimate_images_pose(image_folder, mode)
    return fig, zip_path

if __name__ == "__main__":
    examples = [["examples/" + img, 'fast'] for img in os.listdir("examples/")]

    # Create a Gradio interface
    iface = gr.Interface(
        fn=estimate_video_pose,
        inputs=[
            gr.Video(label='video'),
            gr.Radio(
                choices=['fast', 'precise'],
                value='fast',
                label='Estimation Model: "fast" is quick, usually within 1 second; "precise" has higher accuracy but usually takes several minutes',
            ),
        ],
        outputs=['plot', 'file'],
        title="PoseDiffusion Demo: Solving Pose Estimation via Diffusion-aided Bundle Adjustment",
        description="Upload a video for camera pose estimation. The object should be centrally located within the frame.",
        examples=examples,
        cache_examples=True,
    )
    iface.launch()
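
# To try this demo locally (assuming this script is saved as app.py, the Hydra
# configs under ./cfgs/ and the PoseDiffusion checkpoints they reference are in
# place, ffmpeg is installed, and an examples/ folder with sample videos exists):
#   python app.py
# then open the local URL printed by Gradio.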