import argparse
import gradio
import os
import torch
import numpy as np
import tempfile
import functools
import subprocess
from slam3r.pipeline.recon_offline_pipeline import get_img_tokens, initialize_scene, adapt_keyframe_stride, i2p_inference_batch, l2w_inference, normalize_views, scene_frame_retrieve
from slam3r.datasets.wild_seq import Seq_Data
from slam3r.models import Local2WorldModel, Image2PointsModel
from slam3r.utils.device import to_numpy
from slam3r.utils.recon_utils import *
from scipy.spatial.transform import Rotation
import PIL
import trimesh
from pdb import set_trace as bb
# from dust3r
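# flips the y and z axes to go from an OpenCV-style camera frame (x right, y down, z forward)
# to the OpenGL-style frame (x right, y up, z backward) expected by the GLB viewer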
OPENGL = np.array([[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]])
def geotrf(Trf, pts, ncol=None, norm=False):
""" Apply a geometric transformation to a list of 3-D points.
H: 3x3 or 4x4 projection matrix (typically a Homography)
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
ncol: int. number of columns of the result (2 or 3)
norm: float. if != 0, the resut is projected on the z=norm plane.
Returns an array of projected 2d points.
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
Trf.ndim == 3 and pts.ndim == 4):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d + 1:
pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
else:
raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
else:
if Trf.ndim >= 3:
n = Trf.ndim - 2
assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1] + 1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.11, marker=None):
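    """Add a pyramid-shaped camera marker (optionally textured with its image) to a trimesh scene.
    pose_c2w: 4x4 camera-to-world matrix.
    edge_color: (r, g, b) color used for the camera mesh.
    image: optional HxWx3 image shown on the camera's image plane.
    screen_width: size of the camera marker in world units.
    """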
if image is not None:
image = np.asarray(image)
H, W, THREE = image.shape
assert THREE == 3
if image.dtype != np.uint8:
image = np.uint8(255*image)
elif imsize is not None:
W, H = imsize
elif focal is not None:
H = W = focal / 1.1
else:
H = W = 1
if isinstance(focal, np.ndarray):
focal = focal[0]
if not focal:
focal = min(H,W) * 1.1 # default value
# create fake camera
height = max( screen_width/10, focal * screen_width / H )
width = screen_width * 0.5**0.5
rot45 = np.eye(4)
rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
rot45[2, 3] = -height # set the tip of the cone = optical center
aspect_ratio = np.eye(4)
aspect_ratio[0, 0] = W/H
transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
# this is the image
if image is not None:
vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
img = trimesh.Trimesh(vertices=vertices, faces=faces)
uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
scene.add_geometry(img)
# this is the camera mesh
rot2 = np.eye(4)
rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
vertices = geotrf(transform, vertices)
faces = []
for face in cam.faces:
if 0 in face:
continue
a, b, c = face
a2, b2, c2 = face + len(cam.vertices)
a3, b3, c3 = face + 2*len(cam.vertices)
# add 3 pseudo-edges
faces.append((a, b, b2))
faces.append((a, a2, c))
faces.append((c2, b, c))
faces.append((a, b, b3))
faces.append((a, a3, c))
faces.append((c3, b, c))
# no culling
faces += [(c, b, a) for a, b, c in faces]
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
cam.visual.face_colors[:, :3] = edge_color
scene.add_geometry(cam)
if marker == 'o':
marker = trimesh.creation.icosphere(3, radius=screen_width/4)
marker.vertices += pose_c2w[:3,3]
marker.visual.face_colors[:,:3] = edge_color
scene.add_geometry(marker)
def rgb_gradient(n):
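    """Return n colors forming a red -> green -> blue gradient, used to color the camera markers in frame order."""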
assert n > 1
red = (255, 0, 0)
green = (0, 255, 0)
blue = (0, 0, 255)
if n == 2:
return [red, blue]
if n == 3:
return [red, green, blue]
stage1_count = (n - 1) // 2
stage2_count = n - 1 - stage1_count
gradient = []
for i in range(stage1_count + 1):
ratio = i / stage1_count
r = int(red[0] * (1 - ratio) + green[0] * ratio)
g = int(red[1] * (1 - ratio) + green[1] * ratio)
b = int(red[2] * (1 - ratio) + green[2] * ratio)
gradient.append((r, g, b))
for i in range(1, stage2_count + 1):
ratio = i / stage2_count
r = int(green[0] * (1 - ratio) + blue[0] * ratio)
g = int(green[1] * (1 - ratio) + blue[1] * ratio)
b = int(green[2] * (1 - ratio) + blue[2] * ratio)
gradient.append((r, g, b))
return gradient
def extract_frames(video_path: str, fps: float) -> str:
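    """Extract frames from a video with ffmpeg at the given fps and return the temp directory holding the JPEGs."""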
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, "%03d.jpg")
command = [
"ffmpeg",
"-i", video_path,
"-vf", f"fps={fps}",
output_path
]
subprocess.run(command, check=True)
return temp_dir
def recon_scene(i2p_model:Image2PointsModel, device,
save_dir, img_dir_or_list,
conf_thres_res, num_points_save):
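    """Run the SLAM3R Image-to-Points pipeline on a list of images or a video path.

    Frames are encoded once, registered jointly in a single initialization window,
    per-frame camera poses are estimated from the resulting point maps with a fixed
    pinhole intrinsic, and the scene is exported as a .glb file. Returns the export
    path and the per-frame results (kept as gradio state so the visualization can be
    re-filtered without rerunning inference).
    """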
    max_num_frames = 7 # take at most 7 frames because of the slow CPU runtime on HF
# max_num_frames = 10 # fixed for this demo
kf_stride = 1 # fixed for this demo
# np.random.seed(4)
# load the imgs or video
if isinstance(img_dir_or_list, str):
img_dir_or_list = extract_frames(img_dir_or_list, fps=5) # fps fixed for this demo
dataset = Seq_Data(img_dir_or_list, to_tensor=True)
data_views = dataset[0][:]
num_views = len(data_views)
# sample frames
    assert num_views > 1, 'single-image reconstruction is not supported'
if num_views > max_num_frames:
sample_indices = np.linspace(0, num_views-1, num=max_num_frames, dtype=int)
data_views = [data_views[i] for i in sample_indices]
num_views = len(data_views)
# Pre-save the RGB images along with their corresponding masks
# in preparation for visualization at last.
rgb_imgs = []
for i in range(len(data_views)):
if data_views[i]['img'].shape[0] == 1:
data_views[i]['img'] = data_views[i]['img'][0]
rgb_imgs.append(transform_img(dict(img=data_views[i]['img'][None]))[...,::-1])
# preprocess data for extracting their img tokens with encoder
for view in data_views:
view['img'] = torch.tensor(view['img'][None])
view['true_shape'] = torch.tensor(view['true_shape'][None])
for key in ['valid_mask', 'pts3d_cam', 'pts3d']:
if key in view:
del view[key]
to_device(view, device=device)
# pre-extract img tokens by encoder, which can be reused
res_shapes, res_feats, res_poses = get_img_tokens(data_views, i2p_model) # 300+fps
print('finish pre-extracting img tokens')
# re-organize input views for the following inference.
input_views = []
for i in range(num_views):
input_views.append(dict(label=data_views[i]['label'],
img_tokens=res_feats[i],
true_shape=data_views[i]['true_shape'],
img_pos=res_poses[i]))
# run slam3r i2p
initial_pcds, initial_confs, init_ref_id = initialize_scene(input_views, i2p_model, winsize=num_views, return_ref_id=True) # 5*(1,224,224,3)
print('finish I2P iterations with the best reference')
# format as l2w results
num_init = len(initial_pcds)
per_frame_res = dict(i2p_pcds=[], i2p_confs=[], l2w_pcds=[], l2w_confs=[])
for key in per_frame_res:
per_frame_res[key] = [None for _ in range(num_init)]
# registered_confs_mean = [_ for _ in range(num_init)]
# set up the world coordinates with the initial window
for i in range(num_init):
per_frame_res['l2w_confs'][i*kf_stride] = initial_confs[i][0].to(device) # 224,224
# registered_confs_mean[i*kf_stride] = per_frame_res['l2w_confs'][i*kf_stride].mean().cpu()
# set up the world coordinates with frames in the initial window
for i in range(num_init):
input_views[i*kf_stride]['pts3d_world'] = initial_pcds[i]
conf_thres_i2p = 1.5
initial_valid_masks = [conf > conf_thres_i2p for conf in initial_confs] # 1,224,224
normed_pts = normalize_views([view['pts3d_world'] for view in input_views[:num_init*kf_stride:kf_stride]],
initial_valid_masks)
for i in range(num_init):
input_views[i*kf_stride]['pts3d_world'] = normed_pts[i]
# filter out points with low confidence
input_views[i*kf_stride]['pts3d_world'][~initial_valid_masks[i]] = 0
per_frame_res['l2w_pcds'][i*kf_stride] = normed_pts[i] # 224,224,3
per_frame_res['rgb_imgs'] = rgb_imgs
# estimate camera pose
per_frame_res['cam_pose'] = []
fx = fy = 224 # fake focal length. TODO: estimate focal length
cx = cy = 112 # center of 224x224 reso
intrin = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
for i in range(num_init):
pose, _ = estimate_camera_pose(per_frame_res['l2w_pcds'][i].squeeze(), intrin)
per_frame_res['cam_pose'].append(pose)
save_path = get_model_from_scene(per_frame_res=per_frame_res,
save_dir=save_dir,
num_points_save=num_points_save,
conf_thres_res=conf_thres_res)
return save_path, per_frame_res
def get_model_from_scene(per_frame_res, save_dir,
num_points_save=200000,
conf_thres_res=3,
valid_masks=None
):
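    """Assemble the registered point clouds into a colored trimesh scene and export it.

    Points are filtered by the per-frame confidence maps (conf > conf_thres_res) and by
    the optional validity masks, subsampled to at most num_points_save points, combined
    with a colored camera frustum per frame, and saved to save_dir/recon.glb.
    """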
# collect the registered point clouds and rgb colors
pcds = []
rgbs = []
pred_frame_num = len(per_frame_res['l2w_pcds'])
registered_confs = per_frame_res['l2w_confs']
registered_pcds = per_frame_res['l2w_pcds']
rgb_imgs = per_frame_res['rgb_imgs']
for i in range(pred_frame_num):
registered_pcd = to_numpy(registered_pcds[i])
if registered_pcd.shape[0] == 3:
registered_pcd = registered_pcd.transpose(1,2,0)
registered_pcd = registered_pcd.reshape(-1,3)
rgb = rgb_imgs[i].reshape(-1,3)
pcds.append(registered_pcd)
rgbs.append(rgb)
res_pcds = np.concatenate(pcds, axis=0)
res_rgbs = np.concatenate(rgbs, axis=0)
pts_count = len(res_pcds)
valid_ids = np.arange(pts_count)
# filter out points with gt valid masks
if valid_masks is not None:
valid_masks = np.stack(valid_masks, axis=0).reshape(-1)
# print('filter out ratio of points by gt valid masks:', 1.-valid_masks.astype(float).mean())
else:
valid_masks = np.ones(pts_count, dtype=bool)
# filter out points with low confidence
if registered_confs is not None:
conf_masks = []
for i in range(len(registered_confs)):
conf = registered_confs[i]
conf_mask = (conf > conf_thres_res).reshape(-1).cpu()
conf_masks.append(conf_mask)
conf_masks = np.array(torch.cat(conf_masks))
valid_ids = valid_ids[conf_masks&valid_masks]
    print('ratio of points filtered out: {:.2f}%'.format((1.-len(valid_ids)/pts_count)*100))
# sample from the resulting pcd consisting of all frames
n_samples = min(num_points_save, len(valid_ids))
print(f"resampling {n_samples} points from {len(valid_ids)} points")
sampled_idx = np.random.choice(valid_ids, n_samples, replace=False)
sampled_pts = res_pcds[sampled_idx]
sampled_rgbs = res_rgbs[sampled_idx]
scene = trimesh.Scene()
# trimesh: scene pts
scene.add_geometry(trimesh.PointCloud(vertices=sampled_pts, colors=sampled_rgbs/255.))
# trimesh: cam poses
poses = per_frame_res['cam_pose']
colors = rgb_gradient(len(poses))
for i, pose_c2w in enumerate(poses):
add_scene_cam(scene, pose_c2w, edge_color=colors[i], image=255-rgb_imgs[i])
# trimesh: viewpoint for render
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
scene.apply_transform(np.linalg.inv(poses[0] @ OPENGL @ rot))
# trimesh: save to file
save_name = "recon.glb"
    save_path = os.path.join(save_dir, save_name)
scene.export(save_path)
return save_path
def change_inputfile_type(input_type):
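    """Swap the upload widget between multi-image and single-video mode when the dropdown changes."""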
if input_type == "2-10 images":
inputfiles = gradio.File(file_count="multiple", file_types=["image"],
scale=1,
label="Click to upload 2-10 images")
elif input_type == "A short video":
inputfiles = gradio.File(file_count="single", file_types=["video"],
scale=1,
label="Click to upload a short video")
return inputfiles
def main_demo(i2p_model, device, tmpdirname, server_name, server_port):
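    """Build the gradio Blocks UI, wire up the events, and launch the demo server."""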
recon_scene_func = functools.partial(recon_scene, i2p_model, device)
with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="SLAM3R I2P") as demo:
        # the scene state is saved so that num_points_save / conf_thres_res can be changed without rerunning the inference
per_frame_res = gradio.State(None)
tmpdir_name = gradio.State(tmpdirname)
gradio.HTML('''
<h1 style="text-align: center;">SLAM3R Image-to-Points and camera pose estimation (CPU demo)</h1>
<p style="text-align: center;">
<a href="https://github.com/PKU-VCL-3DV/SLAM3R">Code</a> |
<a href="https://openaccess.thecvf.com/content/CVPR2025/html/Liu_SLAM3R_Real-Time_Dense_Scene_Reconstruction_from_Monocular_RGB_Videos_CVPR_2025_paper.html">Paper</a>
</p>
<p>
Upload 2–10 images or a short video of a static scene from different viewpoints. SLAM3R’s Image-to-Points module reconstructs scene geometry and can estimate camera poses.
</p>
''')
with gradio.Column():
with gradio.Row():
input_type = gradio.Dropdown(["A short video", "2-10 images"],
scale=1,
value='2-10 images',
label="Select type of input files")
inputfiles = gradio.File(file_count="multiple", file_types=["image"],
scale=1,
label="Click to upload 2-10 images")
run_btn = gradio.Button("Run")
with gradio.Row():
conf_thres_res = gradio.Slider(value=4, minimum=1., maximum=10,
# visible=False,
interactive=True,
label="Confidence threshold for the result")
num_points_save = gradio.Number(value=1000000, precision=0, minimum=1,
# visible=False,
interactive=True,
label="Number of points sampled from the result")
outmodel = gradio.Model3D(camera_position=(-90, 69, 2.6),
height=500,
clear_color=(0.,0.,0.,0.3))
# events
input_type.change(change_inputfile_type,
inputs=[input_type],
outputs=[inputfiles])
run_btn.click(fn=recon_scene_func,
inputs=[tmpdir_name, inputfiles,
conf_thres_res, num_points_save],
outputs=[outmodel, per_frame_res])
conf_thres_res.release(fn=get_model_from_scene,
inputs=[per_frame_res, tmpdir_name, num_points_save, conf_thres_res],
outputs=outmodel)
num_points_save.change(fn=get_model_from_scene,
inputs=[per_frame_res, tmpdir_name, num_points_save, conf_thres_res],
outputs=outmodel)
demo.launch(share=False, server_name=server_name, server_port=server_port)
def run_i2p(parser: argparse.ArgumentParser):
    # the parser passed in from __main__ has no arguments defined yet, so declare the ones used below
    parser.add_argument("--device", type=str, default="cpu", help="device to run the model on")
    parser.add_argument("--tmp_dir", type=str, default=None, help="directory to store temporary files")
    args = parser.parse_args()
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path
server_name = '0.0.0.0' # '127.0.0.1'
server_port = 7860
i2p_model = Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
i2p_model.to(args.device)
i2p_model.eval()
# this demo will write the 3D model inside tmpdirname
with tempfile.TemporaryDirectory(suffix='slam3r_i2p_gradio_demo') as tmpdirname:
main_demo(i2p_model, args.device, tmpdirname, server_name, server_port)
if __name__ == "__main__":
run_i2p(argparse.ArgumentParser())
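# Example usage (assuming this file is saved as app.py):
#   python app.py --device cpu
# then open http://localhost:7860 in a browser.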