# NOTE: removed non-source scrape artifacts (page header, file-size line, line-number gutter).
from slam3r.pipeline.recon_online_pipeline import scene_recon_pipeline_online, FrameReader
from slam3r.pipeline.recon_offline_pipeline import scene_recon_pipeline_offline
import argparse
from slam3r.utils.recon_utils import *
from slam3r.datasets.wild_seq import Seq_Data
from slam3r.models import Image2PointsModel, Local2WorldModel, inf
from slam3r.utils.device import to_numpy
import os
def load_model(model_name, weights, device='cuda'):
    """Instantiate a model from a constructor expression and load a checkpoint.

    Args:
        model_name: Python expression string spelling out the full constructor
            call (e.g. "Image2PointsModel(...)"), as passed on the CLI.
            SECURITY NOTE: this string is eval()'d — it must come from a
            trusted source (here: the user's own command line), never from
            untrusted input.
        weights: path to a torch checkpoint containing a 'model' state-dict.
        device: torch device string to load onto (default 'cuda').

    Returns:
        The constructed model with weights loaded and moved to ``device``.
    """
    print('Loading model: {:s}'.format(model_name))
    # eval() is intentional: the CLI supplies the entire constructor call.
    model = eval(model_name)
    model.to(device)
    print('Loading pretrained: ', weights)
    ckpt = torch.load(weights, map_location=device)
    # strict=False tolerates missing/unexpected keys; the returned
    # IncompatibleKeys report is printed so mismatches show up in the log.
    print(model.load_state_dict(ckpt['model'], strict=False))
    del ckpt  # release checkpoint memory immediately
    return model
# ---------------- Command-line interface ----------------
# The two model-default strings below are full constructor calls that
# load_model() passes to eval(); edit them to change the architecture.
parser = argparse.ArgumentParser(description="Inference on a scene")
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
parser.add_argument('--i2p_model', type=str, default="Image2PointsModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 11)")
parser.add_argument("--l2w_model", type=str, default="Local2WorldModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 11, need_encoder=False)")
# If a weights path is omitted, the corresponding model is pulled from the
# HuggingFace hub instead (see the __main__ block).
parser.add_argument('--i2p_weights', type=str, help='path to the weights of i2p model')
parser.add_argument("--l2w_weights", type=str, help="path to the weights of l2w model")
# Exactly one input source must be given.
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument("--dataset", type=str, help="a string indicating the dataset")
input_group.add_argument("--img_dir", type=str, help="directory of the input images")
parser.add_argument("--save_dir", type=str, default="results", help="directory to save the results")
parser.add_argument("--test_name", type=str, required=True, help="name of the test")
parser.add_argument('--save_all_views', action='store_true', help='whether to save all views respectively')
# args for the whole scene reconstruction
parser.add_argument("--keyframe_stride", type=int, default=3,
help="the stride of sampling keyframes, -1 for auto adaptation")
parser.add_argument("--initial_winsize", type=int, default=5,
help="the number of initial frames to be used for scene initialization")
parser.add_argument("--win_r", type=int, default=3,
help="the radius of the input window for I2P model")
parser.add_argument("--conf_thres_i2p", type=float, default=1.5,
help="confidence threshold for the i2p model")
parser.add_argument("--num_scene_frame", type=int, default=10,
help="the number of scene frames to be selected from \
buffering set when registering new keyframes")
parser.add_argument("--max_num_register", type=int, default=10,
help="maximal number of frames to be registered in one go")
parser.add_argument("--conf_thres_l2w", type=float, default=12,
help="confidence threshold for the l2w model(when saving final results)")
parser.add_argument("--num_points_save", type=int, default=2000000,
help="number of points to be saved in the final reconstruction")
parser.add_argument("--norm_input", action="store_true",
help="whether to normalize the input pointmaps for l2w model")
parser.add_argument("--save_frequency", type=int,default=3,
help="per xxx frame to save")
# NOTE(review): action='store_true' combined with default=True means this flag
# is always True and cannot be switched off from the CLI — confirm whether
# default=False (or a --no_save_each_frame option) was intended.
parser.add_argument("--save_each_frame",action='store_true',default=True,
help="whether to save each frame to .ply")
parser.add_argument("--video_path",type = str)
# ---- online-mode / buffering-set options ----
parser.add_argument("--retrieve_freq",type = int,default=1,
help="(online mode only) frequency of retrieving reference frames")
parser.add_argument("--update_buffer_intv", type=int, default=1,
help="the interval of updating the buffering set")
parser.add_argument('--buffer_size', type=int, default=100,
help='maximal size of the buffering set, -1 if infinite')
parser.add_argument("--buffer_strategy", type=str, choices=['reservoir', 'fifo'], default='reservoir',
help='strategy for maintaining the buffering set: reservoir-sampling or first-in-first-out')
parser.add_argument("--save_online", action='store_true',
help="whether to save the construct result online.")
#params for auto adaptation of keyframe frequency
parser.add_argument("--keyframe_adapt_min", type=int, default=1,
help="minimal stride of sampling keyframes when auto adaptation")
parser.add_argument("--keyframe_adapt_max", type=int, default=20,
help="maximal stride of sampling keyframes when auto adaptation")
parser.add_argument("--keyframe_adapt_stride", type=int, default=1,
help="stride for trying different keyframe stride")
parser.add_argument("--perframe", type=int, default=1)
parser.add_argument("--seed", type=int, default=42, help="seed for python random")
parser.add_argument('--gpu_id', type=int, default=-1, help='gpu id, -1 for auto select')
parser.add_argument('--save_preds', action='store_true', help='whether to save all per-frame preds')
parser.add_argument('--save_for_eval', action='store_true', help='whether to save partial per-frame preds for evaluation')
parser.add_argument("--online", action="store_true", help="whether to implement online reconstruction")
if __name__ == "__main__":
    import random  # stdlib; imported locally so this block is self-contained

    args = parser.parse_args()

    # Pick a GPU automatically unless one was requested explicitly.
    if args.gpu_id == -1:
        args.gpu_id = get_free_gpu()
    print("using gpu: ", args.gpu_id)
    # Only bind a CUDA device when actually running on CUDA; the original
    # called torch.cuda.set_device unconditionally, which raises on a
    # CPU-only run (--device cpu).
    if str(args.device).startswith("cuda"):
        torch.cuda.set_device(f"cuda:{args.gpu_id}")

    # Seed both stdlib random and numpy: the --seed help text promises
    # "seed for python random", but previously only numpy was seeded.
    random.seed(args.seed)
    np.random.seed(args.seed)

    # ----------Load model and ckpt-----------
    # Explicit weight paths take priority; otherwise pull the released
    # checkpoints from the HuggingFace hub.
    if args.i2p_weights is not None:
        i2p_model = load_model(args.i2p_model, args.i2p_weights, args.device)
    else:
        i2p_model = Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
        i2p_model.to(args.device)
    if args.l2w_weights is not None:
        l2w_model = load_model(args.l2w_model, args.l2w_weights, args.device)
    else:
        l2w_model = Local2WorldModel.from_pretrained('siyan824/slam3r_l2w')
        l2w_model.to(args.device)
    i2p_model.eval()
    l2w_model.eval()

    save_dir = os.path.join(args.save_dir, args.test_name)
    os.makedirs(save_dir, exist_ok=True)

    if args.online:
        # --dataset and --img_dir are mutually exclusive and exactly one is
        # set; previously only args.dataset was forwarded, so an --img_dir
        # run handed FrameReader a None source. Presumably FrameReader
        # accepts either source string — TODO confirm against its definition.
        input_source = args.dataset if args.dataset is not None else args.img_dir
        picture_capture = FrameReader(input_source)
        scene_recon_pipeline_online(i2p_model, l2w_model, picture_capture, args, save_dir)
    else:
        if args.dataset:
            print("Loading dataset: ", args.dataset)
            dataset = Seq_Data(img_dir=args.dataset,
                               img_size=224, silent=False, sample_freq=1,
                               start_idx=0, num_views=-1, start_freq=1, to_tensor=True)
        elif args.img_dir:
            dataset = Seq_Data(img_dir=args.img_dir, img_size=224, to_tensor=True)
        # Some dataset wrappers expose set_epoch (e.g. for deterministic
        # shuffling); call it when available.
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(0)
        scene_recon_pipeline_offline(i2p_model, l2w_model, dataset, args, save_dir)