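"""Run SHEAP face reconstruction on a video and optionally render the result.

Frames are read sequentially by a background thread, each frame is face-detected
and cropped, SHEAP regresses FLAME parameters, meshes are rendered by background
worker threads, and ffmpeg assembles the final video.

Example invocation (the script filename here is a placeholder):

    python process_video.py input.mp4 --out_path input_rendered.mp4
"""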
import argparse
import os
import shutil
import subprocess
import threading
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from demo import create_rendering_image
from sheap import load_sheap_model
from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats

try:
    import face_alignment
except ImportError:
    raise ImportError(
        "The 'face_alignment' package is required. Please install it via 'pip install face-alignment'."
    )
from sheap.fa_landmark_utils import detect_face_and_crop


class RenderingThread(threading.Thread):
    """Background thread for rendering frames to images."""

    def __init__(
        self,
        render_queue: Queue,
        temp_dir: Path,
        faces: torch.Tensor,
        c2w: torch.Tensor,
        render_size: int,
    ):
        """
        Initialize rendering thread.

        Args:
            render_queue: Queue containing (frame_idx, cropped_frame, verts) tuples
            temp_dir: Directory to save rendered images
            faces: Face indices tensor from FLAME model
            c2w: Camera-to-world transformation matrix
            render_size: Size of each sub-image in the rendered output
        """
        super().__init__(daemon=True)
        self.render_queue = render_queue
        self.temp_dir = temp_dir
        self.faces = faces
        self.c2w = c2w
        self.render_size = render_size
        self.stop_event = threading.Event()
        self.frames_rendered = 0

    def run(self):
        """Process rendering queue until stop signal is received."""
        # Set PyOpenGL platform for this thread
        os.environ["PYOPENGL_PLATFORM"] = "egl"

        while not self.stop_event.is_set():
            try:
                # Get item from queue with timeout to allow checking stop_event
                try:
                    item = self.render_queue.get(timeout=0.1)
                except Empty:  # Haven't finished, but nothing to render yet
                    continue
                if item is None:  # Sentinel value to stop
                    break

                frame_idx, cropped_frame, verts = item

                # Render the frame
                cropped_pil = Image.fromarray(cropped_frame)
                combined = create_rendering_image(
                    original_image=cropped_pil,
                    verts=verts,
                    faces=self.faces,
                    c2w=self.c2w,
                    output_size=self.render_size,
                )

                # Save to temp directory with zero-padded frame number
                output_path = self.temp_dir / f"frame_{frame_idx:06d}.png"
                combined.save(output_path)

                self.frames_rendered += 1
                self.render_queue.task_done()

            except Exception as e:
                if not self.stop_event.is_set():
                    print(f"Error rendering frame: {e}")
                    import traceback

                    traceback.print_exc()

    def stop(self):
        """Signal the thread to stop."""
        self.stop_event.set()


class VideoFrameDataset(IterableDataset):
    """Iterable dataset for streaming video frames with face detection and cropping.
    
    Uses a background thread for video frame loading while face detection runs in the main thread.
    """

    def __init__(
        self,
        video_path: str,
        fa_model: face_alignment.FaceAlignment,
        smoothing_alpha: float = 0.3,
        frame_buffer_size: int = 32,
    ):
        """
        Initialize video frame dataset.

        Args:
            video_path: Path to video file
            fa_model: FaceAlignment model instance for face detection
            smoothing_alpha: EMA weight given to the newly detected bounding box
                (1 = use the detection as-is, 0 = the box never updates).
                Lower values = more smoothing.
            frame_buffer_size: Size of the frame buffer queue for the background thread
        """
        super().__init__()
        self.video_path = video_path
        self.fa_model = fa_model
        self.smoothing_alpha = smoothing_alpha
        self.frame_buffer_size = frame_buffer_size
        self.prev_bbox: Optional[Tuple[int, int, int, int]] = None

        # Get video metadata (don't keep capture open)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        self.fps = cap.get(cv2.CAP_PROP_FPS)
        self.num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        print(
            f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}"
        )

    def _video_reader_thread(self, frame_queue: Queue, stop_event: threading.Event):
        """Background thread that reads video frames and puts them in a queue.
        
        Args:
            frame_queue: Queue to put (frame_idx, frame_rgb) tuples
            stop_event: Event to signal thread to stop
        """
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            frame_queue.put(("error", f"Could not open video file: {self.video_path}"))
            return

        frame_idx = 0
        try:
            while not stop_event.is_set():
                ret, frame_bgr = cap.read()
                if not ret:
                    break

                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                
                # Put frame in queue (blocks if queue is full)
                frame_queue.put((frame_idx, frame_rgb))
                frame_idx += 1

        finally:
            cap.release()
            # Signal end of video
            frame_queue.put(None)

    def __iter__(self):
        """
        Iterate through video frames sequentially.
        
        Video frame loading happens in a background thread, while face detection
        and processing happen in the main thread.

        Yields:
            Dictionary containing frame_idx, processed image, and bounding box
        """
        # Reset smoothing state for new iteration
        self.prev_bbox = None

        # Create queue and start background thread for video reading
        frame_queue = Queue(maxsize=self.frame_buffer_size)
        stop_event = threading.Event()
        reader_thread = threading.Thread(
            target=self._video_reader_thread,
            args=(frame_queue, stop_event),
            daemon=True
        )
        reader_thread.start()

        try:
            while True:
                # Get frame from background thread
                item = frame_queue.get()
                
                # Check for end of video
                if item is None:
                    break
                
                # Check for error
                if isinstance(item, tuple) and len(item) == 2 and item[0] == "error":
                    raise RuntimeError(item[1])
                
                frame_idx, frame_rgb = item

                # Convert to torch tensor (C, H, W) with values in [0, 1]
                image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0

                # Detect face and crop (runs in main thread, can use GPU)
                bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5)

                # Apply smoothing using exponential moving average
                bbox = self._smooth_bbox(bbox)
                x0, y0, x1, y1 = bbox

                cropped = image[:, y0:y1, x0:x1]

                # Resize to 224x224 for SHEAP model
                cropped_resized = TF.resize(cropped, [224, 224], antialias=True)
                cropped_for_render = TF.resize(cropped, [512, 512], antialias=True)

                yield {
                    "frame_idx": frame_idx,
                    "image": cropped_resized,
                    "bbox": bbox,
                    "original_frame": frame_rgb,  # Keep original for reference (as numpy array)
                    "cropped_frame": cropped_for_render,  # Cropped region resized to 512x512
                }

        finally:
            # Clean up background thread
            stop_event.set()
            reader_thread.join(timeout=1.0)

    def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        """Apply exponential moving average smoothing to bounding box."""
        if self.prev_bbox is None:
            self.prev_bbox = bbox
            return bbox

        x0, y0, x1, y1 = bbox
        prev_x0, prev_y0, prev_x1, prev_y1 = self.prev_bbox

        # Smooth: new_bbox = alpha * detected_bbox + (1 - alpha) * prev_bbox
        smoothed = (
            int(self.smoothing_alpha * x0 + (1 - self.smoothing_alpha) * prev_x0),
            int(self.smoothing_alpha * y0 + (1 - self.smoothing_alpha) * prev_y0),
            int(self.smoothing_alpha * x1 + (1 - self.smoothing_alpha) * prev_x1),
            int(self.smoothing_alpha * y1 + (1 - self.smoothing_alpha) * prev_y1),
        )

        self.prev_bbox = smoothed
        return smoothed

    def __len__(self) -> int:
        return self.num_frames


def process_video(
    video_path: str,
    model_type: str = "expressive",
    batch_size: int = 1,
    num_workers: int = 0,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    output_video_path: Optional[str] = None,
    render_size: int = 512,
    num_render_workers: int = 1,
    max_queue_size: int = 128,
) -> List[Dict[str, Any]]:
    """
    Process video frames through SHEAP model and optionally render output video.

    Uses an IterableDataset for efficient sequential video processing without seeking overhead.
    Rendering is done in a background thread, and ffmpeg is used to create the final video.

    Args:
        video_path: Path to video file
        model_type: SHEAP model variant ("paper", "expressive", or "lightweight")
        batch_size: Batch size for processing
        num_workers: Number of workers (0 or 1 only). Will be clamped to max 1.
        device: Device to run model on ("cpu" or "cuda")
        output_video_path: If provided, render and save output video to this path
        render_size: Size of each sub-image in the rendered output
        num_render_workers: Number of background threads for rendering
        max_queue_size: Maximum size of the rendering queue

    Returns:
        List of dictionaries containing frame index, bounding box, and FLAME parameters
    """
    # This IterableDataset streams frames sequentially, so at most one DataLoader
    # worker is supported.
    if num_workers > 1:
        print("Warning: num_workers > 1 is not supported with IterableDataset. Using num_workers=1.")
    num_workers = min(num_workers, 1)

    # Load SHEAP model
    print(f"Loading SHEAP model (type: {model_type})...")
    sheap_model = load_sheap_model(model_type=model_type)
    sheap_model.eval()
    sheap_model = sheap_model.to(device)

    # Load face alignment model.
    # Run face alignment on CPU whenever a DataLoader worker is used: CUDA cannot
    # be reliably re-initialized inside forked worker processes.
    fa_device = "cpu" if num_workers >= 1 else device
    print(f"Loading face alignment model on {fa_device}...")
    fa_model = face_alignment.FaceAlignment(
        face_alignment.LandmarksType.THREE_D, flip_input=False, device=fa_device
    )

    # Create dataset and dataloader
    dataset = VideoFrameDataset(video_path, fa_model)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    print(f"Processing {len(dataset)} frames from {video_path}")

    # Initialize FLAME model and rendering thread if rendering
    flame = None
    rendering_threads = []
    render_queue = None
    temp_dir = None
    c2w = None

    if output_video_path:
        print("Loading FLAME model for rendering...")
        flame_dir = Path("FLAME2020/")
        flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt")
        flame = flame.to(device)  # Move FLAME to GPU
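        # Fixed camera-to-world pose for rendering: identity rotation with the
        # camera translated one unit along +z.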
        c2w = torch.tensor(
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32
        )

        # Create temporary directory for rendered frames
        temp_dir = Path("./temp_sheap_render/")
        temp_dir.mkdir(parents=True, exist_ok=True)
        print(f"Using temporary directory: {temp_dir}")

        # Start multiple background rendering threads
        render_queue = Queue(maxsize=max_queue_size)
        for _ in range(num_render_workers):
            thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
            thread.start()
            rendering_threads.append(thread)
        print(f"Started {num_render_workers} background rendering threads")

    results = []
    frame_count = 0

    with torch.no_grad():
        progbar = tqdm(total=len(dataset), desc="Processing frames")
        for batch in dataloader:
            frame_indices = batch["frame_idx"]
            images = batch["image"].to(device)
            bboxes = batch["bbox"]

            # Process through SHEAP model
            flame_params_dict = sheap_model(images)

            # Generate vertices for this batch if rendering
            if output_video_path and flame is not None:
                verts = flame(
                    shape=flame_params_dict["shape_from_facenet"],
                    expression=flame_params_dict["expr"],
                    pose=pose_components_to_rotmats(flame_params_dict),
                    eyelids=flame_params_dict["eyelids"],
                    translation=flame_params_dict["cam_trans"],
                )

            # Store results and queue for rendering
            for i in range(len(frame_indices)):
                frame_idx = _extract_scalar(frame_indices[i])
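                # The default collate_fn turns each sample's bbox tuple into a list of
                # four per-coordinate batch tensors, so index each coordinate tensor at i.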
                bbox = tuple(_extract_scalar(b[i]) for b in bboxes)

                result = {
                    "frame_idx": frame_idx,
                    "bbox": bbox,
                    "flame_params": {k: v[i].cpu() for k, v in flame_params_dict.items()},
                }
                results.append(result)

                # Queue frame for rendering
                if output_video_path:
                    cropped_frame = _tensor_to_numpy_image(batch["cropped_frame"][i])
                    render_queue.put((frame_idx, cropped_frame, verts[i].cpu()))
                    frame_count += 1

            progbar.update(len(frame_indices))
        progbar.close()

    # Finalize rendering and create output video
    if output_video_path and render_queue is not None:
        _finalize_rendering(
            rendering_threads,
            render_queue,
            num_render_workers,
            temp_dir,
            dataset.fps,
            output_video_path,
        )

    return results


def _extract_scalar(value: Any) -> int:
    """Extract scalar integer from tensor or return as-is."""
    return value.item() if isinstance(value, torch.Tensor) else value


def _tensor_to_numpy_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert (C, H, W) tensor [0, 1] to numpy (H, W, C) uint8 [0, 255]."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


def _finalize_rendering(
    rendering_threads: List[RenderingThread],
    render_queue: Queue,
    num_render_workers: int,
    temp_dir: Path,
    fps: float,
    output_video_path: str,
) -> None:
    """Finish rendering threads and create final video with ffmpeg."""
    print("\nWaiting for rendering threads to complete...")

    # Add sentinel values to stop workers
    for _ in range(num_render_workers):
        render_queue.put(None)

    # Wait for all threads to finish
    for thread in rendering_threads:
        thread.join()

    total_rendered = sum(thread.frames_rendered for thread in rendering_threads)
    print(f"Rendered {total_rendered} frames")

    # Create video with ffmpeg
    print("Creating video with ffmpeg...")
    output_path = Path(output_video_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

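    # "-crf 23" is libx264's default quality level; "yuv420p" keeps the output
    # playable in common players that cannot decode 4:4:4 chroma.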
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",  # Overwrite output file if it exists
        "-framerate",
        str(fps),
        "-i",
        str(temp_dir / "frame_%06d.png"),
        "-c:v",
        "libx264",
        "-pix_fmt",
        "yuv420p",
        "-preset",
        "medium",
        "-crf",
        "23",
        str(output_path),
    ]

    subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
    print(f"Video saved to: {output_video_path}")

    # Clean up temporary directory
    if temp_dir.exists():
        print(f"Removing temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir)
        print("Cleanup complete")


if __name__ == "__main__":
    # video_path = "skarsgard.mp4"
    # output_video_path = "skarsgard_rendered.mp4"
    parser = argparse.ArgumentParser(description="Process and render video with SHEAP model.")
    parser.add_argument("in_path", type=str, help="Path to input video file.")
    parser.add_argument(
        "--out_path", type=str, help="Path to save rendered output video.", default=None
    )
    args = parser.parse_args()

    if args.out_path is None:
        args.out_path = str(Path(args.in_path).with_name(f"{Path(args.in_path).stem}_rendered.mp4"))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    results = process_video(
        video_path=args.in_path,
        model_type="expressive",
        device=device,
        output_video_path=args.out_path,
    )