Source code for dmme.equations.ddim.ddim
from torch import Tensor
import torch
from torch.distributions import Normal
import dmme.equations as eq
[docs]def linear_tau(timesteps: int, sub_timesteps: int) -> Tensor:
r"""Linear sub-sequence :math:`\tau`
Args:
timesteps: total timesteps :math:`T`
sub_timesteps: sub-sequence length less than :math:`T`
"""
all_t = torch.arange(0, sub_timesteps + 1)
c = timesteps / sub_timesteps
tau = torch.round(c * all_t)
return tau.long()
[docs]def quadratic_tau(timesteps: int, sub_timesteps: int) -> Tensor:
r"""Quadratic sub-sequence :math:`\tau`
Args:
timesteps: total timesteps :math:`T`
sub_timesteps: sub-sequence length less than :math:`T`
"""
all_t = torch.arange(0, sub_timesteps + 1)
c = timesteps / (sub_timesteps**2)
tau = torch.round(c * all_t**2)
return tau.long()
[docs]def reverse_process(
x_t: Tensor,
alpha_bar_t: Tensor,
alpha_bar_t_minus_one: Tensor,
noise_in_x_t: Tensor,
) -> Normal:
r"""Deterministic Denoising Process where :math:`\sigma_t = 0` for all :math:`t`
Args:
x_t: :math:`x_t`
alpha_bar_t: :math:`\bar\alpha_t`
alpha_bar_t_minus_one: :math:`\bar\alpha_{t-1}` of shape :math:`(N, 1, 1, *)`
noise_in_x_t: estimated noise in :math:`x_t` predicted by a neural network
"""
predicted_x_0 = (x_t - torch.sqrt(1 - alpha_bar_t) * noise_in_x_t) / torch.sqrt(
alpha_bar_t_minus_one
)
p = eq.ddpm.forward_process(predicted_x_0, alpha_bar_t_minus_one)
return p