Source code for dmme.ddim

from typing import Optional, Tuple

import torch
from torch import nn
from torch import Tensor
from tqdm import tqdm

from dmme.ddpm import LitDDPM, DDPM, pad


class LitDDIM(LitDDPM):
    r"""LightningModule for sampling with DDIM from :code:`LitDDPM`'s checkpoints

    Args:
        model (nn.Module): neural network predicting noise :math:`\epsilon_\theta`
        lr (float): learning rate, defaults to :math:`2e-4`
        warmup (int): linearly increases learning rate for `warmup` steps until `lr` is reached, defaults to 5000
        imgsize (Tuple[int, int, int]): image size in `(C, H, W)`
        timesteps (int): total timesteps for the forward and reverse process, :math:`T`
        decay (float): EMA decay value
        sample_steps (int): number of steps :math:`S` used by the generation process
        tau_schedule (str): tau schedule to use for generation, `"linear"` or `"quadratic"`
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 2e-4,
        warmup: int = 5000,
        imgsize: Tuple[int, int, int] = (3, 32, 32),
        timesteps: int = 1000,
        decay: float = 0.9999,
        sample_steps: int = 50,
        tau_schedule: str = "quadratic",
    ):
        super().__init__(model, lr, warmup, imgsize, timesteps, decay)

        self.sample_steps = sample_steps
        # The sampling sub-sequence must have exactly `sample_steps` entries so
        # that `generate`, which walks indices sample_steps..1, starts at
        # :math:`\tau_S = T` and ends at :math:`\tau_0 = 0`.
        self.process = DDIM(
            timesteps,
            tau_schedule=tau_schedule,
            sample_steps=sample_steps,
        )

    def forward(self, x_t: Tensor, t: int) -> Tensor:
        r"""Denoise image once using :code:`DDIM`

        Args:
            x_t (torch.Tensor): image of shape :math:`(N, C, H, W)`
            t (int): index :math:`i` into the sampling sub-sequence :math:`\tau`;
                the step moves the sample from :math:`\tau_i` to :math:`\tau_{i-1}`

        Returns:
            (torch.Tensor): generated sample of shape :math:`(N, C, H, W)`
        """
        # A one-element tensor on the input's device, so buffer lookups in the
        # process happen on the same device and broadcast over the batch.
        step_index = torch.tensor([t], device=x_t.device)
        return self.process.sample(self.model, x_t, step_index)

    def generate(self, x_t: Tensor) -> Tensor:
        r"""Iteratively sample from :math:`p_\theta(x_{\tau_{i-1}}|x_{\tau_i})` to generate images

        Args:
            x_t (torch.Tensor): :math:`x_T` to start from

        Returns:
            (torch.Tensor): the sample after all `sample_steps` denoising steps
        """
        for step_index in tqdm(range(self.sample_steps, 0, -1), leave=False):
            x_t = self(x_t, step_index)

        return x_t


class DDIM(DDPM):
    r"""Reverse process and Sampling for DDIM

    Builds the increasing sub-sequence :math:`\tau_0 = 0, \dots, \tau_S = T` of
    training timesteps that the sampler visits.

    Args:
        timesteps (int): total timesteps :math:`T`
        tau_schedule (str): tau schedule, `"linear"` (:math:`\tau_i = round(ci)`)
            or `"quadratic"` (:math:`\tau_i = round(ci^2)`), case-insensitive
        sample_steps (int, optional): length :math:`S` of the sampling
            sub-sequence, defaults to `timesteps` (one step per training timestep)

    Raises:
        ValueError: if `sample_steps` is smaller than 1
        NotImplementedError: if `tau_schedule` is not a known schedule
    """

    # tau[i] is the training timestep visited at sampling step i (tau[0] == 0)
    tau: Tensor

    def __init__(
        self,
        timesteps: int,
        tau_schedule: str = "quadratic",
        sample_steps: Optional[int] = None,
    ) -> None:
        super().__init__(timesteps)

        if sample_steps is None:
            sample_steps = timesteps
        if sample_steps < 1:
            raise ValueError(f"sample_steps must be >= 1, got {sample_steps}")

        # NOTE(review): assumes the parent stores `beta` with one extra entry
        # for t = 0, so the largest valid timestep is size - 1 -- confirm in DDPM.
        full_timesteps = self.beta.size(0) - 1

        schedule = tau_schedule.lower()
        # `c` is chosen so that the last entry lands on `full_timesteps`.
        if schedule == "linear":
            c = full_timesteps / sample_steps
            tau = [round(c * i) for i in range(sample_steps + 1)]
        elif schedule == "quadratic":
            c = full_timesteps / (sample_steps**2)
            tau = [round(c * i**2) for i in range(sample_steps + 1)]
        else:
            raise NotImplementedError(
                f"unknown tau_schedule {tau_schedule!r}, "
                'expected "linear" or "quadratic"'
            )

        # The list already starts with tau_0 = 0 (from i = 0), so it is used
        # as-is: tau[i] == tau_i for i in 0..sample_steps. Adding another
        # leading entry would shift every lookup by one sampling step.
        self.register_buffer("tau", torch.tensor(tau), persistent=False)

    def reverse_process(self, model, x_t, t):
        r"""Deterministic DDIM reverse step from :math:`\tau_i` to :math:`\tau_{i-1}`

        With :math:`\epsilon_\theta = \epsilon_\theta(\bold{x}_{\tau_i}, \tau_i)`
        computes

        .. math::
            \bold{x}_{\tau_{i-1}} = \sqrt{\bar\alpha_{\tau_{i-1}}}
            \bigg(\frac{\bold{x}_{\tau_i} - \sqrt{1-\bar\alpha_{\tau_i}}\epsilon_\theta}
            {\sqrt{\bar\alpha_{\tau_i}}}\bigg)
            + \sqrt{1-\bar\alpha_{\tau_{i-1}}}\epsilon_\theta

        No noise is added, so the step is fully determined by `x_t` and `model`.

        Args:
            model (nn.Module): model for estimating noise, called as `model(x_t, tau_i)`
            x_t (torch.Tensor): current sample :math:`\bold{x}_{\tau_i}`
            t (int or torch.Tensor): index :math:`i` into `tau`, with :math:`i \geq 1`;
                :code:`LitDDIM.forward` passes a tensor of shape :math:`(1,)`

        Returns:
            (torch.Tensor): :math:`\bold{x}_{\tau_{i-1}}`, same shape as `x_t`
        """
        # Map sub-sequence indices to training timesteps.
        current_timestep = self.tau[t]
        previous_timestep = self.tau[t - 1]

        alpha_bar_current = self.alpha_bar[current_timestep]
        alpha_bar_previous = self.alpha_bar[previous_timestep]

        noise_estimate = model(x_t, current_timestep)

        # Estimate of the clean image implied by the predicted noise.
        predicted_x_0 = (
            x_t - torch.sqrt(1 - alpha_bar_current) * noise_estimate
        ) / torch.sqrt(alpha_bar_current)

        # Re-noise the estimate to the previous timestep's noise level.
        direction_pointing_to_x_t = (
            torch.sqrt(1 - alpha_bar_previous) * noise_estimate
        )

        return (
            torch.sqrt(alpha_bar_previous) * predicted_x_0
            + direction_pointing_to_x_t
        )

    def sample(self, model, x_t, t):
        r"""Sample from :math:`p_\theta(x_{\tau_{i-1}}|x_{\tau_i})`

        Args:
            model (nn.Module): model for estimating noise
            x_t (torch.Tensor): image of shape :math:`(N, C, H, W)`
            t (int or torch.Tensor): index :math:`i` into the sampling sub-sequence `tau`

        Returns:
            (torch.Tensor): generated sample of shape :math:`(N, C, H, W)`
        """
        # `reverse_process` performs the tau lookup itself; the index is passed
        # through unchanged so it is not mapped through `tau` twice.
        return self.reverse_process(model, x_t, t)