DDPM
In physics and chemistry, the principle of microscopic reversibility states that
“the microscopic detailed dynamics of particles and fields is time-reversible because the microscopic equations of motion are symmetric with respect to inversion in time”.
This means that if a data distribution is diffused into noise, a reverse process exists at the microscopic level, because the equations that describe the dynamics are “symmetric with respect to inversion in time”.
Assuming this reverse process exists, the Denoising Diffusion Probabilistic Model (DDPM) generates data by gradually denoising samples, starting from Gaussian noise.
Since the principle holds for “microscopic detailed dynamics”, we design a forward diffusion process that diffuses data into Gaussian noise through many small steps.
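In each step, we sample from a Gaussian distribution that slightly perturbs the data. Formally, we define the forward process as a Markov chain of Gaussians with a variance schedule \(\beta_1, \dots, \beta_T\):
\[ q(\bx_{1:T}|\bx_0) = \prod_{t=1}^{T} q(\bx_t|\bx_{t-1}), \qquad q(\bx_t|\bx_{t-1}) = \mathcal{N}(\bx_t; \sqrt{1-\beta_t}\,\bx_{t-1}, \beta_t\bI) \]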
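A minimal PyTorch sketch of the forward chain, assuming the standard DDPM transition \(q(\bx_t|\bx_{t-1}) = \mathcal{N}(\bx_t; \sqrt{1-\beta_t}\bx_{t-1}, \beta_t\bI)\) and the linear schedule from \(10^{-4}\) to \(0.02\). `forward_step` and `diffuse` are illustrative names, not part of dmme. Running the chain on constant toy data shows that after \(T\) small steps the samples are close to a standard normal.

```python
import torch


def forward_step(x_prev: torch.Tensor, beta_t: float) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I)."""
    noise = torch.randn_like(x_prev)
    return (1.0 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * noise


def diffuse(x_0: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
    """Run the Markov chain x_0 -> x_1 -> ... -> x_T one transition at a time."""
    x = x_0
    for beta_t in betas.tolist():
        x = forward_step(x, beta_t)
    return x


torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)   # linear schedule, T = 1000
x_0 = torch.full((10_000,), 0.8)           # toy "data": every entry equals 0.8
x_T = diffuse(x_0, betas)
print(x_T.mean().item(), x_T.std().item())  # close to 0 and 1: x_T is nearly N(0, 1)
```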
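Note that we can sample \(\bx_t\) at an arbitrary timestep \(t\) in closed form. With \(\alpha_t = 1-\beta_t\) and \(\bar\alpha_t = \prod_{s=1}^{t}\alpha_s\):
\[ q(\bx_t|\bx_0) = \mathcal{N}(\bx_t; \sqrt{\bar\alpha_t}\,\bx_0, (1-\bar\alpha_t)\bI) \]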
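The closed form follows from composing Gaussian transitions: the mean coefficient picks up a factor \(\sqrt{\alpha_t}\) per step, and the variance follows \(v_t = \alpha_t v_{t-1} + \beta_t\) with \(v_0 = 0\), whose solution is \(1-\bar\alpha_t\). The sketch below (names such as `q_sample` are illustrative, with \(t\) 1-indexed as in the text) checks both facts numerically and then samples \(\bx_t\) in one shot.

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # alpha_bars[t - 1] = prod_{s <= t} alpha_s


def q_sample(x_0: torch.Tensor, t: int, noise: torch.Tensor) -> torch.Tensor:
    """Closed-form x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise (t is 1-indexed)."""
    abar = alpha_bars[t - 1]
    return abar.sqrt() * x_0 + (1.0 - abar).sqrt() * noise


# Step-by-step recursion for the variance and the mean coefficient of x_T given x_0
var = torch.tensor(0.0, dtype=torch.float64)   # Var[x_0 | x_0] = 0
coef = torch.tensor(1.0, dtype=torch.float64)  # x_0 enters with coefficient 1
for a, b in zip(alphas, betas):
    var = a * var + b          # Var[x_t] = alpha_t * Var[x_{t-1}] + beta_t
    coef = a.sqrt() * coef     # mean scales by sqrt(alpha_t) each step

print(var.item(), (1.0 - alpha_bars[-1]).item())   # equal up to rounding
print(coef.item(), alpha_bars[-1].sqrt().item())   # equal up to rounding
```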
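If \(\beta_t\) is small enough, the reverse process should also exist, and by the same symmetry it should also be a Markov chain of Gaussians, starting from \(p(\bx_T)=\mathcal{N}(\bx_T; \bzero, \bI)\):
\[ p_\theta(\bx_{0:T}) = p(\bx_T)\prod_{t=1}^{T} p_\theta(\bx_{t-1}|\bx_t), \qquad p_\theta(\bx_{t-1}|\bx_t) = \mathcal{N}(\bx_{t-1}; \bmu_\theta(\bx_t, t), \bSigma_\theta(\bx_t, t)) \]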
To generate data, we sample \(\bx_T\) from the standard normal distribution and then iteratively sample \(\bx_{t-1} \sim p_\theta(\bx_{t-1}|\bx_t)\) for \(t = T, \dots, 1\).
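A sketch of this ancestral sampling loop, assuming the noise-prediction parameterization \(\bmu_\theta(\bx_t,t) = \frac{1}{\sqrt{\alpha_t}}\big(\bx_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bx_t,t)\big)\), the choice \(\sigma_t^2 = \beta_t\), and no noise on the final step. `sample_loop` and `zero_eps_model` are hypothetical; the placeholder model stands in for a trained network and only exercises the loop.

```python
import torch


@torch.no_grad()
def sample_loop(eps_model, shape, betas):
    """x_T ~ N(0, I), then x_{t-1} ~ p_theta(x_{t-1} | x_t) for t = T, ..., 1."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)
    for t in range(len(betas), 0, -1):       # t = T, ..., 1 (1-indexed as in the text)
        i = t - 1                            # 0-indexed position in the schedule tensors
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps = eps_model(x, t_batch)
        mean = (x - betas[i] / (1.0 - alpha_bars[i]).sqrt() * eps) / alphas[i].sqrt()
        if t > 1:
            x = mean + betas[i].sqrt() * torch.randn_like(x)  # sigma_t^2 = beta_t
        else:
            x = mean                          # final step: no noise is added
    return x


def zero_eps_model(x, t):
    """Hypothetical stand-in for a trained noise predictor; always predicts zero noise."""
    return torch.zeros_like(x)


torch.manual_seed(0)
out = sample_loop(zero_eps_model, (2, 3, 8, 8), torch.linspace(1e-4, 0.02, 1000))
print(out.shape)  # torch.Size([2, 3, 8, 8])
```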
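For training, we optimize the usual variational lower bound on the negative log-likelihood, as in variational autoencoders. The bound can be rewritten as
\[ L = \mathbb{E}_q\Big[\underbrace{D_\text{KL}(q(\bx_T|\bx_0)\,\|\,p(\bx_T))}_{L_T} + \sum_{t>1}\underbrace{D_\text{KL}(q(\bx_{t-1}|\bx_t,\bx_0)\,\|\,p_\theta(\bx_{t-1}|\bx_t))}_{L_{t-1}} \underbrace{-\log p_\theta(\bx_0|\bx_1)}_{L_0}\Big] \]
that is, \(L = L_T + \sum_{t\gt1}L_{t-1} + L_0\). With the \(\beta_t\) held fixed, \(L_T\) is a constant and can be ignored during training.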
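We parameterize the neural network so that \(p_\theta(\bx_{t-1}|\bx_t)\) closely matches the forward-process posterior in \(L_{t-1}\), which is tractable when conditioned on \(\bx_0\):
\[ q(\bx_{t-1}|\bx_t,\bx_0) = \mathcal{N}(\bx_{t-1}; \tilde\bmu_t(\bx_t,\bx_0), \tilde\beta_t\bI), \qquad \tilde\bmu_t(\bx_t,\bx_0) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\bx_0 + \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\bx_t, \qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t \]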
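The posterior coefficients can be computed directly from the schedule. This sketch (the helper name `q_posterior` is illustrative) uses the convention \(\bar\alpha_0 = 1\), which makes the \(t = 1\) posterior a point mass at \(\bx_0\): the mean equals \(\bx_0\) and \(\tilde\beta_1 = 0\). It also confirms that \(\tilde\beta_t \le \beta_t\) for every step.

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
# abar_{t-1} for t = 1..T, using the convention abar_0 = 1
alpha_bars_prev = torch.cat([torch.ones(1, dtype=torch.float64), alpha_bars[:-1]])


def q_posterior(x_0: torch.Tensor, x_t: torch.Tensor, t: int):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for a 1-indexed timestep t."""
    i = t - 1
    coef_x0 = alpha_bars_prev[i].sqrt() * betas[i] / (1.0 - alpha_bars[i])
    coef_xt = alphas[i].sqrt() * (1.0 - alpha_bars_prev[i]) / (1.0 - alpha_bars[i])
    mean = coef_x0 * x_0 + coef_xt * x_t
    var = (1.0 - alpha_bars_prev[i]) / (1.0 - alpha_bars[i]) * betas[i]  # beta_tilde_t
    return mean, var


beta_tilde = (1.0 - alpha_bars_prev) / (1.0 - alpha_bars) * betas
print(bool((beta_tilde <= betas).all()))  # True
```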
Recall that \(p_\theta(\bx_{t-1}|\bx_t) = \mathcal{N}(\bx_{t-1}; \bmu_\theta(\bx_t, t), \bSigma_\theta(\bx_t, t))\) for \({1 \lt t \leq T}\).
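We set the reverse-process covariance to untrained, time-dependent constants, \(p_\theta(\bx_{t-1} | \bx_t) = \mathcal{N}(\bx_{t-1}; \bmu_\theta(\bx_t, t), \sigma_t^2\bI)\). Since the KL divergence between two Gaussians has a closed form, we can write:
\[ L_{t-1} = \mathbb{E}_q\Big[\frac{1}{2\sigma_t^2}\big\|\tilde\bmu_t(\bx_t,\bx_0) - \bmu_\theta(\bx_t,t)\big\|^2\Big] + C \]
where \(C\) does not depend on \(\theta\). Experimentally, \(\sigma_t^2 = \beta_t\) and \(\sigma_t^2 = \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t\) gave similar results.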
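The step from a KL divergence to a squared error can be checked numerically. For isotropic Gaussians in \(d\) dimensions, \(D_\text{KL}(\mathcal{N}(\bmu_1, s_1^2\bI)\,\|\,\mathcal{N}(\bmu_2, s_2^2\bI)) = \frac{d}{2}\big(\frac{s_1^2}{s_2^2} - 1 + \log\frac{s_2^2}{s_1^2}\big) + \frac{\|\bmu_1-\bmu_2\|^2}{2s_2^2}\), and only the last term depends on the network's mean. In this sketch `gaussian_kl` is a hypothetical helper and the variances 0.019 and 0.02 are illustrative values, not taken from a real schedule.

```python
import math

import torch


def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q * I) || N(mu_p, var_p * I) ) over the last dimension."""
    d = mu_q.shape[-1]
    const = 0.5 * d * (var_q / var_p - 1.0 + math.log(var_p / var_q))  # mean-independent
    return const + ((mu_q - mu_p) ** 2).sum(-1) / (2.0 * var_p)


torch.manual_seed(0)
d = 16
mu_tilde = torch.randn(d)            # stands in for the posterior mean
mu_theta = torch.randn(d)            # stands in for the network's predicted mean
beta_tilde_t, sigma2_t = 0.019, 0.02  # illustrative posterior / model variances

C = gaussian_kl(mu_tilde, beta_tilde_t, mu_tilde, sigma2_t)        # KL with equal means
mse_term = ((mu_tilde - mu_theta) ** 2).sum() / (2 * sigma2_t)     # ||.||^2 / (2 sigma_t^2)
kl = gaussian_kl(mu_tilde, beta_tilde_t, mu_theta, sigma2_t)
print(torch.allclose(kl, mse_term + C))  # True
```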
Input image data is assumed to be integers in \(\{0, 1, \dots, 255\}\) scaled linearly to \([-1, 1]\). The last step of the reverse process, \(p_\theta(\bx_0|\bx_1)\), is set to an independent discrete decoder. At the final step of sampling, no noise is added: the output is \(\bmu_\theta(\bx_1, 1)\).
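A small sketch of the linear scaling between 8-bit pixels and the model's \([-1, 1]\) range. `to_model_range` and `to_uint8` are hypothetical helpers; rounding back to integers here is only a simple stand-in for display, not the discrete decoder itself, which assigns a probability to each of the 256 bins.

```python
import torch


def to_model_range(x: torch.Tensor) -> torch.Tensor:
    """Map uint8 pixels {0, ..., 255} linearly to floats in [-1, 1]."""
    return x.float() / 127.5 - 1.0


def to_uint8(x: torch.Tensor) -> torch.Tensor:
    """Map floats in [-1, 1] back to uint8 pixels, rounding and clamping out-of-range values."""
    return ((x + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)


pixels = torch.arange(256).to(torch.uint8)
scaled = to_model_range(pixels)
print(scaled.min().item(), scaled.max().item())  # -1.0 1.0
```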
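Reparameterizing the mean to predict the noise, \(\bmu_\theta(\bx_t, t) = \frac{1}{\sqrt{\alpha_t}}\big(\bx_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bx_t, t)\big)\) with \(\bx_t = \sqrt{\bar\alpha_t}\bx_0 + \sqrt{1-\bar\alpha_t}\epsilon\), we can simplify the loss to
\[ L_{t-1} - C = \mathbb{E}_{\bx_0,\epsilon}\Big[\lambda_t\big\|\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\bx_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)\big\|^2\Big], \qquad \lambda_t = \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)} \]
For small \(t\), \(\lambda_t\) is comparatively large, so the bound emphasizes the easiest denoising steps. The paper finds that setting \(\lambda_t = 1\), which gives the simplified objective \(L_\text{simple} = \mathbb{E}_{t,\bx_0,\epsilon}\big[\|\epsilon - \epsilon_\theta(\bx_t, t)\|^2\big]\) with \(t\) uniform in \(\{1, \dots, T\}\), improves sample quality.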
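A sketch of one \(L_\text{simple}\) evaluation, plus the weights \(\lambda_t\) it replaces (computed with \(\sigma_t^2 = \beta_t\)). `simple_loss` and `zero_model` are hypothetical names; the placeholder model predicts zero noise, so its loss is just the mean of \(\epsilon^2\), roughly 1. The \(\lambda_t\) values show that the true bound weights \(t = 1\) about fifty times more than \(t = T\).

```python
import torch
import torch.nn.functional as F


def simple_loss(eps_model, x_0, alpha_bars):
    """L_simple: MSE between the true and predicted noise at t ~ Uniform{1, ..., T}."""
    n, T = x_0.shape[0], alpha_bars.shape[0]
    t = torch.randint(1, T + 1, (n,), device=x_0.device)   # 1-indexed timesteps
    abar = alpha_bars.to(x_0)[t - 1]                        # match dtype/device, shape (N,)
    abar = abar.view(n, *([1] * (x_0.dim() - 1)))           # broadcast over (C, H, W)
    noise = torch.randn_like(x_0)
    x_t = abar.sqrt() * x_0 + (1.0 - abar).sqrt() * noise
    return F.mse_loss(eps_model(x_t, t), noise)


betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# Weights of the true bound with sigma_t^2 = beta_t; index 0 corresponds to t = 1
lam = betas**2 / (2 * betas * alphas * (1 - alpha_bars))
print(lam[0].item(), lam[-1].item())  # about 0.5 and 0.01


def zero_model(x, t):
    """Hypothetical placeholder network that always predicts zero noise."""
    return torch.zeros_like(x)


torch.manual_seed(0)
loss = simple_loss(zero_model, torch.randn(64, 3, 8, 8), alpha_bars)
print(loss.item())  # about 1, the mean of noise**2
```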
DDPM | Forward, Reverse, Sampling for DDPM
| Variance schedule constants \(\beta_t\) increasing linearly from \(10^{-4}\) to \(0.02\)
UNet | UNet with GroupNorm and Attention; predicts noise from \(x_t\) and \(t\)
LitDDPM | LightningModule for training DDPM
Sampler
- class dmme.ddpm.DDPM(timesteps)
Forward, Reverse, Sampling for DDPM
- Parameters:
timesteps (int) – total timesteps \(T\)
- forward_process(x_0: Tensor, t: Tensor, noise: Tensor)
Forward Diffusion Process
Samples \(x_t\) from \(q(x_t|x_0) = \mathcal{N}(x_t;\sqrt{\bar\alpha_t}\bold{x}_0,(1-\bar\alpha_t)\bold{I})\)
Computes \(\bold{x}_t = \sqrt{\bar\alpha_t}\bold{x}_0 + \sqrt{1-\bar\alpha_t}\epsilon\)
- Parameters:
x_0 (torch.Tensor) – data to add noise to
t (torch.Tensor) – timestep \(t\) in \(x_t\)
noise (torch.Tensor, optional) – \(\epsilon\), noise used in the forward process
- Returns:
\(\bold{x}_t \sim q(\bold{x}_t|\bold{x}_0)\)
- Return type:
(torch.Tensor)
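A hypothetical re-implementation of a `forward_process`-style method, to show how a per-example timestep tensor of shape \((N,)\) broadcasts against image batches of shape \((N, C, H, W)\). `ToyDiffusion` is not dmme's class; the sketch assumes 1-indexed timesteps and the linear schedule, and dmme may index \(t\) differently.

```python
import torch
from torch import nn


class ToyDiffusion(nn.Module):
    """Illustrative stand-in: holds the schedule and a batched closed-form forward process."""

    def __init__(self, timesteps: int = 1000):
        super().__init__()
        betas = torch.linspace(1e-4, 0.02, timesteps)
        self.register_buffer("betas", betas)
        self.register_buffer("alpha_bars", torch.cumprod(1.0 - betas, dim=0))

    def forward_process(self, x_0, t, noise=None):
        """x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps, with t a (N,) tensor in 1..T."""
        if noise is None:
            noise = torch.randn_like(x_0)
        abar = self.alpha_bars[t - 1].view(-1, *([1] * (x_0.dim() - 1)))  # (N, 1, 1, 1)
        return abar.sqrt() * x_0 + (1.0 - abar).sqrt() * noise


torch.manual_seed(0)
diffusion = ToyDiffusion(timesteps=1000)
x_0 = torch.randn(4, 3, 8, 8)
t = torch.tensor([1, 10, 500, 1000])
eps = torch.randn_like(x_0)
x_t = diffusion.forward_process(x_0, t, eps)
print(x_t.shape)  # torch.Size([4, 3, 8, 8])
```

At \(t = 1\) the output is almost \(\bx_0\), and at \(t = T\) it is almost pure noise \(\epsilon\).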
- reverse_process(model, x_t, t, noise)
Reverse Denoising Process
Samples \(x_{t-1}\) from \(p_\theta(\bold{x}_{t-1}|\bold{x}_t) = \mathcal{N}(\bold{x}_{t-1};\mu_\theta(\bold{x}_t, t), \sigma_t^2\bold{I})\)
\[\begin{aligned} \bold\mu_\theta(\bold{x}_t, t) &= \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) \\ \sigma_t^2 &= \beta_t \end{aligned} \]
Computes \(\bold{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) +\sigma_t\epsilon\)
- Parameters:
model (nn.Module) – model for estimating noise
x_t (torch.Tensor) – \(x_t\), the sample at timestep \(t\)
t (int) – current timestep
noise (torch.Tensor) – \(\epsilon\), the noise added to the predicted mean
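Why the noise-prediction form of \(\bmu_\theta\) is the right target: if \(\epsilon_\theta\) returned the exact noise used to build \(\bx_t\), the reverse-step mean would coincide with the posterior mean \(\tilde\bmu_t(\bx_t, \bx_0)\) of \(q(\bx_{t-1}|\bx_t,\bx_0)\). This sketch checks that identity numerically; `reverse_step` is an illustrative helper (not dmme's method) with \(\sigma_t^2 = \beta_t\) and 1-indexed \(t\).

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bars_prev = torch.cat([torch.ones(1, dtype=torch.float64), alpha_bars[:-1]])


def reverse_step(x_t, t, eps_pred, noise):
    """One step x_t -> x_{t-1} with sigma_t^2 = beta_t; no noise when t == 1."""
    i = t - 1
    mean = (x_t - betas[i] / (1.0 - alpha_bars[i]).sqrt() * eps_pred) / alphas[i].sqrt()
    sigma = betas[i].sqrt() if t > 1 else 0.0
    return mean + sigma * noise


torch.manual_seed(0)
t, i = 500, 499
x_0 = torch.randn(4, dtype=torch.float64)
eps = torch.randn(4, dtype=torch.float64)
x_t = alpha_bars[i].sqrt() * x_0 + (1 - alpha_bars[i]).sqrt() * eps

# Mean of the reverse step (zero noise) using the exact eps
mu_eps = reverse_step(x_t, t, eps, torch.zeros(4, dtype=torch.float64))
# Posterior mean mu_tilde_t(x_t, x_0)
mu_tilde = (
    alpha_bars_prev[i].sqrt() * betas[i] / (1 - alpha_bars[i]) * x_0
    + alphas[i].sqrt() * (1 - alpha_bars_prev[i]) / (1 - alpha_bars[i]) * x_t
)
print(torch.allclose(mu_eps, mu_tilde))  # True
```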
- sample(model, x_t, t, noise)
Sample from \(p_\theta(x_{t-1}|x_t)\)
- Parameters:
model (nn.Module) – model for estimating noise
x_t (torch.Tensor) – image of shape \((N, C, H, W)\)
t (int) – starting \(t\) to sample from
noise (torch.Tensor) – noise to use for sampling; if None, new noise is sampled
- Returns:
generated sample of shape \((N, C, H, W)\)
- Return type:
(torch.Tensor)
Model
- class dmme.ddpm.UNet(in_channels=3, pos_dim=128, emb_dim=512, num_blocks=2, channels=(128, 256, 256, 256), attn_depth=(2,), groups=32, drop_rate=0.1)
UNet with GroupNorm and Attention, Predicts noise from \(x_t\) and \(t\)
- Parameters:
in_channels (int) – input image channels
pos_dim (int) – sinusoidal position encoding dim
emb_dim (int) – time embedding mlp dim
num_blocks (int) – number of resblocks to use
channels (Tuple[int, ...]) – tuple of channel dimensions
attn_depth (Tuple[int, ...]) – depths at which self-attention is applied
groups (int) – number of groups in nn.GroupNorm
drop_rate (float) – drop_rate in ResBlock
- forward(x, t)
Predicts the noise in \(x_t\) from \(x_t\) and a timestep embedding of \(t\), using a UNet
- Parameters:
x (torch.Tensor) – \(x_t\), tensor of shape \((N, C, H, W)\)
t (torch.Tensor) – \(t\), tensor of shape \((N,)\)
- Returns:
\(\epsilon_\theta(x_t,t)\) predicted noise from image, a tensor of shape \((N, C, H, W)\)
- Return type:
(torch.Tensor)
SinusoidalPositionEmbeddings | Transformer Sinusoidal Position Encoding
Attention | Self Attention layer
PreNorm | Pre Normalization with residual connections
ResBlock | Basic wide ResBlock for the UNet, with GroupNorm and optional self-attention
conv3x3 | Builds a 3x3 convolution block with normalization and dropout, in the order norm, activation, dropout, conv
- class dmme.ddpm.SinusoidalPositionEmbeddings(dim)
Transformer Sinusoidal Position Encoding
- Parameters:
dim (int) – embedding dimension
- forward(t)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
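A sketch of a Transformer-style sinusoidal timestep encoding, assuming geometric frequencies \(10000^{-k/(d/2)}\) and a layout with all sine features followed by all cosine features (some implementations interleave them instead). `SinusoidalEmbedding` is a hypothetical name, not dmme's class, and dmme's exact layout may differ.

```python
import math

import torch
from torch import nn


class SinusoidalEmbedding(nn.Module):
    """Maps timesteps of shape (N,) to sinusoidal features of shape (N, dim)."""

    def __init__(self, dim: int):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError("dim must be even")
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Frequencies 1 / 10000^(k / half) for k = 0..half-1
        k = torch.arange(half, device=t.device, dtype=torch.float32)
        freqs = torch.exp(-math.log(10000.0) * k / half)
        args = t.float()[:, None] * freqs[None, :]          # (N, half)
        return torch.cat([args.sin(), args.cos()], dim=-1)  # (N, dim): sines, then cosines


emb = SinusoidalEmbedding(128)(torch.tensor([0, 10, 999]))
print(emb.shape)  # torch.Size([3, 128])
```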
- class dmme.ddpm.Attention(dim)
Self Attention layer
- Parameters:
dim (int) – \(d_\text{model}\)
- class dmme.ddpm.PreNorm(norm_layer, attention_layer)
Pre Normalization with residual connections
- Parameters:
norm_layer (nn.Module) – normalization layer
attention_layer (nn.Module) – attention layer
- forward(x)
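A sketch of spatial self-attention wrapped in a pre-norm residual connection, assuming a single attention head, 1x1 convolutions for the query/key/value and output projections, and that the wrapper computes \(x + \text{layer}(\text{norm}(x))\). `SelfAttention2d` and `PreNormResidual` are hypothetical names; dmme's `Attention` and `PreNorm` may be structured differently.

```python
import torch
from torch import nn


class SelfAttention2d(nn.Module):
    """Single-head self-attention over the H*W positions of a (N, C, H, W) feature map."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Conv2d(dim, 3 * dim, kernel_size=1)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.scale = dim ** -0.5

    def forward(self, x):
        n, c, h, w = x.shape
        q, k, v = self.qkv(x).reshape(n, 3, c, h * w).unbind(dim=1)      # each (N, C, HW)
        attn = torch.softmax(q.transpose(1, 2) @ k * self.scale, dim=-1)  # (N, HW, HW)
        out = (v @ attn.transpose(1, 2)).reshape(n, c, h, w)              # weighted sum of values
        return self.proj(out)


class PreNormResidual(nn.Module):
    """Normalize first, apply the wrapped layer, then add the residual: x + fn(norm(x))."""

    def __init__(self, norm_layer: nn.Module, fn: nn.Module):
        super().__init__()
        self.norm = norm_layer
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))


block = PreNormResidual(nn.GroupNorm(32, 128), SelfAttention2d(128))
x = torch.randn(2, 128, 8, 8)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 8, 8])
```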
- class dmme.ddpm.ResBlock(in_channels, out_channels, emb_dim, groups, drop_rate, attention=False)
Basic wide ResBlock for the UNet, with GroupNorm and optional self-attention
- Parameters:
in_channels (int) – number of input channels
out_channels (int) – number of output channels
emb_dim (int) – timestep embedding dim
groups (int) – number of groups in nn.GroupNorm
drop_rate (float) – dropout applied in each conv
attention (bool) – flag for adding self-attention layer
- forward(x, t)
- dmme.ddpm.conv3x3(in_channels, out_channels, groups, drop_rate)
Builds a 3x3 convolution block with normalization and dropout, in the order norm, activation, dropout, conv
- Parameters:
in_channels (int) – passed to nn.Conv2d
out_channels (int) – passed to nn.Conv2d
groups (int) – passed to nn.GroupNorm
drop_rate (float) – passed to nn.Dropout2d
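A sketch of a "norm act drop conv" block and a timestep-conditioned residual block built from it. It assumes SiLU as the activation, dropout in each conv (as the ResBlock parameters describe), a SiLU + Linear projection of the timestep embedding added per channel, and a 1x1 convolution on the skip path when channel counts differ; the attention option is omitted. `conv3x3_block` and `TimeResBlock` are hypothetical names, not dmme's implementations.

```python
import torch
from torch import nn


def conv3x3_block(in_channels, out_channels, groups, drop_rate):
    """GroupNorm -> SiLU -> Dropout2d -> 3x3 Conv2d ("norm act drop conv" order)."""
    return nn.Sequential(
        nn.GroupNorm(groups, in_channels),
        nn.SiLU(),                       # assumed activation
        nn.Dropout2d(drop_rate),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    )


class TimeResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding (no attention)."""

    def __init__(self, in_channels, out_channels, emb_dim, groups, drop_rate):
        super().__init__()
        self.conv1 = conv3x3_block(in_channels, out_channels, groups, drop_rate)
        self.emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))
        self.conv2 = conv3x3_block(out_channels, out_channels, groups, drop_rate)
        # 1x1 conv on the skip path only when the channel count changes
        self.skip = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = h + self.emb_proj(t_emb)[:, :, None, None]  # broadcast embedding over H, W
        h = self.conv2(h)
        return self.skip(x) + h


block = TimeResBlock(64, 128, emb_dim=512, groups=32, drop_rate=0.1).eval()
out = block(torch.randn(2, 64, 16, 16), torch.randn(2, 512))
print(out.shape)  # torch.Size([2, 128, 16, 16])
```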
Training
- class dmme.ddpm.LitDDPM(model: Module, lr: float = 0.0002, warmup: int = 5000, imgsize: Tuple[int, int, int] = (3, 32, 32), timesteps: int = 1000, decay: float = 0.9999)
LightningModule for training DDPM
- Parameters:
model (nn.Module) – neural network predicting noise \(\epsilon_\theta\)
lr (float) – learning rate, defaults to \(2 \times 10^{-4}\)
warmup (int) – number of warmup steps over which the learning rate increases linearly until it reaches lr, defaults to 5000
imgsize (Tuple[int, int, int]) – image size in (C, H, W)
timesteps (int) – total timesteps \(T\) for the forward and reverse processes
decay (float) – decay rate of the exponential moving average (EMA)
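A sketch of the two training details these parameters describe, assuming the warmup is a linear ramp implemented with `torch.optim.lr_scheduler.LambdaLR` and that the EMA is taken over model weights after every optimizer step. The Adam optimizer, the toy `nn.Linear` model, and `ema_update` are illustrative choices, and the warmup is shortened to 10 steps for the demo.

```python
import copy

import torch
from torch import nn

model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model)           # shadow copy holding the EMA weights
for p in ema_model.parameters():
    p.requires_grad_(False)

lr, warmup, decay = 2e-4, 10, 0.9999       # warmup shortened from 5000 for the demo
opt = torch.optim.Adam(model.parameters(), lr=lr)
# Linear warmup: lr * min(1, (step + 1) / warmup), then constant
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, (step + 1) / warmup))


@torch.no_grad()
def ema_update(ema: nn.Module, online: nn.Module, decay: float) -> None:
    """ema <- decay * ema + (1 - decay) * online, parameter by parameter."""
    for e, p in zip(ema.parameters(), online.parameters()):
        e.mul_(decay).add_(p, alpha=1.0 - decay)


lrs = []
for step in range(warmup + 5):
    loss = model(torch.randn(8, 4)).pow(2).mean()   # dummy objective
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()
    ema_update(ema_model, model, decay)
    lrs.append(opt.param_groups[0]["lr"])

print(lrs[0], lrs[-1])  # ramps up, then stays at lr
```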
- forward(x_t: Tensor, t: int, noise: Optional[Tensor] = None)
Denoise image once using DDPM
- Parameters:
x_t (torch.Tensor) – image of shape \((N, C, H, W)\)
t (int) – starting \(t\) to sample from
noise (torch.Tensor) – noise to use for sampling; if None, new noise is sampled
- Returns:
generated sample of shape \((N, C, H, W)\)
- Return type:
(torch.Tensor)