feat: Add backend for refinement
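Adds backend_utils.py, an Open3D-based multiway-registration helper (pairwise colored / point-to-plane ICP plus pose-graph optimization), and reworks app.py so reconstruct() becomes a generator: it first yields a coarse scene, then a refined scene rebuilt from the optimized per-frame poses. Scene assembly and export are factored out into create_scene()/save_scene(), and the interface gains a second Model3D output for the refined model.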
- app.py +62 -15
- backend_utils.py +144 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -15,7 +15,9 @@ from scipy.spatial.transform import Rotation
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
-import spaces
+import open3d as o3d
+from backend_utils import improved_multiway_registration
+

 # Default values
 DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -143,7 +145,6 @@ def generate_mask(image: np.ndarray):
     mask_np = np.array(mask) / 255.0
     return mask_np

-@spaces.GPU
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
     # Extract frames from video
@@ -176,7 +177,7 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
         if remove_background:
             mask = generate_mask(image)
         else:
-            mask = np.ones_like(conf)
+            mask = np.ones_like(conf)

         images_all.append((image[None, ...] + 1.0)/2.0)
         pts_all.append(pts[None, ...])
@@ -192,6 +193,54 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
     conf_sig_all = (conf_all-1) / conf_all
     combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)

+    # Create coarse result
+    coarse_scene = create_scene(pts_all, images_all, combined_mask, as_pointcloud)
+    coarse_output_path = save_scene(coarse_scene, as_pointcloud)
+
+    yield coarse_output_path, None, f"Reconstruction completed. FPS: {fps:.2f}"
+
+    # Create point clouds for multiway registration
+    pcds = []
+    for j in range(len(pts_all)):
+        pcd = o3d.geometry.PointCloud()
+        mask = combined_mask[j]
+        pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
+        pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
+        pcds.append(pcd)
+
+    # Perform global optimization
+    print("Performing global registration...")
+    transformed_pcds, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
+
+    # Apply transformations from pose_graph to original pts_all
+    transformed_pts_all = np.zeros_like(pts_all)
+    for j in range(len(pts_all)):
+        # Get the transformation matrix from the pose graph
+        transformation = pose_graph.nodes[j].pose
+
+        # Reshape pts_all[j] to (H*W, 3)
+        H, W, _ = pts_all[j].shape
+        pts_reshaped = pts_all[j].reshape(-1, 3)
+
+        # Apply transformation to all points
+        homogeneous_pts = np.hstack((pts_reshaped, np.ones((pts_reshaped.shape[0], 1))))
+        transformed_pts = (transformation @ homogeneous_pts.T).T[:, :3]
+
+        # Reshape back to (H, W, 3) and store
+        transformed_pts_all[j] = transformed_pts.reshape(H, W, 3)
+
+    print(f"Original shape: {pts_all.shape}, Transformed shape: {transformed_pts_all.shape}")
+
+    # Create refined result
+    refined_scene = create_scene(transformed_pts_all, images_all, combined_mask, as_pointcloud)
+    refined_output_path = save_scene(refined_scene, as_pointcloud)
+
+    # Clean up temporary directory
+    os.system(f"rm -rf {demo_path}")
+
+    yield coarse_output_path, refined_output_path, f"Refinement completed. FPS: {fps:.2f}"
+
+def create_scene(pts_all, images_all, combined_mask, as_pointcloud):
     scene = trimesh.Scene()

     if as_pointcloud:
@@ -206,37 +255,35 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
         meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
     mesh = trimesh.Trimesh(**cat_meshes(meshes))
     scene.add_geometry(mesh)
-
+
     rot = np.eye(4)
     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
     scene.apply_transform(np.linalg.inv(OPENGL @ rot))
-
-
+    return scene
+def save_scene(scene, as_pointcloud):
     if as_pointcloud:
-        output_path = tempfile.mktemp(suffix='.ply')
+        output_path = tempfile.mktemp(suffix='.ply')
     else:
         output_path = tempfile.mktemp(suffix='.obj')
     scene.export(output_path)
-
-    # Clean up temporary directory
-    os.system(f"rm -rf {demo_path}")
-
-    return output_path, f"Reconstruction completed. FPS: {fps:.2f}"
+    return output_path

+# Update the Gradio interface
 iface = gr.Interface(
     fn=reconstruct,
     inputs=[
         gr.Video(label="Input Video"),
-        gr.Slider(0, 1, value=1e-
+        gr.Slider(0, 1, value=1e-6, label="Confidence Threshold"),
         gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
         gr.Checkbox(label="As Pointcloud", value=False),
         gr.Checkbox(label="Remove Background", value=False)
     ],
     outputs=[
-        gr.Model3D(label="3D Model", display_mode="solid"),
+        gr.Model3D(label="Coarse 3D Model", display_mode="solid"),
+        gr.Model3D(label="Refined 3D Model", display_mode="solid"),
         gr.Textbox(label="Status")
     ],
-    title="3D Reconstruction with Spatial Memory and Background Removal",
+    title="3D Reconstruction with Spatial Memory, Background Removal, and Global Optimization",
 )

 if __name__ == "__main__":
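Note that reconstruct() now yields twice instead of returning, which is what lets the Space display the coarse model while refinement is still running. Below is a minimal sketch of that pattern, assuming a reasonably recent Gradio; the function and labels are illustrative, not from the commit, and on older releases the queue must be enabled for generator functions.

import gradio as gr

def two_stage(x):
    # First yield: coarse result, placeholder for the refined slot
    yield f"coarse({x})", None, "Reconstruction completed."
    # Second yield: both results plus the final status
    yield f"coarse({x})", f"refined({x})", "Refinement completed."

demo = gr.Interface(
    fn=two_stage,
    inputs=gr.Textbox(label="Input"),
    outputs=[gr.Textbox(label="Coarse"), gr.Textbox(label="Refined"), gr.Textbox(label="Status")],
)

if __name__ == "__main__":
    demo.queue().launch()  # queueing is what allows streaming partial outputs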
backend_utils.py
ADDED
@@ -0,0 +1,144 @@
+import numpy as np
+import open3d as o3d
+
+def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None, overlap=3, quadratic_overlap=True, use_colored_icp=True):
+    if max_correspondence_distance_coarse is None:
+        max_correspondence_distance_coarse = voxel_size * 15
+    if max_correspondence_distance_fine is None:
+        max_correspondence_distance_fine = voxel_size * 1.5
+
+    def preprocess_point_cloud(pcd, voxel_size):
+        pcd_down = pcd.voxel_down_sample(voxel_size)
+        pcd_down.estimate_normals(
+            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30))
+        # Apply statistical outlier removal
+        cl, ind = pcd_down.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
+        pcd_down = pcd_down.select_by_index(ind)
+        return pcd_down
+
+    def pairwise_registration(source, target, use_colored_icp, voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine):
+        current_transformation = np.identity(4)  # Start with identity matrix
+
+        if use_colored_icp:
+            print("Apply colored point cloud registration")
+            voxel_radius = [5*voxel_size, 3*voxel_size, voxel_size]
+            max_iter = [60, 35, 20]
+
+            for scale in range(3):
+                iter = max_iter[scale]
+                radius = voxel_radius[scale]
+
+                source_down = source.voxel_down_sample(radius)
+                target_down = target.voxel_down_sample(radius)
+
+                source_down.estimate_normals(
+                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
+                target_down.estimate_normals(
+                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
+
+                try:
+                    result_icp = o3d.pipelines.registration.registration_colored_icp(
+                        source_down, target_down, radius, current_transformation,
+                        o3d.pipelines.registration.TransformationEstimationForColoredICP(),
+                        o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
+                                                                          relative_rmse=1e-6,
+                                                                          max_iteration=iter))
+                    current_transformation = result_icp.transformation
+                except RuntimeError as e:
+                    print(f"Colored ICP failed at scale {scale}: {str(e)}")
+                    print("Keeping the previous transformation")
+                    # We keep the previous transformation, no need to reassign
+
+            transformation_icp = current_transformation
+        else:
+            print("Apply point-to-plane ICP")
+            try:
+                icp_coarse = o3d.pipelines.registration.registration_icp(
+                    source, target, max_correspondence_distance_coarse, current_transformation,
+                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
+                current_transformation = icp_coarse.transformation
+
+                icp_fine = o3d.pipelines.registration.registration_icp(
+                    source, target, max_correspondence_distance_fine,
+                    current_transformation,
+                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
+                transformation_icp = icp_fine.transformation
+            except RuntimeError as e:
+                print(f"Point-to-plane ICP failed: {str(e)}")
+                print("Keeping the best available transformation")
+                transformation_icp = current_transformation
+
+        try:
+            information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
+                source, target, max_correspondence_distance_fine,
+                transformation_icp)
+        except RuntimeError as e:
+            print(f"Failed to compute information matrix: {str(e)}")
+            print("Using identity information matrix")
+            information_icp = np.identity(6)
+
+        return transformation_icp, information_icp
+
+    def full_registration(pcds_down):
+        pose_graph = o3d.pipelines.registration.PoseGraph()
+        odometry = np.identity(4)
+        pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(odometry))
+        n_pcds = len(pcds_down)
+
+        pairs = []
+        for i in range(n_pcds - 1):
+            for j in range(i + 1, min(i + overlap + 1, n_pcds)):
+                pairs.append((i, j))
+                if quadratic_overlap:
+                    q = 2**(j-i)
+                    if q > overlap and i + q < n_pcds:
+                        pairs.append((i, i + q))
+
+        for source_id, target_id in pairs:
+            transformation_icp, information_icp = pairwise_registration(
+                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
+                voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine)
+            print(f"Build PoseGraph: {source_id} -> {target_id}")
+
+            if target_id == source_id + 1:
+                odometry = np.dot(transformation_icp, odometry)
+                pose_graph.nodes.append(
+                    o3d.pipelines.registration.PoseGraphNode(
+                        np.linalg.inv(odometry)))
+
+            pose_graph.edges.append(
+                o3d.pipelines.registration.PoseGraphEdge(source_id,
+                                                         target_id,
+                                                         transformation_icp,
+                                                         information_icp,
+                                                         uncertain=False))
+        return pose_graph
+
+    # Preprocess point clouds
+    print("Preprocessing point clouds...")
+    pcds_down = [preprocess_point_cloud(pcd, voxel_size) for pcd in pcds]
+
+    print("Full registration ...")
+    pose_graph = full_registration(pcds_down)
+
+    print("Optimizing PoseGraph ...")
+    option = o3d.pipelines.registration.GlobalOptimizationOption(
+        max_correspondence_distance=max_correspondence_distance_fine,
+        edge_prune_threshold=0.25,
+        reference_node=0)
+
+    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
+        o3d.pipelines.registration.global_optimization(
+            pose_graph,
+            o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
+            o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
+            option)
+
+    print("Transform points and combine")
+    pcd_combined = o3d.geometry.PointCloud()
+    for point_id in range(len(pcds)):
+        print(pose_graph.nodes[point_id].pose)
+        pcds[point_id].transform(pose_graph.nodes[point_id].pose)
+        pcd_combined += pcds[point_id]
+
+    return pcd_combined, pose_graph
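For reference, a minimal usage sketch of the new helper (the fragment file names are hypothetical; the voxel_size mirrors the call in app.py). Registration edges connect each fragment to its next `overlap` neighbors, plus quadratically spaced pairs (i, i + 2^k) for loop closure. Because odometry nodes store inverted cumulative transforms, pose_graph.nodes[i].pose maps fragment i into the frame of fragment 0, and the helper also returns the fragments merged after those poses are applied.

import open3d as o3d
from backend_utils import improved_multiway_registration

# Hypothetical inputs: any list of o3d.geometry.PointCloud works;
# colored ICP (the default) expects per-point colors.
pcds = [o3d.io.read_point_cloud(f"fragment_{i}.ply") for i in range(3)]

combined, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)

for i, node in enumerate(pose_graph.nodes):
    print(f"pose of fragment {i}:\n{node.pose}")  # 4x4 world-from-fragment matrix

o3d.io.write_point_cloud("combined.ply", combined)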
requirements.txt
CHANGED
@@ -16,4 +16,5 @@ gdown
 imageio[ffmpeg]
 transformers
 kornia
-timm
+timm
+open3d