github-actions[bot] committed commit cdcf094 (1 parent: 6230849)

Sync from GitHub: e4ae36a6b560759a3e49020c0fe8fc46cedcea2d

README.md CHANGED
@@ -1,14 +1,315 @@
1
- ---
2
- title: Rgbd Depth
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Camera Depth Models for accurate metric depth estimation fro
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Camera Depth Models (CDM)
2
+
3
+ Optimized Python package for RGB-D depth refinement using Vision Transformer encoders. This implementation is aligned with the [ByteDance CDM reference implementation](https://github.com/bytedance/camera-depth-models) with additional performance optimizations for CUDA, MPS (Apple Silicon), and CPU.
4
+
5
+ [![Tests](https://github.com/Aedelon/camera-depth-models/actions/workflows/test.yml/badge.svg)](https://github.com/Aedelon/camera-depth-models/actions/workflows/test.yml)
6
+ [![PyPI version](https://img.shields.io/pypi/v/rgbd-depth.svg)](https://pypi.org/project/rgbd-depth/)
7
+ [![PyPI downloads](https://img.shields.io/pypi/dm/rgbd-depth.svg)](https://pypi.org/project/rgbd-depth/)
8
+ [![Hugging Face Spaces](https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Aedelon/rgbd-depth)
9
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
10
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
11
+ [![PyTorch 2.0+](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
12
+
13
+ ## 🎮 Try it Online
14
+
15
+ [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Aedelon/rgbd-depth)
16
+
17
+ Try rgbd-depth directly in your browser with our interactive Gradio demo! No installation required.
18
+
19
+ ## Overview
20
+
21
+ Camera Depth Models (CDMs) are sensor-specific depth models trained to produce clean, simulation-like depth maps from noisy real-world depth camera data. By bridging the visual gap between simulation and reality through depth perception, CDMs enable robotic policies trained purely in simulation to transfer directly to real robots.
22
+
23
+ **Original work by ByteDance Research.** This package provides an optimized implementation with:
24
+ - ✅ **Pixel-perfect alignment** with reference implementation (verified: 0 pixel difference)
25
+ - ⚡ **Device-specific optimizations**: xFormers (CUDA), SDPA fallback, torch.compile
26
+ - 🎯 **Mixed precision support**: FP16 (CUDA/MPS), BF16 (CUDA)
27
+ - 🔧 **Better CLI**: Device selection, optimization control, precision modes
28
+ - 📦 **Easy installation**: Single `pip install` command
29
+
30
+ ## Why This Package?
31
+
32
+ This is an **optimized, production-ready** version of ByteDance's Camera Depth Models with several improvements:
33
+
34
+ | Feature | ByteDance Original | This Package |
35
+ |---------|-------------------|--------------|
36
+ | **Installation** | Manual setup | `pip install rgbd-depth` |
37
+ | **CUDA Optimization** | Basic | xFormers (~8% faster) + torch.compile |
38
+ | **Apple Silicon (MPS)** | Not optimized | Native support with fallbacks |
39
+ | **Mixed Precision** | Manual | Automatic FP16/BF16 with `--precision` flag |
40
+ | **CLI** | Basic | Enhanced with device selection, optimization control |
41
+ | **Documentation** | Minimal | Comprehensive guides (README + OPTIMIZATION.md) |
42
+ | **Testing** | None | CI/CD with automated tests |
43
+ | **PyPI Package** | No | ✅ Yes (`rgbd-depth`) |
44
+
45
+ **Choose this package if you want:**
46
+ - 🚀 Faster inference on CUDA (xFormers) or Apple Silicon (MPS)
47
+ - 🎯 Easy mixed precision (FP16/BF16) without code changes
48
+ - 📦 Simple installation via PyPI
49
+ - 🔧 Production-ready CLI with device/precision control
50
+ - ✅ Maintained with CI/CD and tests
51
+
52
+ ### Key Features
53
+
54
+ - **Metric Depth Estimation**: Produces accurate absolute depth measurements in meters
55
+ - **Multi-Camera Support**: Optimized models for various depth sensors (RealSense D405/D435/L515, ZED 2i, Azure Kinect)
56
+ - **Performance Optimizations**: ~8% faster on CUDA with xFormers, automatic backend selection
57
+ - **Mixed Precision**: FP16/BF16 support for faster inference on compatible hardware
58
+ - **Sim-to-Real Ready**: Generates simulation-quality depth from real camera data
59
+
60
+ ## Architecture
61
+
62
+ CDM uses a dual-branch Vision Transformer architecture:
63
+ - **RGB Branch**: Extracts semantic information from RGB images
64
+ - **Depth Branch**: Processes noisy depth sensor data
65
+ - **Cross-Attention Fusion**: Combines RGB semantics with depth scale information
66
+ - **DPT Decoder**: Produces final metric depth estimation
67
+
68
+ Supported ViT encoder sizes:
69
+ - `vits`: Small (64 features, 384 output channels)
70
+ - `vitb`: Base (128 features, 768 output channels)
71
+ - `vitl`: Large (256 features, 1024 output channels)
72
+ - `vitg`: Giant (384 features, 1536 output channels)
73
+
74
+ All pretrained models we provide are based on `vitl`.
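+
+ For intuition, the fusion can be sketched as a tiny dual-branch module: RGB tokens query depth tokens through cross-attention, and a decoder head regresses depth. This is a minimal illustration only, not the actual `RGBDDepth` implementation; the class name, dimensions, and layer choices below are assumptions for clarity.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class DualBranchFusionSketch(nn.Module):
+     """Illustrative sketch: RGB and depth token streams fused with cross-attention."""
+
+     def __init__(self, embed_dim: int = 1024, num_heads: int = 16):
+         super().__init__()
+         # Stand-ins for the two ViT branches (the real model uses DINOv2 encoders).
+         self.rgb_proj = nn.Linear(embed_dim, embed_dim)
+         self.depth_proj = nn.Linear(embed_dim, embed_dim)
+         # RGB tokens attend to depth tokens to pick up metric-scale cues.
+         self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
+         # Stand-in for the DPT decoder head that regresses depth.
+         self.head = nn.Linear(embed_dim, 1)
+
+     def forward(self, rgb_tokens: torch.Tensor, depth_tokens: torch.Tensor) -> torch.Tensor:
+         q = self.rgb_proj(rgb_tokens)          # queries: RGB semantics
+         kv = self.depth_proj(depth_tokens)     # keys/values: noisy depth
+         fused, _ = self.cross_attn(q, kv, kv)  # cross-attention fusion
+         return self.head(fused)                # per-token depth prediction
+
+ # e.g.: DualBranchFusionSketch()(torch.randn(1, 1369, 1024), torch.randn(1, 1369, 1024))
+ ```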
75
+
76
+ ## Installation
77
+
78
+ ### From PyPI (recommended)
79
+
80
+ ```bash
81
+ # Basic installation
82
+ pip install rgbd-depth
83
+
84
+ # With CUDA optimizations (xFormers)
85
+ pip install rgbd-depth[xformers]
86
+
87
+ # Development installation
88
+ git clone https://github.com/Aedelon/camera-depth-models.git
89
+ cd camera-depth-models
90
+ pip install -e .
91
+ ```
92
+
93
+ **Requirements:**
94
+ - Python 3.8+
95
+ - PyTorch 2.0+ with appropriate CUDA/MPS support
96
+ - OpenCV, NumPy, Pillow
97
+
98
+ ## Quick Start
99
+
100
+ ```bash
101
+ # CUDA (optimizations auto-enabled, FP16 for best speed)
102
+ python infer.py --input rgb.png --depth depth.png --precision fp16
103
+
104
+ # Apple Silicon (MPS)
105
+ python infer.py --input rgb.png --depth depth.png --device mps
106
+
107
+ # CPU (FP32 only)
108
+ python infer.py --input rgb.png --depth depth.png --device cpu
109
+ ```
110
+
111
+ > Example images are provided in `input_data/`. Pre-trained models can be downloaded from [Hugging Face](https://huggingface.co/collections/depth-anything/camera-depth-models-68b521181dedd223f4b020db).
112
+
113
+ ## Usage
114
+
115
+ ### Command Line Interface
116
+
117
+ **Basic inference:**
118
+ ```bash
119
+ python infer.py \
120
+ --input /path/to/rgb.png \
121
+ --depth /path/to/depth.png \
122
+ --output refined_depth.png
123
+ ```
124
+
125
+ **CUDA with optimizations (default):**
126
+ ```bash
127
+ # FP32 (best accuracy)
128
+ python infer.py --input rgb.png --depth depth.png
129
+
130
+ # FP16 (best speed, ~2× faster)
131
+ python infer.py --input rgb.png --depth depth.png --precision fp16
132
+
133
+ # BF16 (best stability)
134
+ python infer.py --input rgb.png --depth depth.png --precision bf16
135
+
136
+ # Disable optimizations (debugging)
137
+ python infer.py --input rgb.png --depth depth.png --no-optimize
138
+ ```
139
+
140
+ **Apple Silicon (MPS):**
141
+ ```bash
142
+ # FP32 (default)
143
+ python infer.py --input rgb.png --depth depth.png --device mps
144
+
145
+ # FP16 (faster)
146
+ python infer.py --input rgb.png --depth depth.png --device mps --precision fp16
147
+ ```
148
+
149
+ **CPU:**
150
+ ```bash
151
+ # FP32 only (FP16 not recommended on CPU)
152
+ python infer.py --input rgb.png --depth depth.png --device cpu
153
+ ```
154
+
155
+ ### Command Line Arguments
156
+
157
+ **Required:**
158
+ - `--input`: Path to RGB input image (JPG/PNG)
159
+ - `--depth`: Path to depth input image (PNG, 16-bit or 32-bit)
160
+
161
+ **Optional:**
162
+ - `--output`: Output visualization path (default: `output.png`)
163
+ - `--device`: Device to use: `auto`, `cuda`, `mps`, `cpu` (default: `auto`)
164
+ - `--precision`: Precision mode: `fp32`, `fp16`, `bf16` (default: `fp32`)
165
+ - `--no-optimize`: Disable optimizations on CUDA (for debugging)
166
+ - `--encoder`: Model size: `vits`, `vitb`, `vitl`, `vitg` (default: `vitl`)
167
+ - `--input-size`: Input resolution for inference (default: 518)
168
+ - `--depth-scale`: Scale factor for depth values (default: 1000.0)
169
+ - `--max-depth`: Maximum valid depth in meters (default: 6.0)
170
+
171
+ ### Python API
172
+
173
+ ```python
174
+ import torch
175
+ from rgbddepth.dpt import RGBDDepth
176
+ import cv2
177
+ import numpy as np
178
+
179
+ # Load model with optimizations
180
+ model = RGBDDepth(encoder='vitl', features=256, use_xformers=True)
181
+ model.load_state_dict(torch.load('model.pth'))
182
+ model.eval()
183
+ model = model.to('cuda') # or 'mps', 'cpu'
184
+
185
+ # Optional: compile for extra speed on CUDA
186
+ model = torch.compile(model)
187
+
188
+ # Load images
189
+ rgb = cv2.imread('rgb.jpg')[:, :, ::-1] # BGR to RGB
190
+ depth = cv2.imread('depth.png', cv2.IMREAD_UNCHANGED) / 1000.0 # Convert to meters
191
+
192
+ # Create similarity depth (inverse depth)
193
+ simi_depth = np.zeros_like(depth)
194
+ simi_depth[depth > 0] = 1 / depth[depth > 0]
195
+
196
+ # Run inference with mixed precision
197
+ with torch.amp.autocast('cuda', dtype=torch.float16):
198
+ pred_depth = model.infer_image(rgb, simi_depth, input_size=518)
199
+ ```
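+
+ Note that the Gradio demo in `app.py` treats the raw prediction as inverse depth and converts it back to metric depth before visualization. If you follow the same convention, a possible continuation of the example above is (the 16-bit PNG export is just one way to store the result):
+
+ ```python
+ # Continuing from the example above (numpy as np and cv2 already imported).
+ # Convert inverse depth back to metric depth, mirroring app.py.
+ pred_metric = np.where(pred_depth > 1e-8, 1.0 / pred_depth, 0.0)
+
+ # One option: save as a 16-bit PNG in millimetres.
+ cv2.imwrite('refined_depth.png', (pred_metric * 1000.0).astype(np.uint16))
+ ```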
200
+
201
+ ## Model Training
202
+
203
+ CDMs are trained on synthetic datasets generated using camera-specific noise models:
204
+
205
+ 1. **Noise Model Training**: Learn hole and value noise patterns from real camera data
206
+ 2. **Synthetic Data Generation**: Apply learned noise to clean simulation depth
207
+ 3. **CDM Training**: Train depth estimation model on synthetic noisy data
208
+
209
+ Training datasets: HyperSim, DREDS, HISS, IRS (280,000+ images total)
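+
+ As a rough illustration of step 2, a hand-written noise model would drop random holes and perturb depth values; the real pipeline learns these patterns from camera data, so the parameters below are placeholders only:
+
+ ```python
+ from typing import Optional
+
+ import numpy as np
+
+ def synthesize_noisy_depth(clean_depth: np.ndarray,
+                            hole_prob: float = 0.05,
+                            value_sigma: float = 0.01,
+                            rng: Optional[np.random.Generator] = None) -> np.ndarray:
+     """Toy stand-in for a learned camera noise model (value noise + hole noise)."""
+     rng = rng or np.random.default_rng()
+     noisy = clean_depth.astype(np.float32)
+     # Value noise: small multiplicative perturbation of the clean depths.
+     noisy *= 1.0 + rng.normal(0.0, value_sigma, size=noisy.shape)
+     # Hole noise: randomly zero out pixels to mimic sensor dropout.
+     noisy[rng.random(noisy.shape) < hole_prob] = 0.0
+     return noisy
+ ```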
210
+
211
+ ## Supported Cameras
212
+
213
+ Pre-trained models are currently available for:
214
+ - Intel RealSense D405/D435/L515
215
+ - Stereolabs ZED 2i (2 modes: Quality, Neural)
216
+ - Microsoft Azure Kinect
217
+
218
+ ## File Structure
219
+
220
+ ```
221
+ cdm/
222
+ ├── infer.py # Main inference script
223
+ ├── setup.py # Package installation
224
+ ├── rgbddepth/ # Core package
225
+ │ ├── __init__.py
226
+ │ ├── dpt.py # Main RGBDDepth model
227
+ │ ├── dinov2.py # DINOv2 encoder
228
+ │ ├── dinov2_layers/ # ViT transformer layers
229
+ │ └── util/ # Utility functions
230
+ │ ├── blocks.py # Neural network blocks
231
+ │ └── transform.py # Image preprocessing
232
+ └── README.md
233
+ ```
234
+
235
+ ## Performance
236
+
237
+ ### Accuracy
238
+
239
+ This implementation achieves **pixel-perfect alignment** with the ByteDance reference:
240
+ - ✅ **0 pixel difference** between vanilla and optimized inference (verified on test images; a quick check is sketched after this list)
241
+ - ✅ **Identical checkpoint loading** (weights are fully compatible)
242
+ - ✅ **Numerical precision preserved** (min=0.2036, max=1.1217, exact match)
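+
+ A quick way to reproduce this check is to dump both predictions to `.npy` files and compare them; the file names below are placeholders:
+
+ ```python
+ import numpy as np
+
+ vanilla = np.load("pred_vanilla.npy")      # run with --no-optimize
+ optimized = np.load("pred_optimized.npy")  # run with optimizations enabled
+
+ print("bit-identical:", np.array_equal(vanilla, optimized))
+ print("max abs diff:", np.abs(vanilla - optimized).max())
+ print("range:", float(optimized.min()), float(optimized.max()))
+ ```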
243
+
244
+ CDMs achieve state-of-the-art performance on metric depth estimation:
245
+ - Superior accuracy compared to existing prompt-based depth models
246
+ - Zero-shot generalization across different camera types
247
+ - Real-time inference suitable for robot control (lightweight ViT variants)
248
+
249
+ ### Speed Benchmarks
250
+
251
+ | Device | Mode | Precision | Time | vs Baseline | Notes |
252
+ |--------|------|-----------|------|-------------|-------|
253
+ | **CUDA** | Vanilla | FP32 | TBD | - | Reference |
254
+ | **CUDA** | Optimized (xFormers) | FP32 | TBD | ~8% faster | Recommended |
255
+ | **CUDA** | Optimized | FP16 | TBD | ~2× faster | Best speed |
256
+ | **CUDA** | Optimized | BF16 | TBD | ~2× faster | Best stability |
257
+ | **MPS** | Vanilla | FP32 | 1.34s | - | torch.compile: no gain |
258
+ | **MPS** | Vanilla | FP16 | TBD | TBD | To be benchmarked |
259
+ | **CPU** | Vanilla | FP32 | 13.37s | - | Optimizations make it ~11% slower |
260
+
261
+ **Notes:**
262
+ - **CUDA**: Optimizations auto-enabled by default (use `--no-optimize` to disable)
263
+ - **MPS**: torch.compile provides no gain for Vision Transformers (~0% improvement)
264
+ - **CPU**: torch.compile is counterproductive (compilation overhead > gains)
265
+ - xFormers is CUDA-only (~8% faster than native SDPA)
266
+
267
+ For detailed optimization strategies, see [OPTIMIZATION.md](OPTIMIZATION.md).
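+
+ To reproduce such numbers on your own hardware, a simple timing loop with warm-up and CUDA synchronization is enough. The sketch below assumes `model`, `rgb`, and `simi_depth` are set up as in the Python API example above:
+
+ ```python
+ import time
+
+ import torch
+
+ # Warm-up so torch.compile / cuDNN autotuning do not skew the measurement.
+ for _ in range(3):
+     model.infer_image(rgb, simi_depth, input_size=518)
+
+ if torch.cuda.is_available():
+     torch.cuda.synchronize()
+ start = time.perf_counter()
+ for _ in range(10):
+     model.infer_image(rgb, simi_depth, input_size=518)
+ if torch.cuda.is_available():
+     torch.cuda.synchronize()
+ print(f"average inference time: {(time.perf_counter() - start) / 10:.3f}s")
+ ```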
268
+
269
+ ## What's Different from Reference?
270
+
271
+ This implementation maintains **100% compatibility** with ByteDance CDM while adding:
272
+
273
+ ### 1. Performance Optimizations
274
+ - **xFormers support**: ~8% faster attention on CUDA (automatic fallback to SDPA)
275
+ - **torch.compile**: JIT compilation (CUDA only, auto-enabled)
276
+ - **Mixed precision**: FP16/BF16 support via `torch.amp.autocast`
277
+ - **Device-specific strategies**: Optimizations only where beneficial
278
+
279
+ ### 2. Better CLI/API
280
+ - `--device` flag: Force specific device (auto/cuda/mps/cpu)
281
+ - `--precision` flag: Choose FP32/FP16/BF16
282
+ - `--no-optimize` flag: Disable optimizations for debugging
283
+ - Automatic device detection and optimization selection
284
+
285
+ ### 3. Improved Architecture
286
+ - `FlexibleCrossAttention`: Inherits from `nn.MultiheadAttention` for checkpoint compatibility
287
+ - Automatic backend selection: xFormers (CUDA) → SDPA (fallback); a minimal sketch follows this list
288
+ - Device-aware preprocessing: Uses model's device instead of auto-detection
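+
+ A minimal sketch of that fallback logic (not the actual `FlexibleCrossAttention` code; tensor layouts noted in the comments):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ try:
+     from xformers.ops import memory_efficient_attention
+     HAS_XFORMERS = True
+ except ImportError:
+     HAS_XFORMERS = False
+
+ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+     """Inputs are (batch, seq, heads, head_dim); use xFormers on CUDA, else SDPA."""
+     if HAS_XFORMERS and q.is_cuda:
+         return memory_efficient_attention(q, k, v)
+     # torch.nn.functional.scaled_dot_product_attention expects (batch, heads, seq, head_dim).
+     out = F.scaled_dot_product_attention(
+         q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+     )
+     return out.transpose(1, 2)
+ ```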
289
+
290
+ ### 4. Code Quality
291
+ - Type hints and better documentation
292
+ - Cleaner argument parsing
293
+ - Validation for precision/device combinations (sketched after this list)
294
+ - Helpful warnings for incompatible configurations
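+
+ The precision/device validation mirrors what the Gradio demo (`app.py`) does; a condensed sketch of that dtype-selection rule:
+
+ ```python
+ import warnings
+
+ import torch
+
+ def select_autocast_dtype(precision: str, device: torch.device):
+     """Return an autocast dtype, or None to run in plain FP32."""
+     if precision == "fp16" and device.type in ("cuda", "mps"):
+         return torch.float16
+     if precision == "bf16" and device.type == "cuda":
+         return torch.bfloat16
+     if precision != "fp32":
+         warnings.warn(f"{precision} is not supported on {device.type}; falling back to fp32")
+     return None
+ ```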
295
+
296
+ All changes are **backwards compatible** with original checkpoints and produce **identical numerical results**.
297
+
298
+ ## Citation
299
+
300
+ If you use CDM in your research, please cite:
301
+
302
+ ```bibtex
303
+ @article{liu2025manipulation,
304
+ title={Manipulation as in Simulation: Enabling Accurate Geometry Perception in Robots},
305
+ author={Liu, Minghuan and Zhu, Zhengbang and Han, Xiaoshen and Hu, Peng and Lin, Haotong and
306
+ Li, Xinyao and Chen, Jingxiao and Xu, Jiafeng and Yang, Yichu and Lin, Yunfeng and
307
+ Li, Xinghang and Yu, Yong and Zhang, Weinan and Kong, Tao and Kang, Bingyi},
308
+ journal={arXiv preprint},
309
+ year={2025}
310
+ }
311
+ ```
312
+
313
+ ## License
314
+
315
+ This project is licensed under the Apache 2.0 License. See [LICENSE](../LICENSE) for details.
app.py ADDED
@@ -0,0 +1,322 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Gradio demo for rgbd-depth on Hugging Face Spaces."""
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from rgbddepth import RGBDDepth
13
+
14
+ # Global model cache
15
+ MODELS = {}
16
+
17
+
18
+ def load_model(encoder: str, use_xformers: bool = False):
19
+ """Load model with caching."""
20
+ cache_key = f"{encoder}_{use_xformers}"
21
+
22
+ if cache_key not in MODELS:
23
+ # Model configs
24
+ configs = {
25
+ "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
26
+ "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
27
+ "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
28
+ "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
29
+ }
30
+
31
+ config = configs[encoder].copy()
32
+ config["use_xformers"] = use_xformers
33
+
34
+ model = RGBDDepth(**config)
35
+
36
+ # Try to load weights if checkpoint exists
37
+ try:
38
+ checkpoint = torch.load(f"checkpoints/{encoder}.pt", map_location="cpu")
39
+ if "model" in checkpoint:
40
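+ # The slices below strip a fixed-length key prefix (e.g. the 7-character "module." added by DataParallel); the 7/9 offsets are assumptions about the published checkpoint layouts.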
+ states = {k[7:]: v for k, v in checkpoint["model"].items()}
41
+ elif "state_dict" in checkpoint:
42
+ states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
43
+ else:
44
+ states = checkpoint
45
+
46
+ model.load_state_dict(states, strict=False)
47
+ print(f"✓ Loaded checkpoint for {encoder}")
48
+ except FileNotFoundError:
49
+ print(f"⚠ No checkpoint found for {encoder}, using random weights (demo only)")
50
+
51
+ # Move to GPU if available
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ model = model.to(device).eval()
54
+
55
+ MODELS[cache_key] = model
56
+
57
+ return MODELS[cache_key]
58
+
59
+
60
+ def process_depth(
61
+ rgb_image: np.ndarray,
62
+ depth_image: np.ndarray,
63
+ encoder: str = "vitl",
64
+ input_size: int = 518,
65
+ depth_scale: float = 1000.0,
66
+ max_depth: float = 25.0,
67
+ use_xformers: bool = False,
68
+ precision: str = "fp32",
69
+ colormap: str = "Spectral",
70
+ ) -> tuple[Image.Image, str]:
71
+ """Process RGB-D depth refinement.
72
+
73
+ Args:
74
+ rgb_image: RGB image as numpy array [H, W, 3]
75
+ depth_image: Depth image as numpy array [H, W] or [H, W, 3]
76
+ encoder: Model encoder type
77
+ input_size: Input size for inference
78
+ depth_scale: Scale factor for depth values
79
+ max_depth: Maximum valid depth value
80
+ use_xformers: Whether to use xFormers (CUDA only)
81
+ precision: Precision mode (fp32/fp16/bf16)
82
+ colormap: Matplotlib colormap for visualization
83
+
84
+ Returns:
85
+ Tuple of (refined depth image, info message)
86
+ """
87
+ try:
88
+ # Validate inputs
89
+ if rgb_image is None:
90
+ return None, "❌ Please upload an RGB image"
91
+ if depth_image is None:
92
+ return None, "❌ Please upload a depth image"
93
+
94
+ # Convert depth to single channel if needed
95
+ if depth_image.ndim == 3:
96
+ depth_image = depth_image[:, :, 0]
97
+
98
+ # Normalize depth
99
+ depth_normalized = depth_image.astype(np.float32) / depth_scale
100
+ depth_normalized[depth_normalized > max_depth] = 0.0
101
+
102
+ # Create inverse depth (similarity depth)
103
+ simi_depth = np.zeros_like(depth_normalized)
104
+ valid_mask = depth_normalized > 0
105
+ simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
106
+
107
+ # Load model
108
+ model = load_model(encoder, use_xformers and torch.cuda.is_available())
109
+ device = next(model.parameters()).device
110
+
111
+ # Determine precision
112
+ if precision == "fp16" and device.type in ["cuda", "mps"]:
113
+ dtype = torch.float16
114
+ elif precision == "bf16" and device.type == "cuda":
115
+ dtype = torch.bfloat16
116
+ else:
117
+ dtype = None # FP32
118
+
119
+ # Run inference
120
+ if dtype is not None:
121
+ device_type = "cuda" if device.type == "cuda" else "cpu"
122
+ with torch.amp.autocast(device_type=device_type, dtype=dtype):
123
+ pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)
124
+ else:
125
+ pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)
126
+
127
+ # Convert from inverse depth to depth
128
+ pred = np.where(pred > 1e-8, 1.0 / pred, 0.0)
+ pred_min, pred_max = pred.min(), pred.max()  # computed here so the info message works even on the grayscale fallback path
129
+
130
+ # Colorize for visualization
131
+ try:
132
+ import matplotlib
133
+ import matplotlib.pyplot as plt
134
+
135
+ # Normalize to [0, 1]
136
+ pred_min, pred_max = pred.min(), pred.max()
137
+ if pred_max - pred_min > 1e-8:
138
+ pred_norm = (pred - pred_min) / (pred_max - pred_min)
139
+ else:
140
+ pred_norm = np.zeros_like(pred)
141
+
142
+ # Apply colormap
143
+ cm_func = matplotlib.colormaps[colormap]
144
+ pred_colored = cm_func(pred_norm, bytes=True)[:, :, :3] # RGB only
145
+
146
+ # Create PIL Image
147
+ output_image = Image.fromarray(pred_colored)
148
+
149
+ except ImportError:
150
+ # Fallback to grayscale if matplotlib not available
151
+ pred_norm = ((pred - pred.min()) / (pred.max() - pred.min() + 1e-8) * 255).astype(np.uint8)
152
+ output_image = Image.fromarray(pred_norm, mode='L').convert('RGB')
153
+
154
+ # Create info message
155
+ info = f"""
156
+ ✅ **Refinement complete!**
157
+
158
+ **Model:** {encoder.upper()}
159
+ **Precision:** {precision.upper()}
160
+ **Device:** {device.type.upper()}
161
+ **Input size:** {input_size}px
162
+ **Depth range:** {pred_min:.3f}m - {pred_max:.3f}m
163
+ **xFormers:** {'✓ Enabled' if use_xformers and torch.cuda.is_available() else '✗ Disabled'}
164
+ """
165
+
166
+ return output_image, info.strip()
167
+
168
+ except Exception as e:
169
+ return None, f"❌ Error: {str(e)}"
170
+
171
+
172
+ # Create Gradio interface
173
+ with gr.Blocks(title="rgbd-depth Demo") as demo:
174
+ gr.Markdown("""
175
+ # 🎨 rgbd-depth: RGB-D Depth Refinement
176
+
177
+ High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).
178
+
179
+ ⚠️ **Note:** This demo uses random weights for demonstration. For real results:
180
+ 1. Download checkpoints from [Hugging Face](https://huggingface.co/collections/depth-anything/camera-depth-models-68b521181dedd223f4b020db)
181
+ 2. Place in `checkpoints/` directory
182
+ 3. Restart the app
183
+ """)
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ gr.Markdown("### 📥 Inputs")
188
+
189
+ rgb_input = gr.Image(
190
+ label="RGB Image",
191
+ type="numpy",
192
+ height=300,
193
+ )
194
+
195
+ depth_input = gr.Image(
196
+ label="Input Depth Map",
197
+ type="numpy",
198
+ height=300,
199
+ )
200
+
201
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
202
+ encoder_choice = gr.Radio(
203
+ choices=["vits", "vitb", "vitl", "vitg"],
204
+ value="vitl",
205
+ label="Encoder Model",
206
+ info="Larger = better quality but slower",
207
+ )
208
+
209
+ input_size = gr.Slider(
210
+ minimum=256,
211
+ maximum=1024,
212
+ value=518,
213
+ step=2,
214
+ label="Input Size",
215
+ info="Resolution for processing (higher = better but slower)",
216
+ )
217
+
218
+ depth_scale = gr.Number(
219
+ value=1000.0,
220
+ label="Depth Scale",
221
+ info="Scale factor to convert depth values to meters",
222
+ )
223
+
224
+ max_depth = gr.Number(
225
+ value=25.0,
226
+ label="Max Depth (m)",
227
+ info="Maximum valid depth value",
228
+ )
229
+
230
+ precision_choice = gr.Radio(
231
+ choices=["fp32", "fp16", "bf16"],
232
+ value="fp32",
233
+ label="Precision",
234
+ info="fp16/bf16 = faster but slightly less accurate (fp16: CUDA/MPS, bf16: CUDA only)",
235
+ )
236
+
237
+ use_xformers = gr.Checkbox(
238
+ value=False,
239
+ label="Use xFormers (CUDA only)",
240
+ info="~8% faster on CUDA with xFormers installed",
241
+ )
242
+
243
+ colormap_choice = gr.Dropdown(
244
+ choices=["Spectral", "viridis", "plasma", "inferno", "magma", "turbo"],
245
+ value="Spectral",
246
+ label="Colormap",
247
+ info="Visualization colormap",
248
+ )
249
+
250
+ process_btn = gr.Button("🚀 Refine Depth", variant="primary", size="lg")
251
+
252
+ with gr.Column():
253
+ gr.Markdown("### 📤 Output")
254
+
255
+ output_image = gr.Image(
256
+ label="Refined Depth Map",
257
+ type="pil",
258
+ height=600,
259
+ )
260
+
261
+ output_info = gr.Markdown()
262
+
263
+ # Example inputs
264
+ gr.Markdown("### 📸 Examples")
265
+ gr.Examples(
266
+ examples=[
267
+ ["example_data/color_12.png", "example_data/depth_12.png"],
268
+ ],
269
+ inputs=[rgb_input, depth_input],
270
+ label="Try with example images",
271
+ )
272
+
273
+ # Process button click
274
+ process_btn.click(
275
+ fn=process_depth,
276
+ inputs=[
277
+ rgb_input,
278
+ depth_input,
279
+ encoder_choice,
280
+ input_size,
281
+ depth_scale,
282
+ max_depth,
283
+ use_xformers,
284
+ precision_choice,
285
+ colormap_choice,
286
+ ],
287
+ outputs=[output_image, output_info],
288
+ )
289
+
290
+ # Footer
291
+ gr.Markdown("""
292
+ ---
293
+
294
+ ### 🔗 Links
295
+
296
+ - **GitHub:** [Aedelon/camera-depth-models](https://github.com/Aedelon/camera-depth-models)
297
+ - **PyPI:** [rgbd-depth](https://pypi.org/project/rgbd-depth/)
298
+ - **Paper:** [Manipulation-as-in-Simulation](https://manipulation-as-in-simulation.github.io/)
299
+
300
+ ### 📦 Install
301
+
302
+ ```bash
303
+ pip install rgbd-depth
304
+ ```
305
+
306
+ ### 💻 CLI Usage
307
+
308
+ ```bash
309
+ rgbd-depth \\
310
+ --model-path model.pt \\
311
+ --rgb-image input.jpg \\
312
+ --depth-image depth.png \\
313
+ --output refined.png
314
+ ```
315
+
316
+ ---
317
+
318
+ Built with ❤️ by [Aedelon](https://github.com/Aedelon) | Powered by [Gradio](https://gradio.app)
319
+ """)
320
+
321
+ if __name__ == "__main__":
322
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ # Hugging Face Spaces requirements
2
+ # Generated from pyproject.toml - DO NOT install rgbd-depth itself (causes circular dependency)
3
+
4
+ # Core dependencies (from pyproject.toml)
5
+ torch>=2.0.0
6
+ torchvision>=0.15.0
7
+ opencv-python>=4.5.0
8
+ numpy>=1.20.0
9
+ Pillow>=9.0.0
10
+
11
+ # Gradio demo
12
+ gradio>=4.0.0
13
+ matplotlib>=3.5.0
14
+
15
+ # Model downloads from HuggingFace
16
+ huggingface-hub>=0.16.0
rgbddepth/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """RGBD Depth - Optimized RGB-D depth refinement using Vision Transformers.
2
+
3
+ This package provides optimized depth refinement for RGB-D cameras with support
4
+ for CUDA (xFormers), MPS (Apple Silicon), and CPU devices.
5
+ """
6
+
7
+ __version__ = "1.0.2"
8
+
9
+ from .dinov2 import DinoVisionTransformer
10
+ from .dpt import RGBDDepth
11
+
12
+ __all__ = ["RGBDDepth", "DinoVisionTransformer", "__version__"]
rgbddepth/dinov2.py ADDED
@@ -0,0 +1,441 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import math
12
+ from functools import partial
13
+ from typing import Callable, Sequence, Tuple, Union
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import MemEffAttention, Mlp
21
+ from .dinov2_layers import NestedTensorBlock as Block
22
+ from .dinov2_layers import PatchEmbed, SwiGLUFFNFused
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ def named_apply(
28
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
29
+ ) -> nn.Module:
30
+ if not depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ for child_name, child_module in module.named_children():
33
+ child_name = ".".join((name, child_name)) if name else child_name
34
+ named_apply(
35
+ fn=fn,
36
+ module=child_module,
37
+ name=child_name,
38
+ depth_first=depth_first,
39
+ include_root=True,
40
+ )
41
+ if depth_first and include_root:
42
+ fn(module=module, name=name)
43
+ return module
44
+
45
+
46
+ class BlockChunk(nn.ModuleList):
47
+ def forward(self, x):
48
+ for b in self:
49
+ x = b(x)
50
+ return x
51
+
52
+
53
+ class DinoVisionTransformer(nn.Module):
54
+ def __init__(
55
+ self,
56
+ img_size=224,
57
+ patch_size=16,
58
+ in_chans=3,
59
+ embed_dim=768,
60
+ depth=12,
61
+ num_heads=12,
62
+ mlp_ratio=4.0,
63
+ qkv_bias=True,
64
+ ffn_bias=True,
65
+ proj_bias=True,
66
+ drop_path_rate=0.0,
67
+ drop_path_uniform=False,
68
+ init_values=None, # for layerscale: None or 0 => no layerscale
69
+ embed_layer=PatchEmbed,
70
+ act_layer=nn.GELU,
71
+ block_fn=Block,
72
+ ffn_layer="mlp",
73
+ block_chunks=1,
74
+ num_register_tokens=0,
75
+ interpolate_antialias=False,
76
+ interpolate_offset=0.1,
77
+ ):
78
+ """
79
+ Args:
80
+ img_size (int, tuple): input image size
81
+ patch_size (int, tuple): patch size
82
+ in_chans (int): number of input channels
83
+ embed_dim (int): embedding dimension
84
+ depth (int): depth of transformer
85
+ num_heads (int): number of attention heads
86
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
87
+ qkv_bias (bool): enable bias for qkv if True
88
+ proj_bias (bool): enable bias for proj in attn if True
89
+ ffn_bias (bool): enable bias for ffn if True
90
+ drop_path_rate (float): stochastic depth rate
91
+ drop_path_uniform (bool): apply uniform drop rate across blocks
92
+ weight_init (str): weight init scheme
93
+ init_values (float): layer-scale init values
94
+ embed_layer (nn.Module): patch embedding layer
95
+ act_layer (nn.Module): MLP activation layer
96
+ block_fn (nn.Module): transformer block class
97
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
98
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
99
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
100
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
101
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
102
+ """
103
+ super().__init__()
104
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
105
+
106
+ self.num_features = self.embed_dim = (
107
+ embed_dim # num_features for consistency with other models
108
+ )
109
+ self.num_tokens = 1
110
+ self.n_blocks = depth
111
+ self.num_heads = num_heads
112
+ self.patch_size = patch_size
113
+ self.num_register_tokens = num_register_tokens
114
+ self.interpolate_antialias = interpolate_antialias
115
+ self.interpolate_offset = interpolate_offset
116
+
117
+ self.patch_embed = embed_layer(
118
+ img_size=img_size,
119
+ patch_size=patch_size,
120
+ in_chans=in_chans,
121
+ embed_dim=embed_dim,
122
+ )
123
+ num_patches = self.patch_embed.num_patches
124
+
125
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
126
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
127
+ assert num_register_tokens >= 0
128
+ self.register_tokens = (
129
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
130
+ if num_register_tokens
131
+ else None
132
+ )
133
+
134
+ if drop_path_uniform is True:
135
+ dpr = [drop_path_rate] * depth
136
+ else:
137
+ dpr = [
138
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
139
+ ] # stochastic depth decay rule
140
+
141
+ if ffn_layer == "mlp":
142
+ logger.info("using MLP layer as FFN")
143
+ ffn_layer = Mlp
144
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
145
+ logger.info("using SwiGLU layer as FFN")
146
+ ffn_layer = SwiGLUFFNFused
147
+ elif ffn_layer == "identity":
148
+ logger.info("using Identity layer as FFN")
149
+
150
+ def f(*args, **kwargs):
151
+ return nn.Identity()
152
+
153
+ ffn_layer = f
154
+ else:
155
+ raise NotImplementedError
156
+
157
+ blocks_list = [
158
+ block_fn(
159
+ dim=embed_dim,
160
+ num_heads=num_heads,
161
+ mlp_ratio=mlp_ratio,
162
+ qkv_bias=qkv_bias,
163
+ proj_bias=proj_bias,
164
+ ffn_bias=ffn_bias,
165
+ drop_path=dpr[i],
166
+ norm_layer=norm_layer,
167
+ act_layer=act_layer,
168
+ ffn_layer=ffn_layer,
169
+ init_values=init_values,
170
+ )
171
+ for i in range(depth)
172
+ ]
173
+ if block_chunks > 0:
174
+ self.chunked_blocks = True
175
+ chunked_blocks = []
176
+ chunksize = depth // block_chunks
177
+ for i in range(0, depth, chunksize):
178
+ # this is to keep the block index consistent if we chunk the block list
179
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
180
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
181
+ else:
182
+ self.chunked_blocks = False
183
+ self.blocks = nn.ModuleList(blocks_list)
184
+
185
+ self.norm = norm_layer(embed_dim)
186
+ self.head = nn.Identity()
187
+
188
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
189
+
190
+ self.init_weights()
191
+
192
+ def init_weights(self):
193
+ trunc_normal_(self.pos_embed, std=0.02)
194
+ nn.init.normal_(self.cls_token, std=1e-6)
195
+ if self.register_tokens is not None:
196
+ nn.init.normal_(self.register_tokens, std=1e-6)
197
+ named_apply(init_weights_vit_timm, self)
198
+
199
+ def interpolate_pos_encoding(self, x, w, h):
200
+ previous_dtype = x.dtype
201
+ npatch = x.shape[1] - 1
202
+ N = self.pos_embed.shape[1] - 1
203
+ if npatch == N and w == h:
204
+ return self.pos_embed
205
+ pos_embed = self.pos_embed.float()
206
+ class_pos_embed = pos_embed[:, 0]
207
+ patch_pos_embed = pos_embed[:, 1:]
208
+ dim = x.shape[-1]
209
+ w0 = w // self.patch_size
210
+ h0 = h // self.patch_size
211
+ # we add a small number to avoid floating point error in the interpolation
212
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
213
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
214
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
215
+ # w0, h0 = w0 + 0.1, h0 + 0.1
216
+
217
+ sqrt_N = math.sqrt(N)
218
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
219
+ patch_pos_embed = nn.functional.interpolate(
220
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
221
+ scale_factor=(sx, sy),
222
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
223
+ mode="bicubic",
224
+ antialias=self.interpolate_antialias,
225
+ )
226
+
227
+ assert int(w0) == patch_pos_embed.shape[-2]
228
+ assert int(h0) == patch_pos_embed.shape[-1]
229
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
230
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
231
+
232
+ def prepare_tokens_with_masks(self, x, masks=None):
233
+ B, nc, w, h = x.shape
234
+ x = self.patch_embed(x)
235
+ if masks is not None:
236
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
237
+
238
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
239
+ x = x + self.interpolate_pos_encoding(x, w, h)
240
+
241
+ if self.register_tokens is not None:
242
+ x = torch.cat(
243
+ (
244
+ x[:, :1],
245
+ self.register_tokens.expand(x.shape[0], -1, -1),
246
+ x[:, 1:],
247
+ ),
248
+ dim=1,
249
+ )
250
+
251
+ return x
252
+
253
+ def forward_features_list(self, x_list, masks_list):
254
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
255
+ for blk in self.blocks:
256
+ x = blk(x)
257
+
258
+ all_x = x
259
+ output = []
260
+ for x, masks in zip(all_x, masks_list):
261
+ x_norm = self.norm(x)
262
+ output.append(
263
+ {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+ )
271
+ return output
272
+
273
+ def forward_features(self, x, masks=None):
274
+ if isinstance(x, list):
275
+ return self.forward_features_list(x, masks)
276
+
277
+ x = self.prepare_tokens_with_masks(x, masks)
278
+
279
+ for blk in self.blocks:
280
+ x = blk(x)
281
+
282
+ x_norm = self.norm(x)
283
+ return {
284
+ "x_norm_clstoken": x_norm[:, 0],
285
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
286
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
287
+ "x_prenorm": x,
288
+ "masks": masks,
289
+ }
290
+
291
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
292
+ x = self.prepare_tokens_with_masks(x)
293
+ # If n is an int, take the n last blocks. If it's a list, take them
294
+ output, total_block_len = [], len(self.blocks)
295
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
296
+ for i, blk in enumerate(self.blocks):
297
+ x = blk(x)
298
+ if i in blocks_to_take:
299
+ output.append(x)
300
+ assert len(output) == len(
301
+ blocks_to_take
302
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
303
+ return output
304
+
305
+ def _get_intermediate_layers_chunked(self, x, n=1):
306
+ x = self.prepare_tokens_with_masks(x)
307
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
308
+ # If n is an int, take the n last blocks. If it's a list, take them
309
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
310
+ for block_chunk in self.blocks:
311
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
312
+ x = blk(x)
313
+ if i in blocks_to_take:
314
+ output.append(x)
315
+ i += 1
316
+ assert len(output) == len(
317
+ blocks_to_take
318
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
319
+ return output
320
+
321
+ def get_intermediate_layers(
322
+ self,
323
+ x: torch.Tensor,
324
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
325
+ reshape: bool = False,
326
+ return_class_token: bool = False,
327
+ norm=True,
328
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
329
+ if self.chunked_blocks:
330
+ outputs = self._get_intermediate_layers_chunked(x, n)
331
+ else:
332
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
333
+ if norm:
334
+ outputs = [self.norm(out) for out in outputs]
335
+ class_tokens = [out[:, 0] for out in outputs]
336
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
337
+ if reshape:
338
+ B, _, w, h = x.shape
339
+ outputs = [
340
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
341
+ .permute(0, 3, 1, 2)
342
+ .contiguous()
343
+ for out in outputs
344
+ ]
345
+ if return_class_token:
346
+ return tuple(zip(outputs, class_tokens))
347
+ return tuple(outputs)
348
+
349
+ def forward(self, *args, is_training=False, **kwargs):
350
+ ret = self.forward_features(*args, **kwargs)
351
+ if is_training:
352
+ return ret
353
+ else:
354
+ return self.head(ret["x_norm_clstoken"])
355
+
356
+
357
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
358
+ """ViT weight initialization, original timm impl (for reproducibility)"""
359
+ if isinstance(module, nn.Linear):
360
+ trunc_normal_(module.weight, std=0.02)
361
+ if module.bias is not None:
362
+ nn.init.zeros_(module.bias)
363
+
364
+
365
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
366
+ model = DinoVisionTransformer(
367
+ patch_size=patch_size,
368
+ embed_dim=384,
369
+ depth=12,
370
+ num_heads=6,
371
+ mlp_ratio=4,
372
+ block_fn=partial(Block, attn_class=MemEffAttention),
373
+ num_register_tokens=num_register_tokens,
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=768,
383
+ depth=12,
384
+ num_heads=12,
385
+ mlp_ratio=4,
386
+ block_fn=partial(Block, attn_class=MemEffAttention),
387
+ num_register_tokens=num_register_tokens,
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
394
+ model = DinoVisionTransformer(
395
+ patch_size=patch_size,
396
+ embed_dim=1024,
397
+ depth=24,
398
+ num_heads=16,
399
+ mlp_ratio=4,
400
+ block_fn=partial(Block, attn_class=MemEffAttention),
401
+ num_register_tokens=num_register_tokens,
402
+ **kwargs,
403
+ )
404
+ return model
405
+
406
+
407
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
408
+ """
409
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
410
+ """
411
+ model = DinoVisionTransformer(
412
+ patch_size=patch_size,
413
+ embed_dim=1536,
414
+ depth=40,
415
+ num_heads=24,
416
+ mlp_ratio=4,
417
+ block_fn=partial(Block, attn_class=MemEffAttention),
418
+ num_register_tokens=num_register_tokens,
419
+ **kwargs,
420
+ )
421
+ return model
422
+
423
+
424
+ def DINOv2(model_name):
425
+ model_zoo = {
426
+ "vits": vit_small,
427
+ "vitb": vit_base,
428
+ "vitl": vit_large,
429
+ "vitg": vit_giant2,
430
+ }
431
+
432
+ return model_zoo[model_name](
433
+ img_size=518,
434
+ patch_size=14,
435
+ init_values=1.0,
436
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
437
+ block_chunks=0,
438
+ num_register_tokens=0,
439
+ interpolate_antialias=False,
440
+ interpolate_offset=0.1,
441
+ )
rgbddepth/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .attention import MemEffAttention
8
+ from .block import NestedTensorBlock
9
+ from .mlp import Mlp
10
+ from .patch_embed import PatchEmbed
11
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
12
+
13
+ __all__ = [
14
+ "MemEffAttention",
15
+ "NestedTensorBlock",
16
+ "Mlp",
17
+ "PatchEmbed",
18
+ "SwiGLUFFN",
19
+ "SwiGLUFFNFused",
20
+ ]
rgbddepth/dinov2_layers/attention.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor, nn
14
+
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
+ try:
19
+ from xformers.ops import memory_efficient_attention, unbind
20
+
21
+ XFORMERS_AVAILABLE = True
22
+ except ImportError:
23
+ logger.warning("xFormers not available")
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Attention(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int = 8,
32
+ qkv_bias: bool = False,
33
+ proj_bias: bool = True,
34
+ attn_drop: float = 0.0,
35
+ proj_drop: float = 0.0,
36
+ ) -> None:
37
+ super().__init__()
38
+ self.num_heads = num_heads
39
+ head_dim = dim // num_heads
40
+ self.scale = head_dim**-0.5
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.attn_drop = nn.Dropout(attn_drop)
44
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
45
+ self.proj_drop = nn.Dropout(proj_drop)
46
+
47
+ def forward(self, x: Tensor) -> Tensor:
48
+ B, N, C = x.shape
49
+ qkv = (
50
+ self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
51
+ )
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
rgbddepth/dinov2_layers/block.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Any, Callable, Dict, List, Tuple
13
+
14
+ import torch
15
+ from torch import Tensor, nn
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+ logger = logging.getLogger("dinov2")
23
+
24
+
25
+ try:
26
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
27
+
28
+ XFORMERS_AVAILABLE = True
29
+ except ImportError:
30
+ logger.warning("xFormers not available")
31
+ XFORMERS_AVAILABLE = False
32
+
33
+
34
+ class Block(nn.Module):
35
+ def __init__(
36
+ self,
37
+ dim: int,
38
+ num_heads: int,
39
+ mlp_ratio: float = 4.0,
40
+ qkv_bias: bool = False,
41
+ proj_bias: bool = True,
42
+ ffn_bias: bool = True,
43
+ drop: float = 0.0,
44
+ attn_drop: float = 0.0,
45
+ init_values=None,
46
+ drop_path: float = 0.0,
47
+ act_layer: Callable[..., nn.Module] = nn.GELU,
48
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
49
+ attn_class: Callable[..., nn.Module] = Attention,
50
+ ffn_layer: Callable[..., nn.Module] = Mlp,
51
+ ) -> None:
52
+ super().__init__()
53
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
54
+ self.norm1 = norm_layer(dim)
55
+ self.attn = attn_class(
56
+ dim,
57
+ num_heads=num_heads,
58
+ qkv_bias=qkv_bias,
59
+ proj_bias=proj_bias,
60
+ attn_drop=attn_drop,
61
+ proj_drop=drop,
62
+ )
63
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
64
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
65
+
66
+ self.norm2 = norm_layer(dim)
67
+ mlp_hidden_dim = int(dim * mlp_ratio)
68
+ self.mlp = ffn_layer(
69
+ in_features=dim,
70
+ hidden_features=mlp_hidden_dim,
71
+ act_layer=act_layer,
72
+ drop=drop,
73
+ bias=ffn_bias,
74
+ )
75
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
76
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
77
+
78
+ self.sample_drop_ratio = drop_path
79
+
80
+ def forward(self, x: Tensor) -> Tensor:
81
+ def attn_residual_func(x: Tensor) -> Tensor:
82
+ return self.ls1(self.attn(self.norm1(x)))
83
+
84
+ def ffn_residual_func(x: Tensor) -> Tensor:
85
+ return self.ls2(self.mlp(self.norm2(x)))
86
+
87
+ if self.training and self.sample_drop_ratio > 0.1:
88
+ # the overhead is compensated only for a drop path rate larger than 0.1
89
+ x = drop_add_residual_stochastic_depth(
90
+ x,
91
+ residual_func=attn_residual_func,
92
+ sample_drop_ratio=self.sample_drop_ratio,
93
+ )
94
+ x = drop_add_residual_stochastic_depth(
95
+ x,
96
+ residual_func=ffn_residual_func,
97
+ sample_drop_ratio=self.sample_drop_ratio,
98
+ )
99
+ elif self.training and self.sample_drop_ratio > 0.0:
100
+ x = x + self.drop_path1(attn_residual_func(x))
101
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
102
+ else:
103
+ x = x + attn_residual_func(x)
104
+ x = x + ffn_residual_func(x)
105
+ return x
106
+
107
+
108
+ def drop_add_residual_stochastic_depth(
109
+ x: Tensor,
110
+ residual_func: Callable[[Tensor], Tensor],
111
+ sample_drop_ratio: float = 0.0,
112
+ ) -> Tensor:
113
+ # 1) extract subset using permutation
114
+ b, n, d = x.shape
115
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
116
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
117
+ x_subset = x[brange]
118
+
119
+ # 2) apply residual_func to get residual
120
+ residual = residual_func(x_subset)
121
+
122
+ x_flat = x.flatten(1)
123
+ residual = residual.flatten(1)
124
+
125
+ residual_scale_factor = b / sample_subset_size
126
+
127
+ # 3) add the residual
128
+ x_plus_residual = torch.index_add(
129
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
130
+ )
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(
147
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
148
+ )
149
+ else:
150
+ x_plus_residual = scaled_index_add(
151
+ x,
152
+ brange,
153
+ residual.to(dtype=x.dtype),
154
+ scaling=scaling_vector,
155
+ alpha=residual_scale_factor,
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = (
168
+ [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
169
+ )
170
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
171
+ if all_shapes not in attn_bias_cache.keys():
172
+ seqlens = []
173
+ for b, x in zip(batch_sizes, x_list):
174
+ for _ in range(b):
175
+ seqlens.append(x.shape[1])
176
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
177
+ attn_bias._batch_sizes = batch_sizes
178
+ attn_bias_cache[all_shapes] = attn_bias
179
+
180
+ if branges is not None:
181
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
182
+ 1, -1, x_list[0].shape[-1]
183
+ )
184
+ else:
185
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
186
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
187
+
188
+ return attn_bias_cache[all_shapes], cat_tensors
189
+
190
+
191
+ def drop_add_residual_stochastic_depth_list(
192
+ x_list: List[Tensor],
193
+ residual_func: Callable[[Tensor, Any], Tensor],
194
+ sample_drop_ratio: float = 0.0,
195
+ scaling_vector=None,
196
+ ) -> Tensor:
197
+ # 1) generate random set of indices for dropping samples in the batch
198
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
199
+ branges = [s[0] for s in branges_scales]
200
+ residual_scale_factors = [s[1] for s in branges_scales]
201
+
202
+ # 2) get attention bias and index+concat the tensors
203
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
204
+
205
+ # 3) apply residual_func to get residual, and split the result
206
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
207
+
208
+ outputs = []
209
+ for x, brange, residual, residual_scale_factor in zip(
210
+ x_list, branges, residual_list, residual_scale_factors
211
+ ):
212
+ outputs.append(
213
+ add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
214
+ )
215
+ return outputs
216
+
217
+
218
+ class NestedTensorBlock(Block):
219
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
220
+ """
221
+ x_list contains a list of tensors to nest together and run
222
+ """
223
+ assert isinstance(self.attn, MemEffAttention)
224
+
225
+ if self.training and self.sample_drop_ratio > 0.0:
226
+
227
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
228
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
229
+
230
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
231
+ return self.mlp(self.norm2(x))
232
+
233
+ x_list = drop_add_residual_stochastic_depth_list(
234
+ x_list,
235
+ residual_func=attn_residual_func,
236
+ sample_drop_ratio=self.sample_drop_ratio,
237
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
238
+ )
239
+ x_list = drop_add_residual_stochastic_depth_list(
240
+ x_list,
241
+ residual_func=ffn_residual_func,
242
+ sample_drop_ratio=self.sample_drop_ratio,
243
+ scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
244
+ )
245
+ return x_list
246
+ else:
247
+
248
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
249
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
250
+
251
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
252
+ return self.ls2(self.mlp(self.norm2(x)))
253
+
254
+ attn_bias, x = get_attn_bias_and_cat(x_list)
255
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
256
+ x = x + ffn_residual_func(x)
257
+ return attn_bias.split(x)
258
+
259
+ def forward(self, x_or_x_list):
260
+ if isinstance(x_or_x_list, Tensor):
261
+ return super().forward(x_or_x_list)
262
+ elif isinstance(x_or_x_list, list):
263
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
264
+ return self.forward_nested(x_or_x_list)
265
+ else:
266
+ raise AssertionError
rgbddepth/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
rgbddepth/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor, nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
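
`LayerScale` above is a per-channel learnable gain on each residual branch, initialised to 1e-5 so that new branches start with a near-zero contribution. A small sketch (the import path assumes this commit's layout):

```python
import torch
from rgbddepth.dinov2_layers.layer_scale import LayerScale

ls = LayerScale(dim=8, init_values=1e-5)
x = torch.randn(2, 4, 8)        # (batch, tokens, channels)
out = ls(x)
print(out.shape)                # torch.Size([2, 4, 8])
print((out / x).mean().item())  # ~1e-5: the branch is heavily down-weighted at init
```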
rgbddepth/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
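
Shape sanity check for the `Mlp` above: `hidden_features` and `out_features` default to `in_features`, while the ViT blocks typically pass a 4x expansion (the sizes below are illustrative):

```python
import torch
from rgbddepth.dinov2_layers.mlp import Mlp

mlp = Mlp(in_features=384, hidden_features=4 * 384)
tokens = torch.randn(2, 10, 384)
print(mlp(tokens).shape)  # torch.Size([2, 10, 384])
```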
rgbddepth/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert (
74
+ H % patch_H == 0
75
+ ), f"Input image height {H} is not a multiple of patch height {patch_H}"
76
+ assert (
77
+ W % patch_W == 0
78
+ ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
79
+
80
+ x = self.proj(x) # B C H W
81
+ H, W = x.size(2), x.size(3)
82
+ x = x.flatten(2).transpose(1, 2) # B HW C
83
+ x = self.norm(x)
84
+ if not self.flatten_embedding:
85
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
86
+ return x
87
+
88
+ def flops(self) -> float:
89
+ Ho, Wo = self.patches_resolution
90
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
91
+ if self.norm is not None:
92
+ flops += Ho * Wo * self.embed_dim
93
+ return flops
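
Patch-grid arithmetic for the `PatchEmbed` above: an H x W input with patch size p produces an (H/p) x (W/p) token grid, so the 518-pixel inputs used elsewhere in this package with patch size 14 give 37 x 37 = 1369 tokens. The `embed_dim=384` below is just an illustrative ViT-S value.

```python
import torch
from rgbddepth.dinov2_layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=384)
x = torch.randn(1, 3, 518, 518)
tokens = pe(x)
print(pe.patches_resolution)  # (37, 37)
print(tokens.shape)           # torch.Size([1, 1369, 384])
```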
rgbddepth/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ import torch.nn.functional as F
10
+ from torch import Tensor, nn
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
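
The hidden-width rule in `SwiGLUFFNFused` above scales the requested hidden size by 2/3 (SwiGLU uses two input projections) and rounds up to a multiple of 8. A worked example of that arithmetic:

```python
def fused_hidden(hidden_features: int) -> int:
    # Same expression as in SwiGLUFFNFused.__init__ above.
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8

print(fused_hidden(4 * 1024))  # 2736  (ViT-L style: embed_dim 1024, 4x expansion)
print(fused_hidden(4 * 384))   # 1024  (ViT-S style)
```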
rgbddepth/dpt.py ADDED
@@ -0,0 +1,312 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision.transforms import Compose
10
+
11
+ from .dinov2 import DINOv2
12
+ from .flexible_attention import FlexibleCrossAttention
13
+ from .util.blocks import FeatureFusionBlock, _make_scratch
14
+ from .util.transform import NormalizeImage, PrepareForNet, Resize
15
+
16
+
17
+ def _make_fusion_block(features, use_bn, size=None):
18
+ return FeatureFusionBlock(
19
+ features,
20
+ nn.ReLU(False),
21
+ deconv=False,
22
+ bn=use_bn,
23
+ expand=False,
24
+ align_corners=True,
25
+ size=size,
26
+ )
27
+
28
+
29
+ class ConvBlock(nn.Module):
30
+ def __init__(self, in_feature, out_feature):
31
+ super().__init__()
32
+
33
+ self.conv_block = nn.Sequential(
34
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
35
+ nn.BatchNorm2d(out_feature),
36
+ nn.ReLU(True),
37
+ )
38
+
39
+ def forward(self, x):
40
+ return self.conv_block(x)
41
+
42
+
43
+ class DPTHead(nn.Module):
44
+ def __init__(
45
+ self,
46
+ in_channels,
47
+ features=256,
48
+ use_bn=False,
49
+ out_channels=[256, 512, 1024, 1024],
50
+ use_clstoken=False,
51
+ sigact_out=False,
52
+ ):
53
+ super(DPTHead, self).__init__()
54
+
55
+ self.use_clstoken = use_clstoken
56
+
57
+ self.projects = nn.ModuleList(
58
+ [
59
+ nn.Conv2d(
60
+ in_channels=in_channels,
61
+ out_channels=out_channel,
62
+ kernel_size=1,
63
+ stride=1,
64
+ padding=0,
65
+ )
66
+ for out_channel in out_channels
67
+ ]
68
+ )
69
+
70
+ self.resize_layers = nn.ModuleList(
71
+ [
72
+ nn.ConvTranspose2d(
73
+ in_channels=out_channels[0],
74
+ out_channels=out_channels[0],
75
+ kernel_size=4,
76
+ stride=4,
77
+ padding=0,
78
+ ),
79
+ nn.ConvTranspose2d(
80
+ in_channels=out_channels[1],
81
+ out_channels=out_channels[1],
82
+ kernel_size=2,
83
+ stride=2,
84
+ padding=0,
85
+ ),
86
+ nn.Identity(),
87
+ nn.Conv2d(
88
+ in_channels=out_channels[3],
89
+ out_channels=out_channels[3],
90
+ kernel_size=3,
91
+ stride=2,
92
+ padding=1,
93
+ ),
94
+ ]
95
+ )
96
+
97
+ if use_clstoken:
98
+ self.readout_projects = nn.ModuleList()
99
+ for _ in range(len(self.projects)):
100
+ self.readout_projects.append(
101
+ nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())
102
+ )
103
+
104
+ self.scratch = _make_scratch(
105
+ out_channels,
106
+ features,
107
+ groups=1,
108
+ expand=False,
109
+ )
110
+
111
+ self.scratch.stem_transpose = None
112
+
113
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
114
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
115
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
116
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
117
+
118
+ head_features_1 = features
119
+ head_features_2 = 32
120
+
121
+ self.scratch.output_conv1 = nn.Conv2d(
122
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
123
+ )
124
+
125
+ if not sigact_out:
126
+ self.scratch.output_conv2 = nn.Sequential(
127
+ nn.Conv2d(
128
+ head_features_1 // 2,
129
+ head_features_2,
130
+ kernel_size=3,
131
+ stride=1,
132
+ padding=1,
133
+ ),
134
+ nn.ReLU(True),
135
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
136
+ nn.ReLU(True),
137
+ nn.Identity(),
138
+ )
139
+ else:
140
+ self.scratch.output_conv2 = nn.Sequential(
141
+ nn.Conv2d(
142
+ head_features_1 // 2,
143
+ head_features_2,
144
+ kernel_size=3,
145
+ stride=1,
146
+ padding=1,
147
+ ),
148
+ nn.ReLU(True),
149
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
150
+ nn.Sigmoid(),
151
+ )
152
+
153
+ def forward(self, out_features, patch_h, patch_w):
154
+ out = []
155
+ for i, x in enumerate(out_features):
156
+ if self.use_clstoken:
157
+ x, cls_token = x[0], x[1]
158
+ readout = cls_token.unsqueeze(1).expand_as(x)
159
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
160
+ else:
161
+ x = x[0]
162
+
163
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
164
+
165
+ x = self.projects[i](x)
166
+ x = self.resize_layers[i](x)
167
+
168
+ out.append(x)
169
+
170
+ layer_1, layer_2, layer_3, layer_4 = out
171
+
172
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
173
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
174
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
175
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
176
+
177
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
178
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
179
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
180
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
181
+
182
+ out = self.scratch.output_conv1(path_1)
183
+ out = F.interpolate(
184
+ out,
185
+ (int(patch_h * 14), int(patch_w * 14)),
186
+ mode="bilinear",
187
+ align_corners=True,
188
+ )
189
+ out = self.scratch.output_conv2(out)
190
+
191
+ return out
192
+
193
+
194
+ class RGBDDepth(nn.Module):
195
+ def __init__(
196
+ self,
197
+ encoder="vitl",
198
+ features=256,
199
+ out_channels=[256, 512, 1024, 1024],
200
+ use_bn=False,
201
+ use_clstoken=False,
202
+ max_depth=20.0,
203
+ use_xformers=False,
204
+ ):
205
+ super(RGBDDepth, self).__init__()
206
+
207
+ self.intermediate_layer_idx = {
208
+ "vits": [2, 5, 8, 11],
209
+ "vitb": [2, 5, 8, 11],
210
+ "vitl": [4, 11, 17, 23],
211
+ "vitg": [9, 19, 29, 39],
212
+ }
213
+
214
+ self.max_depth = max_depth
215
+
216
+ self.encoder = encoder
217
+ self.pretrained = DINOv2(model_name=encoder)
218
+ self.depth_pretrained = DINOv2(model_name=encoder)
219
+
220
+ # self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, sigact_out=False)
221
+ self.depth_head_rgbd = DPTHead(
222
+ self.pretrained.embed_dim * 2,
223
+ features,
224
+ use_bn,
225
+ out_channels=out_channels,
226
+ use_clstoken=use_clstoken,
227
+ sigact_out=False,
228
+ )
229
+
230
+ # cross attention with xFormers support
231
+ num_heads = 4
232
+ self.crossAtts = nn.ModuleList(
233
+ [
234
+ FlexibleCrossAttention(
235
+ self.pretrained.embed_dim, num_heads, use_xformers=use_xformers
236
+ )
237
+ for _ in range(4)
238
+ ]
239
+ )
240
+
241
+ def forward(self, x):
242
+ rgb, depth = x[:, :3], x[:, 3:]
243
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
244
+
245
+ with torch.no_grad():
246
+ features_rgb = self.pretrained.get_intermediate_layers(
247
+ rgb, self.intermediate_layer_idx[self.encoder], return_class_token=True
248
+ )
249
+
250
+ features_depth = self.depth_pretrained.get_intermediate_layers(
251
+ depth.repeat(1, 3, 1, 1),
252
+ self.intermediate_layer_idx[self.encoder],
253
+ return_class_token=True,
254
+ )
255
+ features = []
256
+ for f_rgb, f_depth, crossAtt in zip(features_rgb, features_depth, self.crossAtts):
257
+ B, N, C = f_rgb[0].shape
258
+ tf_rgb = f_rgb[0].reshape(B * N, 1, C)
259
+ tf_depth = f_depth[0].reshape(B * N, 1, C)
260
+ token_feat = torch.concat((tf_rgb, tf_depth), axis=1)
261
+ att_feat, _ = crossAtt(token_feat, token_feat, token_feat)
262
+ att_feat = att_feat.reshape(B * N, 2, C).sum(axis=1).reshape(B, N, C)
263
+
264
+ feat = torch.concat((f_rgb[0], att_feat), axis=2)
265
+ cls_t = torch.concat((f_rgb[1], f_depth[1]), axis=1)
266
+ tuples = (feat, cls_t)
267
+ features.append(tuples)
268
+ depth = self.depth_head_rgbd(features, patch_h, patch_w)
269
+ depth = F.relu(depth)
270
+ return depth.squeeze(1)
271
+
272
+ @torch.no_grad()
273
+ def infer_image(self, raw_image, depth_low_res, input_size=518):
274
+ inputs, (h, w) = self.image2tensor(raw_image, depth_low_res, input_size)
275
+ pred_depth = self.forward(inputs)
276
+ pred_depth = F.interpolate(pred_depth[:, None], (h, w), mode="nearest")[0, 0]
277
+ return pred_depth.cpu().numpy()
278
+
279
+ def image2tensor(self, raw_image, depth, input_size=518):
280
+ transform = Compose(
281
+ [
282
+ Resize(
283
+ width=input_size,
284
+ height=input_size,
285
+ resize_target=True,
286
+ keep_aspect_ratio=True,
287
+ ensure_multiple_of=14,
288
+ resize_method="lower_bound",
289
+ image_interpolation_method=cv2.INTER_CUBIC,
290
+ ),
291
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
292
+ PrepareForNet(),
293
+ ]
294
+ )
295
+
296
+ h, w = raw_image.shape[:2]
297
+
298
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
299
+ prepared = transform({"image": image, "depth": depth})
300
+ image = prepared["image"]
301
+ image = torch.from_numpy(image).unsqueeze(0)
302
+
303
+ depth = prepared["depth"]
304
+ depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)
305
+
306
+ inputs = torch.cat((image, depth), dim=1)
307
+
308
+ # Use the same device as model parameters
309
+ device = next(self.parameters()).device
310
+ inputs = inputs.to(device)
311
+
312
+ return inputs, (h, w)
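
Typical use of `RGBDDepth.infer_image` as defined above: a BGR uint8 image plus a low-resolution metric depth map go in, and a refined depth map at the original resolution comes out. This is only a sketch; the file names and the commented-out checkpoint path are placeholders, and the constructor arguments are simply the class defaults (vitl encoder).

```python
import cv2
import numpy as np
import torch
from rgbddepth.dpt import RGBDDepth

model = RGBDDepth()  # class defaults: encoder="vitl", features=256, max_depth=20.0
# model.load_state_dict(torch.load("checkpoints/cdm.pth", map_location="cpu"))  # placeholder path
model.eval()

rgb = cv2.imread("rgb.png")                          # BGR uint8, H x W x 3 (placeholder file)
raw_depth = np.load("depth.npy").astype(np.float32)  # depth map, H x W (placeholder file)

refined = model.infer_image(rgb, raw_depth, input_size=518)
print(refined.shape, refined.dtype)                  # (H, W) float32
```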
rgbddepth/flexible_attention.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Flexible cross-attention module with xFormers support and automatic fallback."""
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class FlexibleCrossAttention(nn.MultiheadAttention):
12
+ """Cross-attention with optional xFormers support and automatic fallback to SDPA.
13
+
14
+ This module inherits from nn.MultiheadAttention to ensure weight compatibility.
15
+ It overrides forward() to use xFormers when available and requested.
16
+
17
+ Uses:
18
+ 1. xFormers memory-efficient attention (CUDA only, if installed and use_xformers=True)
19
+ 2. PyTorch native SDPA (Scaled Dot Product Attention, PyTorch 2.0+, default)
20
+ 3. Standard MultiheadAttention (fallback for older PyTorch versions)
21
+
22
+ Args:
23
+ embed_dim: Total dimension of the model
24
+ num_heads: Number of parallel attention heads
25
+ use_xformers: Whether to attempt using xFormers (only works on CUDA)
26
+ """
27
+
28
+ def __init__(self, embed_dim, num_heads, use_xformers=False, **kwargs):
29
+ # Initialize parent with batch_first=True to match original usage
30
+ super().__init__(embed_dim, num_heads, batch_first=True, **kwargs)
31
+
32
+ self.embed_dim = embed_dim
33
+ self.num_heads = num_heads
34
+ self.head_dim = embed_dim // num_heads
35
+
36
+ # Check if xFormers is available and requested
37
+ self.use_xformers = use_xformers and self._check_xformers()
38
+
39
+ def _check_xformers(self):
40
+ """Check if xFormers is available for import.
41
+
42
+ Returns:
43
+ bool: True if xFormers can be imported, False otherwise
44
+ """
45
+ try:
46
+ import importlib.util
47
+
48
+ return importlib.util.find_spec("xformers.ops") is not None
49
+ except (ImportError, ValueError):
50
+ return False
51
+
52
+ def forward(self, query, key, value, **kwargs):
53
+ """Forward pass with automatic backend selection.
54
+
55
+ Args:
56
+ query: Query tensor of shape [B, N, C]
57
+ key: Key tensor of shape [B, N, C]
58
+ value: Value tensor of shape [B, N, C]
59
+
60
+ Returns:
61
+ tuple: (output, attention_weights)
62
+ - output: Attention output of shape [B, N, C]
63
+ - attention_weights: None (not computed for efficiency)
64
+ """
65
+ if not self.use_xformers:
66
+ # Standard path using parent nn.MultiheadAttention (with SDPA in PyTorch 2.0+)
67
+ # This uses the original weights (in_proj_weight, out_proj) from checkpoint
68
+ return super().forward(query, key, value, need_weights=False, **kwargs)
69
+ else:
70
+ # xFormers memory-efficient attention path
71
+ import xformers.ops as xops
72
+
73
+ # Use parent's projection weights for Q, K, V
74
+ # in_proj_weight contains concatenated [W_q; W_k; W_v]
75
+ # This ensures we use the exact same weights as standard MultiheadAttention
76
+ if self.in_proj_weight is not None:
77
+ # Split the combined in_proj_weight into Q, K, V weights
78
+ w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)
79
+ b_q, b_k, b_v = None, None, None
80
+ if self.in_proj_bias is not None:
81
+ b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)
82
+
83
+ # Apply projections using the same weights as standard attention
84
+ q = torch.nn.functional.linear(query, w_q, b_q)
85
+ k = torch.nn.functional.linear(key, w_k, b_k)
86
+ v = torch.nn.functional.linear(value, w_v, b_v)
87
+ else:
88
+ # Separate projection weights (shouldn't happen with default config)
89
+ q = torch.nn.functional.linear(query, self.q_proj_weight, self.in_proj_bias)
90
+ k = torch.nn.functional.linear(key, self.k_proj_weight)
91
+ v = torch.nn.functional.linear(value, self.v_proj_weight)
92
+
93
+ # Reshape for multi-head attention: [B, N, C] -> [B, N, H, C//H]
94
+ B, N, C = q.shape
95
+ q = q.reshape(B, N, self.num_heads, self.head_dim)
96
+ k = k.reshape(B, N, self.num_heads, self.head_dim)
97
+ v = v.reshape(B, N, self.num_heads, self.head_dim)
98
+
99
+ # Apply xFormers memory-efficient attention
100
+ # This is significantly faster and uses less memory than standard attention
101
+ out = xops.memory_efficient_attention(q, k, v)
102
+
103
+ # Reshape back: [B, N, H, C//H] -> [B, N, C]
104
+ out = out.reshape(B, N, C)
105
+
106
+ # Use parent's output projection (same weights as standard attention)
107
+ out = torch.nn.functional.linear(out, self.out_proj.weight, self.out_proj.bias)
108
+
109
+ return out, None
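
With `use_xformers=False` (the default), `FlexibleCrossAttention` simply defers to `nn.MultiheadAttention` with `batch_first=True`, so checkpoints saved against the standard module load unchanged. A minimal call in the token layout `RGBDDepth` uses (one RGB token and one depth token per position, i.e. shape `[B*N, 2, C]`):

```python
import torch
from rgbddepth.flexible_attention import FlexibleCrossAttention

attn = FlexibleCrossAttention(embed_dim=1024, num_heads=4)  # 1024 = ViT-L embed_dim
tokens = torch.randn(8, 2, 1024)                            # (B*N, 2 tokens, C)
out, weights = attn(tokens, tokens, tokens)
print(out.shape)        # torch.Size([8, 2, 1024])
print(weights is None)  # True: attention weights are skipped (need_weights=False)
```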
rgbddepth/util/__init__.py ADDED
File without changes
rgbddepth/util/blocks.py ADDED
@@ -0,0 +1,208 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import torch.nn as nn
6
+
7
+
8
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
9
+ scratch = nn.Module()
10
+
11
+ out_shape1 = out_shape
12
+ out_shape2 = out_shape
13
+ out_shape3 = out_shape
14
+ if len(in_shape) >= 4:
15
+ out_shape4 = out_shape
16
+
17
+ if expand:
18
+ out_shape1 = out_shape
19
+ out_shape2 = out_shape * 2
20
+ out_shape3 = out_shape * 4
21
+ if len(in_shape) >= 4:
22
+ out_shape4 = out_shape * 8
23
+
24
+ scratch.layer1_rn = nn.Conv2d(
25
+ in_shape[0],
26
+ out_shape1,
27
+ kernel_size=3,
28
+ stride=1,
29
+ padding=1,
30
+ bias=False,
31
+ groups=groups,
32
+ )
33
+ scratch.layer2_rn = nn.Conv2d(
34
+ in_shape[1],
35
+ out_shape2,
36
+ kernel_size=3,
37
+ stride=1,
38
+ padding=1,
39
+ bias=False,
40
+ groups=groups,
41
+ )
42
+ scratch.layer3_rn = nn.Conv2d(
43
+ in_shape[2],
44
+ out_shape3,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ bias=False,
49
+ groups=groups,
50
+ )
51
+ if len(in_shape) >= 4:
52
+ scratch.layer4_rn = nn.Conv2d(
53
+ in_shape[3],
54
+ out_shape4,
55
+ kernel_size=3,
56
+ stride=1,
57
+ padding=1,
58
+ bias=False,
59
+ groups=groups,
60
+ )
61
+
62
+ return scratch
63
+
64
+
65
+ class ResidualConvUnit(nn.Module):
66
+ """Residual convolution module."""
67
+
68
+ def __init__(self, features, activation, bn):
69
+ """Init.
70
+
71
+ Args:
72
+ features (int): number of features
73
+ """
74
+ super().__init__()
75
+
76
+ self.bn = bn
77
+
78
+ self.groups = 1
79
+
80
+ self.conv1 = nn.Conv2d(
81
+ features,
82
+ features,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1,
86
+ bias=True,
87
+ groups=self.groups,
88
+ )
89
+
90
+ self.conv2 = nn.Conv2d(
91
+ features,
92
+ features,
93
+ kernel_size=3,
94
+ stride=1,
95
+ padding=1,
96
+ bias=True,
97
+ groups=self.groups,
98
+ )
99
+
100
+ if self.bn:
101
+ self.bn1 = nn.BatchNorm2d(features)
102
+ self.bn2 = nn.BatchNorm2d(features)
103
+
104
+ self.activation = activation
105
+
106
+ self.skip_add = nn.quantized.FloatFunctional()
107
+
108
+ def forward(self, x):
109
+ """Forward pass.
110
+
111
+ Args:
112
+ x (tensor): input
113
+
114
+ Returns:
115
+ tensor: output
116
+ """
117
+
118
+ out = self.activation(x)
119
+ out = self.conv1(out)
120
+ if self.bn:
121
+ out = self.bn1(out)
122
+
123
+ out = self.activation(out)
124
+ out = self.conv2(out)
125
+ if self.bn:
126
+ out = self.bn2(out)
127
+
128
+ if self.groups > 1:
129
+ out = self.conv_merge(out)
130
+
131
+ return self.skip_add.add(out, x)
132
+
133
+
134
+ class FeatureFusionBlock(nn.Module):
135
+ """Feature fusion block."""
136
+
137
+ def __init__(
138
+ self,
139
+ features,
140
+ activation,
141
+ deconv=False,
142
+ bn=False,
143
+ expand=False,
144
+ align_corners=True,
145
+ size=None,
146
+ ):
147
+ """Init.
148
+
149
+ Args:
150
+ features (int): number of features
151
+ """
152
+ super(FeatureFusionBlock, self).__init__()
153
+
154
+ self.deconv = deconv
155
+ self.align_corners = align_corners
156
+
157
+ self.groups = 1
158
+
159
+ self.expand = expand
160
+ out_features = features
161
+ if self.expand:
162
+ out_features = features // 2
163
+
164
+ self.out_conv = nn.Conv2d(
165
+ features,
166
+ out_features,
167
+ kernel_size=1,
168
+ stride=1,
169
+ padding=0,
170
+ bias=True,
171
+ groups=1,
172
+ )
173
+
174
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
175
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
176
+
177
+ self.skip_add = nn.quantized.FloatFunctional()
178
+
179
+ self.size = size
180
+
181
+ def forward(self, *xs, size=None):
182
+ """Forward pass.
183
+
184
+ Returns:
185
+ tensor: output
186
+ """
187
+ output = xs[0]
188
+
189
+ if len(xs) == 2:
190
+ res = self.resConfUnit1(xs[1])
191
+ output = self.skip_add.add(output, res)
192
+
193
+ output = self.resConfUnit2(output)
194
+
195
+ if (size is None) and (self.size is None):
196
+ modifier = {"scale_factor": 2}
197
+ elif size is None:
198
+ modifier = {"size": self.size}
199
+ else:
200
+ modifier = {"size": size}
201
+
202
+ output = nn.functional.interpolate(
203
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
204
+ )
205
+
206
+ output = self.out_conv(output)
207
+
208
+ return output
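
Shape sketch for the fusion pyramid that `DPTHead` builds from these blocks: each `FeatureFusionBlock` refines the incoming path, adds the lateral skip when one is given (both at the same resolution), then upsamples to the requested `size`. The resolutions below correspond to a 37 x 37 patch grid and are illustrative only.

```python
import torch
import torch.nn as nn
from rgbddepth.util.blocks import FeatureFusionBlock

fuse = FeatureFusionBlock(features=256, activation=nn.ReLU(False), align_corners=True)

path_4 = torch.randn(1, 256, 37, 37)        # coarser path, already upsampled to 37 x 37
layer_3 = torch.randn(1, 256, 37, 37)       # lateral skip at the same resolution
out = fuse(path_4, layer_3, size=(74, 74))  # fuse, then upsample to the next level's size
print(out.shape)  # torch.Size([1, 256, 74, 74])
```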
rgbddepth/util/transform.py ADDED
@@ -0,0 +1,169 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ class Resize(object):
10
+ """Resize sample to given size (width, height)."""
11
+
12
+ def __init__(
13
+ self,
14
+ width,
15
+ height,
16
+ resize_target=True,
17
+ keep_aspect_ratio=False,
18
+ ensure_multiple_of=1,
19
+ resize_method="lower_bound",
20
+ image_interpolation_method=cv2.INTER_AREA,
21
+ ):
22
+ """Init.
23
+
24
+ Args:
25
+ width (int): desired output width
26
+ height (int): desired output height
27
+ resize_target (bool, optional):
28
+ True: Resize the full sample (image, mask, target).
29
+ False: Resize image only.
30
+ Defaults to True.
31
+ keep_aspect_ratio (bool, optional):
32
+ True: Keep the aspect ratio of the input sample.
33
+ Output sample might not have the given width and height, and
34
+ resize behaviour depends on the parameter 'resize_method'.
35
+ Defaults to False.
36
+ ensure_multiple_of (int, optional):
37
+ Output width and height is constrained to be multiple of this parameter.
38
+ Defaults to 1.
39
+ resize_method (str, optional):
40
+ "lower_bound": Output will be at least as large as the given size.
41
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
42
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
43
+ Defaults to "lower_bound".
44
+ """
45
+ self.__width = width
46
+ self.__height = height
47
+
48
+ self.__resize_target = resize_target
49
+ self.__keep_aspect_ratio = keep_aspect_ratio
50
+ self.__multiple_of = ensure_multiple_of
51
+ self.__resize_method = resize_method
52
+ self.__image_interpolation_method = image_interpolation_method
53
+
54
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
55
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if max_val is not None and y > max_val:
58
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ if y < min_val:
61
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
62
+
63
+ return y
64
+
65
+ def get_size(self, width, height):
66
+ # determine new height and width
67
+ scale_height = self.__height / height
68
+ scale_width = self.__width / width
69
+
70
+ if self.__keep_aspect_ratio:
71
+ if self.__resize_method == "lower_bound":
72
+ # scale such that output size is lower bound
73
+ if scale_width > scale_height:
74
+ # fit width
75
+ scale_height = scale_width
76
+ else:
77
+ # fit height
78
+ scale_width = scale_height
79
+ elif self.__resize_method == "upper_bound":
80
+ # scale such that output size is upper bound
81
+ if scale_width < scale_height:
82
+ # fit width
83
+ scale_height = scale_width
84
+ else:
85
+ # fit height
86
+ scale_width = scale_height
87
+ elif self.__resize_method == "minimal":
88
+ # scale as little as possible
89
+ if abs(1 - scale_width) < abs(1 - scale_height):
90
+ # fit width
91
+ scale_height = scale_width
92
+ else:
93
+ # fit height
94
+ scale_width = scale_height
95
+ else:
96
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
97
+
98
+ if self.__resize_method == "lower_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
101
+ elif self.__resize_method == "upper_bound":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
104
+ elif self.__resize_method == "minimal":
105
+ new_height = self.constrain_to_multiple_of(scale_height * height)
106
+ new_width = self.constrain_to_multiple_of(scale_width * width)
107
+ else:
108
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
109
+
110
+ return (new_width, new_height)
111
+
112
+ def __call__(self, sample):
113
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
114
+
115
+ # resize sample
116
+ sample["image"] = cv2.resize(
117
+ sample["image"],
118
+ (width, height),
119
+ interpolation=self.__image_interpolation_method,
120
+ )
121
+
122
+ if self.__resize_target:
123
+ if "depth" in sample:
124
+ sample["depth"] = cv2.resize(
125
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
126
+ )
127
+
128
+ if "mask" in sample:
129
+ sample["mask"] = cv2.resize(
130
+ sample["mask"].astype(np.float32),
131
+ (width, height),
132
+ interpolation=cv2.INTER_NEAREST,
133
+ )
134
+
135
+ return sample
136
+
137
+
138
+ class NormalizeImage(object):
139
+ """Normlize image by given mean and std."""
140
+
141
+ def __init__(self, mean, std):
142
+ self.__mean = mean
143
+ self.__std = std
144
+
145
+ def __call__(self, sample):
146
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
147
+
148
+ return sample
149
+
150
+
151
+ class PrepareForNet(object):
152
+ """Prepare sample for usage as network input."""
153
+
154
+ def __init__(self):
155
+ pass
156
+
157
+ def __call__(self, sample):
158
+ image = np.transpose(sample["image"], (2, 0, 1))
159
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
160
+
161
+ if "depth" in sample:
162
+ depth = sample["depth"].astype(np.float32)
163
+ sample["depth"] = np.ascontiguousarray(depth)
164
+
165
+ if "mask" in sample:
166
+ sample["mask"] = sample["mask"].astype(np.float32)
167
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
168
+
169
+ return sample
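
Putting the three transforms together the way `RGBDDepth.image2tensor` does: keep the aspect ratio, make both sides at least 518 px ("lower_bound") and multiples of 14 (the ViT patch size), then normalise and transpose to CHW. For a 640 x 480 input the resize lands on 686 x 518; the random arrays below merely stand in for a real image and depth map.

```python
import cv2
import numpy as np
from rgbddepth.util.transform import NormalizeImage, PrepareForNet, Resize

resize = Resize(
    width=518, height=518,
    resize_target=True, keep_aspect_ratio=True,
    ensure_multiple_of=14, resize_method="lower_bound",
    image_interpolation_method=cv2.INTER_CUBIC,
)
print(resize.get_size(640, 480))  # (686, 518) as (new_width, new_height)

sample = {
    "image": np.random.rand(480, 640, 3).astype(np.float32),  # stand-in RGB in [0, 1]
    "depth": np.random.rand(480, 640).astype(np.float32),     # stand-in depth map
}
sample = resize(sample)
sample = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(sample)
sample = PrepareForNet()(sample)
print(sample["image"].shape, sample["depth"].shape)  # (3, 518, 686) (518, 686)
```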