# Source: commit 8bd45de ("init") by siyan824
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Dataloader for preprocessed Replica dataset provided by NICER-SLAM
# --------------------------------------------------------
import os.path as osp
import os
import cv2
import numpy as np
from glob import glob
import json
import trimesh
SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from slam3r.utils.image import imread_cv2
class Replica(BaseStereoViewDataset):
    """Multi-view dataset over the NICER-SLAM preprocessed Replica scenes.

    Each item is a tuple of ``num_fetch_views`` frame indices drawn from a single
    scene (after keeping every ``sample_freq``-th frame); ``sel_view`` chooses which
    of the fetched frames ``_get_views`` actually returns.

    Files read under ``ROOT``::

        ROOT/cam_params.json             # {"camera": {"fx", "fy", "cx", "cy", "scale", ...}}
        ROOT/<scene>/traj.txt            # rows reshaped to (-1, 4, 4) camera-to-world poses
        ROOT/<scene>/results/frame*.jpg  # RGB frames
        ROOT/<scene>/results/depth*.png  # depth maps, divided by camera "scale" on load

    Args:
        ROOT: dataset root directory.
        num_views: number of views returned per item.
        num_fetch_views: number of frame indices stored per item; defaults to
            ``num_views``.
        sel_view: positions (into the fetched tuple) of the views to return;
            defaults to ``range(num_views)``.
        scene_name: a scene name or list of scene names; only allowed when the
            base-class ``split`` is None.
        sample_freq: keep every ``sample_freq``-th frame of each scene.
        start_freq: stride between the first frames of consecutive tuples.
        sample_dis: stride between frames inside one tuple (in kept frames).
        cycle: if True, tuples wrap around the end of the scene and the anchor
            frame ``i`` sits at tuple position ``ref_id``.
        ref_id: anchor position for cyclic tuples; a negative value selects
            ``(num_views - 1) // 2``.
        print_mess: print the labels of every loaded item.
    """

    def __init__(self,
                 ROOT='data/Replica',
                 num_views=2,
                 num_fetch_views=None,
                 sel_view=None,
                 scene_name=None,
                 sample_freq=20,
                 start_freq=1,
                 sample_dis=1,
                 cycle=False,
                 ref_id=-1,
                 print_mess=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.print_mess = print_mess
        self.sample_freq = sample_freq
        self.start_freq = start_freq
        self.sample_dis = sample_dis
        self.cycle = cycle
        self.num_fetch_views = num_fetch_views if num_fetch_views is not None else num_views
        self.sel_view = np.arange(num_views) if sel_view is None else np.array(sel_view)
        self.num_views = num_views
        assert ref_id < num_views
        self.ref_id = ref_id if ref_id >= 0 else (num_views - 1) // 2

        # Scene selection: fixed train/val split, or an explicit scene list.
        self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
        if self.split == 'train':
            self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2"]
        elif self.split == 'val':
            self.scene_names = ["office3", "office4"]
        if scene_name is not None:
            # An explicit scene choice must not be mixed with a predefined split.
            assert self.split is None
            if isinstance(scene_name, list):
                self.scene_names = scene_name
            else:
                assert isinstance(scene_name, str)
                self.scene_names = [scene_name]

        self._load_data()
        print(self)

    def _load_data(self):
        """Index intrinsics, frames, poses and view tuples for ``self.scene_names``.

        Frames of all scenes are concatenated into flat lists; ``self.pairs`` holds
        indices into those flat lists.
        """
        self.sceneids = []
        self.image_paths = []
        self.trajectories = []  # c2w poses, one per kept frame
        self.pairs = []

        with open(os.path.join(self.ROOT, "cam_params.json"), 'r') as f:
            self.intrinsic = json.load(f)['camera']
        K = np.eye(3)
        K[0, 0] = self.intrinsic['fx']
        K[1, 1] = self.intrinsic['fy']
        K[0, 2] = self.intrinsic['cx']
        K[1, 2] = self.intrinsic['cy']
        self.intri_mat = K

        num_count = 0  # flat index of the current scene's first kept frame
        for scene_idx, scene_name in enumerate(self.scene_names):
            scene_dir = os.path.join(self.ROOT, scene_name)
            image_paths = sorted(glob(os.path.join(scene_dir, "results", "frame*.jpg")))
            image_paths = image_paths[::self.sample_freq]
            image_num = len(image_paths)
            self.pairs += self._scene_pairs(image_num, num_count)
            # Poses are subsampled with the same stride so they stay aligned with frames.
            poses = np.loadtxt(os.path.join(scene_dir, "traj.txt")).reshape(-1, 4, 4)
            self.trajectories.append(poses[::self.sample_freq])
            self.image_paths += image_paths
            self.sceneids += [scene_idx] * image_num
            num_count += image_num
        self.trajectories = np.concatenate(self.trajectories, axis=0)
        assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids) == len(self.image_paths), \
            f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.image_paths)}"

    def _scene_pairs(self, image_num, offset):
        """Build the flat-index view tuples for one scene.

        Args:
            image_num: number of kept frames in the scene.
            offset: flat index of the scene's first kept frame.

        Returns:
            List of lists, each of length ``self.num_fetch_views``.
        """
        pairs = []
        if not self.cycle:
            # Consecutive tuples; stop once a tuple would run past the scene end.
            span = self.sample_dis * (self.num_fetch_views - 1)
            for i in range(0, image_num, self.start_freq):
                last_id = i + span
                if last_id >= image_num:
                    break
                pairs.append([j + offset for j in range(i, last_id + 1, self.sample_dis)])
        else:
            # Wrap around the scene so every start frame yields a tuple, with
            # frame i placed at position ref_id (Python % is non-negative here).
            for i in range(0, image_num, self.start_freq):
                pairs.append([(i + (j - self.ref_id) * self.sample_dis) % image_num + offset
                              for j in range(self.num_fetch_views)])
        return pairs

    @staticmethod
    def _depth_path(image_path):
        """Map ``<dir>/frameXXXX.jpg`` to its sibling ``<dir>/depthXXXX.png``.

        Only the file name is rewritten, so directories whose names contain
        "frame" or ".jpg" are left intact.
        """
        image_dir, image_name = os.path.split(image_path)
        stem = os.path.splitext(image_name)[0]
        return os.path.join(image_dir, stem.replace("frame", "depth", 1) + ".png")

    def _get_views(self, idx, resolution, rng):
        """Load, crop/resize and package the selected views of tuple ``idx``.

        Returns:
            List of dicts with keys img, depthmap, camera_pose, camera_intrinsics,
            dataset, label and instance.
        """
        image_idxes = self.pairs[idx]
        assert len(image_idxes) == self.num_fetch_views
        image_idxes = [image_idxes[i] for i in self.sel_view]

        views = []
        for view_idx in image_idxes:
            scene_id = self.sceneids[view_idx]
            camera_pose = self.trajectories[view_idx]
            image_path = self.image_paths[view_idx]
            image_name = os.path.basename(image_path)

            rgb_image = imread_cv2(image_path)
            depthmap = imread_cv2(self._depth_path(image_path), cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32)
            depthmap[~np.isfinite(depthmap)] = 0  # TODO: invalid depth handling
            depthmap /= self.intrinsic['scale']  # raw depth units -> scene units

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, self.intri_mat, resolution, rng=rng, info=view_idx)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='Replica',
                label=self.scene_names[scene_id] + '_' + image_name,
                instance=f'{str(idx)}_{str(view_idx)}',
            ))
        if self.print_mess:
            print(f"loading {[view['label'] for view in views]}")
        return views
if __name__ == "__main__":
    # Smoke test: fetch a few random items and dump each view's RGB image plus its
    # pts3d point cloud; the per-view clouds (and their union) can then be viewed
    # together to check that the dataloader's poses are consistent.
    num_views = 5
    dataset = Replica(ref_id=1, print_mess=True, cycle=True, resolution=224, num_views=num_views,
                      sample_freq=100, seed=777, start_freq=1, sample_dis=1)
    save_dir = "visualization/replica_views"

    for idx in np.random.permutation(len(dataset))[:10]:
        views = dataset[(idx, 0)]
        assert len(views) == num_views
        item_dir = osp.join(save_dir, str(idx))
        os.makedirs(item_dir, exist_ok=True)

        all_pts = []
        all_color = []
        for i, view in enumerate(views):
            # The code assumes img is channel-first in [-1, 1]; map to HWC RGB in [0, 1].
            img = (np.array(view['img']).transpose(1, 2, 0) + 1) / 2
            save_path = osp.join(item_dir, f"{i}_{view['label']}")
            print(save_path)
            # cv2 writes BGR 8-bit images; convert explicitly rather than relying
            # on imwrite's implicit float -> uint8 cast.
            bgr = np.clip(img[..., ::-1] * 255, 0, 255).round().astype(np.uint8)
            cv2.imwrite(save_path, bgr)
            print(f"save to {save_path}")

            pts3d = np.array(view['pts3d']).reshape(-1, 3)
            colors = img.reshape(-1, 3)
            trimesh.PointCloud(pts3d, colors=colors).export(osp.splitext(save_path)[0] + '.ply')
            all_pts.append(pts3d)
            all_color.append(colors)

        pct = trimesh.PointCloud(np.concatenate(all_pts, axis=0), np.concatenate(all_color, axis=0))
        pct.export(osp.join(item_dir, "all.ply"))