Source code for dmme.ddpm.ddpm_sampler

import torch
from torch import nn
import torch.nn.functional as F

import einops

from dmme.common import gaussian, gaussian_like, uniform_int, denorm


class DDPMSampler(nn.Module):
    """Wrapper for computing forward and reverse processes, sampling data, and computing loss for DDPM

    Paper: https://arxiv.org/abs/2006.11239

    Code: https://github.com/hojonathanho/diffusion

    Args:
        model (nn.Module): model
        timesteps (int): diffusion timesteps
    """

    def __init__(
        self,
        model: nn.Module,
        timesteps: int,
    ):
        super().__init__()

        self.model = model
        self.timesteps = timesteps

        # `noise_schedule` may return None, in which case no buffers are
        # registered here.
        beta = self.noise_schedule()
        if beta is not None:
            self.register_alphas(beta)

    def forward_process(self, x_0, t, noise=None):
        r"""Forward Diffusion Process

        Samples :math:`x_t` from :math:`q(x_t|x_0) = \mathcal{N}(x_t;\sqrt{\bar\alpha_t}\bold{x}_0,(1-\bar\alpha_t)\bold{I})`

        Computes :math:`\bold{x}_t = \sqrt{\bar\alpha_t}\bold{x}_0 + \sqrt{1-\bar\alpha_t}\epsilon`

        Args:
            x_0 (torch.Tensor): data to add noise to
            t (int or torch.Tensor): 1-based :math:`t` in :math:`x_t`; it is used
                (minus one) to index the cached schedule buffers
            noise (torch.Tensor, optional): :math:`\epsilon`, noise used in the forward process;
                drawn with ``gaussian_like(x_0)`` when omitted

        Returns:
            (torch.Tensor): :math:`\bold{x}_t \sim q(\bold{x}_t|\bold{x}_0)`
        """
        # Timesteps are 1-based while the buffers are 0-based.
        t_index = t - 1

        if noise is None:
            noise = gaussian_like(x_0)

        x_t = (
            self.sqrt_alpha_bar[t_index] * x_0
            + self.sqrt_one_minus_alpha_bar[t_index] * noise
        )
        return x_t

    def noise_schedule(self):
        r"""Noise Schedule for DDPM

        DDPM sets :math:`T = 1000` and linearly increases :math:`\beta_t` from :math:`10^{-4}` to :math:`0.02`

        Returns:
            (torch.Tensor): :math:`\beta_1, \, ... \, ,\beta_T` as a tensor of shape :math:`(T,)`
        """
        return linear_schedule(timesteps=self.timesteps)

    def register_alphas(self, beta):
        r"""Caches :math:`\alpha_t` used in the forward and reverse process

        :math:`\alpha_t` is constant so we register them in `nn.Module`'s buffers

        Registers the buffers ``beta``, ``sqrt_alpha_bar``, ``sqrt_one_minus_alpha_bar``,
        ``one_over_sqrt_alpha``, ``beta_over_sqrt_one_minus_alpha_bar``, ``sigma`` and ``t``.

        Args:
            beta (torch.Tensor): beta values to use to compute alphas, a tensor of shape :math:`(T,)`
        """
        alpha = 1.0 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)

        # Forward-process coefficients get three trailing singleton axes so
        # that indexing with a batch of timesteps yields an (N, 1, 1, 1) tensor.
        sqrt_alpha_bar = einops.rearrange(torch.sqrt(alpha_bar), "t -> t 1 1 1")
        sqrt_one_minus_alpha_bar = einops.rearrange(
            torch.sqrt(1 - alpha_bar), "t -> t 1 1 1"
        )

        # Reverse-process coefficients stay one-dimensional, shape (T,).
        one_over_sqrt_alpha = 1 / torch.sqrt(alpha)
        beta_over_sqrt_one_minus_alpha_bar = beta / torch.sqrt(1 - alpha_bar)
        sigma = torch.sqrt(beta)

        self.register_buffer("beta", beta)
        self.register_buffer("sqrt_alpha_bar", sqrt_alpha_bar)
        self.register_buffer("sqrt_one_minus_alpha_bar", sqrt_one_minus_alpha_bar)
        self.register_buffer("one_over_sqrt_alpha", one_over_sqrt_alpha)
        self.register_buffer(
            "beta_over_sqrt_one_minus_alpha_bar", beta_over_sqrt_one_minus_alpha_bar
        )
        self.register_buffer("sigma", sigma)
        # NOTE(review): arange(1, timesteps) holds 1..T-1 (T-1 rows), i.e. T
        # itself is missing. Left unchanged so buffer shapes in existing state
        # dicts keep matching -- confirm whether T was meant to be included.
        self.register_buffer("t", torch.arange(1, self.timesteps)[:, None])

    def reverse_process(self, x_t, t, noise=None):
        r"""Reverse Denoising Process

        Samples :math:`x_{t-1}` from
        :math:`p_\theta(\bold{x}_{t-1}|\bold{x}_t) = \mathcal{N}(\bold{x}_{t-1};\mu_\theta(\bold{x}_t, t), \sigma_t^2\bold{I})`

        .. math::
            \begin{aligned}
            \bold\mu_\theta(\bold{x}_t, t)
            &= \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t
            -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) \\
            \sigma_t^2 &= \beta_t
            \end{aligned}

        Computes :math:`\bold{x}_{t-1}
        = \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t
        -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg)
        +\sigma_t\epsilon`

        Args:
            x_t (torch.Tensor): x_t
            t (int): current timestep (1-based)
            noise (torch.Tensor, optional): noise; drawn with ``gaussian_like(x_t)`` when omitted.
                :meth:`sample` passes ``0`` for the final step.

        Returns:
            (torch.Tensor): :math:`\bold{x}_{t-1}`
        """
        # Timesteps are 1-based while the buffers are 0-based.
        t_index = t - 1
        # The model is given the timestep as a float tensor of shape (1,).
        t = torch.tensor([t], device=x_t.device).float()

        if noise is None:
            noise = gaussian_like(x_t)

        noise_estimate = self.model(x_t, t)
        mean = self.one_over_sqrt_alpha[t_index] * (
            x_t - self.beta_over_sqrt_one_minus_alpha_bar[t_index] * noise_estimate
        )
        x_t_minus_one = mean + self.sigma[t_index] * noise
        return x_t_minus_one

    def compute_loss(self, x_0, t=None, noise=None):
        r"""Computes the loss

        :math:`L_\text{simple} = \mathbb{E}_{\bold{x}_0\sim q(\bold{x}_0),
        \epsilon\sim\mathcal{N}(\bold{0},\bold{I}), t\sim\mathcal{U}(1,T)}
        \left[\|\epsilon-\epsilon_\theta(\bold{x}_t, t) \|^2\right]`

        Args:
            x_0 (torch.Tensor): :math:`x_0`
            t (int or torch.Tensor, optional): sampled :math:`t`; when omitted, one integer
                timestep per batch element is drawn uniformly from :math:`\{1, ..., T\}`
            noise (torch.Tensor, optional): sampled :math:`\epsilon`;
                drawn with ``gaussian_like(x_0)`` when omitted

        Returns:
            (torch.Tensor): mean squared error between the noise and the model's estimate
        """
        if t is None:
            batch_size = x_0.size(0)
            # Draw t from {1, ..., T}. `forward_process` indexes its buffers
            # with t - 1, so t = 0 would wrap around to the last timestep's
            # coefficients while the model is told t = 0. `torch.randint`
            # excludes the upper bound, hence `timesteps + 1`.
            t = torch.randint(
                1, self.timesteps + 1, (batch_size,), device=x_0.device
            )

        if noise is None:
            noise = gaussian_like(x_0)

        noisy_x = self.forward_process(x_0, t, noise)
        noise_estimate = self.model(noisy_x, t)

        loss = F.mse_loss(noise, noise_estimate)
        return loss

    def _resolve_bound(self, n):
        """Converts a ``start``/``end`` argument of :meth:`sample` into a timestep bound

        ``None`` maps to ``timesteps + 1``, negative values wrap around
        (``-1`` becomes ``timesteps``), and non-negative values are returned unchanged.
        """
        # None has to be handled before the comparison: `None < 0` raises TypeError.
        if n is None:
            return self.timesteps + 1
        if n < 0:
            return n % self.timesteps + 1
        return n

    @torch.inference_mode()
    def sample(self, x_shape, start=0, end=None, step=1, save_last=True, device=None):
        r"""Generate Samples

        Iteratively sample from :math:`p_\theta(x_t,t)`

        .. code-block:: python

            x = gaussian()
            for t in range(T, 0, -1):
                z = gaussian() if t > 1 else 0
                x = reverse_process(x, t, z)
            return x

        start, end, step parameters specify which steps to return

        Args:
            x_shape (Tuple[int, int, int]): image shape
            start (int): start t; negative values count back from the end
            end (int, optional): end t (exclusive); ``None`` places the bound past the last timestep
            step (int): step
            save_last (bool): whether to save last sample
            device (torch.device): device

        Returns:
            (List[torch.Tensor]): denoised samples
        """
        history = []

        start = self._resolve_bound(start)
        end = self._resolve_bound(end)

        x_t = gaussian(x_shape, device=device)
        # The initial pure-noise sample is kept only when `start` resolves to T.
        if self.timesteps == start:
            history.append(denorm(x_t))

        for t in range(self.timesteps, 1, -1):
            x_t = self.reverse_process(x_t, t)

            # After this step x_t holds the sample for timestep t - 1.
            if end > t - 1 >= start and (t - 1 - start) % step == 0:
                history.append(denorm(x_t))

        # No noise is added on the final step.
        x_0 = self.reverse_process(x_t, 1, noise=0)
        if save_last or end > 0 >= start:
            history.append(denorm(x_0))

        return history

    def forward(self, x, t):
        r"""Predicts the noise given :math:`x_t` and :math:`t`

        Applies forward to the internal model

        Expects :math:`x_t` to have shape :math:`(N, C, H, W)`

        Args:
            x (torch.Tensor): image
            t (int): :math:`t` in :math:`\bold{x}_t`

        Returns:
            whatever the wrapped model returns for ``(x, t)``
        """
        return self.model(x, t)
def linear_schedule(timesteps, start=0.0001, end=0.02):
    r"""Evenly spaced constants rising from ``start`` to ``end``

    With the defaults the values increase linearly from :math:`10^{-4}` to :math:`0.02`

    Args:
        timesteps (int): total timesteps, i.e. the number of values returned
        start (float): starting value, defaults to 0.0001
        end (float): end value, defaults to 0.02

    Returns:
        (torch.Tensor): one-dimensional tensor with ``timesteps`` entries
    """
    betas = torch.linspace(start, end, steps=timesteps)
    return betas