import argparse
import os

import numpy as np
import torch

from slam3r.pipeline.recon_online_pipeline import scene_recon_pipeline_online, FrameReader
from slam3r.pipeline.recon_offline_pipeline import scene_recon_pipeline_offline
from slam3r.utils.recon_utils import *  # also expected to provide get_free_gpu
from slam3r.datasets.wild_seq import Seq_Data
from slam3r.models import Image2PointsModel, Local2WorldModel, inf
from slam3r.utils.device import to_numpy


def load_model(model_name, weights, device='cuda'):
    """Instantiate a model from its constructor string and load pretrained weights."""
    print('Loading model: {:s}'.format(model_name))
    # model_name is a full Python constructor expression (see the --i2p_model /
    # --l2w_model defaults below), so eval() builds the model directly from it
    model = eval(model_name)
    model.to(device)
    print('Loading pretrained:', weights)
    ckpt = torch.load(weights, map_location=device)
    print(model.load_state_dict(ckpt['model'], strict=False))
    del ckpt  # free the checkpoint memory
    return model
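
# Example call (the checkpoint path is illustrative):
#   model = load_model("Image2PointsModel(...)", "checkpoints/slam3r_i2p.pth", device="cuda")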

parser = argparse.ArgumentParser(description="Inference on a scene")
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
parser.add_argument('--i2p_model', type=str, default="Image2PointsModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
mv_dec1='MultiviewDecoderBlock_max', mv_dec2='MultiviewDecoderBlock_max', enc_minibatch=11)")
parser.add_argument("--l2w_model", type=str, default="Local2WorldModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
mv_dec1='MultiviewDecoderBlock_max', mv_dec2='MultiviewDecoderBlock_max', enc_minibatch=11, need_encoder=False)")
parser.add_argument('--i2p_weights', type=str, help='path to the weights of i2p model')
parser.add_argument("--l2w_weights", type=str, help="path to the weights of l2w model")
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument("--dataset", type=str, help="a string indicating the dataset")
input_group.add_argument("--img_dir", type=str, help="directory of the input images")
parser.add_argument("--save_dir", type=str, default="results", help="directory to save the results") 
parser.add_argument("--test_name", type=str, required=True, help="name of the test")
parser.add_argument('--save_all_views', action='store_true', help='whether to save the result of each view separately')


# args for the whole scene reconstruction
parser.add_argument("--keyframe_stride", type=int, default=3, 
                    help="the stride of sampling keyframes, -1 for auto adaptation")
parser.add_argument("--initial_winsize", type=int, default=5, 
                    help="the number of initial frames to be used for scene initialization")
parser.add_argument("--win_r", type=int, default=3, 
                    help="the radius of the input window for I2P model")
parser.add_argument("--conf_thres_i2p", type=float, default=1.5, 
                    help="confidence threshold for the i2p model")
parser.add_argument("--num_scene_frame", type=int, default=10, 
                    help="the number of scene frames to be selected from \
                        buffering set when registering new keyframes")
parser.add_argument("--max_num_register", type=int, default=10, 
                    help="maximal number of frames to be registered in one go")
parser.add_argument("--conf_thres_l2w", type=float, default=12, 
                    help="confidence threshold for the l2w model(when saving final results)")
parser.add_argument("--num_points_save", type=int, default=2000000, 
                    help="number of points to be saved in the final reconstruction")
parser.add_argument("--norm_input", action="store_true", 
                    help="whether to normalize the input pointmaps for l2w model")
parser.add_argument("--save_frequency", type=int,default=3,
                    help="per xxx frame to save")
parser.add_argument("--save_each_frame",action='store_true',default=True,
                    help="whether to save each frame to .ply")
parser.add_argument("--video_path",type = str)
parser.add_argument("--retrieve_freq",type = int,default=1, 
                    help="(online mode only) frequency of retrieving reference frames")
parser.add_argument("--update_buffer_intv", type=int, default=1, 
                    help="the interval of updating the buffering set")
parser.add_argument('--buffer_size', type=int, default=100, 
                    help='maximal size of the buffering set, -1 if infinite')
parser.add_argument("--buffer_strategy", type=str, choices=['reservoir', 'fifo'], default='reservoir', 
                    help='strategy for maintaining the buffering set: reservoir-sampling or first-in-first-out')
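# 'reservoir' maintains (per standard reservoir sampling) an approximately uniform
# sample of all keyframes seen so far; 'fifo' keeps only the most recent ones.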
parser.add_argument("--save_online", action='store_true', 
                    help="whether to save the construct result online.")

# params for auto adaptation of the keyframe stride
parser.add_argument("--keyframe_adapt_min", type=int, default=1,
                    help="minimum keyframe stride during auto adaptation")
parser.add_argument("--keyframe_adapt_max", type=int, default=20,
                    help="maximum keyframe stride during auto adaptation")
parser.add_argument("--keyframe_adapt_stride", type=int, default=1,
                    help="step size when searching over candidate keyframe strides")
parser.add_argument("--perframe", type=int, default=1)

parser.add_argument("--seed", type=int, default=42, help="seed for python random")
parser.add_argument('--gpu_id', type=int, default=-1, help='gpu id, -1 for auto select')
parser.add_argument('--save_preds', action='store_true', help='whether to save all per-frame preds')    
parser.add_argument('--save_for_eval', action='store_true', help='whether to save partial per-frame preds for evaluation')   
parser.add_argument("--online", action="store_true", help="whether to implement online reconstruction")

if __name__ == "__main__":
    args = parser.parse_args()
    if args.gpu_id == -1:
        args.gpu_id = get_free_gpu()
    print("using gpu: ", args.gpu_id)
    torch.cuda.set_device(f"cuda:{args.gpu_id}")
    np.random.seed(args.seed)
    
    
    # ---------- Load models and checkpoints ----------
    if args.i2p_weights is not None:
        i2p_model = load_model(args.i2p_model, args.i2p_weights, args.device)
    else:
        i2p_model = Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
        i2p_model.to(args.device)
    if args.l2w_weights is not None:
        l2w_model = load_model(args.l2w_model, args.l2w_weights, args.device)
    else:
        l2w_model = Local2WorldModel.from_pretrained('siyan824/slam3r_l2w')
        l2w_model.to(args.device)
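    # inference only: eval() disables dropout and freezes normalization statistics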
    i2p_model.eval()
    l2w_model.eval()
    
    save_dir = os.path.join(args.save_dir, args.test_name)
    os.makedirs(save_dir, exist_ok=True)
    
    if args.online:
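        # online mode reads frames through FrameReader (expects --dataset as the source)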
        picture_capture = FrameReader(args.dataset)
        scene_recon_pipeline_online(i2p_model, l2w_model, picture_capture, args, save_dir)
    else:
        if args.dataset:
            print("Loading dataset:", args.dataset)
            dataset = Seq_Data(img_dir=args.dataset, img_size=224, silent=False,
                               sample_freq=1, start_idx=0, num_views=-1,
                               start_freq=1, to_tensor=True)
        elif args.img_dir:
            dataset = Seq_Data(img_dir=args.img_dir, img_size=224, to_tensor=True)
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(0)

        scene_recon_pipeline_offline(i2p_model, l2w_model, dataset, args, save_dir)
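
# All outputs are written under os.path.join(args.save_dir, args.test_name);
# see the --save_* flags above for what gets saved and how often.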