Source code for dmme.callbacks.generate

from tqdm import tqdm

import torch

import pytorch_lightning as pl

import dmme


class GenerateImage(pl.Callback):
    r"""Generate samples to check training progress

    At the end of every ``every_n_epochs``-th training epoch this callback
    starts from ``dmme.gaussian`` noise, repeatedly calls
    ``pl_module(x_t, t)`` for ``t = timesteps, ..., 1``, and keeps
    ``vis_length`` ``dmme.denorm``-ed snapshots along the way. The snapshots
    are assembled with ``dmme.make_history`` and logged. Only
    ``WandbLogger`` and ``TensorBoardLogger`` are handled; any other logger
    type receives nothing.

    Args:
        imgsize (Tuple[int, int, int]): A tuple of ints representing image shape
            :math:`(C, H, W)`
        timesteps (int): Number of denoising steps; the module is called with
            ``t = timesteps, timesteps - 1, ..., 1``
        batch_size (int): Number of samples to generate
        vis_length (int): Length of denoising sequence to visualize
        every_n_epochs (int): Only save those images every N epochs
    """

    def __init__(
        self,
        imgsize,
        timesteps,
        batch_size=8,
        vis_length=20,
        every_n_epochs=5,
    ):
        super().__init__()
        self.imgsize = imgsize
        self.timesteps = timesteps
        self.batch_size = batch_size
        self.vis_length = vis_length
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module):
        """Lightning hook: maybe generate and log samples after each train epoch."""
        self._shared_hook(trainer, pl_module)

    def _shared_hook(self, trainer, pl_module):
        """Generate and log a denoising grid if the current epoch is due."""
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        # Nothing to log to: skip the (expensive) sampling loop entirely.
        if trainer.logger is None:
            return

        history = self.generate_img(pl_module)
        grid = dmme.make_history(history)

        # ``trainer.logger`` may be a single logger or a list of them.
        if isinstance(trainer.logger, list):
            loggers = trainer.logger
        else:
            loggers = [trainer.logger]

        for logger in loggers:
            self._log(logger, grid, pl_module.global_step)

    def _log(self, logger, grid, step=None):
        """Log ``grid`` under the key ``"generated_images"``.

        Args:
            logger: A Lightning logger; only ``WandbLogger`` and
                ``TensorBoardLogger`` are acted upon.
            grid: The image grid returned by ``dmme.make_history``.
            step (int, optional): Global step passed to TensorBoard's
                ``add_image``. It is passed in explicitly because this method
                has no access to the LightningModule.
        """
        if isinstance(logger, pl.loggers.WandbLogger):
            logger.log_image("generated_images", [grid])
        elif isinstance(logger, pl.loggers.TensorBoardLogger):
            logger.experiment.add_image("generated_images", grid, step)

    @torch.inference_mode()
    def generate_img(self, pl_module):
        """Run the denoising loop from noise and collect snapshots.

        The module is switched to eval mode for sampling. Its previous
        train/eval mode is restored afterwards, even if sampling raises.

        Args:
            pl_module: Module called as ``pl_module(x_t, t)`` to produce the
                next ``x_t`` (presumably the :math:`x_{t-1}` estimate — confirm
                against the model's ``forward``).

        Returns:
            list: ``dmme.denorm``-ed copies of ``x_t``, one captured before each
            step whose ``t`` is in the evenly spaced checkpoint set, in
            descending-``t`` order, followed by the final sample.
        """
        was_training = pl_module.training
        pl_module.eval()
        try:
            denoising_sequence = []
            x_t = dmme.gaussian(
                (self.batch_size, *self.imgsize), device=pl_module.device
            )

            timesteps = self.timesteps
            # Evenly spaced checkpoints from ``timesteps`` downwards. Integer
            # arithmetic avoids float truncation (e.g. 999 instead of 1000).
            # A set gives O(1) membership in the per-step check below.
            save_t = {
                timesteps * i // (self.vis_length - 1)
                for i in range(self.vis_length - 1, 0, -1)
            }

            for t in tqdm(range(timesteps, 0, -1), leave=False):
                if t in save_t:
                    denoising_sequence.append(dmme.denorm(x_t.clone().detach()))
                x_t = pl_module(x_t, t)

            # Always include the fully denoised result as the last frame.
            denoising_sequence.append(dmme.denorm(x_t.clone().detach()))
        finally:
            pl_module.train(was_training)
        return denoising_sequence