| | from statistics import mode |
| | from fvcore.common.config import CfgNode |
| | import numpy as np |
| | import os |
| | import cv2 |
| | import glob |
| | import tqdm |
| | from PIL import Image |
| | from PIL import ImageOps |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from modeling.MaskFormerModel import MaskFormerModel |
| | from utils.misc import load_parallal_model |
| | from utils.misc import ADEVisualize |
| |
|
| | |
| | |
| | |
| | |
| |
|
class Segmentation():
    """Inference wrapper for semantic segmentation with a MaskFormer model.

    Responsibilities:
      * build/load the model from ``cfg.MODEL.PRETRAINED_WEIGHTS`` (or accept
        a ready-made model),
      * preprocess images: scale to fit ``cfg.INPUT.CROP.MAX_SIZE``, pad each
        side up to a multiple of the network stride, normalize to CHW float,
      * run the model and fuse per-query class scores with mask predictions,
      * undo padding/scaling so the final mask matches the input resolution.
    """

    def __init__(self, cfg, model=None):
        """Store config values and obtain a model.

        Args:
            cfg: CfgNode with MODEL / TEST / INPUT / DATASETS sections and a
                ``local_rank`` attribute selecting the CUDA device.
            model: optional pre-built model. When given, no checkpoint is
                loaded here. Otherwise the model is constructed and weights
                are read from ``cfg.MODEL.PRETRAINED_WEIGHTS`` if the file
                exists; if it does not, ``self.model`` stays ``None`` and a
                warning is printed (original best-effort behavior kept).
        """
        self.cfg = cfg
        self.num_queries = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        self.size_divisibility = cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY
        self.num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        self.device = torch.device("cuda", cfg.local_rank)

        # Network stride: padded input sides must be multiples of 2**5 = 32.
        self.padding_constant = 2 ** 5
        self.test_dir = cfg.TEST.TEST_DIR
        self.output_dir = cfg.TEST.SAVE_DIR
        self.imgMaxSize = cfg.INPUT.CROP.MAX_SIZE
        self.pixel_mean = np.array(cfg.DATASETS.PIXEL_MEAN)
        self.pixel_std = np.array(cfg.DATASETS.PIXEL_STD)
        self.visualize = ADEVisualize()
        self.model = None

        pretrain_weights = cfg.MODEL.PRETRAINED_WEIGHTS
        if model is not None:
            self.model = model
        elif os.path.exists(pretrain_weights):
            self.model = MaskFormerModel(cfg, is_init=False)
            self.load_model(pretrain_weights)
        else:
            print(f'please check weights file: {cfg.MODEL.PRETRAINED_WEIGHTS}')

    def load_model(self, pretrain_weights):
        """Load a training checkpoint and move the model to this rank's GPU.

        Side effects: sets ``self.last_lr`` and ``self.start_epoch`` from the
        checkpoint and puts the model into eval mode.
        """
        # Fix: map onto this process's device instead of the hard-coded
        # 'cuda:0' so non-zero ranks do not allocate checkpoint tensors on
        # GPU 0 (the model is moved to self.device below anyway).
        state_dict = torch.load(pretrain_weights, map_location=self.device)

        ckpt_dict = state_dict['model']
        self.last_lr = state_dict['lr']
        self.start_epoch = state_dict['epoch']
        # load_parallal_model reconciles DataParallel 'module.' prefixes
        # between checkpoint and model (see utils.misc).
        self.model = load_parallal_model(self.model, ckpt_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        # Fix: message typo ("pretrain mode" -> "pretrained model").
        print("loaded pretrained model: {}".format(pretrain_weights))

    def img_transform(self, img):
        """Normalize a PIL/array image to a CHW float32 array in model space.

        Scales to [0, 1], subtracts ``pixel_mean`` and divides by
        ``pixel_std`` (both assumed to be in the same 0-1 range — TODO
        confirm against cfg.DATASETS values), then HWC -> CHW.
        """
        img = np.float32(np.array(img)) / 255.
        img = (img - self.pixel_mean) / self.pixel_std
        img = img.transpose((2, 0, 1))
        return img

    def round2nearest_multiple(self, x, p):
        """Round ``x`` up to the nearest positive multiple of ``p``."""
        return ((x - 1) // p + 1) * p

    def get_img_ratio(self, img_size, target_size):
        """Return the scale factor that fits ``img_size`` inside ``target_size``.

        Compares aspect ratios (max/min side) to decide whether the long or
        the short side is the binding constraint, so the scaled image never
        exceeds the target box in either dimension.

        NOTE: this method appeared twice (verbatim) in the original file; the
        duplicate definition has been removed.
        """
        img_rate = np.max(img_size) / np.min(img_size)
        target_rate = np.max(target_size) / np.min(target_size)
        if img_rate > target_rate:
            # Image is more elongated than the target: long side binds.
            ratio = max(target_size) / max(img_size)
        else:
            ratio = min(target_size) / min(img_size)
        return ratio

    def resize_padding(self, img, outsize, Interpolation=Image.BILINEAR):
        """Resize ``img`` (PIL) to fit ``outsize`` and center-pad with zeros.

        Returns:
            (padded_image, [left, top, right, bottom]) — the padding amounts
            are needed later to crop the prediction back (see postprocess).
        """
        w, h = img.size
        target_w, target_h = outsize[0], outsize[1]
        ratio = self.get_img_ratio([w, h], outsize)
        ow, oh = round(w * ratio), round(h * ratio)
        img = img.resize((ow, oh), Interpolation)
        # Split any odd padding so the resized content stays centered.
        dh, dw = target_h - oh, target_w - ow
        top, bottom = dh // 2, dh - (dh // 2)
        left, right = dw // 2, dw - (dw // 2)
        img = ImageOps.expand(img, border=(left, top, right, bottom), fill=0)
        return img, [left, top, right, bottom]

    def image_preprocess(self, img):
        """Turn an HWC uint8 array into a padded, normalized input tensor.

        Returns:
            (input_tensor, transformer_info) where ``transformer_info``
            records the padding, scale and padded size so predictions can be
            mapped back to the original image.
        """
        img_height, img_width = img.shape[0], img.shape[1]
        this_scale = self.get_img_ratio((img_width, img_height), self.imgMaxSize)
        target_width = img_width * this_scale
        target_height = img_height * this_scale
        # Pad both sides up to the network stride (padding_constant).
        input_width = int(self.round2nearest_multiple(target_width, self.padding_constant))
        input_height = int(self.round2nearest_multiple(target_height, self.padding_constant))

        img, padding_info = self.resize_padding(Image.fromarray(img), (input_width, input_height))
        img = self.img_transform(img)

        transformer_info = {'padding_info': padding_info, 'scale': this_scale,
                            'input_size': (input_height, input_width)}
        input_tensor = torch.from_numpy(img).float().unsqueeze(0).to(self.device)
        return input_tensor, transformer_info

    def semantic_inference(self, mask_cls, mask_pred):
        """Fuse per-query class scores and masks into per-class score maps.

        Args:
            mask_cls: (B, Q, C+1) classification logits per query.
            mask_pred: (B, Q, H, W) mask logits per query.

        Returns:
            (B, C, H, W) numpy array of class score maps.
        """
        # NOTE(review): index 0 of the softmax is dropped here, i.e. the
        # "void/no-object" logit is assumed to be FIRST (standard MaskFormer
        # drops the LAST logit) — confirm against this project's label layout.
        mask_cls = F.softmax(mask_cls, dim=-1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg.cpu().numpy()

    def postprocess(self, pred_mask, transformer_info, target_size):
        """Crop the padding off a predicted mask and resize to ``target_size``.

        ``target_size`` is (width, height), matching cv2.resize's dsize
        convention; nearest-neighbor keeps label ids intact.
        """
        oh, ow = pred_mask.shape[0], pred_mask.shape[1]
        padding_info = transformer_info['padding_info']

        left, top, right, bottom = padding_info[0], padding_info[1], padding_info[2], padding_info[3]
        mask = pred_mask[top: oh - bottom, left: ow - right]
        mask = cv2.resize(mask.astype(np.uint8), dsize=target_size, interpolation=cv2.INTER_NEAREST)
        return mask

    @torch.no_grad()
    def forward(self, img_list=None):
        """Run inference over ``img_list`` (or every .jpg/.png in test_dir).

        Returns:
            list of HxW uint8 label masks, one per input image, at each
            image's original resolution.
        """
        if img_list is None or len(img_list) == 0:
            # Matches *.jpg and *.png (character-class glob).
            img_list = glob.glob(self.test_dir + '/*.[jp][pn]g')
        mask_images = []
        for image_path in tqdm.tqdm(img_list):
            img = Image.open(image_path).convert('RGB')
            # PIL .size is (width, height).
            img_height, img_width = img.size[1], img.size[0]
            # Fix: local variable typo 'inpurt_tensor' -> 'input_tensor'.
            input_tensor, transformer_info = self.image_preprocess(np.array(img))

            outputs = self.model(input_tensor)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]

            # Upsample mask logits to the padded input resolution before
            # fusing with class scores.
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(input_tensor.shape[-2], input_tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )
            pred_masks = self.semantic_inference(mask_cls_results, mask_pred_results)
            # Argmax over the class axis -> per-pixel label ids.
            mask_img = np.argmax(pred_masks, axis=1)[0]
            mask_img = self.postprocess(mask_img, transformer_info, (img_width, img_height))
            mask_images.append(mask_img)
        return mask_images

    def render_image(self, img, mask_img, output_path=None):
        """Overlay ``mask_img`` on ``img`` via the ADE visualizer; optionally save."""
        self.visualize.show_result(img, mask_img, output_path)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|