# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for the preprocessed Project Aria dataset
# --------------------------------------------------------
import os.path as osp
import os
import cv2
import numpy as np
import math

SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from slam3r.utils.image import imread_cv2
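
# The loader assumes a per-scene directory layout like the one sketched below.
# This sketch is inferred from _load_data/_get_views, not a specification of the
# preprocessing step:
#   <ROOT>/<scene_id>/color/<frame>.jpg                RGB frames
#   <ROOT>/<scene_id>/depth/<frame>.png                depth maps (values / 1000 -> meters, >20 m treated as invalid)
#   <ROOT>/<scene_id>/pose/<frame>.txt                 per-frame camera pose matrix
#   <ROOT>/<scene_id>/intrinsic/intrinsic_color.txt    intrinsics; only the top-left 3x3 block is used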


class Aria_Seq(BaseStereoViewDataset):
    def __init__(self,  
                 ROOT='data/projectaria/ase_processed', 
                 num_views=2, 
                 scene_name=None,   # specify scene name(s) to load
                 sample_freq=1,  # stride of the frames inside the sliding window
                 start_freq=1,  # start frequency for the sliding window
                 filter=False,  # filter out the windows with abnormally large stride
                 rand_sel=False,  # randomly select views from a window
                 winsize=0,  # size of the window from which views are randomly selected
                 sel_num=0,  # number of view combinations to sample from each window
                 *args,**kwargs):
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.sample_freq = sample_freq
        self.start_freq = start_freq
        self.num_views = num_views
        
        self.rand_sel = rand_sel 
        if rand_sel:
            assert winsize > 0 and sel_num > 0
            # a window covers winsize+1 frames; both endpoints are always kept,
            # so there are C(winsize-1, num_views-2) ways to pick the interior views
            comb_num = math.comb(winsize-1, num_views-2)
            assert comb_num >= sel_num, "sel_num exceeds the number of distinct view combinations"
            self.winsize = winsize
            self.sel_num = sel_num
        else:
            # fixed stride: the window spans the num_views frames sampled every sample_freq frames
            self.winsize = sample_freq*(num_views-1)
        
        # keep only numeric scene folders, sorted numerically
        scene_ids = sorted(int(name) for name in os.listdir(self.ROOT) if name.isdigit())
        self.scene_names = [str(scene_id) for scene_id in scene_ids]
        total_scene_num = len(self.scene_names)
        
        if self.split == 'train':
            # the first 90% of the scenes form the training set, the rest the test set
            self.scene_names = self.scene_names[:int(total_scene_num*0.9)]
        elif self.split=='test':
            self.scene_names = self.scene_names[int(total_scene_num*0.9):]
        if scene_name is not None:
            assert self.split is None
            if isinstance(scene_name, list): 
                self.scene_names = scene_name
            else:
                if isinstance(scene_name, int):
                    scene_name = str(scene_name)
                assert isinstance(scene_name, str)
                self.scene_names = [scene_name]
                
        self._load_data(filter=filter)
        print(self)

    def filter_windows(self, sid, eid, image_names):
        return False

    def _load_data(self, filter=False):
        self.sceneids = []     # scene index of each image
        self.images = []       # image file names, concatenated over all scenes
        self.intrinsics = []   # one (3, 3) intrinsic matrix per scene
        self.win_bid = []      # (first, last) global image index of each sliding window

        num_count = 0
        for id, scene_name in enumerate(self.scene_names):
            scene_dir = os.path.join(self.ROOT, scene_name)
            # print(id, scene_name)
            image_names = os.listdir(os.path.join(scene_dir, 'color'))
            image_names = sorted(image_names)
            intrinsic = np.loadtxt(os.path.join(scene_dir, 'intrinsic', 'intrinsic_color.txt'))[:3,:3]
            image_num = len(image_names)
            # precompute the window indices
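            # illustrative example: with image_num=10, start_freq=2, winsize=4,
            # the per-scene windows are (0, 4), (2, 6), (4, 8); any window whose
            # last frame would fall beyond the scene is dropped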
            for i in range(0, image_num, self.start_freq):
                last_id = i+self.winsize
                if last_id >= image_num:
                    break
                if filter and self.filter_windows(i, last_id, image_names):
                    continue
                self.win_bid.append((num_count+i, num_count+last_id))
            
            self.intrinsics.append(intrinsic)
            self.images += image_names
            self.sceneids += [id,] * image_num
            num_count += image_num
        # print(self.sceneids, self.scene_names)    
        self.intrinsics = np.stack(self.intrinsics, axis=0)
        print(self.intrinsics.shape)
        assert len(self.sceneids)==len(self.images), f"{len(self.sceneids)}, {len(self.images)}"   
        
    def __len__(self):
        if self.rand_sel:
            return self.sel_num*len(self.win_bid)
        return len(self.win_bid)    
    
    def get_img_idxes(self, idx, rng):
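        # illustrative example: for a window (sid=10, eid=16) with num_views=4
        # and sel_num=3, the first draw for the window (idx % sel_num == 0)
        # returns the evenly spaced [10, 12, 14, 16]; the other draws keep both
        # endpoints and pick the 2 interior views at random from 11..15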
        if self.rand_sel:
            sid, eid = self.win_bid[idx//self.sel_num]
            if idx % self.sel_num == 0:
                return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
                
            if self.num_views == 2:
                return [sid, eid]
            sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
            sel_ids.sort()
            return [sid] + list(sel_ids) + [eid]
        else:
            sid, eid = self.win_bid[idx]
            return [sid + i*self.sample_freq for i in range(self.num_views)]
            
    
    def _get_views(self, idx, resolution, rng):

        image_idxes = self.get_img_idxes(idx, rng)
        # print(image_idxes)
        views = []
        for view_idx in image_idxes:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scene_names[scene_id])

            intrinsics = self.intrinsics[scene_id]
            basename = self.images[view_idx]
            camera_pose = np.loadtxt(osp.join(scene_dir, 'pose', basename.replace('.jpg', '.txt')))
            # Load RGB image
            rgb_image = imread_cv2(osp.join(scene_dir, 'color', basename))
            # Load depthmap
            depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED)
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[depthmap > 20] = 0  # invalid
            
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='Aria',
                label=self.scene_names[scene_id] + '_' + basename,
                instance=f'{idx}_{view_idx}',
            ))
        # print([view['label'] for view in views])
        return views  

if __name__ == "__main__":
    import trimesh

    num_views = 4
    # dataset = Aria_Seq(resolution=(224,224), 
    #                              num_views=num_views,
    #                              start_freq=1, sample_freq=2)
    dataset = Aria_Seq(split='train', resolution=(224,224), 
                                 num_views=num_views,
                                 start_freq=1, rand_sel=True, winsize=6, sel_num=3)
    save_dir = "visualization/aria_seq_views"
    os.makedirs(save_dir, exist_ok=True)

    for idx in np.random.permutation(len(dataset))[:10]:
        os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
        views = dataset[(idx,0)]
        assert len(views) == num_views
        all_pts = []
        all_color=[]
        for i, view in enumerate(views):
            img = np.array(view['img']).transpose(1, 2, 0)
            # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg")
            save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
            # convert RGB -> BGR for cv2.imwrite and rescale values from [-1, 1] to [0, 1]
            img = img[..., ::-1]
            img = (img + 1) / 2
            cv2.imwrite(save_path, img * 255)
            print(f"save to {save_path}")
            img = img[..., ::-1]  # back to RGB for the point-cloud colors
            pts3d = np.array(view['pts3d']).reshape(-1,3)
            pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
            pct.export(save_path.replace('.jpg','.ply'))
            all_pts.append(pts3d)
            all_color.append(img.reshape(-1, 3))
        all_pts = np.concatenate(all_pts, axis=0)
        all_color = np.concatenate(all_color, axis=0)
        pct = trimesh.PointCloud(all_pts, all_color)
        pct.export(osp.join(save_dir, str(idx), "all.ply"))