from typing import Tuple, Optional
import torch
from torch import nn
from torch.optim import Adam
import pytorch_lightning as pl
from dmme.ddpm.ddpm_sampler import DDPMSampler
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from dmme.common import denorm, set_default
from dmme.lr_scheduler import WarmupLR
from dmme.callbacks import EMA
from .ddpm_sampler import DDPMSampler
from .unet import UNet
[docs]class LitDDPM(pl.LightningModule):
"""LightningModule for training DDPM
Args:
sampler (nn.Module): an instance of `DDPMSampler`
lr (float): learning rate, defaults to :math:`2e-4`
warmup (int): linearly increases learning rate for
`warmup` steps until `lr` is reached, defaults to 5000
imgsize (Tuple[int, int, int]): image size in `(C, H, W)`
timesteps (int): total timesteps for the
forward and reverse process, :math:`T`
decay (float): EMA decay value
"""
def __init__(
    self,
    sampler: Optional[nn.Module] = None,
    lr: float = 2e-4,
    warmup: int = 5000,
    imgsize: Tuple[int, int, int] = (3, 32, 32),
    timesteps: int = 1000,
    decay: float = 0.9999,
):
    """Store hyperparameters and attach the sampler.

    Args:
        sampler (nn.Module, optional): sampler to train; when ``None`` a
            ``DDPMSampler`` wrapping a fresh ``UNet`` is created
        lr (float): learning rate
        warmup (int): number of warmup steps for the learning rate
        imgsize (Tuple[int, int, int]): image size in ``(C, H, W)``
        timesteps (int): total timesteps handed to the default sampler
        decay (float): EMA decay value
    """
    super().__init__()
    # The sampler is a module, not a plain hyperparameter, so it is excluded.
    self.save_hyperparameters(ignore="sampler")

    # Build the default lazily: constructing a UNet is wasted work when the
    # caller already supplied a sampler. The channel count follows `imgsize`
    # so the default network agrees with the configured image shape.
    if sampler is None:
        sampler = DDPMSampler(UNet(in_channels=imgsize[0]), timesteps=timesteps)
    self.sampler = sampler
def forward(self, x_t, start_t=None, stop_t=0, step_t=-1, noise=None):
    r"""Iteratively sample from :math:`p_\theta(x_{t-1}|x_t)` starting with :math:`x_t`

    Timesteps are visited as ``range(start_t, stop_t, step_t)``, so
    ``stop_t`` itself is never sampled.

    Args:
        x_t (torch.Tensor): image of shape :math:`(N, C, H, W)`
        start_t (int, optional): starting :math:`t` to sample from, if `None`
            uses ``self.sampler.timesteps``
        stop_t (int): stops sampling when reached
        step_t (int): step size for the sequence of timesteps
        noise (torch.Tensor): noise to use for sampling, indexed as
            ``noise[t - 1]`` at step :math:`t`; if `None` the sampler is
            given `None` at every step

    Returns:
        (torch.Tensor): generated samples
    """
    if start_t is None:
        start_t = self.sampler.timesteps

    for t in range(start_t, stop_t, step_t):
        # Without user-supplied noise, hand `None` to the sampler directly
        # instead of indexing into a placeholder list.
        step_noise = None if noise is None else noise[t - 1]
        x_t = self.sampler.sample(x_t, t, step_noise)

    return x_t
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch using the sampler

    Args:
        batch: a pair whose first element is the batch of clean images
            :math:`x_0`; the second element is ignored
        batch_idx (int): index of the batch (unused)

    Returns:
        the loss returned by ``self.sampler.compute_loss``
    """
    x_0, _ = batch

    loss = self.sampler.compute_loss(x_0)
    self.log("train/loss", loss)

    return loss
def test_step(self, batch, batch_idx):
    """Generate samples for evaluation and feed them to the metrics

    Args:
        batch: a pair whose first element is the batch of real images;
            the second element is ignored
        batch_idx (int): index of the batch (unused)
    """
    x, _ = batch

    # NOTE(review): denorm presumably maps images to the [0, 1] range the
    # metrics expect with normalize=True -- confirm against dmme.common.
    self.fid.update(denorm(x), real=True)

    # Start the reverse process from standard normal noise shaped like x.
    x_T = torch.randn_like(x)
    # Pass the starting timestep explicitly: run the whole reverse chain.
    x_0 = self(x_T, self.sampler.timesteps)

    fake_x = denorm(x_0)
    self.fid.update(fake_x, real=False)
    self.inception.update(fake_x)
def test_epoch_end(self, outputs):
    """Compute metrics and log at the end of the epoch

    Args:
        outputs: collected outputs of ``test_step`` (unused)
    """
    fid_score = self.fid.compute()

    # InceptionScore.compute() returns (mean, std) of the score itself; the
    # exponential of the KL term is already applied inside torchmetrics, so
    # exponentiating again would inflate the reported value.
    inception_score, _ = self.inception.compute()

    self.log("fid", fid_score)
    self.log("inception_score", inception_score)
[docs] def setup(self, stage: str):
"""Prepare metrics for test stage"""
if stage == "test":
self.fid = FrechetInceptionDistance(
normalize=True,
reset_real_features=False,
)
self.inception = InceptionScore(normalize=True)