import argparse
import os
from typing import Literal

import numpy as np
import torch
from monai.data import PersistentDataset, load_decathlon_datalist
from monai.transforms import (
    Compose,
    ConcatItemsd,
    DeleteItemsd,
    EnsureTyped,
    LoadImaged,
    NormalizeIntensityd,
    RandCropByPosNegLabeld,
    RandWeightedCropd,
    ToTensord,
    Transform,
    Transposed,
)
from torch.utils.data.dataloader import default_collate

from .custom_transforms import (
    ClipMaskIntensityPercentilesd,
    ElementwiseProductd,
    NormalizeIntensity_customd,
)


class DummyMILDataset(torch.utils.data.Dataset):
    def __init__(self, args, num_samples=8):
        self.num_samples = num_samples
        self.args = args

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Simulate the output of 'data_transform': a list of dicts,
        # one per patch in the bag ('tile_count' entries).
        bag = []
        label_value = float(index % 2)
        for _ in range(self.args.tile_count):
            item = {
                # Shape (C=3, D, H, W), matching Transposed(indices=(0, 3, 1, 2))
                "image": torch.randn(
                    3, self.args.depth, self.args.tile_size, self.args.tile_size
                ),
                "label": torch.tensor(label_value, dtype=torch.float32),
            }
            if self.args.use_heatmap:
                item["final_heatmap"] = torch.randn(
                    1, self.args.depth, self.args.tile_size, self.args.tile_size
                )
            bag.append(item)
        return bag


def list_data_collate(batch: list):
    """Combine the instances of each bag into a single dict by stacking them
    along a new first dim:

        [{'image': 3xDxHxW}, {'image': 3xDxHxW}, ...] -> {'image': Nx3xDxHxW}

    then apply the default collate, which forms a batch of shape BxNx3xDxHxW.
    """
    for i, item in enumerate(batch):
        data = item[0]
        data["image"] = torch.stack([ix["image"] for ix in item], dim=0)
        if all("final_heatmap" in ix for ix in item):
            data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
        batch[i] = data
    return default_collate(batch)


def data_transform(args: argparse.Namespace) -> Transform:
    if args.use_heatmap:
        transform = Compose(
            [
                LoadImaged(
                    keys=["image", "mask", "dwi", "adc", "heatmap"],
                    reader="ITKReader",
                    ensure_channel_first=True,
                    dtype=np.float32,
                ),
                ClipMaskIntensityPercentilesd(
                    keys=["image"], lower=0, upper=99.5, mask_key="mask"
                ),
                # Stack the three single-channel volumes to (3, H, W, D).
                ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
                NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
                ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
                RandWeightedCropd(
                    keys=["image", "final_heatmap"],
                    w_key="final_heatmap",
                    spatial_size=(args.tile_size, args.tile_size, args.depth),
                    num_samples=args.tile_count,
                ),
                EnsureTyped(keys=["label"], dtype=torch.float32),
                # Transpose both keys to (C, D, H, W) so the heatmap layout
                # stays consistent with the image (and with DummyMILDataset).
                Transposed(keys=["image", "final_heatmap"], indices=(0, 3, 1, 2)),
                DeleteItemsd(keys=["mask", "dwi", "adc", "heatmap"]),
                ToTensord(keys=["image", "label", "final_heatmap"]),
            ]
        )
    else:
        transform = Compose(
            [
                LoadImaged(
                    keys=["image", "mask", "dwi", "adc"],
                    reader="ITKReader",
                    ensure_channel_first=True,
                    dtype=np.float32,
                ),
                ClipMaskIntensityPercentilesd(
                    keys=["image"], lower=0, upper=99.5, mask_key="mask"
                ),
                # Stack the three single-channel volumes to (3, H, W, D).
                ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
                NormalizeIntensityd(keys=["image"], channel_wise=True),
                RandCropByPosNegLabeld(
                    keys=["image"],
                    label_key="mask",
                    spatial_size=(args.tile_size, args.tile_size, args.depth),
                    pos=1,
                    neg=0,
                    num_samples=args.tile_count,
                ),
                EnsureTyped(keys=["label"], dtype=torch.float32),
                Transposed(keys=["image"], indices=(0, 3, 1, 2)),
                DeleteItemsd(keys=["mask", "dwi", "adc"]),
                ToTensord(keys=["image", "label"]),
            ]
        )
    return transform

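# A quick standalone shape check for list_data_collate. This is a sketch with
# illustrative sizes (two bags of 4 patches, each patch (3, 16, 64, 64)), not
# values taken from any real configuration:
#
#     bags = [
#         [
#             {"image": torch.randn(3, 16, 64, 64),
#              "label": torch.tensor(b, dtype=torch.float32)}
#             for _ in range(4)
#         ]
#         for b in (0.0, 1.0)
#     ]
#     batch = list_data_collate(bags)
#     batch["image"].shape  # torch.Size([2, 4, 3, 16, 64, 64])
#     batch["label"]        # tensor([0., 1.])
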
def get_dataloader(
    args: argparse.Namespace, split: Literal["train", "test"]
) -> torch.utils.data.DataLoader:
    if args.dry_run:
        print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
        dummy_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
        return torch.utils.data.DataLoader(
            dummy_ds,
            batch_size=args.batch_size,
            collate_fn=list_data_collate,  # uses the custom stacking logic
            num_workers=0,  # keep it simple for the dry run
        )

    data_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key=split,
        base_dir=args.data_root,
    )
    cache_dir_ = os.path.join(args.logdir, "cache")
    os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
    transform = data_transform(args)
    dataset = PersistentDataset(
        data=data_list, transform=transform, cache_dir=os.path.join(cache_dir_, split)
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=(split == "train"),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="fork" if args.workers > 0 else None,
        sampler=None,
        collate_fn=list_data_collate,
    )
    return loader
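

# Minimal smoke test using the synthetic dry-run path. The Namespace fields
# below are illustrative assumptions mirroring the attributes this module
# reads, not project defaults. Run it as a module (e.g.
# `python -m <package>.dataloader`, where <package> is hypothetical) so the
# relative import of custom_transforms resolves.
if __name__ == "__main__":
    _args = argparse.Namespace(
        dry_run=True,
        use_heatmap=False,
        tile_count=4,
        tile_size=64,
        depth=16,
        batch_size=2,
    )
    _loader = get_dataloader(_args, split="train")
    _batch = next(iter(_loader))
    print(_batch["image"].shape)  # expected: torch.Size([2, 4, 3, 16, 64, 64])
    print(_batch["label"].shape)  # expected: torch.Size([2])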