|
|
|
|
|
import json |
|
|
import os |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import matplotlib.patches as patches |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pycocotools.mask as mask_utils |
|
|
import torch |
|
|
from matplotlib.colors import to_rgb |
|
|
from PIL import Image |
|
|
from skimage.color import lab2rgb, rgb2lab |
|
|
from sklearn.cluster import KMeans |
|
|
from torchvision.ops import masks_to_boxes |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def generate_colors(n_colors=256, n_samples=5000):
    """Build a palette of visually distinct colors via k-means in CIELAB space.

    Random RGB samples are converted to LAB (a perceptually uniform space),
    clustered into ``n_colors`` groups, and the cluster centers are converted
    back to RGB.

    :param n_colors: number of palette entries to produce.
    :param n_samples: number of random RGB samples to cluster.
    :return: np.ndarray of shape (n_colors, 3) with RGB values in [0, 1].
    """
    np.random.seed(42)  # fixed seed so the palette is reproducible across runs
    samples_rgb = np.random.rand(n_samples, 3)
    # skimage's converters expect an image-shaped array, hence the reshapes.
    samples_lab = rgb2lab(samples_rgb.reshape(1, -1, 3)).reshape(-1, 3)

    clusterer = KMeans(n_clusters=n_colors, n_init=10)
    clusterer.fit(samples_lab)

    palette = lab2rgb(clusterer.cluster_centers_.reshape(1, -1, 3)).reshape(-1, 3)
    return np.clip(palette, 0, 1)
|
|
|
|
|
|
|
|
# Global palette shared by all visualization helpers; object ids index into it
# (modulo its length) so each object keeps a stable color across frames.
COLORS = generate_colors(n_colors=128, n_samples=5000)
|
|
|
|
|
|
|
|
def show_img_tensor(img_batch, vis_img_idx=0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Display one image from a normalized (B, C, H, W) tensor batch.

    The image is denormalized with ``mean``/``std``, clipped to [0, 1] and
    drawn on the current matplotlib axes.

    :param img_batch: batched image tensor of shape (B, C, H, W).
    :param vis_img_idx: index of the image in the batch to display.
    :param mean: per-channel normalization mean used during preprocessing
        (previously hard-coded to 0.5; now configurable).
    :param std: per-channel normalization std used during preprocessing.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)
    im_tensor = img_batch[vis_img_idx].detach().cpu()
    assert im_tensor.dim() == 3
    # CHW -> HWC for matplotlib, then undo the normalization.
    im_tensor = im_tensor.numpy().transpose((1, 2, 0))
    im_tensor = (im_tensor * std) + mean
    im_tensor = np.clip(im_tensor, 0, 1)
    plt.imshow(im_tensor)
|
|
|
|
|
|
|
|
def draw_box_on_image(image, box, color=(0, 255, 0)):
    """
    Draws a rectangle on a given PIL image using the provided box coordinates in xywh format.

    The border is 3 pixels thick (one pixel inside, on, and just outside each
    edge). Pixels that would fall outside the image are skipped instead of
    raising IndexError, so boxes touching the image border are handled safely.

    :param image: PIL.Image - The image on which to draw the rectangle.
    :param box: tuple - A tuple (x, y, w, h) representing the top-left corner, width, and height of the rectangle.
    :param color: tuple - A tuple (R, G, B) representing the color of the rectangle. Default is green.
    :return: PIL.Image - The image with the rectangle drawn on it.
    """
    image = image.convert("RGB")
    img_w, img_h = image.size

    x, y, w, h = (int(v) for v in box)

    pixels = image.load()

    def _put(px, py):
        # Write a single pixel, ignoring coordinates outside the image
        # (the original code raised IndexError for edge-touching boxes).
        if 0 <= px < img_w and 0 <= py < img_h:
            pixels[px, py] = color

    # Horizontal edges (top and bottom), 3 px thick.
    for i in range(x, x + w):
        for j in (y - 1, y, y + 1, y + h - 2, y + h - 1, y + h):
            _put(i, j)

    # Vertical edges (left and right), 3 px thick.
    for j in range(y, y + h):
        for i in (x - 1, x, x + 1, x + w - 2, x + w - 1, x + w):
            _put(i, j)

    return image
|
|
|
|
|
|
|
|
def plot_bbox(
    img_height,
    img_width,
    box,
    box_format="XYXY",
    relative_coords=True,
    color="r",
    linestyle="solid",
    text=None,
    ax=None,
):
    """Draw a single bounding box (and optional label) on a matplotlib axes.

    :param img_height: image height in pixels, used to scale relative coords.
    :param img_width: image width in pixels, used to scale relative coords.
    :param box: box coordinates in the format given by ``box_format``.
    :param box_format: one of "XYXY", "XYWH" or "CxCyWH".
    :param relative_coords: if True, ``box`` values are in [0, 1] and are
        scaled by the image size before drawing.
    :param color: matplotlib edge/text color.
    :param linestyle: matplotlib line style for the rectangle edge.
    :param text: optional label drawn just above the box's top-left corner.
    :param ax: axes to draw on; defaults to the current axes.
    :raises RuntimeError: if ``box_format`` is not recognized.
    """
    # Normalize every supported format to (x, y, w, h) with (x, y) top-left.
    if box_format == "XYXY":
        x0, y0, x1, y1 = box
        x, y, w, h = x0, y0, x1 - x0, y1 - y0
    elif box_format == "XYWH":
        x, y, w, h = box
    elif box_format == "CxCyWH":
        cx, cy, w, h = box
        x, y = cx - w / 2, cy - h / 2
    else:
        raise RuntimeError(f"Invalid box_format {box_format}")

    if relative_coords:
        # Scale normalized coordinates up to pixel units.
        x, w = x * img_width, w * img_width
        y, h = y * img_height, h * img_height

    ax = plt.gca() if ax is None else ax
    ax.add_patch(
        patches.Rectangle(
            (x, y),
            w,
            h,
            linewidth=1.5,
            edgecolor=color,
            facecolor="none",
            linestyle=linestyle,
        )
    )

    if text is not None:
        # White background box keeps the label readable over the image.
        ax.text(
            x,
            y - 5,
            text,
            color=color,
            weight="bold",
            fontsize=8,
            bbox={"facecolor": "w", "alpha": 0.75, "pad": 2},
        )
|
|
|
|
|
|
|
|
def plot_mask(mask, color="r", ax=None):
    """Overlay a binary mask on an axes as a semi-transparent colored layer.

    :param mask: 2D array of shape (H, W); nonzero pixels are shown.
    :param color: any matplotlib color spec for the overlay.
    :param ax: axes to draw on; defaults to the current axes.
    """
    height, width = mask.shape
    # RGBA image: constant color everywhere, alpha 0.5 where the mask is set
    # and 0 (fully transparent) elsewhere.
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    overlay[..., :3] = to_rgb(color)
    overlay[..., 3] = mask * 0.5

    target_ax = plt.gca() if ax is None else ax
    target_ax.imshow(overlay)
|
|
|
|
|
|
|
|
def normalize_bbox(bbox_xywh, img_w, img_h):
    """Normalize pixel-space XYWH box coordinates to the [0, 1] range.

    Accepts either a plain list of 4 numbers or a torch tensor whose last
    dimension has size 4 (arbitrary batch dimensions are allowed for
    tensors). The input is never modified; a normalized copy is returned.

    :param bbox_xywh: box(es) as [x, y, w, h] in pixels.
    :param img_w: image width in pixels.
    :param img_h: image height in pixels.
    :return: normalized box(es) of the same type as the input.
    """
    if isinstance(bbox_xywh, list):
        assert (
            len(bbox_xywh) == 4
        ), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
        x, y, w, h = bbox_xywh
        # x and w scale with width; y and h scale with height.
        return [x / img_w, y / img_h, w / img_w, h / img_h]

    assert isinstance(
        bbox_xywh, torch.Tensor
    ), "Only torch tensors are supported for batching."
    normalized = bbox_xywh.clone()
    assert (
        normalized.size(-1) == 4
    ), "bbox_xywh tensor must have last dimension of size 4."
    # Even indices (x, w) scale with width; odd indices (y, h) with height.
    normalized[..., 0::2] /= img_w
    normalized[..., 1::2] /= img_h
    return normalized
|
|
|
|
|
|
|
|
def visualize_frame_output(frame_idx, video_frames, outputs, figsize=(12, 8)):
    """Plot a single frame with every predicted box, mask and score overlaid.

    :param frame_idx: index of the frame to display.
    :param video_frames: sequence of frames accepted by ``load_frame``.
    :param outputs: dict with parallel lists under "out_probs",
        "out_boxes_xywh", "out_obj_ids" and "out_binary_masks".
    :param figsize: matplotlib figure size.
    """
    plt.figure(figsize=figsize)
    plt.title(f"frame {frame_idx}")
    img = load_frame(video_frames[frame_idx])
    img_H, img_W, _ = img.shape
    plt.imshow(img)

    for i, prob in enumerate(outputs["out_probs"]):
        box_xywh = outputs["out_boxes_xywh"][i]
        obj_id = outputs["out_obj_ids"][i]
        binary_mask = outputs["out_binary_masks"][i]
        # Color keyed on object id so identities keep stable colors over time.
        color = COLORS[obj_id % len(COLORS)]
        plot_bbox(
            img_H,
            img_W,
            box_xywh,
            text=f"(id={obj_id}, {prob=:.2f})",
            box_format="XYWH",
            color=color,
        )
        plot_mask(binary_mask, color=color)
|
|
|
|
|
|
|
|
def visualize_formatted_frame_output(
    frame_idx,
    video_frames,
    outputs_list,
    titles=None,
    points_list=None,
    points_labels_list=None,
    figsize=(12, 8),
    title_suffix="",
    prompt_info=None,
):
    """Visualize one or more sets of segmentation masks on a video frame.

    Each element of ``outputs_list`` is drawn on its own subplot: a bounding
    box derived from each mask, the mask overlay itself, and (on frame 0 only)
    any prompt boxes/points from ``prompt_info``.

    Args:
        frame_idx: Frame index to visualize
        video_frames: Sequence of frames accepted by ``load_frame``
        outputs_list: List of {frame_idx: {obj_id: mask_tensor}} or single dict {obj_id: mask_tensor}
        titles: List of titles for each set of outputs_list
        points_list: Optional list of point coordinates, one entry per subplot
        points_labels_list: Optional list of point labels, parallel to points_list
        figsize: Figure size tuple
        title_suffix: Additional title suffix appended to every subplot title
        prompt_info: Dictionary with prompt information (boxes, points, etc.);
            coordinates appear to be normalized to [0, 1] — TODO confirm with caller
    """
    # Normalize the two single-dict call conventions into a one-element list.
    if isinstance(outputs_list, dict) and frame_idx in outputs_list:
        # Case 1: a single {frame_idx: {obj_id: mask}} mapping.
        outputs_list = [outputs_list]
    elif isinstance(outputs_list, dict) and not any(
        isinstance(k, int) for k in outputs_list.keys()
    ):
        # Case 2: a bare {obj_id: mask} dict (no integer frame keys); wrap it
        # under the requested frame index.
        single_frame_outputs = {frame_idx: outputs_list}
        outputs_list = [single_frame_outputs]

    num_outputs = len(outputs_list)
    if titles is None:
        titles = [f"Set {i+1}" for i in range(num_outputs)]
    assert (
        len(titles) == num_outputs
    ), "length of `titles` should match that of `outputs_list` if not None."

    _, axes = plt.subplots(1, num_outputs, figsize=figsize)
    if num_outputs == 1:
        # plt.subplots returns a bare Axes for a single subplot; wrap in a list
        # so the indexing below works uniformly.
        axes = [axes]

    img = load_frame(video_frames[frame_idx])
    img_H, img_W, _ = img.shape

    for idx in range(num_outputs):
        ax, outputs_set, ax_title = axes[idx], outputs_list[idx], titles[idx]
        ax.set_title(f"Frame {frame_idx} - {ax_title}{title_suffix}")
        ax.imshow(img)

        if frame_idx in outputs_set:
            _outputs = outputs_set[frame_idx]
        else:
            print(f"Warning: Frame {frame_idx} not found in outputs_set")
            continue

        # Prompts are only drawn on the first frame of the video.
        if prompt_info and frame_idx == 0:
            if "boxes" in prompt_info:
                for box in prompt_info["boxes"]:
                    # Prompt boxes arrive as XYWH; convert to XYXY corners.
                    x, y, w, h = box
                    plot_bbox(
                        img_H,
                        img_W,
                        [x, y, x + w, y + h],
                        box_format="XYXY",
                        relative_coords=True,
                        color="yellow",
                        linestyle="dashed",
                        text="PROMPT BOX",
                        ax=ax,
                    )

            if "points" in prompt_info and "point_labels" in prompt_info:
                points = np.array(prompt_info["points"])
                labels = np.array(prompt_info["point_labels"])

                # Scale normalized (x, y) coordinates up to pixel units.
                points_pixel = points * np.array([img_W, img_H])

                # Positive prompt points (label == 1) drawn as lime stars.
                pos_points = points_pixel[labels == 1]
                if len(pos_points) > 0:
                    ax.scatter(
                        pos_points[:, 0],
                        pos_points[:, 1],
                        color="lime",
                        marker="*",
                        s=200,
                        edgecolor="white",
                        linewidth=2,
                        label="Positive Points",
                        zorder=10,
                    )

                # Negative prompt points (label == 0) drawn as red stars.
                neg_points = points_pixel[labels == 0]
                if len(neg_points) > 0:
                    ax.scatter(
                        neg_points[:, 0],
                        neg_points[:, 1],
                        color="red",
                        marker="*",
                        s=200,
                        edgecolor="white",
                        linewidth=2,
                        label="Negative Points",
                        zorder=10,
                    )

        objects_drawn = 0
        for obj_id, binary_mask in _outputs.items():
            # Works for both torch tensors and plain numpy arrays.
            mask_sum = (
                binary_mask.sum()
                if hasattr(binary_mask, "sum")
                else np.sum(binary_mask)
            )

            if mask_sum > 0:
                # masks_to_boxes requires a torch tensor input.
                if not isinstance(binary_mask, torch.Tensor):
                    binary_mask = torch.tensor(binary_mask)

                if binary_mask.any():
                    # Derive a tight box from the mask, then normalize to [0, 1].
                    box_xyxy = masks_to_boxes(binary_mask.unsqueeze(0)).squeeze()
                    box_xyxy = normalize_bbox(box_xyxy, img_W, img_H)
                else:
                    # Empty mask: fall back to a small centered placeholder box.
                    box_xyxy = [0.45, 0.45, 0.55, 0.55]

                # Color keyed on object id so identities stay stable over time.
                color = COLORS[obj_id % len(COLORS)]

                plot_bbox(
                    img_H,
                    img_W,
                    box_xyxy,
                    text=f"(id={obj_id})",
                    box_format="XYXY",
                    color=color,
                    ax=ax,
                )

                mask_np = (
                    binary_mask.numpy()
                    if isinstance(binary_mask, torch.Tensor)
                    else binary_mask
                )
                plot_mask(mask_np, color=color, ax=ax)
                objects_drawn += 1

        if objects_drawn == 0:
            # Make it obvious when a subplot has no detections at all.
            ax.text(
                0.5,
                0.5,
                "No objects detected",
                transform=ax.transAxes,
                fontsize=16,
                ha="center",
                va="center",
                color="red",
                weight="bold",
            )

        # Extra per-subplot points supplied by the caller (independent of
        # prompt_info); points_labels_list must be provided alongside.
        if points_list is not None and points_list[idx] is not None:
            show_points(
                points_list[idx], points_labels_list[idx], ax=ax, marker_size=200
            )

        ax.axis("off")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
|
|
def render_masklet_frame(img, outputs, frame_idx=None, alpha=0.5):
    """
    Overlays masklets and bounding boxes on a single image frame.
    Args:
        img: np.ndarray, shape (H, W, 3), uint8 or float32 in [0,255] or [0,1]
        outputs: dict with keys: out_boxes_xywh, out_probs, out_obj_ids, out_binary_masks
        frame_idx: int or None, for overlaying frame index text
        alpha: float, mask overlay alpha
    Returns:
        overlay: np.ndarray, shape (H, W, 3), uint8
    """
    # Convert float images (assumed in [0, 1]) to uint8 [0, 255].
    # NOTE(review): the `img.max() <= 1.0` test would also fire on an
    # almost-black uint8 image and wrongly rescale it — confirm inputs.
    if img.dtype == np.float32 or img.max() <= 1.0:
        img = (img * 255).astype(np.uint8)
    img = img[..., :3]  # drop an alpha channel if present
    height, width = img.shape[:2]
    overlay = img.copy()

    # First pass: alpha-blend each object's mask into the overlay.
    for i in range(len(outputs["out_probs"])):
        obj_id = outputs["out_obj_ids"][i]
        # Color keyed on object id so identities keep stable colors.
        color = COLORS[obj_id % len(COLORS)]
        color255 = (color * 255).astype(np.uint8)
        mask = outputs["out_binary_masks"][i]
        if mask.shape != img.shape[:2]:
            # Resize the mask to the image size; nearest-neighbor keeps it binary.
            mask = cv2.resize(
                mask.astype(np.float32),
                (img.shape[1], img.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        mask_bool = mask > 0.5
        for c in range(3):
            overlay[..., c][mask_bool] = (
                alpha * color255[c] + (1 - alpha) * overlay[..., c][mask_bool]
            ).astype(np.uint8)

    # Second pass: draw boxes and labels on top of the blended masks.
    for i in range(len(outputs["out_probs"])):
        box_xywh = outputs["out_boxes_xywh"][i]
        obj_id = outputs["out_obj_ids"][i]
        prob = outputs["out_probs"][i]
        color = COLORS[obj_id % len(COLORS)]
        color255 = tuple(int(x * 255) for x in color)
        # Boxes are normalized XYWH; convert to pixel XYXY corners.
        x, y, w, h = box_xywh
        x1 = int(x * width)
        y1 = int(y * height)
        x2 = int((x + w) * width)
        y2 = int((y + h) * height)
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color255, 2)
        if prob is not None:
            label = f"id={obj_id}, p={prob:.2f}"
        else:
            label = f"id={obj_id}"
        cv2.putText(
            overlay,
            label,
            (x1, max(y1 - 10, 0)),  # clamp so the label stays inside the image
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color255,
            1,
            cv2.LINE_AA,
        )

    # Optional frame index stamp in the top-left corner.
    if frame_idx is not None:
        cv2.putText(
            overlay,
            f"Frame {frame_idx}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

    return overlay
|
|
|
|
|
|
|
|
def save_masklet_video(video_frames, outputs, out_path, alpha=0.5, fps=10):
    """Render masklet overlays onto video frames and save as an mp4 file.

    Frames are first written with OpenCV to a temporary file, then re-encoded
    with ffmpeg for broader player compatibility.

    :param video_frames: sequence of frames accepted by ``load_frame``.
    :param outputs: dict mapping frame_idx -> per-frame outputs dict as
        expected by ``render_masklet_frame``.
    :param out_path: destination path of the re-encoded video.
    :param alpha: mask overlay opacity.
    :param fps: frames per second of the output video.
    :raises subprocess.CalledProcessError: if the ffmpeg re-encode fails.
    """
    first_img = load_frame(video_frames[0])
    height, width = first_img.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # Write next to the destination instead of a fixed "temp.mp4" in the CWD,
    # so concurrent runs do not clobber each other's intermediate file.
    temp_path = f"{out_path}.tmp.mp4"
    writer = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))

    try:
        for frame_idx in tqdm(sorted(outputs.keys())):
            img = load_frame(video_frames[frame_idx])
            overlay = render_masklet_frame(
                img, outputs[frame_idx], frame_idx=frame_idx, alpha=alpha
            )
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
        writer.release()

        # Re-encode so the result plays in common players/browsers; check=True
        # surfaces ffmpeg failures instead of silently reporting success.
        subprocess.run(["ffmpeg", "-y", "-i", temp_path, out_path], check=True)
        print(f"Re-encoded video saved to {out_path}")
    finally:
        writer.release()  # releasing twice is safe
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
|
|
|
|
|
|
|
def save_masklet_image(frame, outputs, out_path, alpha=0.5, frame_idx=None):
    """
    Save a single image with masklet overlays.

    :param frame: frame accepted by ``load_frame``.
    :param outputs: per-frame outputs dict for ``render_masklet_frame``.
    :param out_path: destination image path.
    :param alpha: mask overlay opacity.
    :param frame_idx: optional frame index stamped on the image.
    """
    rendered = render_masklet_frame(
        load_frame(frame), outputs, frame_idx=frame_idx, alpha=alpha
    )
    Image.fromarray(rendered).save(out_path)
    print(f"Overlay image saved to {out_path}")
|
|
|
|
|
|
|
|
def prepare_masks_for_visualization(frame_to_output):
    """Re-key raw per-frame outputs as {frame_idx: {obj_id: binary_mask}}.

    Objects whose mask has no foreground pixels are dropped. The mapping is
    updated in place and also returned for convenience.

    :param frame_to_output: dict mapping frame_idx to a dict with
        "out_obj_ids" (tensor) and "out_binary_masks" (indexable masks).
    :return: the same dict, with each value replaced by {obj_id: mask}.
    """
    for frame_idx, raw in frame_to_output.items():
        masks = raw["out_binary_masks"]
        # Keep only objects with at least one foreground pixel.
        frame_to_output[frame_idx] = {
            obj_id: masks[i]
            for i, obj_id in enumerate(raw["out_obj_ids"].tolist())
            if masks[i].any()
        }
    return frame_to_output
|
|
|
|
|
|
|
|
def convert_coco_to_masklet_format(
    annotations, img_info, is_prediction=False, score_threshold=0.5
):
    """
    Convert COCO format annotations to format expected by render_masklet_frame

    :param annotations: list of COCO annotation dicts; each must carry an RLE
        "segmentation" and may carry a "bbox" (XYWH) and, for predictions,
        a "score".
    :param img_info: dict with "height" and "width" of the image.
    :param is_prediction: if True, use ann["score"] as the probability,
        otherwise 1.0.
    :param score_threshold: threshold applied when binarizing decoded masks.
    :return: dict with "out_boxes_xywh" (normalized), "out_probs",
        "out_obj_ids" and "out_binary_masks" parallel lists.
    """
    outputs = {
        "out_boxes_xywh": [],
        "out_probs": [],
        "out_obj_ids": [],
        "out_binary_masks": [],
    }

    img_h, img_w = img_info["height"], img_info["width"]

    for idx, ann in enumerate(annotations):
        # Decode the RLE mask once per annotation; previously it was decoded
        # a second time when deriving the bbox from mask extents.
        mask = mask_utils.decode(ann["segmentation"])

        if "bbox" in ann:
            bbox = ann["bbox"]
            if max(bbox) > 1.0:
                # Pixel coordinates -> normalized [0, 1] XYWH.
                bbox = [
                    bbox[0] / img_w,
                    bbox[1] / img_h,
                    bbox[2] / img_w,
                    bbox[3] / img_h,
                ]
        else:
            # No bbox given: derive a tight box from the mask extents.
            rows = np.any(mask, axis=1)
            cols = np.any(mask, axis=0)
            if np.any(rows) and np.any(cols):
                rmin, rmax = np.where(rows)[0][[0, -1]]
                cmin, cmax = np.where(cols)[0][[0, -1]]
                bbox = [
                    cmin / img_w,
                    rmin / img_h,
                    (cmax - cmin + 1) / img_w,
                    (rmax - rmin + 1) / img_h,
                ]
            else:
                bbox = [0, 0, 0, 0]  # empty mask -> degenerate box

        outputs["out_boxes_xywh"].append(bbox)

        # Ground-truth annotations get probability 1.0.
        outputs["out_probs"].append(ann["score"] if is_prediction else 1.0)

        # Annotation order doubles as the object id (drives the color choice).
        outputs["out_obj_ids"].append(idx)
        outputs["out_binary_masks"].append((mask > score_threshold).astype(np.uint8))

    return outputs
|
|
|
|
|
|
|
|
def save_side_by_side_visualization(img, gt_anns, pred_anns, noun_phrase):
    """
    Create side-by-side visualization of GT and predictions

    :param img: RGB image array.
    :param gt_anns: ground-truth outputs dict for ``render_masklet_frame``.
    :param pred_anns: prediction outputs dict for ``render_masklet_frame``.
    :param noun_phrase: phrase shown in the figure's super-title.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
    fig.suptitle(f"Noun phrase: '{noun_phrase}'", fontsize=16, fontweight="bold")

    # Render both panels with identical settings so they are comparable.
    panels = ((ax1, gt_anns, "Ground Truth"), (ax2, pred_anns, "Predictions"))
    for axis, anns, panel_title in panels:
        axis.imshow(render_masklet_frame(img, anns, alpha=0.5))
        axis.set_title(panel_title, fontsize=14, fontweight="bold")
        axis.axis("off")

    # Leave room for the super-title above the panels.
    plt.subplots_adjust(top=0.88)
    plt.tight_layout()
|
|
|
|
|
|
|
|
def bitget(val, idx):
    """Return bit ``idx`` (0 = least significant) of ``val``.

    Works elementwise on numpy integer arrays as well as plain ints.
    """
    return (val & (1 << idx)) >> idx
|
|
|
|
|
|
|
|
def pascal_color_map():
    """Build the 512-entry PASCAL VOC label colormap.

    Each label index is expanded bit-by-bit into the R/G/B channels,
    producing the standard VOC segmentation palette (e.g. 1 -> (128, 0, 0),
    2 -> (0, 128, 0)).

    :return: np.ndarray of shape (512, 3), dtype uint8.
    """
    colormap = np.zeros((512, 3), dtype=int)
    ind = np.arange(512, dtype=int)
    # Fill from the most significant bit down, consuming 3 label bits per step.
    for shift in range(7, -1, -1):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap.astype(np.uint8)
|
|
|
|
|
|
|
|
def draw_masks_to_frame(
    frame: np.ndarray, masks: np.ndarray, colors: np.ndarray
) -> np.ndarray:
    """Blend colored masks into a frame and outline each mask's contour.

    :param frame: image array of shape (H, W, 3).
    :param masks: iterable of binary masks of shape (H, W).
    :param colors: iterable of per-mask colors (channel order matches frame).
    :return: the frame with masks blended in and contours drawn.
    """
    masked_frame = frame
    for mask, color in zip(masks, colors):
        # Blend the solid-color mask region into the frame at 25% opacity.
        curr_masked_frame = np.where(mask[..., None], color, masked_frame)
        masked_frame = cv2.addWeighted(masked_frame, 0.75, curr_masked_frame, 0.25, 0)

        # cv2.findContours dropped its first return value in OpenCV 4.x,
        # hence the version check on the major version digit.
        if int(cv2.__version__[0]) > 3:
            contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8).copy(),
                cv2.RETR_TREE,
                cv2.CHAIN_APPROX_NONE,
            )
        else:
            _, contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8).copy(),
                cv2.RETR_TREE,
                cv2.CHAIN_APPROX_NONE,
            )

        # Triple stroke (white, black, then the mask color) keeps the outline
        # visible against any background.
        cv2.drawContours(
            masked_frame, contours, -1, (255, 255, 255), 7
        )
        cv2.drawContours(
            masked_frame, contours, -1, (0, 0, 0), 5
        )
        cv2.drawContours(
            masked_frame, contours, -1, color.tolist(), 3
        )
    return masked_frame
|
|
|
|
|
|
|
|
def get_annot_df(file_path: str):
    """Load a COCO-style annotation JSON into per-section pandas DataFrames.

    The "info" and "licenses" sections are kept as-is; every other section
    (e.g. "videos", "annotations") becomes a DataFrame.

    :param file_path: path to the annotation JSON file.
    :return: dict mapping section name to a DataFrame (or the raw value for
        "info"/"licenses").
    """
    with open(file_path, "r") as f:
        data = json.load(f)

    return {
        key: value if key in ("info", "licenses") else pd.DataFrame(value)
        for key, value in data.items()
    }
|
|
|
|
|
|
|
|
def get_annot_dfs(file_list: list[str]):
    """Load several annotation files, keyed by file stem (dataset name).

    :param file_list: paths to annotation JSON files.
    :return: dict mapping dataset name to the output of ``get_annot_df``.
    """
    return {
        Path(annot_file).stem: get_annot_df(annot_file)
        for annot_file in tqdm(file_list)
    }
|
|
|
|
|
|
|
|
def get_media_dir(media_dir: str, dataset: str):
    """Resolve the frame-image directory for a known dataset split.

    :param media_dir: root media directory.
    :param dataset: dataset split name.
    :return: path to the dataset's JPEG frames directory.
    :raises ValueError: if the dataset name is not recognized.
    """
    # Map each dataset split to its (subdirectory, frame-rate folder) pair.
    dataset_dirs = {
        "saco_veval_sav_test": ("saco_sav", "JPEGImages_24fps"),
        "saco_veval_sav_val": ("saco_sav", "JPEGImages_24fps"),
        "saco_veval_yt1b_test": ("saco_yt1b", "JPEGImages_6fps"),
        "saco_veval_yt1b_val": ("saco_yt1b", "JPEGImages_6fps"),
        "saco_veval_smartglasses_test": ("saco_sg", "JPEGImages_6fps"),
        "saco_veval_smartglasses_val": ("saco_sg", "JPEGImages_6fps"),
        "sa_fari_test": ("sa_fari", "JPEGImages_6fps"),
    }
    if dataset not in dataset_dirs:
        raise ValueError(f"Dataset {dataset} not found")
    subdir, frames_dir = dataset_dirs[dataset]
    return os.path.join(media_dir, subdir, frames_dir)
|
|
|
|
|
|
|
|
def get_all_annotations_for_frame(
    dataset_df: pd.DataFrame, video_id: int, frame_idx: int, data_dir: str, dataset: str
):
    """Load one frame of a video plus all of its annotation masks.

    :param dataset_df: mapping of DataFrames with "annotations" and "videos"
        sections (as produced by ``get_annot_df``).
    :param video_id: id of the video in the "videos" table.
    :param frame_idx: index of the frame within the video.
    :param data_dir: dataset root; frames are read from ``<data_dir>/media``.
    :param dataset: dataset split name, resolved via ``get_media_dir``.
    :return: (frame, masks, noun_phrases) with masks/noun_phrases sorted by
        noun phrase, or (frame, None, None) if the video has no annotations.
    """
    media_dir = os.path.join(data_dir, "media")

    annot_df = dataset_df["annotations"]
    video_df = dataset_df["videos"]

    # Look up the single row describing this video.
    video_df_current = video_df[video_df.id == video_id]
    assert (
        len(video_df_current) == 1
    ), f"Expected 1 video row, got {len(video_df_current)}"
    video_row = video_df_current.iloc[0]
    file_name = video_row.file_names[frame_idx]
    file_path = os.path.join(
        get_media_dir(media_dir=media_dir, dataset=dataset), file_name
    )
    # OpenCV loads BGR; convert to RGB for plotting.
    frame = cv2.imread(file_path)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # All annotations attached to this video.
    annot_df_current_video = annot_df[annot_df.video_id == video_id]
    if len(annot_df_current_video) == 0:
        print(f"No annotations found for video_id {video_id}")
        return frame, None, None
    else:
        empty_mask = np.zeros(frame.shape[:2], dtype=np.uint8)
        # Decode each annotation's RLE mask for this frame (empty mask when
        # the object is absent in the frame), paired with its noun phrase.
        mask_np_pairs = annot_df_current_video.apply(
            lambda row: (
                (
                    mask_utils.decode(row.segmentations[frame_idx])
                    if row.segmentations[frame_idx]
                    else empty_mask
                ),
                row.noun_phrase,
            ),
            axis=1,
        )

        # Sort by noun phrase so the output order is deterministic.
        mask_np_pairs = sorted(mask_np_pairs, key=lambda x: x[1])
        masks, noun_phrases = zip(*mask_np_pairs)

        return frame, masks, noun_phrases
|
|
|
|
|
|
|
|
def visualize_prompt_overlay(
    frame_idx,
    video_frames,
    title="Prompt Visualization",
    text_prompt=None,
    point_prompts=None,
    point_labels=None,
    bounding_boxes=None,
    box_labels=None,
    obj_id=None,
):
    """Simple prompt visualization function.

    Draws any combination of a text-prompt banner, point prompts,
    bounding-box prompts and an object-id badge on top of the chosen frame.

    :param frame_idx: index of the frame to show.
    :param video_frames: sequence of frames accepted by ``load_frame``.
    :param title: axes title.
    :param text_prompt: optional text prompt shown in a red banner.
    :param point_prompts: optional list of (x, y) points; treated as
        normalized to [0, 1] (scaled by image size below).
    :param point_labels: optional labels per point; 1 = positive (green
        circle), anything else = negative (red cross).
    :param bounding_boxes: optional list of (x, y, w, h) boxes, normalized.
    :param box_labels: optional labels per box; 1 = positive (green).
    :param obj_id: optional object id shown in a blue badge.
    """
    img = Image.fromarray(load_frame(video_frames[frame_idx]))
    fig, ax = plt.subplots(1, figsize=(6, 4))
    ax.imshow(img)

    img_w, img_h = img.size

    if text_prompt:
        # Red banner with the text prompt, anchored top-left in axes coords.
        ax.text(
            0.02,
            0.98,
            f'Text: "{text_prompt}"',
            transform=ax.transAxes,
            fontsize=12,
            color="white",
            weight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.7),
            verticalalignment="top",
        )

    if point_prompts:
        for i, point in enumerate(point_prompts):
            x, y = point

            # Normalized coordinates -> pixel coordinates.
            x_img, y_img = x * img_w, y * img_h

            # Positive points: green circles; negative: red crosses.
            # Points without labels default to positive styling.
            if point_labels and len(point_labels) > i:
                color = "green" if point_labels[i] == 1 else "red"
                marker = "o" if point_labels[i] == 1 else "x"
            else:
                color = "green"
                marker = "o"

            ax.plot(
                x_img,
                y_img,
                marker=marker,
                color=color,
                markersize=10,
                markeredgewidth=2,
                markeredgecolor="white",
            )
            # Small "P<n>" tag next to each point.
            ax.text(
                x_img + 5,
                y_img - 5,
                f"P{i+1}",
                color=color,
                fontsize=10,
                weight="bold",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8),
            )

    if bounding_boxes:
        for i, box in enumerate(bounding_boxes):
            x, y, w, h = box

            # Normalized XYWH -> pixel units.
            x_img, y_img = x * img_w, y * img_h
            w_img, h_img = w * img_w, h * img_h

            # Positive boxes green, negative red; unlabeled default to green.
            if box_labels and len(box_labels) > i:
                color = "green" if box_labels[i] == 1 else "red"
            else:
                color = "green"

            rect = patches.Rectangle(
                (x_img, y_img),
                w_img,
                h_img,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(rect)
            # Small "B<n>" tag above each box.
            ax.text(
                x_img,
                y_img - 5,
                f"B{i+1}",
                color=color,
                fontsize=10,
                weight="bold",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8),
            )

    if obj_id is not None:
        # Blue badge with the object id, anchored bottom-left in axes coords.
        ax.text(
            0.02,
            0.02,
            f"Object ID: {obj_id}",
            transform=ax.transAxes,
            fontsize=10,
            color="white",
            weight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="blue", alpha=0.7),
            verticalalignment="bottom",
        )

    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
|
|
def plot_results(img, results):
    """Plot detection results (masks, boxes, scores) on top of a PIL image.

    :param img: PIL image (its ``.size`` supplies width/height).
    :param results: dict with "scores", "masks" and "boxes" tensors,
        aligned by index.
    """
    plt.figure(figsize=(12, 8))
    plt.imshow(img)
    num_objects = len(results["scores"])
    print(f"found {num_objects} object(s)")
    width, height = img.size
    for obj_idx in range(num_objects):
        color = COLORS[obj_idx % len(COLORS)]
        plot_mask(results["masks"][obj_idx].squeeze(0).cpu(), color=color)
        # Keep the name `prob` — the `{prob=...}` f-string below renders it.
        prob = results["scores"][obj_idx].item()
        plot_bbox(
            height,
            width,
            results["boxes"][obj_idx].cpu(),
            text=f"(id={obj_idx}, {prob=:.2f})",
            box_format="XYXY",
            color=color,
            relative_coords=False,
        )
|
|
|
|
|
|
|
|
def single_visualization(img, anns, title):
    """
    Create a single image visualization with overlays.

    :param img: RGB image array.
    :param anns: outputs dict for ``render_masklet_frame``.
    :param title: figure super-title.
    """
    figure, axis = plt.subplots(figsize=(7, 7))
    figure.suptitle(title, fontsize=16, fontweight="bold")
    axis.imshow(render_masklet_frame(img, anns, alpha=0.5))
    axis.axis("off")
    plt.tight_layout()
|
|
|
|
|
|
|
|
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a single mask on ``ax`` with a per-object or random color.

    :param mask: mask array whose last two dims are (H, W).
    :param ax: matplotlib axes to draw on.
    :param obj_id: object id used to pick a color from the tab10 colormap;
        defaults to color 0 when omitted.
    :param random_color: if True, use a random RGB color instead.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Stable per-object color from tab10, alpha fixed at 0.6.
        cmap = plt.get_cmap("tab10")
        color = np.array([*cmap(obj_id if obj_id is not None else 0)[:3], 0.6])
    h, w = mask.shape[-2:]
    ax.imshow(mask.reshape(h, w, 1) * color.reshape(1, 1, -1))
|
|
|
|
|
|
|
|
def show_box(box, ax):
    """Draw an XYXY box on ``ax`` as a green, unfilled rectangle.

    :param box: (x0, y0, x1, y1) pixel coordinates.
    :param ax: matplotlib axes to draw on.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle(
        (x0, y0), x1 - x0, y1 - y0, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(rect)
|
|
|
|
|
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter positive (green) and negative (red) point prompts on ``ax``.

    :param coords: array of (x, y) point coordinates, shape (N, 2).
    :param labels: array of N labels; 1 = positive, 0 = negative.
    :param ax: matplotlib axes to draw on.
    :param marker_size: scatter marker size.
    """
    # Shared star-marker styling for both classes of points.
    style = dict(marker="*", s=marker_size, edgecolor="white", linewidth=1.25)
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    ax.scatter(positives[:, 0], positives[:, 1], color="green", **style)
    ax.scatter(negatives[:, 0], negatives[:, 1], color="red", **style)
|
|
|
|
|
|
|
|
def load_frame(frame):
    """Coerce a frame given as array, PIL image, or file path to np.ndarray.

    :param frame: np.ndarray (returned as-is), PIL.Image, or path to an
        existing image file.
    :return: the frame as a numpy array.
    :raises ValueError: for any other input type.
    """
    if isinstance(frame, np.ndarray):
        return frame
    if isinstance(frame, Image.Image):
        return np.array(frame)
    if isinstance(frame, str) and os.path.isfile(frame):
        return plt.imread(frame)
    raise ValueError(f"Invalid video frame type: {type(frame)=}")
|
|
|