| from typing import Dict |
|
|
| import numpy as np |
| from einops import rearrange |
| from monai.transforms.transform import Transform |
|
|
|
|
class OrientationGuidanceMultipleLabelDeepEditd(Transform):
    """Map per-label guidance points through the reference image's affine change.

    For every label key, the point list stored under that key in the data dict is
    transformed by ``inv(meta["affine"]) @ meta["original_affine"]``. Both affines
    are read from the metadata of ``data[ref_image]``. Label keys whose points are
    missing or empty are left untouched.

    NOTE(review): assuming both affines map voxel indices to world coordinates
    (the usual MONAI convention -- confirm for your pipeline), this converts points
    expressed in the voxel grid of the originally loaded image into the voxel grid
    of the image after spatial transforms such as re-orientation.
    """

    def __init__(self, ref_image="image", label_names=None):
        """
        Args:
            ref_image: key of the image whose ``meta["affine"]`` and
                ``meta["original_affine"]`` define the point mapping.
            label_names: the label keys whose guidance points are transformed.
                This is either a dict (only its keys are used) or any iterable of
                keys. ``None`` means no keys are processed.
        """
        self.ref_image = ref_image
        self.label_names = label_names

    @staticmethod
    def _as_numpy(x):
        """Return ``x`` as a NumPy array (accepts ndarrays, array-likes and torch tensors)."""
        if hasattr(x, "detach"):  # torch.Tensor / MetaTensor, possibly on an accelerator
            x = x.detach().cpu()
        if hasattr(x, "numpy"):
            return x.numpy()
        return np.asarray(x)

    def transform_points(self, point, affine):
        """Apply a homogeneous affine matrix to a batch of points.

        Args:
            point: array-like of shape ``[..., N, D]``, typically ``[bs, N, D]``,
                holding D spatial coordinates per point.
            affine: matrix of shape ``[D + 1, D + 1]``.

        Returns:
            Float array with the same shape as ``point``, holding the transformed
            coordinates. The homogeneous coordinate is dropped, not divided out.
            This is exact for affine matrices whose last row is ``[0, ..., 0, 1]``.
        """
        point = np.asarray(point)
        spatial_dims = point.shape[-1]
        homogeneous = np.concatenate((point, np.ones(point.shape[:-1] + (1,))), axis=-1)
        # Row-vector form of ``affine @ p`` applied to every point at once.
        transformed = homogeneous @ self._as_numpy(affine).T
        # Keep exactly the spatial coordinates; works for 2-D as well as 3-D points.
        return transformed[..., :spatial_dims]

    def __call__(self, data):
        d: Dict = dict(data)
        label_keys = self.label_names if self.label_names is not None else ()
        # Computed lazily, so the reference image is only accessed when some label
        # actually has points to transform.
        point_affine = None
        for key_label in label_keys:  # iterating a dict yields its keys
            points = d.get(key_label, [])
            if len(points) < 1:
                continue
            if point_affine is None:
                meta = d[self.ref_image].meta
                point_affine = np.linalg.inv(self._as_numpy(meta["affine"])) @ self._as_numpy(
                    meta["original_affine"]
                )
            reoriented_points = self.transform_points(np.array(points)[None], point_affine)
            d[key_label] = reoriented_points[0]
        return d
|
|