import argparse
import gradio
import os
import torch
import numpy as np
import tempfile
import functools
import subprocess
import trimesh  # used below for the point-cloud scene, camera meshes and .glb export
from slam3r.pipeline.recon_offline_pipeline import (
    get_img_tokens, initialize_scene, adapt_keyframe_stride, i2p_inference_batch,
    l2w_inference, normalize_views, scene_frame_retrieve)
from slam3r.datasets.wild_seq import Seq_Data
from slam3r.models import Local2WorldModel, Image2PointsModel
from slam3r.utils.device import to_numpy
from slam3r.utils.recon_utils import *
from scipy.spatial.transform import Rotation
import PIL
from pdb import set_trace as bb
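
# OPENGL below is an axis-flip matrix (it negates the Y and Z axes); it converts
# camera-to-world poses from the OpenCV-style convention into the OpenGL/GLB
# convention used when placing camera markers and setting the render viewpoint.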
# from dust3r
OPENGL = np.array([[1, 0, 0, 0],
                   [0, -1, 0, 0],
                   [0, 0, -1, 0],
                   [0, 0, 0, 1]])

def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res

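# Build a small pyramid-shaped camera marker at the given camera-to-world pose and add
# it to the trimesh scene; when an image is provided it is pasted as a textured quad on
# the frustum's image plane (adapted from the dust3r visualization utilities).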
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None,
                  screen_width=0.11, marker=None):
    if image is not None:
        image = np.asarray(image)
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255 * image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        H = W = focal / 1.1
    else:
        H = W = 1

    if isinstance(focal, np.ndarray):
        focal = focal[0]
    if not focal:
        focal = min(H, W) * 1.1  # default value

    # create fake camera
    height = max(screen_width / 10, focal * screen_width / H)
    width = screen_width * 0.5**0.5
    rot45 = np.eye(4)
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W / H
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)

    # this is the image
    if image is not None:
        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
        img = trimesh.Trimesh(vertices=vertices, faces=faces)
        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
        scene.add_geometry(img)

    # this is the camera mesh
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
    vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        if 0 in face:
            continue
        a, b, c = face
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2 * len(cam.vertices)

        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # no culling
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)

    if marker == 'o':
        marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
        marker.vertices += pose_c2w[:3, 3]
        marker.visual.face_colors[:, :3] = edge_color
        scene.add_geometry(marker)

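# Produce n RGB colors interpolated along red -> green -> blue; used below to color
# the per-frame camera markers according to their order in the sequence.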
def rgb_gradient(n):
    assert n > 1
    red = (255, 0, 0)
    green = (0, 255, 0)
    blue = (0, 0, 255)
    if n == 2:
        return [red, blue]
    if n == 3:
        return [red, green, blue]
    stage1_count = (n - 1) // 2
    stage2_count = n - 1 - stage1_count
    gradient = []
    for i in range(stage1_count + 1):
        ratio = i / stage1_count
        r = int(red[0] * (1 - ratio) + green[0] * ratio)
        g = int(red[1] * (1 - ratio) + green[1] * ratio)
        b = int(red[2] * (1 - ratio) + green[2] * ratio)
        gradient.append((r, g, b))
    for i in range(1, stage2_count + 1):
        ratio = i / stage2_count
        r = int(green[0] * (1 - ratio) + blue[0] * ratio)
        g = int(green[1] * (1 - ratio) + blue[1] * ratio)
        b = int(green[2] * (1 - ratio) + blue[2] * ratio)
        gradient.append((r, g, b))
    return gradient

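# Extract frames from the uploaded video at the requested fps into a fresh temporary
# directory via the ffmpeg CLI (assumes an ffmpeg binary is available on PATH), and
# return that directory so it can be read like an image folder.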
def extract_frames(video_path: str, fps: float) -> str:
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", f"fps={fps}",
        output_path
    ]
    subprocess.run(command, check=True)
    return temp_dir

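# End-to-end reconstruction for the demo: load the uploaded images (or sample frames
# from a video), encode them with the I2P encoder, run the Image-to-Points
# initialization over a single window, normalize and confidence-filter the resulting
# point maps, estimate per-frame camera poses with a fixed pinhole intrinsic, and
# export everything as a .glb file for the viewer.
# Hypothetical usage sketch (paths/filenames are placeholders, not from the repo):
#   glb_path, res = recon_scene(i2p_model, "cpu", "/tmp/out", ["a.jpg", "b.jpg"],
#                               conf_thres_res=4, num_points_save=1_000_000)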
def recon_scene(i2p_model: Image2PointsModel, device,
                save_dir, img_dir_or_list,
                conf_thres_res, num_points_save):
    max_num_frames = 7  # take at most 7 images because of the slow CPU runtime on HF
    # max_num_frames = 10  # fixed for this demo
    kf_stride = 1  # fixed for this demo
    # np.random.seed(4)

    # load the imgs or video
    if isinstance(img_dir_or_list, str):
        img_dir_or_list = extract_frames(img_dir_or_list, fps=5)  # fps fixed for this demo
    dataset = Seq_Data(img_dir_or_list, to_tensor=True)
    data_views = dataset[0][:]
    num_views = len(data_views)

    # sample frames
    assert num_views > 1, 'single image recon not supported'
    if num_views > max_num_frames:
        sample_indices = np.linspace(0, num_views - 1, num=max_num_frames, dtype=int)
        data_views = [data_views[i] for i in sample_indices]
        num_views = len(data_views)

    # Pre-save the RGB images along with their corresponding masks
    # in preparation for visualization at last.
    rgb_imgs = []
    for i in range(len(data_views)):
        if data_views[i]['img'].shape[0] == 1:
            data_views[i]['img'] = data_views[i]['img'][0]
        rgb_imgs.append(transform_img(dict(img=data_views[i]['img'][None]))[..., ::-1])

    # preprocess data for extracting their img tokens with encoder
    for view in data_views:
        view['img'] = torch.tensor(view['img'][None])
        view['true_shape'] = torch.tensor(view['true_shape'][None])
        for key in ['valid_mask', 'pts3d_cam', 'pts3d']:
            if key in view:
                del view[key]
        to_device(view, device=device)

    # pre-extract img tokens by encoder, which can be reused
    res_shapes, res_feats, res_poses = get_img_tokens(data_views, i2p_model)  # 300+fps
    print('finish pre-extracting img tokens')

    # re-organize input views for the following inference.
    input_views = []
    for i in range(num_views):
        input_views.append(dict(label=data_views[i]['label'],
                                img_tokens=res_feats[i],
                                true_shape=data_views[i]['true_shape'],
                                img_pos=res_poses[i]))

    # run slam3r i2p
    initial_pcds, initial_confs, init_ref_id = initialize_scene(input_views, i2p_model,
                                                                winsize=num_views,
                                                                return_ref_id=True)  # 5*(1,224,224,3)
    print('finish I2P iterations with the best reference')

    # format as l2w results
    num_init = len(initial_pcds)
    per_frame_res = dict(i2p_pcds=[], i2p_confs=[], l2w_pcds=[], l2w_confs=[])
    for key in per_frame_res:
        per_frame_res[key] = [None for _ in range(num_init)]
    # registered_confs_mean = [_ for _ in range(num_init)]

    # set up the world coordinates with the initial window
    for i in range(num_init):
        per_frame_res['l2w_confs'][i * kf_stride] = initial_confs[i][0].to(device)  # 224,224
        # registered_confs_mean[i*kf_stride] = per_frame_res['l2w_confs'][i*kf_stride].mean().cpu()

    # set up the world coordinates with frames in the initial window
    for i in range(num_init):
        input_views[i * kf_stride]['pts3d_world'] = initial_pcds[i]

    conf_thres_i2p = 1.5
    initial_valid_masks = [conf > conf_thres_i2p for conf in initial_confs]  # 1,224,224
    normed_pts = normalize_views([view['pts3d_world'] for view in input_views[:num_init * kf_stride:kf_stride]],
                                 initial_valid_masks)
    for i in range(num_init):
        input_views[i * kf_stride]['pts3d_world'] = normed_pts[i]
        # filter out points with low confidence
        input_views[i * kf_stride]['pts3d_world'][~initial_valid_masks[i]] = 0
        per_frame_res['l2w_pcds'][i * kf_stride] = normed_pts[i]  # 224,224,3

    per_frame_res['rgb_imgs'] = rgb_imgs

    # estimate camera pose
    per_frame_res['cam_pose'] = []
    fx = fy = 224  # fake focal length. TODO: estimate focal length
    cx = cy = 112  # center of 224x224 reso
    intrin = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    for i in range(num_init):
        pose, _ = estimate_camera_pose(per_frame_res['l2w_pcds'][i].squeeze(), intrin)
        per_frame_res['cam_pose'].append(pose)

    save_path = get_model_from_scene(per_frame_res=per_frame_res,
                                     save_dir=save_dir,
                                     num_points_save=num_points_save,
                                     conf_thres_res=conf_thres_res)
    return save_path, per_frame_res

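# Turn the per-frame results into a single colored point cloud: flatten and concatenate
# the registered point maps, drop points below the confidence threshold, randomly
# subsample at most num_points_save points, add one camera marker per frame, orient the
# scene relative to the first camera, and export the result as recon.glb in save_dir.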
def get_model_from_scene(per_frame_res, save_dir,
                         num_points_save=200000,
                         conf_thres_res=3,
                         valid_masks=None
                         ):
    # collect the registered point clouds and rgb colors
    pcds = []
    rgbs = []
    pred_frame_num = len(per_frame_res['l2w_pcds'])
    registered_confs = per_frame_res['l2w_confs']
    registered_pcds = per_frame_res['l2w_pcds']
    rgb_imgs = per_frame_res['rgb_imgs']
    for i in range(pred_frame_num):
        registered_pcd = to_numpy(registered_pcds[i])
        if registered_pcd.shape[0] == 3:
            registered_pcd = registered_pcd.transpose(1, 2, 0)
        registered_pcd = registered_pcd.reshape(-1, 3)
        rgb = rgb_imgs[i].reshape(-1, 3)
        pcds.append(registered_pcd)
        rgbs.append(rgb)
    res_pcds = np.concatenate(pcds, axis=0)
    res_rgbs = np.concatenate(rgbs, axis=0)

    pts_count = len(res_pcds)
    valid_ids = np.arange(pts_count)

    # filter out points with gt valid masks
    if valid_masks is not None:
        valid_masks = np.stack(valid_masks, axis=0).reshape(-1)
        # print('filter out ratio of points by gt valid masks:', 1.-valid_masks.astype(float).mean())
    else:
        valid_masks = np.ones(pts_count, dtype=bool)

    # filter out points with low confidence
    if registered_confs is not None:
        conf_masks = []
        for i in range(len(registered_confs)):
            conf = registered_confs[i]
            conf_mask = (conf > conf_thres_res).reshape(-1).cpu()
            conf_masks.append(conf_mask)
        conf_masks = np.array(torch.cat(conf_masks))
        valid_ids = valid_ids[conf_masks & valid_masks]
    print('ratio of points filtered out: {:.2f}%'.format((1. - len(valid_ids) / pts_count) * 100))

    # sample from the resulting pcd consisting of all frames
    n_samples = min(num_points_save, len(valid_ids))
    print(f"resampling {n_samples} points from {len(valid_ids)} points")
    sampled_idx = np.random.choice(valid_ids, n_samples, replace=False)
    sampled_pts = res_pcds[sampled_idx]
    sampled_rgbs = res_rgbs[sampled_idx]

    scene = trimesh.Scene()
    # trimesh: scene pts
    scene.add_geometry(trimesh.PointCloud(vertices=sampled_pts, colors=sampled_rgbs / 255.))
    # trimesh: cam poses
    poses = per_frame_res['cam_pose']
    colors = rgb_gradient(len(poses))
    for i, pose_c2w in enumerate(poses):
        add_scene_cam(scene, pose_c2w, edge_color=colors[i], image=255 - rgb_imgs[i])
    # trimesh: viewpoint for render
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(poses[0] @ OPENGL @ rot))
    # trimesh: save to file
    save_name = "recon.glb"
    save_path = os.path.join(save_dir, save_name)
    scene.export(save_path)
    return save_path

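# Gradio callback: swap the upload widget between multi-image and single-video modes
# when the dropdown selection changes.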
def change_inputfile_type(input_type):
    if input_type == "2-10 images":
        inputfiles = gradio.File(file_count="multiple", file_types=["image"],
                                 scale=1,
                                 label="Click to upload 2-10 images")
    elif input_type == "A short video":
        inputfiles = gradio.File(file_count="single", file_types=["video"],
                                 scale=1,
                                 label="Click to upload a short video")
    return inputfiles

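# Build the Gradio Blocks UI and wire the events: the Run button triggers a full
# reconstruction, while changing the confidence threshold or the number of saved
# points only re-exports the cached per_frame_res state without rerunning inference.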
def main_demo(i2p_model, device, tmpdirname, server_name, server_port):
    recon_scene_func = functools.partial(recon_scene, i2p_model, device)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
                       title="SLAM3R I2P") as demo:
        # scene state is saved so that you can change num_points_save... without rerunning the inference
        per_frame_res = gradio.State(None)
        tmpdir_name = gradio.State(tmpdirname)
        gradio.HTML('''
            <h1 style="text-align: center;">SLAM3R Image-to-Points and camera pose estimation (CPU demo)</h1>
            <p style="text-align: center;">
            <a href="https://github.com/PKU-VCL-3DV/SLAM3R">Code</a> |
            <a href="https://openaccess.thecvf.com/content/CVPR2025/html/Liu_SLAM3R_Real-Time_Dense_Scene_Reconstruction_from_Monocular_RGB_Videos_CVPR_2025_paper.html">Paper</a>
            </p>
            <p>
            Upload 2–10 images or a short video of a static scene from different viewpoints. SLAM3R’s Image-to-Points module reconstructs scene geometry and can estimate camera poses.
            </p>
            ''')
        with gradio.Column():
            with gradio.Row():
                input_type = gradio.Dropdown(["A short video", "2-10 images"],
                                             scale=1,
                                             value='2-10 images',
                                             label="Select type of input files")
                inputfiles = gradio.File(file_count="multiple", file_types=["image"],
                                         scale=1,
                                         label="Click to upload 2-10 images")
            run_btn = gradio.Button("Run")
            with gradio.Row():
                conf_thres_res = gradio.Slider(value=4, minimum=1., maximum=10,
                                               # visible=False,
                                               interactive=True,
                                               label="Confidence threshold for the result")
                num_points_save = gradio.Number(value=1000000, precision=0, minimum=1,
                                                # visible=False,
                                                interactive=True,
                                                label="Number of points sampled from the result")
            outmodel = gradio.Model3D(camera_position=(-90, 69, 2.6),
                                      height=500,
                                      clear_color=(0., 0., 0., 0.3))

            # events
            input_type.change(change_inputfile_type,
                              inputs=[input_type],
                              outputs=[inputfiles])
            run_btn.click(fn=recon_scene_func,
                          inputs=[tmpdir_name, inputfiles,
                                  conf_thres_res, num_points_save],
                          outputs=[outmodel, per_frame_res])
            conf_thres_res.release(fn=get_model_from_scene,
                                   inputs=[per_frame_res, tmpdir_name, num_points_save, conf_thres_res],
                                   outputs=outmodel)
            num_points_save.change(fn=get_model_from_scene,
                                   inputs=[per_frame_res, tmpdir_name, num_points_save, conf_thres_res],
                                   outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)

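# Entry point: parse command-line options, load the pretrained I2P model from the
# Hugging Face Hub checkpoint 'siyan824/slam3r_i2p', and serve the demo from a
# temporary working directory.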
def run_i2p(parser: argparse.ArgumentParser):
    # the parser passed in from __main__ has no options defined yet, so add the two
    # that are used below; the defaults are assumptions chosen for this CPU demo
    parser.add_argument("--device", type=str, default="cpu",
                        help="pytorch device used for inference")
    parser.add_argument("--tmp_dir", type=str, default=None,
                        help="optional directory for temporary files")
    args = parser.parse_args()

    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    server_name = '0.0.0.0'  # '127.0.0.1'
    server_port = 7860

    i2p_model = Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
    i2p_model.to(args.device)
    i2p_model.eval()

    # this demo will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='slam3r_i2p_gradio_demo') as tmpdirname:
        main_demo(i2p_model, args.device, tmpdirname, server_name, server_port)


if __name__ == "__main__":
    run_i2p(argparse.ArgumentParser())
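# Example local launch (the filename app.py and the tmp-dir path are assumptions for
# this Spaces demo):
#   python app.py --device cpu --tmp_dir ./tmp_slam3r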