from typing import Tuple
from tqdm import tqdm
import torch
from dmme import ddpm
from dmme.common import gaussian
def linear_tau(timesteps, sub_timesteps):
    """Linear tau schedule.

    Produces the ``sub_timesteps + 1`` evenly spaced timestep indices
    ``round(i * T / S)`` for ``i = 0, ..., S``, running from 0 to :math:`T`.

    Args:
        timesteps (int): total timesteps :math:`T`
        sub_timesteps (int): sub sequence length less than :math:`T`

    Returns:
        torch.Tensor: int64 tensor of shape ``(sub_timesteps + 1,)``
    """
    stride = timesteps / sub_timesteps
    positions = torch.arange(sub_timesteps + 1)
    return (positions * stride).round().long()
def quadratic_tau(timesteps, sub_timesteps):
    """Quadratic tau schedule.

    Produces the ``sub_timesteps + 1`` timestep indices ``round(c * i**2)``
    for ``i = 0, ..., S`` with :math:`c = T / S^2`, so the sequence starts at
    0, ends at :math:`T`, and its spacing ``c * (2i + 1)`` grows with ``i``
    (steps are densest near timestep 0).

    Args:
        timesteps (int): total timesteps :math:`T`
        sub_timesteps (int): sub sequence length less than :math:`T`

    Returns:
        torch.Tensor: int64 tensor of shape ``(sub_timesteps + 1,)``
    """
    all_t = torch.arange(0, sub_timesteps + 1)
    # Scale by T / S^2 (not 1 / T) so that the last entry, c * S^2, equals T;
    # with c = 1 / T the values never exceed S^2 / T < S and mostly round to 0.
    c = timesteps / (sub_timesteps**2)
    tau = torch.round(c * all_t**2).long()
    return tau
def reverse_process(
    x_tau_i, alpha_bar_tau_i, alpha_bar_tau_i_minus_one, noise_in_x_tau_i
):
    r"""DDIM reverse denoising step from :math:`x_{\tau_i}` to :math:`x_{\tau_{i-1}}`.

    First estimates the clean sample

    .. math::
        \hat x_0 = \frac{x_{\tau_i} - \sqrt{1 - \bar\alpha_{\tau_i}}\,\epsilon}
                        {\sqrt{\bar\alpha_{\tau_i}}}

    and then re-noises :math:`\hat x_0` to level :math:`\bar\alpha_{\tau_{i-1}}`
    with ``ddpm.forward_process`` using the *same* noise estimate
    :math:`\epsilon`. No fresh randomness is drawn here, so the step is
    deterministic given its inputs.

    NOTE(review): assumes ``ddpm.forward_process(x_0, alpha_bar, noise)``
    returns ``sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise`` — confirm
    against ``dmme.ddpm``.

    Args:
        x_tau_i (torch.Tensor): current sample :math:`x_{\tau_i}`
        alpha_bar_tau_i (torch.Tensor): :math:`\bar\alpha_{\tau_i}`
        alpha_bar_tau_i_minus_one (torch.Tensor): :math:`\bar\alpha_{\tau_{i-1}}`
        noise_in_x_tau_i (torch.Tensor): noise :math:`\epsilon` estimated in
            :math:`x_{\tau_i}` (e.g. by the denoising model)

    Returns:
        (torch.Tensor): the next sample :math:`x_{\tau_{i-1}}`
    """
    predicted_x_0 = (
        x_tau_i - torch.sqrt(1 - alpha_bar_tau_i) * noise_in_x_tau_i
    ) / torch.sqrt(alpha_bar_tau_i)

    x_tau_i_minus_one = ddpm.forward_process(
        predicted_x_0, alpha_bar_tau_i_minus_one, noise_in_x_tau_i
    )
    return x_tau_i_minus_one
class DDIM(ddpm.DDPM):
    r"""Reverse process and sampling for DDIM.

    Generation only visits the sub-sequence of timesteps stored in
    :attr:`tau` (``sub_timesteps + 1`` indices) instead of all :math:`T`
    timesteps.

    Args:
        model (nn.Module): model for estimating noise
        timesteps (int): total timesteps :math:`T`
        sub_timesteps (int): sub sequence length :math:`S`, less than :math:`T`
        tau_schedule (str): tau schedule, ``"linear"`` or ``"quadratic"``
            (case-insensitive)

    Raises:
        NotImplementedError: if ``tau_schedule`` is not one of the values above
    """

    # Sub-sequence of timestep indices, registered below as a non-persistent
    # buffer (kept off the state_dict).
    tau: torch.Tensor

    def __init__(
        self, model, timesteps, sub_timesteps, tau_schedule="quadratic"
    ) -> None:
        super().__init__(model, timesteps)
        self.sub_timesteps = sub_timesteps

        tau_schedule = tau_schedule.lower()
        if tau_schedule == "linear":
            tau = linear_tau(timesteps, sub_timesteps)
        elif tau_schedule == "quadratic":
            tau = quadratic_tau(timesteps, sub_timesteps)
        else:
            raise NotImplementedError(
                f"unknown tau_schedule {tau_schedule!r}; "
                'expected "linear" or "quadratic"'
            )

        self.register_buffer("tau", tau, persistent=False)

    def sampling_step(self, x_tau_i, i):
        r"""Take one DDIM step from :math:`x_{\tau_i}` to :math:`x_{\tau_{i-1}}`.

        Args:
            x_tau_i (torch.Tensor): current sample of shape :math:`(N, C, H, W)`
            i (int or torch.Tensor): position in :attr:`tau`; must satisfy
                :math:`1 \le i \le S` (``i = 0`` would wrap ``tau[i - 1]``
                to the last entry). :meth:`generate` passes a 1-element tensor.

        Returns:
            (torch.Tensor): the next sample :math:`x_{\tau_{i-1}}`
        """
        tau_i = self.tau[i]
        tau_i_minus_one = self.tau[i - 1]

        alpha_bar_tau_i = self.alpha_bar[tau_i]
        alpha_bar_tau_i_minus_one = self.alpha_bar[tau_i_minus_one]

        # The model is conditioned on the actual timestep tau_i, not on i.
        noise_in_x_tau_i = self.model(x_tau_i, tau_i)

        x_tau_i_minus_one = reverse_process(
            x_tau_i, alpha_bar_tau_i, alpha_bar_tau_i_minus_one, noise_in_x_tau_i
        )
        return x_tau_i_minus_one

    def generate(self, img_size: Tuple[int, int, int, int]):
        """Generate images of shape :math:`(N, C, H, W)` faster by only sampling the sub sequence.

        Starts from Gaussian noise and applies :meth:`sampling_step` for
        ``i = S, S - 1, ..., 1``.

        Args:
            img_size (Tuple[int, int, int, int]): image size to generate as a
                tuple :math:`(N, C, H, W)`

        Returns:
            (torch.Tensor): generated image of shape :math:`(N, C, H, W)`
        """
        device = self.beta.device
        x_tau_i = gaussian(img_size, device=device)

        # Column vector of positions on the model's device; all_i[i] is a
        # 1-element tensor used as the index `i` in sampling_step.
        all_i = torch.arange(
            0,
            self.sub_timesteps + 1,
            device=device,
        ).unsqueeze(dim=1)

        # Walk the sub-sequence backwards: tau_S -> tau_{S-1} -> ... -> tau_0.
        for i in tqdm(range(self.sub_timesteps, 0, -1), leave=False):
            x_tau_i = self.sampling_step(x_tau_i, all_i[i])
        return x_tau_i