# Source code for dmme.common.vis

import math

import torch
from torchvision.utils import make_grid


def make_history(history):
    r"""Arrange diffusion samples into a single image grid.

    Args:
        history (List[torch.Tensor]): non-empty list of diffusion snapshots,
            each item a tensor of shape :math:`(N, C, H, W)`. All items must
            share one shape so that they can be stacked.

    Returns:
        torch.Tensor: the image grid built by
        :func:`torchvision.utils.make_grid`. With a single snapshot, the
        :math:`N` images are tiled into a grid that is as close to square as
        the divisors of :math:`N` allow. With several snapshots, each row
        shows one sample and each column one entry of ``history``.
    """
    if len(history) == 1:
        img = history[-1]
        batch_size = img.size(0)

        # Fallback when no divisor is found below (batch size 1 or a prime):
        # one image per row.
        nrow = 1

        # Look for the largest divisor of the batch size that does not exceed
        # its square root; that divisor is the number of rows and the
        # quotient is the number of images per row. The stop value is 1
        # (exclusive) so that 2 is tried as well; stopping at 2 skipped it and
        # left batches such as 4, 6 or 8 as a single column.
        for rows in range(int(math.sqrt(batch_size)), 1, -1):
            if batch_size % rows == 0:
                nrow = batch_size // rows
                break

        return make_grid(img, nrow=nrow)

    # (N, T, C, H, W): dim 1 indexes the entries of ``history``.
    stacked = torch.stack(history, dim=1)

    # Flattening the first two dims keeps all T entries of a sample adjacent,
    # so T images per row puts exactly one sample on each row.
    return make_grid(stacked.flatten(0, 1), nrow=stacked.size(1))