Spaces:
Build error
Build error
REF: SAM2 AMG and the corresponding test case.
Browse files- SegmentAnything2AssistApp.py +20 -18
- src/SegmentAnything2Assist/SegmentAnything2Assist.py +32 -11
- test/test_module.py +36 -9
SegmentAnything2AssistApp.py
CHANGED
|
@@ -257,25 +257,27 @@ def generate_auto_mask(
|
|
| 257 |
if VERBOSE:
|
| 258 |
print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
| 276 |
)
|
| 277 |
|
| 278 |
-
if len(
|
| 279 |
gradio.Warning(
|
| 280 |
"No masks generated, please tweak the advanced parameters.", duration=5
|
| 281 |
)
|
|
@@ -294,7 +296,7 @@ def generate_auto_mask(
|
|
| 294 |
),
|
| 295 |
)
|
| 296 |
else:
|
| 297 |
-
choices = [str(i) for i in range(len(
|
| 298 |
|
| 299 |
returning_image = __generate_auto_mask(
|
| 300 |
image, ["0"], output_mode, False, masks, bboxes
|
|
|
|
| 257 |
if VERBOSE:
|
| 258 |
print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
|
| 259 |
|
| 260 |
+
masks, bboxes, predicted_iou, stability_score = (
|
| 261 |
+
segment_anything2assist.generate_automatic_masks(
|
| 262 |
+
image,
|
| 263 |
+
points_per_side,
|
| 264 |
+
points_per_batch,
|
| 265 |
+
pred_iou_thresh,
|
| 266 |
+
stability_score_thresh,
|
| 267 |
+
stability_score_offset,
|
| 268 |
+
mask_threshold,
|
| 269 |
+
box_nms_thresh,
|
| 270 |
+
crop_n_layers,
|
| 271 |
+
crop_nms_thresh,
|
| 272 |
+
crop_overlay_ratio,
|
| 273 |
+
crop_n_points_downscale_factor,
|
| 274 |
+
min_mask_region_area,
|
| 275 |
+
use_m2m,
|
| 276 |
+
multimask_output,
|
| 277 |
+
)
|
| 278 |
)
|
| 279 |
|
| 280 |
+
if len(masks) == 0:
|
| 281 |
gradio.Warning(
|
| 282 |
"No masks generated, please tweak the advanced parameters.", duration=5
|
| 283 |
)
|
|
|
|
| 296 |
),
|
| 297 |
)
|
| 298 |
else:
|
| 299 |
+
choices = [str(i) for i in range(len(masks))]
|
| 300 |
|
| 301 |
returning_image = __generate_auto_mask(
|
| 302 |
image, ["0"], output_mode, False, masks, bboxes
|
src/SegmentAnything2Assist/SegmentAnything2Assist.py
CHANGED
|
@@ -98,7 +98,7 @@ class SegmentAnything2Assist:
|
|
| 98 |
)
|
| 99 |
|
| 100 |
if download:
|
| 101 |
-
self.
|
| 102 |
|
| 103 |
if self.is_model_available():
|
| 104 |
self.sam2 = sam2.build_sam.build_sam2(
|
|
@@ -121,14 +121,14 @@ class SegmentAnything2Assist:
|
|
| 121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
| 122 |
return ret
|
| 123 |
|
| 124 |
-
def
|
| 125 |
if self.is_model_available():
|
| 126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
| 127 |
return True
|
| 128 |
|
| 129 |
return False
|
| 130 |
|
| 131 |
-
def
|
| 132 |
if not force and self.is_model_available():
|
| 133 |
print(f"{self.model_path} already exists. Skipping download.")
|
| 134 |
return False
|
|
@@ -162,7 +162,17 @@ class SegmentAnything2Assist:
|
|
| 162 |
min_mask_region_area=0,
|
| 163 |
use_m2m=False,
|
| 164 |
multimask_output=True,
|
| 165 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
if self.sam2 is None:
|
| 167 |
print(
|
| 168 |
"SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
|
|
@@ -196,8 +206,15 @@ class SegmentAnything2Assist:
|
|
| 196 |
cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
|
| 197 |
]
|
| 198 |
bbox_masks = [mask["bbox"] for mask in masks]
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
def generate_masks_from_image(
|
| 203 |
self,
|
|
@@ -208,7 +225,15 @@ class SegmentAnything2Assist:
|
|
| 208 |
mask_threshold=0.0,
|
| 209 |
max_hole_area=0.0,
|
| 210 |
max_sprinkle_area=0.0,
|
| 211 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
|
| 213 |
self.sam2,
|
| 214 |
mask_threshold=mask_threshold,
|
|
@@ -240,8 +265,6 @@ class SegmentAnything2Assist:
|
|
| 240 |
image_with_bounding_boxes = image.copy()
|
| 241 |
all_masks = None
|
| 242 |
|
| 243 |
-
cv2.imwrite(".tmp/mask_2.png", masks[3])
|
| 244 |
-
|
| 245 |
for _ in auto_list:
|
| 246 |
mask = masks[_]
|
| 247 |
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
|
@@ -252,8 +275,6 @@ class SegmentAnything2Assist:
|
|
| 252 |
else:
|
| 253 |
all_masks = cv2.bitwise_or(all_masks, mask)
|
| 254 |
|
| 255 |
-
cv2.imwrite(".tmp/mask_3.png", masks[3])
|
| 256 |
-
|
| 257 |
random_color = numpy.random.randint(0, 255, size=3)
|
| 258 |
image_with_bounding_boxes = cv2.rectangle(
|
| 259 |
image_with_bounding_boxes,
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
if download:
|
| 101 |
+
self.__download_model()
|
| 102 |
|
| 103 |
if self.is_model_available():
|
| 104 |
self.sam2 = sam2.build_sam.build_sam2(
|
|
|
|
| 121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
| 122 |
return ret
|
| 123 |
|
| 124 |
+
def __load_model(self) -> bool:
|
| 125 |
if self.is_model_available():
|
| 126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
| 127 |
return True
|
| 128 |
|
| 129 |
return False
|
| 130 |
|
| 131 |
+
def __download_model(self, force: bool = False) -> bool:
|
| 132 |
if not force and self.is_model_available():
|
| 133 |
print(f"{self.model_path} already exists. Skipping download.")
|
| 134 |
return False
|
|
|
|
| 162 |
min_mask_region_area=0,
|
| 163 |
use_m2m=False,
|
| 164 |
multimask_output=True,
|
| 165 |
+
) -> typing.Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
|
| 166 |
+
"""
|
| 167 |
+
Generates automatic masks from the given image.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
typing.Tuple: Four numpy arrays where:
|
| 171 |
+
- segmentation_masks: Numpy array shape (N, H, W, C) where N is the number of masks, H is the height of the image, W is the width of the image, and C is the number of channels. Each N is a binary mask of the image of shape (H, W, C).
|
| 172 |
+
- bbox_masks: Numpy array of shape (N, 4) where N is the number of masks and 4 is the bounding box coordinates. Each mask is a bounding box of shape (x, y, w, h).
|
| 173 |
+
- predicted_iou: Numpy array of shape (N,) where N is the number of masks. Each value is the predicted IOU of the mask.
|
| 174 |
+
- stability_score: Numpy array of shape (N,) where N is the number of masks. Each value is the stability score of the mask.
|
| 175 |
+
"""
|
| 176 |
if self.sam2 is None:
|
| 177 |
print(
|
| 178 |
"SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
|
|
|
|
| 206 |
cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
|
| 207 |
]
|
| 208 |
bbox_masks = [mask["bbox"] for mask in masks]
|
| 209 |
+
predicted_iou = [mask["predicted_iou"] for mask in masks]
|
| 210 |
+
stability_score = [mask["stability_score"] for mask in masks]
|
| 211 |
+
|
| 212 |
+
return (
|
| 213 |
+
numpy.array(segmentation_masks, dtype=numpy.uint8),
|
| 214 |
+
numpy.array(bbox_masks, dtype=numpy.uint32),
|
| 215 |
+
numpy.array(predicted_iou, dtype=numpy.float32),
|
| 216 |
+
numpy.array(stability_score, dtype=numpy.float32),
|
| 217 |
+
)
|
| 218 |
|
| 219 |
def generate_masks_from_image(
|
| 220 |
self,
|
|
|
|
| 225 |
mask_threshold=0.0,
|
| 226 |
max_hole_area=0.0,
|
| 227 |
max_sprinkle_area=0.0,
|
| 228 |
+
) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
|
| 229 |
+
"""
|
| 230 |
+
Generates masks from the given image.
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
typing.Tuple: Two numpy arrays where:
|
| 234 |
+
- masks_chw: Numpy array shape (1, H, W) for the mask, H is the height of the image, and W is the width of the image.
|
| 235 |
+
- mask_iou: Numpy array of shape (1,) for IOU of the mask.
|
| 236 |
+
"""
|
| 237 |
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
|
| 238 |
self.sam2,
|
| 239 |
mask_threshold=mask_threshold,
|
|
|
|
| 265 |
image_with_bounding_boxes = image.copy()
|
| 266 |
all_masks = None
|
| 267 |
|
|
|
|
|
|
|
| 268 |
for _ in auto_list:
|
| 269 |
mask = masks[_]
|
| 270 |
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
|
|
|
| 275 |
else:
|
| 276 |
all_masks = cv2.bitwise_or(all_masks, mask)
|
| 277 |
|
|
|
|
|
|
|
| 278 |
random_color = numpy.random.randint(0, 255, size=3)
|
| 279 |
image_with_bounding_boxes = cv2.rectangle(
|
| 280 |
image_with_bounding_boxes,
|
test/test_module.py
CHANGED
|
@@ -2,6 +2,8 @@ import unittest
|
|
| 2 |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
| 3 |
import cv2
|
| 4 |
|
|
|
|
|
|
|
| 5 |
|
| 6 |
class TestSegmentAnything2Assist(unittest.TestCase):
|
| 7 |
def setUp(self) -> None:
|
|
@@ -39,21 +41,46 @@ class TestSegmentAnything2Assist(unittest.TestCase):
|
|
| 39 |
device="cpu",
|
| 40 |
)
|
| 41 |
|
| 42 |
-
def
|
| 43 |
image = cv2.imread("test/assets/liberty.jpg")
|
| 44 |
|
| 45 |
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 46 |
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
| 47 |
)
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
| 3 |
import cv2
|
| 4 |
|
| 5 |
+
import numpy
|
| 6 |
+
|
| 7 |
|
| 8 |
class TestSegmentAnything2Assist(unittest.TestCase):
|
| 9 |
def setUp(self) -> None:
|
|
|
|
| 41 |
device="cpu",
|
| 42 |
)
|
| 43 |
|
| 44 |
+
def _generate_automatic_mask(self):
|
| 45 |
image = cv2.imread("test/assets/liberty.jpg")
|
| 46 |
|
| 47 |
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 48 |
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
| 49 |
)
|
| 50 |
|
| 51 |
+
segmentation_masks, bboxes, predicted_iou, stability_score = (
|
| 52 |
+
sam_model.generate_automatic_masks(image)
|
| 53 |
+
)
|
| 54 |
|
| 55 |
+
self.assertEqual(len(segmentation_masks.shape), 4)
|
| 56 |
+
self.assertEqual(segmentation_masks[0].shape, image.shape)
|
| 57 |
+
self.assertEqual(segmentation_masks.shape[3], 3)
|
| 58 |
+
self.assertEqual(type(segmentation_masks[0][0][0][0]), numpy.uint8)
|
| 59 |
+
self.assertEqual(len(bboxes.shape), 2)
|
| 60 |
+
self.assertEqual(bboxes[0].shape, (4,))
|
| 61 |
+
self.assertEqual(type(bboxes[0][0]), numpy.uint32)
|
| 62 |
+
self.assertEqual(len(predicted_iou.shape), 1)
|
| 63 |
+
self.assertEqual(type(predicted_iou[0]), numpy.float32)
|
| 64 |
+
self.assertEqual(len(stability_score.shape), 1)
|
| 65 |
+
self.assertEqual(type(stability_score[0]), numpy.float32)
|
| 66 |
|
| 67 |
+
for segmentation_mask in segmentation_masks:
|
| 68 |
+
self.assertEqual(segmentation_mask.shape, image.shape)
|
| 69 |
|
| 70 |
+
def test_generate_masks_from_image(self):
|
| 71 |
+
image = cv2.imread("test/assets/liberty.jpg")
|
| 72 |
+
|
| 73 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 74 |
+
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
mask_chw, mask_iou = sam_model.generate_masks_from_image(
|
| 78 |
+
image, None, None, None
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
self.assertEqual(len(mask_chw.shape), 3)
|
| 82 |
+
self.assertEqual(mask_chw[0].shape, image.shape)
|
| 83 |
+
self.assertEqual(mask_chw.shape[0], 1)
|
| 84 |
+
|
| 85 |
+
self.assertEqual(len(mask_iou.shape), 1)
|
| 86 |
+
self.assertEqual(mask_iou.shape[0], 1)
|