Source code for dmme.equations.ddpm.ddpm

from torch import Tensor

import torch
from torch.distributions import Normal

import dmme


def linear_schedule(
    timesteps: int, start: float = 0.0001, end: float = 0.02
) -> Tensor:
    r"""Linearly spaced noise schedule :math:`\beta_t`.

    The schedule holds ``timesteps`` evenly spaced values from ``start`` to
    ``end``. With the default arguments this rises from :math:`10^{-4}` to
    :math:`0.02`.

    Args:
        timesteps: number of values in the linear schedule (before padding)
        start: first value of the schedule, defaults to 0.0001
        end: last value of the schedule, defaults to 0.02

    Returns:
        a 1d tensor representing :math:`\beta_t` indexed by :math:`t`,
        produced by passing the linear schedule through ``dmme.pad``
    """
    # NOTE(review): dmme.pad is defined elsewhere in the package; presumably it
    # adds a leading entry so that index t lines up with diffusion step t —
    # confirm against its implementation.
    return dmme.pad(torch.linspace(start, end, steps=timesteps))
def forward_process(image: Tensor, alpha_bar_t: Tensor) -> Normal:
    r"""Forward (noising) process marginal, :math:`q(x_t|x_0)`.

    Returns the closed-form distribution of the noised sample :math:`x_t`
    given the clean image :math:`x_0`:

    .. math::
        q(x_t|x_0) = \mathcal{N}\left(\sqrt{\bar\alpha_t}\, x_0,\,
            (1 - \bar\alpha_t) I\right)

    Args:
        image: clean image :math:`x_0` of shape :math:`(N, C, H, W)`
        alpha_bar_t: :math:`\bar\alpha_t` of shape :math:`(N, 1, 1, *)`,
            broadcast against ``image``

    Returns:
        gaussian distribution :math:`q(x_t|x_0)`; draw :math:`x_t` from it
        with ``.sample()`` / ``.rsample()``
    """
    mean = torch.sqrt(alpha_bar_t) * image
    # Normal is parameterized by the standard deviation, so take the square
    # root of the variance 1 - alpha_bar_t. It must be strictly positive
    # (alpha_bar_t < 1): torch's Normal rejects a zero scale when argument
    # validation is enabled.
    variance = 1 - alpha_bar_t
    std = torch.sqrt(variance)
    return Normal(mean, std)
def reverse_process(
    x_t: Tensor,
    beta_t: Tensor,
    alpha_t: Tensor,
    alpha_bar_t: Tensor,
    noise_in_x_t: Tensor,
    variance: Tensor,
) -> Normal:
    r"""Reverse denoising process, :math:`p_\theta(x_{t-1}|x_t)`.

    The mean is computed from the noise :math:`\epsilon_\theta(x_t, t)`
    predicted in :math:`x_t`:

    .. math::
        \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}
            \left(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}
            \epsilon_\theta(x_t, t)\right)

    and the standard deviation is the square root of ``variance``.

    Args:
        x_t: noisy sample :math:`x_t`; must broadcast with ``noise_in_x_t``
        beta_t: :math:`\beta_t` of shape :math:`(N, 1, 1, *)`
        alpha_t: :math:`\alpha_t` of shape :math:`(N, 1, 1, *)`
        alpha_bar_t: :math:`\bar\alpha_t` of shape :math:`(N, 1, 1, *)`
        noise_in_x_t: estimated noise in :math:`x_t` predicted by a neural
            network
        variance: variance of the reverse process, either learned or fixed;
            must be strictly positive

    Returns:
        denoising distribution :math:`p_\theta(x_{t-1}|x_t)`
    """
    # Coefficient applied to the predicted noise: beta_t / sqrt(1 - alpha_bar_t).
    noise_scale = beta_t / torch.sqrt(1 - alpha_bar_t)
    mean = 1 / torch.sqrt(alpha_t) * (x_t - noise_scale * noise_in_x_t)
    # Normal takes a standard deviation, not a variance.
    std = torch.sqrt(variance)
    return Normal(mean, std)