File size: 3,180 Bytes
8bd45de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for self-captured img sequence
# --------------------------------------------------------
import os.path as osp
import torch

SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
import sys # noqa: E402
sys.path.insert(0, SLAM3R_DIR) # noqa: E402
from slam3r.utils.image import load_images

class Seq_Data():
    """Dataset that cuts a self-captured image sequence into groups of views.

    All images are loaded once with ``load_images`` and then split into
    windows of ``num_views`` consecutive (already sub-sampled) images. The
    start indices of successive windows are ``start_freq`` apart, and only
    complete windows are kept. ``self[idx]`` returns one window as a list of
    the image dicts produced by ``load_images``.
    """

    def __init__(self,
                 img_dir,
                 img_size=224,
                 silent=False,
                 sample_freq=1,
                 num_views=-1,
                 start_freq=1,
                 postfix=None,
                 to_tensor=False,
                 start_idx=0):
        """Load the image sequence and build the groups of views.

        Args:
            img_dir: Directory of the image sequence, forwarded to
                ``load_images``. Image names must contain a number giving
                their order, e.g. ``frame-0031.color.png``, ``output_414.jpg``.
            img_size: Target size. Only 224 is supported; images are cropped
                and resized to 224x224, so border information is lost.
            silent: If True, ``load_images`` is called with ``verbose=False``.
            sample_freq: Forwarded to ``load_images`` as ``img_freq``
                (the sampling frequency of the images).
            num_views: Number of images per group. It is also forwarded to
                ``load_images`` as ``img_num`` (only the first ``num_views``
                images are taken). If <= 0, a single group holding every
                loaded image is built.
            start_freq: Stride between the start indices of consecutive
                groups. Must be >= 1.
            postfix: Filename suffix of the images (``.jpg``, ``.png``,
                ``color.png``, ...), forwarded to ``load_images``.
            to_tensor: If True, each image's ``'true_shape'`` entry is
                converted with ``torch.tensor``.
            start_idx: Forwarded unchanged to ``load_images``.

        Raises:
            AssertionError: If ``img_size != 224``.
            ValueError: If ``start_freq < 1``.
        """
        assert img_size == 224, "Sorry, only img_size=224 is supported now."
        # Validate before the potentially slow image loading. A zero stride
        # would otherwise fail later inside range() in make_groups, and a
        # negative one would silently yield an empty dataset.
        if start_freq < 1:
            raise ValueError(
                f"start_freq must be a positive integer, got {start_freq!r}")

        # NOTE(review): num_views is passed as img_num; if load_images caps
        # the number of loaded images with it, at most one group is built and
        # start_freq has no effect -- confirm against load_images.
        self.imgs = load_images(img_dir, size=img_size,
                                verbose=not silent, img_freq=sample_freq,
                                postfix=postfix, start_idx=start_idx,
                                img_num=num_views)

        self.num_views = num_views if num_views > 0 else len(self.imgs)
        self.stride = start_freq
        self.img_num = len(self.imgs)
        if to_tensor:
            for img in self.imgs:
                img['true_shape'] = torch.tensor(img['true_shape'])
        self.make_groups()
        self.length = len(self.groups)

        # Only set when img_dir is a path string.
        if isinstance(img_dir, str):
            self.scene_names = [self._scene_name(img_dir)]

    @staticmethod
    def _scene_name(img_dir):
        """Join the last two '/'-separated components of ``img_dir`` with '_'.

        e.g. ``"dataset/7Scenes/office-09/"`` -> ``"7Scenes_office-09"``.
        """
        # Strip every trailing '/', so "a/b//" gives "a_b" rather than "b_".
        parts = img_dir.rstrip('/').split('/')
        return '_'.join(parts[-2:])

    def make_groups(self):
        """(Re)build ``self.groups`` from ``self.imgs``.

        Group ``k`` is ``self.imgs[k*stride : k*stride + num_views]``. Only
        windows that fit entirely inside the sequence are kept, so trailing
        images that cannot fill a whole window are dropped.
        """
        window = self.num_views
        # A window starting at `start` fits iff start + window <= img_num.
        # Starts are also capped at img_num, which only matters for
        # num_views <= 0.
        stop = min(self.img_num, self.img_num - window + 1)
        self.groups = [self.imgs[start:start + window]
                       for start in range(0, stop, self.stride)]

    def __len__(self):
        """Return the number of groups."""
        return len(self.groups)

    def __getitem__(self, idx):
        """Return group ``idx`` (a list of image dicts)."""
        return self.groups[idx]



if __name__ == "__main__":
    # Smoke test: load a 7Scenes sequence and print the frame indices held by
    # each group of views.
    dataset = Seq_Data(img_dir="dataset/7Scenes/office-09",
                       img_size=224, silent=False, sample_freq=10,
                       num_views=5, start_freq=2, postfix="color.png")
    for group in dataset.groups:
        print([img['idx'] for img in group])