# slam3r-i2p_demo / slam3r/datasets/scannetpp_seq.py
# (Hugging Face file-page header: author siyan824, commit "init", hash 8bd45de)
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Dataloader for preprocessed scannet++
# dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes
# https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf
# See datasets_preprocess/preprocess_scannetpp.py
# --------------------------------------------------------
import os.path as osp
import os
import cv2
import numpy as np
import math
SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from slam3r.utils.image import imread_cv2
class ScanNetpp_Seq(BaseStereoViewDataset):
    """Sliding-window multi-view dataset over preprocessed ScanNet++ sequences.

    Each dataset item is a list of `num_views` views taken from a temporal
    window of frames of one scene capture (iPhone and/or DSLR). Every view
    carries the RGB image, metric depthmap, camera intrinsics and the
    cam2world pose.
    """

    def __init__(self,
                 ROOT='data/scannetpp_processed',
                 num_views=2,
                 scene_name=None,   # specify scene name(s) to load
                 sample_freq=1,     # stride of the frames inside the sliding window
                 start_freq=1,      # start frequency for the sliding window
                 img_types=('iphone', 'dslr'),  # capture types to load (tuple: immutable default)
                 filter=False,      # filter out the windows with abnormally large stride
                 rand_sel=False,    # randomly select views from a window
                 winsize=0,         # window size to randomly select views
                 sel_num=0,         # number of combinations to randomly select from a window
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.sample_freq = sample_freq
        self.start_freq = start_freq
        self.num_views = num_views
        self.img_types = img_types
        self.rand_sel = rand_sel
        if rand_sel:
            assert winsize > 0 and sel_num > 0
            # the (num_views-2) interior views are drawn without replacement
            # from winsize-1 candidate slots, so there must be at least
            # sel_num distinct combinations available
            comb_num = math.comb(winsize-1, num_views-2)
            assert comb_num >= sel_num
            self.winsize = winsize
            self.sel_num = sel_num
        else:
            # fixed stride: a window spans (num_views-1) hops of sample_freq
            self.winsize = sample_freq*(num_views-1)

        # self.split is set by the base class; official nvs_sem splits are
        # used unless explicit scene names are given below
        if self.split == 'train':
            with open(os.path.join(ROOT, 'splits', 'nvs_sem_train.txt'), 'r') as f:
                self.scene_names = f.read().splitlines()
        elif self.split == 'test':
            with open(os.path.join(ROOT, 'splits', 'nvs_sem_val.txt'), 'r') as f:
                self.scene_names = f.read().splitlines()

        if scene_name is not None:
            # explicit scene selection is only allowed when no split is set
            assert self.split is None
            if isinstance(scene_name, list):
                self.scene_names = scene_name
            else:
                assert isinstance(scene_name, str)
                self.scene_names = [scene_name]

        self._load_data(filter=filter)
        print(self)

    def filter_windows(self, img_type, sid, eid, image_names):
        """Return True if the window [sid, eid] should be discarded.

        A window is discarded when the frame ids embedded in the first and
        last file names jump by an abnormally large stride, i.e. frames are
        missing from the capture at that point.
        """
        if img_type == 'iphone':   # file names look like frame_000450.jpg
            start_id = int(image_names[sid].split('_')[-1].split('.')[0])
            end_id = int(image_names[eid].split('_')[-1].split('.')[0])
            # presumably iphone frame ids advance in steps of 10 — hence the
            # 10x scaling of the nominal stride; TODO confirm against the
            # preprocessing script
            base_stride = 10*self.winsize
        elif img_type == 'dslr':   # file names look like DSC06967.jpg
            start_id = int(image_names[sid].split('.')[0][-5:])
            end_id = int(image_names[eid].split('.')[0][-5:])
            base_stride = self.winsize
        else:
            # fail loudly instead of the UnboundLocalError the old code hit
            raise ValueError(f"unknown img_type: {img_type}")
        # filter out the windows with abnormally large stride
        return end_id - start_id >= base_stride*3

    def _load_data(self, filter=False):
        """Index every scene: flatten per-scene metadata into parallel
        per-frame arrays and precompute the (first, last) global frame index
        of every sliding window.
        """
        self.sceneids = []       # per-frame index into self.scene_names
        self.images = []         # per-frame image file name
        self.intrinsics = []     # per-frame intrinsics (3,3)
        self.trajectories = []   # per-frame c2w pose (4,4)
        self.win_bid = []        # (first, last) global frame id per window
        num_count = 0            # global frame offset of the current sequence
        for scene_idx, scene_name in enumerate(self.scene_names):
            scene_dir = os.path.join(self.ROOT, scene_name)
            for img_type in self.img_types:
                metadata_path = os.path.join(scene_dir, f'scene_{img_type}_metadata.npz')
                if not os.path.exists(metadata_path):
                    continue
                metadata = np.load(metadata_path)
                image_names = metadata['images'].tolist()
                # check if the images are in the same sequence,
                # only work for certain scenes in ScanNet++ V2
                if img_type == 'dslr':
                    prefixes = [name[0:3] for name in image_names]
                    if len(set(prefixes)) > 1:
                        # dslr images are not in the same sequence
                        print(f"Warning: {scene_name} {img_type} images are not in the same sequence {set(prefixes)}")
                        continue
                # window indexing below relies on temporal (sorted) order
                assert image_names == sorted(image_names)
                intrinsics = metadata['intrinsics']
                trajectories = metadata['trajectories']
                image_num = len(image_names)
                # precompute the window indices (in global frame ids)
                for i in range(0, image_num, self.start_freq):
                    last_id = i+self.winsize
                    if last_id >= image_num:
                        break
                    if filter and self.filter_windows(img_type, i, last_id, image_names):
                        continue
                    self.win_bid.append((num_count+i, num_count+last_id))
                self.trajectories.append(trajectories)
                self.intrinsics.append(intrinsics)
                self.images += image_names
                self.sceneids += [scene_idx,] * image_num
                num_count += image_num

        self.trajectories = np.concatenate(self.trajectories, axis=0)
        self.intrinsics = np.concatenate(self.intrinsics, axis=0)
        assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids) == len(self.images), \
            f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.images)}"

    def __len__(self):
        # with random selection, each window yields sel_num distinct samples
        if self.rand_sel:
            return self.sel_num*len(self.win_bid)
        return len(self.win_bid)

    def get_img_idxes(self, idx, rng):
        """Map a dataset index to the list of global frame indices to load."""
        if self.rand_sel:
            sid, eid = self.win_bid[idx//self.sel_num]
            if idx % self.sel_num == 0:
                # the first sample of each window is deterministic:
                # evenly spaced views covering the whole window
                return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
            # randomly select the interior views, always keeping the start and end view
            if self.num_views == 2:
                return [sid, eid]
            sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
            sel_ids.sort()
            return [sid] + list(sel_ids) + [eid]
        else:
            sid, eid = self.win_bid[idx]
            return [sid + i*self.sample_freq for i in range(self.num_views)]

    def _get_views(self, idx, resolution, rng):
        """Load, crop/resize and package the views for dataset item `idx`."""
        image_idxes = self.get_img_idxes(idx, rng)
        views = []
        for view_idx in image_idxes:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scene_names[scene_id])
            intrinsics = self.intrinsics[view_idx]
            camera_pose = self.trajectories[view_idx]
            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename))
            # Load depthmap (stored as 16-bit png in millimeters -> meters)
            depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='ScanNet++',
                label=self.scene_names[scene_id] + '_' + basename,
                instance=f'{str(idx)}_{str(view_idx)}',
            ))
        return views
if __name__ == "__main__":
    # Visual sanity check: sample a few random windows, dump each view's RGB
    # image and point cloud, plus a fused point cloud, under save_dir.
    from slam3r.datasets.base.base_stereo_view_dataset import view_name
    import trimesh

    num_views = 5
    dataset = ScanNetpp_Seq(split='train', resolution=(224,224),
                            num_views=num_views,
                            start_freq=1, sample_freq=3)
    # random-selection variant:
    # dataset = ScanNetpp_Seq(split='train', resolution=(224,224),
    #                         num_views=num_views,
    #                         start_freq=1, rand_sel=True, winsize=6, sel_num=3)
    save_dir = "visualization/scannetpp_seq_views"
    os.makedirs(save_dir, exist_ok=True)

    for idx in np.random.permutation(len(dataset))[:10]:
        os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
        views = dataset[(idx,0)]
        assert len(views) == num_views
        all_pts = []
        all_color = []
        for i, view in enumerate(views):
            img = np.array(view['img']).transpose(1, 2, 0)  # CHW -> HWC
            save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
            img = img[...,::-1]   # RGB -> BGR, the channel order cv2.imwrite expects
            img = (img+1)/2       # presumably normalized to [-1,1]; map to [0,1] — TODO confirm
            cv2.imwrite(save_path, img*255)
            print(f"save to {save_path}")
            img = img[...,::-1]   # flip back to RGB for the point-cloud colors
            pts3d = np.array(view['pts3d']).reshape(-1,3)
            pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
            pct.export(save_path.replace('.jpg','.ply'))
            all_pts.append(pts3d)
            all_color.append(img.reshape(-1, 3))
        # fused point cloud of all views in this sample
        all_pts = np.concatenate(all_pts, axis=0)
        all_color = np.concatenate(all_color, axis=0)
        pct = trimesh.PointCloud(all_pts, all_color)
        pct.export(osp.join(save_dir, str(idx), "all.ply"))