Improved U-Net from Improved Denoising Diffusion Probabilistic Models

  • UNet – U-Net for predicting noise in images and learning variance

  • ResBlock – 3x3 basic resblocks with group norm, dropout and timestep embeddings

  • MultiHeadAttention – Self-attention with group norm

class dmme.models.iddpm.UNet(in_channels=3, pos_dim=128, emb_dim=512, num_groups=32, dropout=0.3, channels_per_depth=(128, 256, 256, 256), num_blocks=2, attention_depths=(2, 3))

U-Net for predicting noise in images and learning variance

Parameters:
  • in_channels (int) – input channels of image

  • pos_dim (int) – dimension of position embedding

  • emb_dim (int) – dimension of timestep embedding

  • num_groups (int) – number of groups in nn.GroupNorm

  • dropout (float) – dropout rate in nn.Dropout2d

  • channels_per_depth (Tuple[int, ...]) – channels per depth

  • num_blocks (int) – number of resblocks to use in each depth

  • attention_depths (Tuple[int, ...]) – depths to use attention blocks
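To make the interaction between channels_per_depth, num_blocks and attention_depths concrete, here is a minimal sketch that lists the residual blocks an encoder path might contain for the default arguments. The helper plan_encoder is hypothetical and not part of dmme. It assumes depths are 0-indexed and that a stem convolution has already mapped the image to channels_per_depth[0] channels; the library's actual layout may differ.

```python
from typing import List, Tuple


def plan_encoder(
    channels_per_depth: Tuple[int, ...] = (128, 256, 256, 256),
    num_blocks: int = 2,
    attention_depths: Tuple[int, ...] = (2, 3),
) -> List[Tuple[int, int, int, bool]]:
    """Return (depth, c_in, c_out, with_attention) for each encoder block.

    Hypothetical helper: assumes 0-indexed depths and a stem convolution
    that outputs channels_per_depth[0] channels.
    """
    plan = []
    c_prev = channels_per_depth[0]  # assumed width after the stem convolution
    for depth, c_out in enumerate(channels_per_depth):
        for _ in range(num_blocks):
            plan.append((depth, c_prev, c_out, depth in attention_depths))
            c_prev = c_out  # the next block consumes this block's output
    return plan


for depth, c_in, c_out, attn in plan_encoder():
    print(f"depth={depth} c_in={c_in} c_out={c_out} attention={attn}")
```

Under these assumptions the defaults give eight encoder blocks, with attention enabled only in the four blocks at depths 2 and 3.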

forward(x, c)

Predicts noise from x

Parameters:
  • x (torch.Tensor) – image of shape \((N, C, H, W)\)

  • c (torch.Tensor) – timestep of shape \((N,)\)

Returns:

estimated noise in input image x

Return type:

(torch.Tensor)
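The timestep tensor c holds one integer per sample, so it has to be turned into a vector before it can condition the network. One common choice, assumed here and not confirmed by this page, is a Transformer-style sinusoidal embedding of size pos_dim. The function sinusoidal_embedding below is a hypothetical pure-Python sketch of that idea for a single timestep, not the library's implementation.

```python
import math
from typing import List


def sinusoidal_embedding(t: int, dim: int = 128, max_period: float = 10000.0) -> List[float]:
    """Embed one timestep as [sin(t * f_i)..., cos(t * f_i)...].

    Hypothetical helper: dim plays the role of pos_dim and must be even.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Geometrically spaced frequencies from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


# A batch of timesteps of shape (N,) becomes N vectors of length pos_dim.
batch = [sinusoidal_embedding(t) for t in (1, 500, 999)]
print(len(batch), len(batch[0]))
```

In a design like this, a small MLP would then typically project the pos_dim vector to emb_dim before it is passed to the residual blocks.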

class dmme.models.iddpm.ResBlock(c_in, c_out, with_attention=False, num_heads=4, emb_dim=512, num_groups=32, p=0.1)

3x3 basic resblocks with group norm, dropout and timestep embeddings

Parameters:
  • c_in (int) – number of input channels

  • c_out (int) – number of output channels

  • with_attention (bool) – whether to add attention block

  • num_heads (int) – number of attention heads in the attention block

  • emb_dim (int) – input timestep embedding dimension

  • num_groups (int) – number of groups in nn.GroupNorm

  • p (float) – dropout rate in nn.Dropout2d

forward(x, c)
Parameters:
  • x (torch.Tensor) – image of shape \((N, C_\text{in}, H, W)\)

  • c (torch.Tensor) – timestep embedding of shape \((N, d_\text{emb})\)

Returns:

feature map of shape \((N, C_\text{out}, H, W)\)

Return type:

(torch.Tensor)
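Because nn.GroupNorm requires the channel count to be divisible by the number of groups, it can help to validate c_in and c_out against num_groups before building a block. The helper check_group_norm below is a hypothetical sketch of such a check. It assumes the block normalizes both its input and output feature maps, which this page does not state.

```python
def check_group_norm(channels: int, num_groups: int = 32) -> int:
    """Return the channels per group, or raise if the pair is incompatible.

    Hypothetical helper mirroring the divisibility rule of nn.GroupNorm.
    """
    if channels % num_groups != 0:
        raise ValueError(
            f"{channels} channels cannot be split evenly into {num_groups} groups"
        )
    return channels // num_groups


# Channel widths from the default UNet configuration.
for c in (128, 256):
    print(c, "->", check_group_norm(c), "channels per group")
```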

class dmme.models.iddpm.MultiHeadAttention(dim, num_groups, num_heads)

Self-attention with group norm

Parameters:
  • dim (int) – equivalent to \(d_\text{model}\)

  • num_groups (int) – number of groups in nn.GroupNorm

  • num_heads (int) – number of attention heads

forward(x)
Parameters:
  • x (torch.Tensor) – image of shape \((N, C_\text{in}, H, W)\)

Returns:

feature map of shape \((N, C_\text{in}, H, W)\)

Return type:

(torch.Tensor)
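Self-attention over an image is usually computed by treating each of the H × W spatial positions as a token of width dim. The sketch below only tracks tensor shapes through that process. The helper attention_shapes is hypothetical, and it assumes that dim equals the channel count and that heads split the channels evenly; neither is confirmed by this page.

```python
from typing import Dict, Tuple


def attention_shapes(n: int, c: int, h: int, w: int, num_heads: int) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the intermediate tensors in image self-attention.

    Hypothetical helper: assumes dim == c and c divisible by num_heads.
    """
    if c % num_heads != 0:
        raise ValueError("channels must be divisible by num_heads")
    seq_len = h * w           # one token per spatial position
    head_dim = c // num_heads
    return {
        "tokens": (n, seq_len, c),                     # flattened feature map
        "per_head": (n, num_heads, seq_len, head_dim), # queries, keys, values
        "scores": (n, num_heads, seq_len, seq_len),    # attention weights
        "output": (n, c, h, w),                        # reshaped back to an image
    }


for name, shape in attention_shapes(8, 256, 16, 16, 4).items():
    print(name, shape)
```

The scores tensor grows quadratically with H × W, which is one reason attention is often restricted to the deeper, lower-resolution depths of a U-Net.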