Spaces:
Sleeping
Sleeping
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # utilities needed for the inference | |
| # -------------------------------------------------------- | |
| import torch | |
| import numpy as np | |
| from .utils.misc import invalid_to_zeros | |
| from .utils.geometry import geotrf, inv | |
def loss_of_one_batch(loss_func, batch, model, criterion, device,
                      use_amp=False, ret=None,
                      assist_model=None, train=False, epoch=0,
                      args=None):
    """ Dispatch one batch to the loss routine selected by `loss_func`.

    Args:
        loss_func: name of the routine to run; one of "i2p",
            "i2p_corr_score" or "l2w".
        batch, model, criterion, device, use_amp, ret: forwarded unchanged
            to the selected routine.
        assist_model, train, epoch: accepted for interface compatibility;
            this function does not use them.
        args: object exposing `ref_id` (read for "i2p" and
            "i2p_corr_score") or `ref_ids` (read for "l2w"). When None, the
            selected routine's own default of -1 is used instead of failing
            on the attribute access.

    Returns:
        Whatever the selected routine returns.

    Raises:
        NotImplementedError: if `loss_func` is not one of the supported names.
    """
    if loss_func == "i2p":
        ref_id = -1 if args is None else args.ref_id
        return loss_of_one_batch_multiview(batch, model, criterion,
                                           device, use_amp, ret,
                                           ref_id)
    if loss_func == "i2p_corr_score":
        ref_id = -1 if args is None else args.ref_id
        return loss_of_one_batch_multiview_corr_score(batch, model, criterion,
                                                      device, use_amp, ret,
                                                      ref_id)
    if loss_func == "l2w":
        ref_ids = -1 if args is None else args.ref_ids
        return loss_of_one_batch_l2w(
            batch, model, criterion,
            device, use_amp, ret,
            ref_ids=ref_ids, coord_frame_id=0,
            exclude_ident=True, to_zero=True
        )
    # Name the offending value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(f"unsupported loss_func: {loss_func!r}")
def loss_of_one_batch_multiview(batch, model, criterion, device,
                                use_amp=False, ret=None, ref_id=-1):
    """ Run the Image-to-Points model on one batch and compute its loss.

    The view dicts in `batch` are updated in place: the entries 'img',
    'pts3d', 'valid_mask' and 'camera_pose' (when present) are moved to
    `device`.

    Args:
        batch: list of view dicts.
        model: callable invoked as `model(views, ref_id=ref_id)`; must return
            one prediction per view.
        criterion: callable invoked as `criterion(views, preds, ref_id=ref_id)`
            with autocast disabled, or None to skip the loss.
        device: target device for the tensors listed above.
        use_amp: enables CUDA autocast around the model call.
        ret: optional key; when truthy only `result[ret]` is returned.
        ref_id: index of the reference view; -1 selects the middle view.

    Returns:
        A dict with 'views', 'preds', 'loss' and 1-based 'pred{k}'/'view{k}'
        entries, or the single entry selected by `ret`.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    if ref_id == -1:
        ref_id = (len(views) - 1) // 2

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_id=ref_id)
        assert len(preds) == len(views)

        # the loss is evaluated with autocast switched off
        with torch.cuda.amp.autocast(enabled=False):
            loss = None if criterion is None else criterion(views, preds, ref_id=ref_id)

    result = {'views': views, 'preds': preds, 'loss': loss}
    for pos, pred in enumerate(preds):
        result[f'pred{pos+1}'] = pred
        result[f'view{pos+1}'] = views[pos]

    if ret:
        return result[ret]
    return result
def loss_of_one_batch_multiview_corr_score(batch, model, criterion, device,
                                           use_amp=False, ret=None, ref_id=-1):
    """ Run the model with `return_corr_score=True` and compute the pseudo-confidence loss.

    For every non-reference view, the mean of exp(pred['pseudo_conf']) and the
    mean of (pred['conf'] - 1) are each squashed with x / (1 + x); the loss
    for that view is the batch mean of their absolute difference.

    The view dicts in `batch` are updated in place: the entries 'img',
    'pts3d', 'valid_mask' and 'camera_pose' (when present) are moved to
    `device`.

    Args:
        batch: list of view dicts.
        model: callable invoked as
            `model(views, ref_id=ref_id, return_corr_score=True)`; must return
            one prediction dict per view.
        criterion: not used by this function.
        device: target device for the tensors listed above.
        use_amp: enables CUDA autocast around the model call and the loss.
        ret: optional key; when truthy only `result[ret]` is returned.
        ref_id: index of the reference view; -1 selects the middle view.

    Returns:
        A dict with 'views', 'preds', 'loss' and 1-based 'pred{k}'/'view{k}'
        entries (or the single entry selected by `ret`). 'loss' is a
        two-element list: the summed loss and a dict of per-view losses keyed
        'pseudo_conf_loss_{i}'.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    if ref_id == -1:
        ref_id = (len(views) - 1) // 2

    total_loss = 0
    per_view_loss = {}
    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_id=ref_id, return_corr_score=True)
        assert len(preds) == len(views)

        for idx, pred in enumerate(preds):
            if idx == ref_id:
                continue  # the reference view carries no correlation-score loss

            # original notes give 'pseudo_conf' as (B,S) — TODO confirm against the model
            pseudo_conf = torch.exp(pred['pseudo_conf']).mean(dim=1)
            true_conf = (pred['conf'] - 1.).mean(dim=(1, 2))

            # map both scores through x / (1 + x) before comparing them
            pseudo_conf = pseudo_conf / (1 + pseudo_conf)
            true_conf = true_conf / (1 + true_conf)

            view_loss = torch.abs(pseudo_conf - true_conf).mean()
            total_loss += view_loss
            per_view_loss[f'pseudo_conf_loss_{idx}'] = view_loss
    all_loss = [total_loss, per_view_loss]

    result = {'views': views, 'preds': preds, 'loss': all_loss}
    for pos, pred in enumerate(preds):
        result[f'pred{pos+1}'] = pred
        result[f'view{pos+1}'] = views[pos]

    if ret:
        return result[ret]
    return result
def get_multiview_scale(pts:list, valid:list, norm_mode='avg_dis'):
    """ Compute one joint normalization factor over several pointmaps (adapted from DUSt3R).

    Args:
        pts: list of point tensors, each with at least 3 dims and a last
            dimension of size 3.
        valid: list of validity masks, one per entry of `pts`.
        norm_mode: '<norm>_<dis>' string. Only 'avg' is accepted for <norm>;
            <dis> is 'dis' (plain distance to the origin) or 'log1p'
            (log1p of that distance).

    Returns:
        The factor, clipped to be at least 1e-8 and given trailing singleton
        dims until it has as many dims as `pts[0]`.

    Raises:
        ValueError: on an unsupported norm or distance mode.
    """
    for pointmap in pts:
        assert pointmap.ndim >= 3 and pointmap.shape[-1] == 3
    assert len(pts) == len(valid)

    norm_mode, dis_mode = norm_mode.split('_')
    if norm_mode != 'avg':
        raise ValueError(f'bad {norm_mode=}')

    # joint normalization: pool the points of every view before averaging
    pooled_pts = []
    valid_total = 0
    for idx, pointmap in enumerate(pts):
        # original notes give the count as (B,) and the points as (B,H*W,3)
        # — TODO confirm against invalid_to_zeros
        flat_pts, valid_count = invalid_to_zeros(pointmap, valid[idx], ndim=3)
        pooled_pts.append(flat_pts)
        valid_total += valid_count

    # distance of every pooled point to the origin
    all_dis = torch.cat(pooled_pts, dim=1).norm(dim=-1)
    if dis_mode == 'log1p':
        all_dis = torch.log1p(all_dis)
    elif dis_mode != 'dis':
        raise ValueError(f'bad {dis_mode=}')

    # the epsilon keeps the division finite when the count is zero
    norm_factor = all_dis.sum(dim=1) / (valid_total + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)

    # append singleton dims so the factor broadcasts against pts[0]
    for _ in range(pts[0].ndim - norm_factor.ndim):
        norm_factor = norm_factor.unsqueeze(-1)
    return norm_factor
def loss_of_one_batch_l2w(batch, model, criterion, device,
                          use_amp=False, ret=None,
                          ref_ids=-1, coord_frame_id=0,
                          exclude_ident=True, to_zero=True):
    """ Run the Local-to-World model on one batch and compute its reconstruction loss.

    `batch` is modified in place: listed tensors are moved to `device`, the
    coordinate-frame view is removed from the list when `exclude_ident` is
    set, and each view's 'pts3d_world' or 'pts3d_cam' entry is written or
    overwritten with a normalized, channel-first tensor.

    Args:
        batch: list of view dicts.
        model: callable invoked as `model(views, ref_ids=ref_ids)`; must
            return one prediction per remaining view.
        criterion: callable invoked with autocast disabled as
            `criterion(views, preds, ref_id=ref_ids, ref_camera=w2c,
            norm_scale=norm_factor_world)`, or None to skip the loss.
        device: target device for the view tensors.
        use_amp: enables CUDA autocast around the model call.
        ret: optional key; when truthy only `result[ret]` is returned.
        ref_ids: list of indices of the supporting frames (excluding the
            coord frame), indexing the view list after the optional removal.
            -1 selects every remaining view except the last one; -2 selects
            a random half of the remaining views.
        coord_frame_id: all input and output pointmaps are expressed in this
            view's camera coordinates; -1 picks a view at random.
        exclude_ident: whether to drop the coord frame from the views, to
            simulate real-life inference scenarios.
        to_zero: whether to set the invalid points to zero.

    Returns:
        A dict with 'views', 'preds', 'loss' and 1-based 'pred{k}'/'view{k}'
        entries, or the single entry selected by `ret`.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'pts3d_cam', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    if coord_frame_id == -1:
        # randomly select a camera as the target camera
        coord_frame_id = np.random.randint(0, len(views))
    c2w = views[coord_frame_id]['camera_pose']
    w2c = inv(c2w)

    # exclude the frame that has the identity pose
    if exclude_ident:
        views.pop(coord_frame_id)

    if ref_ids == -1:
        # all views except the last one
        ref_ids = list(range(len(views) - 1))
    elif ref_ids == -2:
        # select half of the views randomly
        ref_ids = np.random.choice(len(views), len(views) // 2, replace=False).tolist()
    else:
        assert isinstance(ref_ids, list)

    # transform the supporting frames' points into the target coordinate frame
    for ref_idx in ref_ids:
        ref_view = views[ref_idx]
        ref_view['pts3d_world'] = geotrf(w2c, ref_view['pts3d'])

    # one shared scale for all supporting frames
    norm_factor_world = get_multiview_scale(
        [views[ref_idx]['pts3d_world'] for ref_idx in ref_ids],
        [views[ref_idx]['valid_mask'] for ref_idx in ref_ids],
        norm_mode='avg_dis')

    for idx, view in enumerate(views):
        if idx in ref_ids:
            world_pts = view['pts3d_world'].permute(0, 3, 1, 2)
            view['pts3d_world'] = world_pts / norm_factor_world
            continue
        # every other frame is scaled on its own, in its camera coordinates
        norm_factor_src = get_multiview_scale([view['pts3d_cam']],
                                              [view['valid_mask']],
                                              norm_mode='avg_dis')
        view['pts3d_cam'] = view['pts3d_cam'].permute(0, 3, 1, 2) / norm_factor_src

    if to_zero:
        for idx, view in enumerate(views):
            # original note gives the mask as B,1,H,W after the unsqueeze
            mask = view['valid_mask'].unsqueeze(1).float()
            key = 'pts3d_world' if idx in ref_ids else 'pts3d_cam'
            view[key] = view[key] * mask

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_ids=ref_ids)
        assert len(preds) == len(views)

        # the loss is evaluated with autocast switched off
        with torch.cuda.amp.autocast(enabled=False):
            if criterion is None:
                loss = None
            else:
                loss = criterion(views, preds, ref_id=ref_ids,
                                 ref_camera=w2c, norm_scale=norm_factor_world)

    result = {'views': views, 'preds': preds, 'loss': loss}
    for pos, pred in enumerate(preds):
        result[f'pred{pos+1}'] = pred
        result[f'view{pos+1}'] = views[pos]

    if ret:
        return result[ret]
    return result