# Source code for dmme.ddpm.ddpm

from typing import Tuple

from tqdm import tqdm

import torch
from torch import nn

import einops

import dmme


def linear_schedule(timesteps: int, start=0.0001, end=0.02) -> torch.Tensor:
    r"""Linearly spaced :math:`\beta_t` schedule.

    Produces ``timesteps`` values evenly spaced from ``start`` to ``end``
    (by default :math:`10^{-4}` to :math:`0.02`) and passes them through
    :func:`dmme.pad`.

    Args:
        timesteps (int): total timesteps :math:`T`
        start (float): first value of the schedule, defaults to 0.0001
        end (float): last value of the schedule, defaults to 0.02

    Returns:
        (torch.Tensor): the padded schedule. NOTE(review): ``dmme.pad``
        presumably prepends a 0 so that index ``t`` lines up with 1-based
        timesteps — confirm against ``dmme.pad``.
    """
    evenly_spaced = torch.linspace(start, end, steps=timesteps)
    return dmme.pad(evenly_spaced)
def sample_gaussian(mean, variance, noise):
    r"""Draws a Gaussian sample with the reparameterization trick.

    The randomness is supplied by the caller through ``noise``, so the
    result is the deterministic expression
    :math:`\mu + \sqrt{\sigma^2}\,\epsilon`.

    Args:
        mean (torch.Tensor): mean :math:`\mu` of the distribution
        variance (torch.Tensor): variance :math:`\sigma^2` of the distribution
        noise (torch.Tensor): :math:`\epsilon`, sampled from :math:`\mathcal{N}(0, I)`

    Returns:
        (torch.Tensor): ``mean + sqrt(variance) * noise``
    """
    std = torch.sqrt(variance)
    scaled_noise = std * noise
    return mean + scaled_noise
def forward_process(image, alpha_bar_t, noise):
    r"""Forward (noising) process, sampling :math:`x_t \sim q(x_t|x_0)`.

    Uses the closed form
    :math:`q(x_t|x_0) = \mathcal{N}\left(\sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\right)`,
    so :math:`x_t` is obtained from the clean image in a single step rather
    than by iterating :math:`q(x_t|x_{t-1})`.

    Args:
        image (torch.Tensor): clean image :math:`x_0` of shape :math:`(N, C, H, W)`
        alpha_bar_t (torch.Tensor): :math:`\bar\alpha_t` of shape :math:`(N, 1, 1, *)`
        noise (torch.Tensor): noise sampled from standard normal distribution
            with the same shape as the image

    Returns:
        (torch.Tensor): noisy image :math:`x_t`, same shape as ``image``
    """
    mean = torch.sqrt(alpha_bar_t) * image
    variance = 1 - alpha_bar_t
    return sample_gaussian(mean, variance, noise)
def reverse_process(x_t, beta_t, alpha_t, alpha_bar_t, noise_in_x_t, variance, noise):
    r"""Reverse (denoising) process, sampling :math:`x_{t-1} \sim p_\theta(x_{t-1}|x_t)`.

    The mean is derived from the predicted noise as
    :math:`\mu_\theta = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta\right)`
    and a sample is drawn around it with the given ``variance``.

    Args:
        x_t (torch.Tensor): current noisy image :math:`x_t`
        beta_t (torch.Tensor): :math:`\beta_t` of shape :math:`(N, 1, 1, *)`
        alpha_t (torch.Tensor): :math:`\alpha_t` of shape :math:`(N, 1, 1, *)`
        alpha_bar_t (torch.Tensor): :math:`\bar\alpha_t` of shape :math:`(N, 1, 1, *)`
        noise_in_x_t (torch.Tensor): estimated noise :math:`\epsilon_\theta` in
            :math:`x_t`, predicted by a neural network
        variance (torch.Tensor): variance of the reverse process, either learned or fixed
        noise (torch.Tensor): noise sampled from :math:`\mathcal{N}(0, I)`

    Returns:
        (torch.Tensor): a sample of :math:`x_{t-1}`
    """
    # weight applied to the predicted noise before subtracting it from x_t
    noise_coef = beta_t / torch.sqrt(1 - alpha_bar_t)
    # overall rescaling by 1 / sqrt(alpha_t)
    inv_sqrt_alpha = 1 / torch.sqrt(alpha_t)
    mean = inv_sqrt_alpha * (x_t - noise_coef * noise_in_x_t)
    return sample_gaussian(mean, variance, noise)
def simple_loss(noise, estimated_noise):
    r"""Simple loss objective :math:`L_\text{simple}`.

    Mean squared error between the noise actually used in the forward
    process and the noise predicted by the model, averaged over all elements.

    Args:
        noise (torch.Tensor): noise used in the forward process
        estimated_noise (torch.Tensor): estimated noise with the same shape as :code:`noise`

    Returns:
        (torch.Tensor): scalar mean squared error
    """
    mse = nn.functional.mse_loss
    return mse(input=noise, target=estimated_noise, reduction="mean")
class DDPM(nn.Module):
    r"""Training and Sampling for DDPM

    Args:
        model (nn.Module): network that estimates the noise in :math:`x_t`;
            called as ``model(x_t, t)``
        timesteps (int): total timesteps :math:`T`
    """

    beta: torch.Tensor
    alpha: torch.Tensor
    alpha_bar: torch.Tensor

    def __init__(self, model, timesteps) -> None:
        super().__init__()
        self.model = model
        self.timesteps = timesteps

        beta = linear_schedule(timesteps)
        # trailing singleton dims so that indexing with an (N,) timestep
        # tensor yields (N, 1, 1, 1), which broadcasts against (N, C, H, W)
        beta = einops.rearrange(beta, "t -> t 1 1 1")

        alpha = 1 - beta
        # alpha[0] = 1 (beta[0] is the padding entry) so the cumulative
        # product is unaffected by index 0
        alpha_bar = torch.cumprod(alpha, dim=0)

        # non-persistent: fully derived from `timesteps`, so kept out of state_dict
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_bar", alpha_bar, persistent=False)

    def training_step(self, x_0):
        r"""Computes loss for DDPM

        Draws one random timestep per sample, noises ``x_0`` with the forward
        process, and compares the model's noise estimate with the true noise.

        Args:
            x_0 (torch.Tensor): sample image to add noise and denoise for training

        Returns:
            (torch.Tensor): loss, :math:`L_\text{simple}`
        """
        batch_size = x_0.size(0)

        # NOTE(review): DDPM draws t uniformly from {1, ..., T}, and index 0 of
        # the schedule appears to be padding (alpha[0] = 1). Confirm that
        # dmme.uniform_int(0, T, ...) never returns 0 and can return T.
        time = dmme.uniform_int(
            0,
            self.timesteps,
            batch_size,
            device=x_0.device,
        )
        noise = dmme.gaussian_like(x_0)

        alpha_bar_t = self.alpha_bar[time]
        x_t = forward_process(x_0, alpha_bar_t, noise)

        noise_in_x_t = self.model(x_t, time)
        return simple_loss(noise, noise_in_x_t)

    def sampling_step(self, x_t, t):
        r"""Denoise image by sampling from :math:`p_\theta(x_{t-1}|x_t)`

        Args:
            x_t (torch.Tensor): image of shape :math:`(N, C, H, W)`
            t (torch.Tensor): current timestep :math:`t`, a tensor of shape
                :math:`(N,)`, or :math:`(1,)` to use one timestep for the whole batch

        Returns:
            (torch.Tensor): denoised image :math:`x_{t-1}` of shape :math:`(N, C, H, W)`
        """
        noise = dmme.gaussian_like(x_t)

        # No noise is added on the final step (t == 1). Use a mask that
        # broadcasts over the non-batch dims: with a (1,) `t` it covers the
        # whole batch, with an (N,) `t` it is per-sample. (Index-based zeroing
        # via torch.where(t == 1) only cleared sample 0 when `t` had shape (1,).)
        is_last_step = (t == 1).reshape(-1, *([1] * (x_t.dim() - 1)))
        noise = noise.masked_fill(is_last_step, 0.0)

        beta_t = self.beta[t]
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]

        noise_in_x_t = self.model(x_t, t)
        x_t = reverse_process(
            x_t,
            beta_t,
            alpha_t,
            alpha_bar_t,
            noise_in_x_t,
            variance=beta_t,
            noise=noise,
        )
        return x_t

    # Sampling never needs gradients; without this, the autograd graph of all
    # `timesteps` model calls would be kept alive until the result is freed.
    @torch.no_grad()
    def generate(self, img_size: Tuple[int, int, int, int]):
        r"""Generate image of shape :math:`(N, C, H, W)` by running the full denoising steps

        Args:
            img_size (Tuple[int, int, int, int]): image size to generate as a tuple :math:`(N, C, H, W)`

        Returns:
            (torch.Tensor): generated image of shape :math:`(N, C, H, W)`
        """
        device = self.beta.device
        x_t = dmme.gaussian(img_size, device=device)
        batch_size = x_t.size(0)

        for step in tqdm(range(self.timesteps, 0, -1), leave=False):
            # one timestep per sample, matching the (N,) `t` used in training
            t = torch.full((batch_size,), step, dtype=torch.long, device=device)
            x_t = self.sampling_step(x_t, t)

        return x_t

    def forward(self, x, t):
        """Predicts noise given image and timestep"""
        noise_in_x = self.model(x, t)
        return noise_in_x