# Source path: slam3r-i2p_demo/slam3r/datasets/project_aria_seq.py
# Revision: 8bd45de ("init", uploaded by siyan824)
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for preprocessed project-aria dataset
# --------------------------------------------------------
import os.path as osp
import os
import cv2
import numpy as np
import math
SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from slam3r.utils.image import imread_cv2
class Aria_Seq(BaseStereoViewDataset):
    """Sliding-window multi-view dataset over preprocessed Project Aria scenes.

    Every sub-directory of ``ROOT`` whose name consists only of digits is
    treated as one scene. A scene directory is read through the following
    entries: ``color/`` (images), ``depth/`` (``<stem>.png``),
    ``pose/`` (``<stem>.txt``) and ``intrinsic/intrinsic_color.txt``.

    All frames of all selected scenes are stored in one flat list; a sample
    is a window ``(first, last)`` of global frame indices that never spans
    two scenes, from which ``num_views`` frames are picked.
    """

    def __init__(self,
                 ROOT='data/projectaria/ase_processed',
                 num_views=2,
                 scene_name=None,  # specify scene name(s) to load
                 sample_freq=1,  # stride of the frames inside the sliding window
                 start_freq=1,  # start frequency for the sliding window
                 filter=False,  # filter out the windows with abnormally large stride
                 rand_sel=False,  # randomly select views from a window
                 winsize=0,  # window size to randomly select views
                 sel_num=0,  # number of combinations to randomly select from a window
                 *args, **kwargs):
        """Index the scenes under ``ROOT`` and precompute all windows.

        Args:
            ROOT: directory holding one sub-directory per scene.
            num_views: number of frames returned per sample.
            scene_name: a single scene (``str`` or ``int``) or a list/tuple
                of scenes to load instead of a train/test split. Only
                allowed when ``self.split`` is ``None``.
            sample_freq: frame stride inside a window when ``rand_sel`` is
                off; the window then spans ``sample_freq * (num_views - 1)``.
            start_freq: step between the first frames of consecutive windows.
            filter: when true, windows for which ``filter_windows`` returns
                true are skipped. (The name shadows the builtin but is kept
                because it is part of the public signature.)
            rand_sel: pick the interior frames of each window at random.
            winsize: window span (last index minus first index) used when
                ``rand_sel`` is on; must be positive in that case.
            sel_num: number of samples generated per window when
                ``rand_sel`` is on; must not exceed the number of possible
                interior-frame combinations.
            *args, **kwargs: forwarded to the base-class constructor, which
                is expected to set ``self.split`` (defined outside this file).
        """
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.sample_freq = sample_freq
        self.start_freq = start_freq
        self.num_views = num_views
        self.rand_sel = rand_sel
        if rand_sel:
            assert winsize > 0 and sel_num > 0
            # Both end frames are always kept, so only the winsize-1 interior
            # frames can vary and num_views-2 of them are drawn.
            comb_num = math.comb(winsize-1, num_views-2)
            assert comb_num >= sel_num
            self.winsize = winsize
            self.sel_num = sel_num
        else:
            self.winsize = sample_freq*(num_views-1)

        # Sort numerically but keep the directory names exactly as they are on
        # disk; converting through int() would turn e.g. "007" into "7".
        self.scene_names = sorted(
            (name for name in os.listdir(self.ROOT) if name.isdigit()),
            key=int,
        )

        total_scene_num = len(self.scene_names)
        split_at = int(total_scene_num*0.9)
        if self.split == 'train':
            # choose 90% of the data as training set
            self.scene_names = self.scene_names[:split_at]
        elif self.split == 'test':
            self.scene_names = self.scene_names[split_at:]

        if scene_name is not None:
            assert self.split is None
            if isinstance(scene_name, (list, tuple)):
                # Normalise entries the same way as the scalar case so that
                # integer scene ids can be joined into paths later on.
                self.scene_names = [str(name) for name in scene_name]
            else:
                if isinstance(scene_name, int):
                    scene_name = str(scene_name)
                assert isinstance(scene_name, str)
                self.scene_names = [scene_name]

        self._load_data(filter=filter)
        print(self)

    def filter_windows(self, sid, eid, image_names):
        """Decide whether a candidate window should be discarded.

        Args:
            sid: index of the window's first frame within its scene.
            eid: index of the window's last frame within its scene.
            image_names: sorted file names of the scene's ``color`` directory.

        Returns:
            ``True`` to drop the window. This implementation keeps every
            window and always returns ``False``.
        """
        return False

    def _load_data(self, filter=False):
        """Scan all selected scenes and build the flat frame/window tables.

        Populates:
            self.images: file names of all frames, scene after scene.
            self.sceneids: for every entry of ``self.images``, the index of
                its scene in ``self.scene_names``.
            self.intrinsics: array of shape (scene_num, 3, 3), one matrix
                per scene.
            self.win_bid: ``(first, last)`` pairs of global frame indices,
                one per window.
        """
        self.sceneids = []
        self.images = []
        self.intrinsics = []  # scene_num*(3,3)
        self.win_bid = []

        num_count = 0  # number of frames contributed by the previous scenes
        for scene_idx, scene_name in enumerate(self.scene_names):
            scene_dir = osp.join(self.ROOT, scene_name)
            image_names = sorted(os.listdir(osp.join(scene_dir, 'color')))
            intrinsic = np.loadtxt(osp.join(scene_dir, 'intrinsic', 'intrinsic_color.txt'))[:3, :3]
            image_num = len(image_names)

            # precompute the window indices
            for i in range(0, image_num, self.start_freq):
                last_id = i + self.winsize
                if last_id >= image_num:
                    # Later starts would overrun the scene as well.
                    break
                if filter and self.filter_windows(i, last_id, image_names):
                    continue
                # Shift the scene-local indices into the flat frame list.
                self.win_bid.append((num_count+i, num_count+last_id))

            self.intrinsics.append(intrinsic)
            self.images += image_names
            self.sceneids += [scene_idx] * image_num
            num_count += image_num

        self.intrinsics = np.stack(self.intrinsics, axis=0)
        print(self.intrinsics.shape)
        assert len(self.sceneids)==len(self.images), f"{len(self.sceneids)}, {len(self.images)}"

    def __len__(self):
        """Return the number of samples (``sel_num`` per window if ``rand_sel``)."""
        if self.rand_sel:
            return self.sel_num*len(self.win_bid)
        return len(self.win_bid)

    def get_img_idxes(self, idx, rng):
        """Return the global frame indices that make up sample ``idx``.

        Args:
            idx: sample index in ``range(len(self))``.
            rng: random generator providing ``choice(seq, size, replace=...)``;
                only used when ``rand_sel`` is on.

        Returns:
            A sequence of ``num_views`` indices into ``self.images``, in
            increasing order, starting at the window's first frame.
        """
        if self.rand_sel:
            sid, eid = self.win_bid[idx//self.sel_num]
            if idx % self.sel_num == 0:
                # The first sample of every window is the evenly spaced one.
                return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
            if self.num_views == 2:
                return [sid, eid]
            # Keep both end frames and draw the interior ones without repeats.
            sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
            sel_ids.sort()
            return [sid] + list(sel_ids) + [eid]
        else:
            sid, eid = self.win_bid[idx]
            return [sid + i*self.sample_freq for i in range(self.num_views)]

    def _get_views(self, idx, resolution, rng):
        """Load the images, depth maps and camera data of sample ``idx``.

        Args:
            idx: sample index.
            resolution: target resolution, forwarded to
                ``_crop_resize_if_necessary`` (defined outside this file).
            rng: random generator used for view selection and forwarded to
                ``_crop_resize_if_necessary``.

        Returns:
            A list with one dict per view holding the keys ``img``,
            ``depthmap``, ``camera_pose``, ``camera_intrinsics``,
            ``dataset``, ``label`` and ``instance``.
        """
        image_idxes = self.get_img_idxes(idx, rng)

        views = []
        for view_idx in image_idxes:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scene_names[scene_id])

            intrinsics = self.intrinsics[scene_id]
            basename = self.images[view_idx]
            # Pose and depth files share the image's stem; splitext also
            # handles extensions other than a lower-case ".jpg".
            stem = osp.splitext(basename)[0]
            camera_pose = np.loadtxt(osp.join(scene_dir, 'pose', stem + '.txt'))

            # Load RGB image
            rgb_image = imread_cv2(osp.join(scene_dir, 'color', basename))
            # Load depthmap
            depthmap = imread_cv2(osp.join(scene_dir, 'depth', stem + '.png'), cv2.IMREAD_UNCHANGED)
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            # NOTE(review): the /1000 presumably converts millimetres to
            # metres - confirm against the preprocessing script.
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[depthmap > 20] = 0  # invalid

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='Aria',
                label=self.scene_names[scene_id] + '_' + basename,
                instance=f'{str(idx)}_{str(view_idx)}',
            ))
        return views
if __name__ == "__main__":
    # Visual sanity check: dump the images and point clouds of up to ten
    # random samples below ``save_dir``.
    import trimesh

    num_views = 4
    dataset = Aria_Seq(split='train', resolution=(224,224),
                       num_views=num_views,
                       start_freq=1, rand_sel=True, winsize=6, sel_num=3)
    save_dir = "visualization/aria_seq_views"
    os.makedirs(save_dir, exist_ok=True)

    for idx in np.random.permutation(len(dataset))[:10]:
        sample_dir = osp.join(save_dir, str(idx))
        os.makedirs(sample_dir, exist_ok=True)
        # NOTE(review): the dataset is indexed with an (idx, 0) tuple; the
        # meaning of the second entry is defined by the base class - confirm.
        views = dataset[(idx,0)]
        assert len(views) == num_views

        all_pts = []
        all_color = []
        for i, view in enumerate(views):
            # The code treats view['img'] as channel-first with values in
            # [-1, 1]: move channels last and rescale to [0, 1].
            rgb = (np.array(view['img']).transpose(1, 2, 0) + 1) / 2
            save_path = osp.join(sample_dir, f"{i}_{view['label']}")
            # OpenCV writes BGR, so flip the channel order for the image file.
            cv2.imwrite(save_path, rgb[..., ::-1]*255)
            print(f"save to {save_path}")

            pts3d = np.array(view['pts3d']).reshape(-1,3)
            colors = rgb.reshape(-1, 3)
            pct = trimesh.PointCloud(pts3d, colors=colors)
            # Swap the extension via splitext so that a label not ending in
            # ".jpg" can never make the export target the image file itself.
            pct.export(osp.splitext(save_path)[0] + '.ply')
            all_pts.append(pts3d)
            all_color.append(colors)

        # One merged cloud per sample, next to the per-view clouds.
        all_pts = np.concatenate(all_pts, axis=0)
        all_color = np.concatenate(all_color, axis=0)
        pct = trimesh.PointCloud(all_pts, all_color)
        pct.export(osp.join(sample_dir, "all.ply"))