# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for preprocessed Co3d_v2
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
# See datasets_preprocess/preprocess_co3d.py
# --------------------------------------------------------
import os.path as osp
import json
from collections import deque
import cv2
import numpy as np

SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from slam3r.utils.image import imread_cv2

TRAINING_CATEGORIES = [
    "apple", "backpack", "banana", "baseballbat", "baseballglove", "bench", "bicycle",
    "bottle", "bowl", "broccoli", "cake", "car", "carrot", "cellphone", "chair", "cup",
    "donut", "hairdryer", "handbag", "hydrant", "keyboard", "laptop", "microwave",
    "motorcycle", "mouse", "orange", "parkingmeter", "pizza", "plant", "stopsign",
    "teddybear", "toaster", "toilet", "toybus", "toyplane", "toytrain", "toytruck",
    "tv", "umbrella", "vase", "wineglass",
]
TEST_CATEGORIES = ["ball", "book", "couch", "frisbee", "hotdog", "kite", "remote", "sandwich", "skateboard", "suitcase"]


class Co3d_Seq(BaseStereoViewDataset):
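    """Sequential multi-view sampler over preprocessed Co3D_v2 scenes.

    Each Co3D scene is a turntable-style capture of ~100 frames covering
    roughly 360 degrees. A window spanning `degree` degrees is slid over
    every scene, and `num_views` frames are drawn from each window position,
    `sel_num` times per position.
    """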
    def __init__(self, 
                 mask_bg=True, 
                 ROOT="data/co3d_processed", 
                 num_views=2,
                 degree=90,  # degree range to select views
                 sel_num=1,  # number of views to select inside a degree range
                 *args, 
                 **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.degree = degree
        # ~100 frames cover 360 degrees, so a window of `degree` degrees spans about degree/360*100 frames
        self.winsize = int(degree / 360 * 100)
        self.sel_num = sel_num
        # total samples drawn per scene: `sel_num` draws for each possible window position
        self.sel_num_perseq = (101 - self.winsize) * self.sel_num
        self.num_views = num_views

        # load all scenes
        if self.split == 'train':
            self.categories = TRAINING_CATEGORIES
        elif self.split == 'test':
            self.categories = TEST_CATEGORIES
        else:
            raise ValueError(f"Unknown split {self.split}")
        self.scenes = {}
        for cate in self.categories:  # only load the categories of the requested split
            with open(osp.join(self.ROOT, cate, f'selected_seqs_{self.split}.json'), 'r') as f:
                self.scenes[cate] = json.load(f)
        # flatten to {(category, sequence_name): [frame indices]}
        self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                       for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())  # each scene has ~100 images covering 360 degrees (so 25 frames ~= 90 degrees)
        self.scene_lens = [len(v) for k, v in self.scenes.items()]
        self.invalidate = {scene: {} for scene in self.scene_list}
        
        print(self)

    def __len__(self):
        return len(self.scene_list) * self.sel_num_perseq

    def get_img_idxes(self, idx, rng):
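        """Pick `num_views` frame indices inside one sliding window.

        Each window is sampled `sel_num` times: the first draw spaces the
        views uniformly across the window; the remaining draws keep the two
        window endpoints and randomize the interior views.

        Example (hypothetical numbers): with winsize=25 and num_views=3, a
        window starting at frame 10 spans frames 10..35; the uniform draw
        returns [10, 22, 35], a random draw returns [10, r, 35] with 10 < r < 35.
        """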
        sid = max(0, idx // self.sel_num - 1)  # window start, ranges from 0 to 99 - winsize
        eid = sid + self.winsize
        if idx % self.sel_num == 0:
            # first draw for this window: uniformly spaced samples between sid and eid
            return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)

        # otherwise keep the window endpoints and randomly pick the remaining n-2 views in between
        if self.num_views == 2:
            return [sid, eid]
        sel_ids = rng.choice(range(sid + 1, eid), self.num_views - 2, replace=False)
        sel_ids.sort()
        return [sid] + list(sel_ids) + [eid]
    

    def _get_views(self, idx, resolution, rng):
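        """Assemble `num_views` views (image, depth, pose, intrinsics) for sample `idx`,
        substituting neighboring frames when a depth map turns out to be empty."""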
        # choose a scene
        obj, instance = self.scene_list[idx // self.sel_num_perseq]
        image_pool = self.scenes[obj, instance]
        last = len(image_pool) - 1
        if last <= self.winsize:
            # sequence too short for the requested window: fall back to a random sample
            return self._get_views(rng.integers(0, len(self) - 1), resolution, rng)

        imgs_idxs = self.get_img_idxes(idx % self.sel_num_perseq, rng)

        # wrap indices that fall past the end of shorter sequences
        # (use a fresh name so the dataset index `idx` is not shadowed)
        for i, img_idx in enumerate(imgs_idxs):
            if img_idx > last:
                imgs_idxs[i] = img_idx % last

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg is True) or (self.mask_bg == 'rand' and rng.choice(2))

        views = []
        imgs_idxs = deque(imgs_idxs)
        
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.popleft()
        
            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image, scanning in a random direction
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break
                else:
                    # no valid image found in the whole sequence: try the next sample
                    return self._get_views((idx + 1) % len(self), resolution, rng)

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

            # load camera params
            input_metadata = np.load(impath.replace('jpg', 'npz'))
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            # load image and depth (depth is stored as uint16, scaled by the frame's maximum depth)
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap
                
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            # TODO: check if this is reasonable
            valid_depth = depthmap[depthmap > 0.0]
            if valid_depth.size > 0:
                median_depth = np.median(valid_depth)
                depthmap[depthmap > median_depth * 3] = 0.  # filter out floating outlier points far from the object
            
            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=f"{obj}_{instance}_frame{view_idx:06n}.jpg",
                instance=osp.split(impath)[1],
            ))
        return views


if __name__ == "__main__":
    from slam3r.datasets.base.base_stereo_view_dataset import view_name
    import os
    import trimesh

    num_views = 11
    dataset = Co3d_Seq(split='train', 
                       mask_bg=False, resolution=224, aug_crop=16,
                       num_views=num_views, degree=90, sel_num=3)
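    # NOTE: assumes the preprocessed data lives under data/co3d_processed
    # (see datasets_preprocess/preprocess_co3d.py for how it is generated)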

    save_dir = "visualization/co3d_seq_views"
    os.makedirs(save_dir, exist_ok=True)

    # import tqdm
    # for idx in tqdm.tqdm(np.random.permutation(len(dataset))):
    #     views = dataset[(idx,0)]
    #     print([view['instance'] for view in views])

    for idx in np.random.permutation(len(dataset))[:10]:
    # for idx in range(len(dataset))[5:10000:2000]:
        os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
        views = dataset[(idx,0)]
        assert len(views) == num_views
        all_pts = []
        all_color = []
        for i, view in enumerate(views):
            img = np.array(view['img']).transpose(1, 2, 0)
            save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
            img = img[..., ::-1]  # RGB -> BGR for cv2.imwrite
            img = (img + 1) / 2   # un-normalize from [-1, 1] to [0, 1]
            cv2.imwrite(save_path, img * 255)
            print(f"save to {save_path}")
            pts3d = np.array(view['pts3d']).reshape(-1, 3)
            img = img[..., ::-1]  # flip back to RGB for the point-cloud colors
            pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
            pct.export(save_path.replace('.jpg', '.ply'))
            all_pts.append(pts3d)
            all_color.append(img.reshape(-1, 3))
        all_pts = np.concatenate(all_pts, axis=0)
        all_color = np.concatenate(all_color, axis=0)
        pct = trimesh.PointCloud(all_pts, all_color)
        pct.export(osp.join(save_dir, str(idx), "all.ply"))