| | import os |
| | import numpy as np |
| | from abc import abstractmethod |
| | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset |
| |
|
| |
|
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Interface that makes text-to-image ``IterableDataset`` subclasses chainable.

    Subclasses must implement ``__iter__``. The base class only stores the
    bookkeeping attributes and reports ``num_records`` as the dataset length.

    Args:
        num_records: Value returned by ``__len__`` (and printed on construction).
        valid_ids: Stored as ``self.valid_ids`` and also as ``self.sample_ids``.
            Both attributes refer to the *same* object; no copy is made.
        size: Stored as ``self.size``. Presumably the target image resolution
            used by subclasses (TODO confirm against subclasses).
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        # NOTE: sample_ids aliases valid_ids (same object), so mutating one
        # in place is visible through the other.
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        """Return the number of records passed to the constructor."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield examples; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always. Raising here, rather than returning
                ``None``, gives a clear error instead of
                ``TypeError: iter() returned non-iterator``.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} must implement __iter__'
        )
| |
|
| |
|
class PRNGMixin:
    """Mixin providing a per-process numpy ``RandomState`` via ``self.prng``.

    The generator is created lazily and re-created whenever ``os.getpid()``
    differs from the pid recorded at creation time. Each process therefore
    gets its own freshly created generator rather than sharing one inherited
    copy, which avoids synchronized sampling across multiprocessing workers.
    """

    @property
    def prng(self):
        """Return this process's ``np.random.RandomState``, creating it if needed."""
        pid = os.getpid()
        # Fast path: a generator already exists for this process.
        if getattr(self, "_initpid", None) == pid:
            return self._prng
        # First access in this process: record the pid, then build a fresh
        # generator.
        self._initpid = pid
        self._prng = np.random.RandomState()
        return self._prng
| |
|