from typing import Optional
from typing import Tuple

import pytorch_lightning as pl
import torch
from torch import Tensor
from torch import nn
from torch.optim import Adam
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

import dmme
from dmme.callbacks import EMA
from dmme.lr_scheduler import WarmupLR

from .ddpm import DDPM
from .unet import UNet
class LitDDPM(pl.LightningModule):
    r"""LightningModule for training DDPM

    Args:
        model (nn.Module, optional): neural network predicting noise
            :math:`\epsilon_\theta`. If `None`, a `UNet` whose input channel
            count is `imgsize[0]` is created.
        lr (float): learning rate, defaults to :math:`2e-4`
        warmup (int): linearly increases learning rate for
            `warmup` steps until `lr` is reached, defaults to 5000
        imgsize (Tuple[int, int, int]): image size in `(C, H, W)`
        timesteps (int): total timesteps for the
            forward and reverse process, :math:`T`
        decay (float): EMA decay value
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        lr: float = 2e-4,
        warmup: int = 5000,
        imgsize: Tuple[int, int, int] = (3, 32, 32),
        timesteps: int = 1000,
        decay: float = 0.9999,
    ):
        super().__init__()
        # The network itself is not a plain hyperparameter, so it is excluded;
        # the remaining constructor arguments end up in `self.hparams`.
        self.save_hyperparameters(ignore=["model"])

        if model is None:
            # Follow the configured channel count instead of assuming RGB.
            model = UNet(in_channels=imgsize[0])

        self.diffusion = DDPM(model, timesteps=timesteps)

        # Per torchmetrics, normalize=True means inputs are float images in
        # [0, 1]; reset_real_features=False keeps real-image statistics
        # across fid.reset().
        self.fid = FrechetInceptionDistance(
            normalize=True,
            reset_real_features=False,
        )
        self.inception = InceptionScore(normalize=True)

    def forward(self, x_t: Tensor, t: int):
        r"""Denoise images once using `DDPM` (a single reverse step at `t`)

        Args:
            x_t (torch.Tensor): image of shape :math:`(N, C, H, W)` at
                timestep :math:`t`
            t (int): current timestep :math:`t`

        Returns:
            (torch.Tensor): output of ``self.diffusion.sampling_step``,
            presumably :math:`x_{t-1}` of shape :math:`(N, C, H, W)`
        """
        # A 1-element timestep tensor on the same device as the images;
        # DDPM.sampling_step is expected to broadcast it over the batch.
        timestep = torch.tensor([t], device=x_t.device)
        return self.diffusion.sampling_step(x_t, timestep)

    def training_step(self, batch, batch_idx):
        r"""Train model using :math:`L_\text{simple}`

        Args:
            batch: batch whose first element is the clean image tensor
                :math:`x_0`
            batch_idx (int): index of the batch (unused)

        Returns:
            (torch.Tensor): loss returned by ``self.diffusion.training_step``
        """
        x_0: Tensor = batch[0]
        loss: Tensor = self.diffusion.training_step(x_0)
        self.log("train/loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Generate samples for evaluation and update the metrics

        The real images in ``batch[0]`` and a generated batch of the same size
        are both passed through ``dmme.denorm``. Both update FID, and the
        generated images also update the Inception Score.
        """
        x: Tensor = batch[0]
        self.fid.update(dmme.denorm(x), real=True)

        x_t = self.generate(x.size())
        fake_x: Tensor = dmme.denorm(x_t)
        self.fid.update(fake_x, real=False)
        self.inception.update(fake_x)

    def generate(self, img_size):
        r"""Iteratively sample from :math:`p_\theta(x_{t-1}|x_t)` to generate images

        Args:
            img_size: shape of the images to generate, e.g. ``x.size()``
                for a batch ``x`` of shape :math:`(N, C, H, W)`

        Returns:
            (torch.Tensor): images produced by ``self.diffusion.generate``
        """
        return self.diffusion.generate(img_size=img_size)

    def test_epoch_end(self, outputs):
        """Compute FID and Inception Score over the test set and log them"""
        # NOTE(review): Lightning >= 2.0 removed this hook in favour of
        # `on_test_epoch_end`; confirm the pinned Lightning version.
        fid_score: Tensor = self.fid.compute()
        # InceptionScore.compute() returns (mean, std) of the inception score
        # itself, which is already exp(KL); exponentiating it again is wrong.
        inception_mean, _inception_std = self.inception.compute()

        self.log("fid", fid_score)
        self.log("inception_score", inception_mean)

        # Plain tensors are logged, so Lightning does not reset these metrics
        # itself; clear generated-sample state so another test run starts
        # fresh. Real FID features survive (reset_real_features=False).
        self.fid.reset()
        self.inception.reset()