| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import math |
| from typing import Iterator, Optional, Sequence, TypeVar |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import Dataset, Sampler |
|
|
__all__ = ["DistributedWeightedSampler"]


# Covariant type variable used as the sampler's generic parameter.
T_co = TypeVar("T_co", covariant=True)


class DistributedWeightedSampler(Sampler[T_co]):
    """Weighted random sampler whose draws are split across distributed replicas.

    With ``shuffle=True``, every replica seeds a generator with ``seed + epoch``
    and draws the same ``total_size`` indices, with replacement, from
    ``torch.multinomial`` over ``weights``. Each replica then keeps every
    ``num_replicas``-th index, starting at its own ``rank``. Because all replicas
    share the seed, their slices are disjoint positions of one common draw.

    With ``shuffle=False``, ``weights`` and ``num_samples`` do not influence the
    output. The sampler yields ``range(len(dataset))`` in order, either padded by
    repetition or truncated (when ``drop_last``), so that every replica receives
    the same number of indices.

    Args:
        dataset: Dataset being sampled. ``len(dataset)`` is only read when
            ``shuffle=False``.
        weights: 1-D sequence of non-negative sampling weights. Indices are
            drawn from ``[0, len(weights))``.
        num_samples: Total number of draws across all replicas when
            ``shuffle=True``. Each replica yields
            ``ceil(num_samples / num_replicas)`` indices.
        num_replicas: Number of replicas. Defaults to ``dist.get_world_size()``.
        rank: This replica's rank. Defaults to ``dist.get_rank()``.
        shuffle: If True, use weighted random sampling; otherwise iterate the
            dataset sequentially.
        seed: Base seed, added to the epoch set via :meth:`set_epoch`.
        drop_last: Only used when ``shuffle=False``. If True, drop the tail of
            the dataset instead of padding it to split evenly.

    Raises:
        ValueError: If any of the following holds:
            - ``num_samples`` is not a positive ``int``.
            - ``weights`` is not 1-D.
            - With ``shuffle=True``, ``weights`` is empty, contains
              non-finite or negative values, or has a non-positive sum.
            - ``rank`` lies outside ``[0, num_replicas - 1]``.
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but
            ``torch.distributed`` is not available.
    """

    def __init__(
        self,
        dataset: Dataset,
        weights: Sequence[float],
        num_samples: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # bool is a subclass of int, so it has to be rejected explicitly.
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.float)
        if len(weights_tensor.shape) != 1:
            raise ValueError(
                "weights should be a 1d sequence but given " f"weights have shape {tuple(weights_tensor.shape)}"
            )
        if shuffle:
            # torch.multinomial rejects these inputs only at draw time, with an
            # opaque RuntimeError (typically deep inside a DataLoader).
            # Fail at construction instead.
            if weights_tensor.numel() == 0:
                raise ValueError("weights should not be empty when shuffle=True")
            if not bool(torch.isfinite(weights_tensor).all()):
                raise ValueError("weights should only contain finite values")
            if bool((weights_tensor < 0).any()):
                raise ValueError("weights should be non-negative")
            if float(weights_tensor.sum()) <= 0:
                raise ValueError("weights should have a positive sum")
        # NOTE(review): len(weights) is not compared with len(dataset). With
        # shuffle=True, indices up to len(weights) - 1 can be yielded, so a
        # longer weights sequence would produce out-of-range indices.
        self.weights = weights_tensor

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        # Per-replica sample count.
        if self.shuffle:
            # Draws are with replacement, so any total can be produced exactly.
            # Rounding up gives every replica the same count.
            self.num_samples = math.ceil(num_samples / self.num_replicas)
        elif self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Round down to the nearest length that splits evenly, so no
            # replica needs padded (duplicated) indices.
            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[T_co]:
        """Yield this replica's share of the indices for the current epoch."""
        if self.shuffle:
            # All replicas use the same seed, so they draw the same sequence
            # and the per-rank slices below do not overlap.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.multinomial(
                input=self.weights, num_samples=self.total_size, replacement=True, generator=g
            ).tolist()
        else:
            indices = list(range(len(self.dataset)))
            if not self.drop_last:
                # Pad by repeating from the start until total_size divides evenly.
                padding_size = self.total_size - len(indices)
                if padding_size <= len(indices):
                    indices += indices[:padding_size]
                else:
                    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
            else:
                # Drop the tail so total_size divides evenly.
                indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Strided subsample: replica r takes positions r, r + R, r + 2R, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch, which is added to ``seed`` when ``shuffle=True``.

        Call this with the same value on every replica before each epoch, so
        that successive epochs draw different samples while the replicas stay
        in sync.
        """
        self.epoch = epoch
|
|