Source code for dmme.data_modules.data_module

import multiprocessing as mp

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class DataModule(pl.LightningDataModule):
    """Base ``LightningDataModule`` with defaults for generative modeling.

    > Defaults are set from DDPM.

    Subclasses implement `setup_train` and `setup_test` to prepare the
    training and test sets. In practice both use the training data, but
    augmentations are only applied in `setup_train`.

    `DataLoader`s are created with good defaults, using the batch size
    given to `__init__`.

    Args:
        batch_size (int): batch size for `DataLoader`
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup_train(self):
        """Prepare and return the training set (must be overridden)."""
        raise NotImplementedError

    def setup_test(self):
        """Prepare and return the test set (must be overridden)."""
        raise NotImplementedError

    def setup(self, stage: str):
        """Prepare the dataset that matches ``stage``.

        ``"fit"`` stores the result of `setup_train` in ``self.train_set``;
        ``"test"`` stores the result of `setup_test` in ``self.test_set``.
        Any other stage leaves the module unchanged.
        """
        if stage == "fit":
            self.train_set = self.setup_train()
            return
        if stage == "test":
            self.test_set = self.setup_test()

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a `DataLoader` using the shared defaults."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=True,
            # One worker per CPU reported by the module-level `cpu_count`.
            num_workers=cpu_count(),
        )

    def train_dataloader(self):
        """Return a shuffling `DataLoader` over ``self.train_set``.

        ``num_workers`` is set automatically from the cpu count.
        """
        return self._build_loader(self.train_set, shuffle=True)

    def test_dataloader(self):
        """Return a non-shuffling `DataLoader` over ``self.test_set``.

        ``num_workers`` is set automatically from the cpu count.
        """
        return self._build_loader(self.test_set, shuffle=False)
def cpu_count(*args, **kwargs):
    """Return the CPU count reported by the ``multiprocessing`` package.

    Every positional and keyword argument is forwarded unchanged to
    ``multiprocessing.cpu_count``.
    """
    n_cpus = mp.cpu_count(*args, **kwargs)
    return n_cpus