import torch
import cv2
import numpy as np
from os.path import join
from tqdm import tqdm
import matplotlib.pyplot as plt
import trimesh

from slam3r.utils.device import to_numpy, collate_with_cat, to_cpu
from slam3r.inference import loss_of_one_batch_multiview, \
    inv, get_multiview_scale
from slam3r.utils.geometry import xy_grid

try:
    import poselib  # noqa
    HAS_POSELIB = True
except Exception:
    HAS_POSELIB = False


def save_traj(views, pred_frame_num, save_dir, scene_id, args, intrinsics=None, traj_name='traj'):
    """Recover per-frame camera poses from predicted world-frame pointmaps with
    PnP + RANSAC, then save the trajectory as a txt file and a 2D plot."""
    save_name = f"{scene_id}_{traj_name}.txt"
    c2ws = []
    H, W, _ = views[0]['pts3d_world'][0].shape
    for i in tqdm(range(pred_frame_num)):
        pts = to_numpy(views[i]['pts3d_world'][0])
        # 2D-3D correspondences: pixel (u, v) <-> predicted 3D point pts[v, u]
        u, v = np.meshgrid(np.arange(W), np.arange(H))
        points_2d = np.stack((u, v), axis=-1)
        dist_coeffs = np.zeros(4, dtype=np.float32)
        success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
            pts.reshape(-1, 3).astype(np.float32),
            points_2d.reshape(-1, 2).astype(np.float32),
            intrinsics[i].astype(np.float32),
            dist_coeffs)
        rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
        # extrinsic parameters (4x4 world-to-camera matrix)
        extrinsic_matrix = np.hstack((rotation_matrix, translation_vector.reshape(-1, 1)))
        extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
        c2w = inv(extrinsic_matrix)
        c2ws.append(c2w)
    c2ws = np.stack(c2ws, axis=0)
    translations = c2ws[:, :3, 3]

    # draw the trajectory in the horizontal plane
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plot_traj(ax, [i for i in range(len(translations))], translations,
              '-', "black", "estimated trajectory")
    ax.set_xlabel('x [m]')
    ax.set_ylabel('y [m]')
    plt.savefig(join(save_dir, save_name.replace('.txt', '.png')), dpi=90)
    np.savetxt(join(save_dir, save_name), c2ws.reshape(-1, 16))
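
# A minimal, hedged self-check of the pose math used in save_traj above: PnP
# returns a world-to-camera extrinsic [R|t], and the camera-to-world pose is its
# inverse. This hypothetical `_demo_extrinsic_to_c2w` helper uses np.linalg.inv
# in place of slam3r's `inv` and made-up rotation/translation values.
def _demo_extrinsic_to_c2w():
    rvec = np.array([0.1, -0.2, 0.3])    # arbitrary axis-angle rotation
    R, _ = cv2.Rodrigues(rvec)
    t = np.array([[1.0], [2.0], [3.0]])  # arbitrary translation
    w2c = np.vstack((np.hstack((R, t)), [0, 0, 0, 1]))
    c2w = np.linalg.inv(w2c)
    # the inverse of [R|t] is [R^T | -R^T t]: the camera center in world coords
    assert np.allclose(c2w[:3, :3], R.T)
    assert np.allclose(c2w[:3, 3], (-R.T @ t).ravel())
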
def plot_traj(ax, stamps, traj, style, color, label):
    """Plot a trajectory using matplotlib.

    Input:
        ax -- the plot
        stamps -- time stamps (1xn)
        traj -- trajectory (3xn)
        style -- line style
        color -- line color
        label -- plot legend
    """
    stamps.sort()
    interval = np.median([s - t for s, t in zip(stamps[1:], stamps[:-1])])
    x = []
    y = []
    last = stamps[0]
    for i in range(len(stamps)):
        # break the polyline wherever the time gap exceeds 2x the median interval
        if stamps[i] - last < 2 * interval:
            x.append(traj[i][0])
            y.append(traj[i][1])
        elif len(x) > 0:
            ax.plot(x, y, style, color=color, label=label)
            label = ""
            x = []
            y = []
        last = stamps[i]
    if len(x) > 0:
        ax.plot(x, y, style, color=color, label=label)


def estimate_camera_pose(pts3d, intrinsic):
    """Estimate the camera-to-world pose of a single pointmap via PnP + RANSAC."""
    H, W, _ = pts3d.shape
    pts = to_numpy(pts3d)
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    points_2d = np.stack((u, v), axis=-1)
    dist_coeffs = np.zeros(4, dtype=np.float32)
    success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
        pts.reshape(-1, 3).astype(np.float32),
        points_2d.reshape(-1, 2).astype(np.float32),
        intrinsic.astype(np.float32),
        dist_coeffs)
    if not success:
        return np.eye(4), False
    rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
    # extrinsic parameters (4x4 world-to-camera matrix)
    extrinsic_matrix = np.hstack((rotation_matrix, translation_vector.reshape(-1, 1)))
    extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
    c2w = inv(extrinsic_matrix)
    return c2w, True


def estimate_intrinsics(pts3d_local):
    # estimate the focal length from a batch of camera-frame pointmaps,
    # assuming the principal point lies at the image center
    B, H, W, _ = pts3d_local.shape
    pp = torch.tensor((W / 2, H / 2))
    focal = estimate_focal_knowing_depth(pts3d_local.cpu(), pp, focal_mode='weiszfeld')
    # print(f'Estimated focal of first camera: {focal.item()} (224x224)')
    intrinsic = np.eye(3)
    intrinsic[0, 0] = focal
    intrinsic[1, 1] = focal
    intrinsic[:2, 2] = pp
    return intrinsic


def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):
    """ Reprojection method, for when the absolute depth is known:
        1) estimate the camera focal using a robust estimator
        2) reproject points onto true rays, minimizing a certain error
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # direct estimation of focal
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == 'weiszfeld':
        # init focal with the l2 closed form:
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z |
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least squares
        for _ in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    # print(focal)
    return focal
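
# A hedged synthetic round trip for estimate_focal_knowing_depth: unproject a
# constant-depth plane with a known focal, then check that the Weiszfeld
# estimator recovers that focal from the pointmap alone. The resolution, depth,
# and focal below are made-up values; `_demo_focal_recovery` is hypothetical.
def _demo_focal_recovery():
    H, W, f_true = 224, 224, 300.0
    pp = torch.tensor((W / 2, H / 2))
    u, v = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                          torch.arange(H, dtype=torch.float32), indexing='xy')
    z = torch.full((H, W), 2.0)                   # constant 2 m depth
    x = (u - pp[0]) * z / f_true                  # pinhole back-projection
    y = (v - pp[1]) * z / f_true
    pts3d = torch.stack((x, y, z), dim=-1)[None]  # (1, H, W, 3)
    f_est = estimate_focal_knowing_depth(pts3d, pp, focal_mode='weiszfeld')
    assert abs(f_est.item() - f_true) < 1.0
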
def unsqueeze_view(view):
    """Unsqueeze a view to batch size 1, similar to collate_fn."""
    if len(view['img'].shape) > 3:
        return view  # already batched
    res = dict(img=view['img'][None],
               true_shape=view['true_shape'][None],
               idx=view['idx'],
               instance=view['instance'],
               pts3d_cam=torch.tensor(view['pts3d_cam'][None]),
               valid_mask=torch.tensor(view['valid_mask'][None]),
               camera_pose=torch.tensor(view['camera_pose']),
               pts3d=torch.tensor(view['pts3d'][None]))
    if 'pointmap_img' in view:
        res['pointmap_img'] = view['pointmap_img'][None]
    return res


def transform_img(view):
    # transform to a numpy image: BGR, 0-255, HWC
    img = view['img'][0]
    # print(img.shape)
    img = img.permute(1, 2, 0).cpu().numpy()
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = (img / 2. + 0.5) * 255.
    return img


def save_ply(points: np.ndarray, save_path, colors: np.ndarray = None, metadata: dict = None):
    # colors are expected in [0, 1]; rescale if they look like 0-255 values
    if colors is not None and np.max(colors) > 1:
        colors = colors / 255.
    pcd = trimesh.points.PointCloud(points, colors=colors)
    if metadata is not None:
        for key in metadata:
            pcd.metadata[key] = metadata[key]
    pcd.export(save_path)
    print(">> save to", save_path)


def save_vis(points, dis, vis_path):
    cmap = plt.get_cmap('Reds')
    color = cmap(dis / 0.05)
    save_ply(points=points, save_path=vis_path, colors=color)


def uni_upsample(img, scale):
    # nearest-neighbor upsampling of a 2D array by an integer factor
    img = np.array(img)
    upsampled_img = img[:, None, :, None].repeat(scale, 1).repeat(scale, 3).reshape(img.shape[0] * scale, -1)
    return upsampled_img


def normalize_views(pts3d: list, valid_masks=None, return_factor=False):
    """Normalize the input point clouds by the average distance of the valid points to the origin.

    Args:
        pts3d: list of tensors, each tensor has shape (1,224,224,3)
        valid_masks: list of tensors, each tensor has shape (1,224,224)
        return_factor: whether to return the normalization factor
    """
    num_views = len(pts3d)  # num_views*(1,224,224,3)
    if valid_masks is None:
        valid_masks = [torch.ones(p.shape[:-1], dtype=bool, device=pts3d[0].device) for p in pts3d]
    assert num_views == len(valid_masks)
    norm_factor = get_multiview_scale([pts3d[id] for id in range(num_views)],
                                      [valid_masks[id] for id in range(num_views)],
                                      norm_mode='avg_dis')
    normed_pts3d = [pts3d[id] / norm_factor for id in range(num_views)]
    if return_factor:
        return normed_pts3d, norm_factor
    return normed_pts3d


def to_device(view, device='cuda'):
    """Transfer the input view to the target device."""
    for name in 'img pts3d_cam pts3d_world true_shape img_tokens'.split():
        if name in view:
            view[name] = view[name].to(device)


@torch.no_grad()
def i2p_inference_batch(batch_views: list, model, device='cuda', ref_id=0, tocpu=True, unsqueeze=True):
    """Run inference on a batch of view groups with the Image2Points model.

    batch_views: list of lists, [[view1, view2, ...], [view1, view2, ...], ...]
                                        batch1               batch2       ...
    """
    pairs = []
    for views in batch_views:
        if unsqueeze:
            pairs.append(tuple(unsqueeze_view(view) for view in views))
        else:
            pairs.append(tuple(views))
    input = collate_with_cat(pairs)
    res = loss_of_one_batch_multiview(input, model, None, device, ref_id=ref_id)
    result = [to_cpu(res)] if tocpu else [res]
    output = collate_with_cat(result)  # views, preds, loss, view1, ..., pred1, ...
    return output
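
# A hedged round-trip check for normalize_views (assuming get_multiview_scale
# returns a broadcastable scale factor): multiplying the normalized clouds by
# the returned factor should restore the originals, which is exactly how
# l2w_inference below maps predictions back to the input scale.
def _demo_normalize_roundtrip():
    pts3d = [torch.rand(1, 224, 224, 3) + 0.5 for _ in range(3)]  # random stand-in pointmaps
    normed, factor = normalize_views(pts3d, return_factor=True)
    for raw, n in zip(pts3d, normed):
        assert torch.allclose(n * factor, raw)
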
@torch.no_grad()
def l2w_inference(raw_views, l2w_model, ref_ids, masks=None, normalize=False, device='cuda'):
    """Multi-keyframe co-registration with the Local2World model.

    Input:
        raw_views (should be collated): list of views, each view is a dict containing:
            img_tokens: the img tokens output from the encoder: (B, Patch_H, Patch_W, C)
            pts3d_cam: the point clouds in the camera coordinate: (B, H, W, 3)
            ...
        l2w_model: the Local2World model
        ref_ids: the ids of scene frames
        masks: the masks of the input pointmaps
        normalize: whether to normalize the input point clouds
    """
    # construct a new input to avoid modifying the raw views
    input_views = [dict(img_tokens=view['img_tokens'],
                        true_shape=view['true_shape'],
                        img_pos=view['img_pos'])
                   for view in raw_views]
    for view in input_views:
        to_device(view, device=device)

    # pts3d_world in the input scene frames are normalized together,
    # while pts3d_cam in the input keyframes are normalized separately.
    # Here we calculate the normalized pts3d_world ahead of time.
    if normalize:
        normed_pts_world, norm_factor_world = \
            normalize_views([raw_views[i]['pts3d_world'] for i in ref_ids],
                            None if masks is None else [masks[i] for i in ref_ids],
                            return_factor=True)

    for id, view in enumerate(raw_views):
        if id in ref_ids:
            if normalize:
                pts_world = normed_pts_world[ref_ids.index(id)]
            else:
                pts_world = view['pts3d_world']
            if masks is not None:
                pts_world = pts_world * (masks[id].float())
            input_views[id]['pts3d_world'] = pts_world
        else:
            if normalize:
                input_views[id]['pts3d_cam'] = normalize_views(
                    [raw_views[id]['pts3d_cam']],
                    None if masks is None else [masks[id]])[0]
            else:
                input_views[id]['pts3d_cam'] = raw_views[id]['pts3d_cam']
            if masks is not None:
                input_views[id]['pts3d_cam'] = input_views[id]['pts3d_cam'] * (masks[id].float())

    with torch.no_grad():
        output = l2w_model(input_views, ref_ids=ref_ids)

    # restore the predicted points to the original scale of raw_views
    if normalize:
        for i in range(len(raw_views)):
            if i in ref_ids:
                output[i]['pts3d'] = output[i]['pts3d'] * norm_factor_world
            else:
                output[i]['pts3d_in_other_view'] = output[i]['pts3d_in_other_view'] * norm_factor_world

    return output


def get_free_gpu():
    """Pick the GPU with the most free memory, preferring a completely idle one."""
    # initialize PyCUDA
    try:
        import pycuda.driver as cuda
    except ImportError as e:
        print(f"{e} -- fail to import pycuda, choose GPU 0.")
        return 0
    cuda.init()
    device_count = cuda.Device.count()
    most_free_mem = 0
    most_free_id = 0
    for i in range(device_count):
        try:
            device = cuda.Device(i)
            context = device.make_context()
            # query the free memory on the device
            free_memory = cuda.mem_get_info()[0]
            # if the gpu is totally free, return it
            total_memory = device.total_memory()
            if free_memory == total_memory:
                context.pop()
                return i
            if free_memory > most_free_mem:
                most_free_mem = free_memory
                most_free_id = i
            context.pop()
        except Exception:
            pass
    print("No totally free GPU found! Choose the most free one.")
    return most_free_id
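
# A hedged end-to-end round trip for estimate_camera_pose: synthesize a pointmap
# from a known camera-to-world pose and intrinsics, then check that PnP + RANSAC
# recovers that pose. The resolution, focal, and pose are made-up values, and
# `_demo_pose_recovery` is a hypothetical helper, not part of the SLAM3R pipeline.
def _demo_pose_recovery():
    H, W, f = 64, 64, 100.0
    K = np.array([[f, 0, W / 2], [0, f, H / 2], [0, 0, 1]], dtype=np.float32)
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z = 1.0 + 0.01 * u                  # slanted depths -> non-planar, well-posed PnP
    pts_cam = np.stack(((u - W / 2) * z / f, (v - H / 2) * z / f, z), axis=-1)
    c2w_gt = np.eye(4)
    c2w_gt[:3, 3] = (0.5, -0.2, 0.1)    # pure translation keeps the check readable
    pts_world = pts_cam @ c2w_gt[:3, :3].T + c2w_gt[:3, 3]
    c2w_est, ok = estimate_camera_pose(torch.tensor(pts_world, dtype=torch.float32), K)
    assert ok and np.abs(np.asarray(c2w_est) - c2w_gt).max() < 1e-2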