from typing import OrderedDict
import copy
import math
import torch
from torch import nn
import torch.nn.functional as F
import einops
__all__ = [
"TimeStepEmbedding",
"SinusoidalPositionEmbeddings",
"Block",
"DownSample",
"UpSample",
"Attention",
"ResBlock",
"conv2d",
]
[docs]class UNet(nn.Module):
"""UNet with GroupNorm and Attention, Predicts noise from :math:`x_t` and :math:`t`
Args:
in_channels (int): input image channels
dim (int): initial dim
pos_dim (int): sinusoidal position encoding dim
emb_dim (int): time embedding mlp dim
multipliers (Tuple[int...]): list of channel multipliers
attn_depth (Tuple[int...]): depth where attention is applied
groups (int): number of groups in `nn.GroupNorm`
dropout (float): dropout in `ResBlock`
"""
def __init__(
self,
in_channels=3,
dim=128,
pos_dim=128,
emb_dim=512,
multipliers=(1, 2, 2, 2),
attn_depth=(2,),
groups=32,
dropout=0.1,
):
super().__init__()
self.depth = len(multipliers)
channels = [dim]
for mult in multipliers:
channels.append(dim * mult)
output_dims = channels[1:]
input_dims = channels[:-1]
middle_dim = output_dims[-1]
self.time_emb_mlp = TimeStepEmbedding(pos_dim=pos_dim, emb_dim=emb_dim)
self.first_conv = conv2d(
in_channels, dim, 3, 1, nn.GroupNorm(groups, dim), nn.SiLU()
)
self.final_conv = nn.Conv2d(dim, in_channels, 3, 1, 1)
contract_layers = []
expand_layers = []
for i, (c_in, c_out) in enumerate(zip(input_dims, output_dims)):
attention = i + 1 == attn_depth
contract_layers.append(
Block(
c_in,
c_out,
emb_dim,
groups,
dropout,
num_blocks=3,
add_attention=attention,
),
)
expand_layers.append(
Block(
2 * c_out,
c_in,
emb_dim,
groups,
dropout,
num_blocks=3,
add_attention=attention,
),
)
expand_layers.reverse()
self.contracting_path = nn.ModuleList(contract_layers)
self.expansive_path = nn.ModuleList(expand_layers)
self.downsamples = nn.ModuleList([DownSample(c) for c in output_dims])
self.upsamples = nn.ModuleList([UpSample(c, 2) for c in output_dims[::-1]])
self.middle = Block(
middle_dim,
middle_dim,
emb_dim,
groups,
dropout,
num_blocks=2,
add_attention=True,
)
[docs] def forward(self, x, t):
r"""Using timestep embeddings, predict noise to denoise :math:`x_t` from :math:`x_t` and :math:`t` using a UNet
Args:
x (torch.Tensor): :math:`x_t`, tensor of shape :math:`(N, C, H, W)`
t (int): :math:`t`
Returns:
(torch.Tensor): :math:`\epsilon_\theta(x_t,t)` predicted noise from image, a tensor of shape :math:`(N, C, H, W)`
"""
t = self.time_emb_mlp(t)
x = self.first_conv(x)
x_copies = []
for i in range(self.depth):
x = self.contracting_path[i](x, t)
x_copies.append(x)
x = self.downsamples[i](x)
x = self.middle(x, t)
for i in range(self.depth):
x = self.upsamples[i](x)
copied_x = x_copies.pop()
x = torch.cat([x, copied_x], dim=1)
x = self.expansive_path[i](x, t)
x = self.final_conv(x)
return x
[docs]class TimeStepEmbedding(nn.Module):
"""Timestep embedding network
Args:
pos_dim (int): sinusoidal position encoding dim
emb_dim (int): time embedding mlp dim
"""
def __init__(self, pos_dim=64, emb_dim=256):
super().__init__()
self.position_embedding = SinusoidalPositionEmbeddings(pos_dim)
self.mlp = nn.Sequential(
OrderedDict(
[
("linear0", nn.Linear(pos_dim, emb_dim)),
("act0", nn.SiLU()),
("linear1", nn.Linear(emb_dim, emb_dim)),
]
)
)
[docs] def forward(self, t):
"""Encode :math:`t` into Sinusoidal Position Embedding then use mlps to create timestep embeddings
Args:
t (torch.Tensor): timestep as a tensor of shape :math:`(N, 1)`
Returns:
(torch.Tensor): embedding of shape :math:`(N, 1)`
"""
h = self.position_embedding(t)
h = self.mlp(h)
return h
[docs]class SinusoidalPositionEmbeddings(nn.Module):
"""Transformer position embedding
Args:
dim (int): dim
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
[docs] def forward(self, time):
"""Encode time :math:`t` as a Sinusoidal Position Embedding
Args:
time (torch.Tensor): :math:`t`, a tensor of shape :math:`(N, 1)`
Returns:
(torch.Tensor): position embedding of shape :math:`(N, T)`
"""
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
[docs]class Block(nn.Module):
"""Convolutional Block with multiple resblocks
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
emb_dim (int): time embedding dim
groups (int): num groups in `nn.GroupNorm`
dropout (float): dropout used in `ResBlock`
num_blocks (int): number of resblokcs used
add_attention (bool): whether to add attention to the final layer
"""
def __init__(
self,
in_channels,
out_channels,
emb_dim,
groups,
dropout,
num_blocks=2,
add_attention=False,
):
super().__init__()
self.num_blocks = num_blocks
self.use_attention = add_attention
resblocks = []
attentions = []
for i in range(num_blocks):
if i == 0:
c_in = in_channels
c_out = out_channels
else:
c_in = c_out = out_channels
resblocks.append(ResBlock(c_in, c_out, emb_dim, groups, dropout))
if add_attention:
attentions.append(Attention(c_out, groups))
self.resblocks = nn.ModuleList(resblocks)
self.attentions = nn.ModuleList(attentions)
[docs] def forward(self, x, t):
"""Apply multiple `ResBlocks` with optional `Attention` at the end
Args:
x (torch.Tensor): :math:`x_t`, tensor of shape :math:`(N, C, H, W)`
t (int): :math:`t`
Returns:
x (torch.Tensor): tensor of shape :math:`(N, C, H, W)` where :math:`C` is `out_channels`
"""
for i in range(self.num_blocks - 1):
x = self.resblocks[i](x, t)
if self.use_attention:
x = self.attentions[i](x)
x = self.resblocks[-1](x, t)
return x
[docs]class UpSample(nn.Module):
"""Upsampling layer
Args:
dim (int): number of input and output channels
scale_factor (float): upsample scale
"""
def __init__(self, dim, scale_factor):
super().__init__()
self.upsample = nn.Upsample(scale_factor=scale_factor)
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
[docs] def forward(self, x):
"""Upsample by an arbitrary factor by upsampling with interpolation
followed by 3x3 convolutions with same input and output channels
Returns:
x
"""
h = self.upsample(x)
h = self.conv(h)
return h
[docs]class Attention(nn.Module):
r"""Multi Head Self Attention layer
Args:
dim (int): :math:`d_\text{model}`
groups (int): num groups in `nn.GroupNorm`
"""
def __init__(self, dim, groups=8):
super().__init__()
self.norm = nn.GroupNorm(groups, dim)
self.scale = dim**-0.5
self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(dim, dim, 1)
[docs] def forward(self, x):
"""Multi Head Self Attention on images with prenorm and residual connections
Returns:
x
"""
h, w = x.size()[2:]
x = self.norm(x)
qkv = self.to_qkv(x)
qkv = einops.rearrange(qkv, "b c h w -> b c (h w)")
query, key, value = qkv.chunk(3, dim=1)
score = einops.einsum(query * self.scale, key, "b c qhw, b c khw -> b qhw khw")
attention = F.softmax(score, dim=-1)
out = einops.einsum(attention, value, "b qhw khw, b c khw -> b c qhw")
out = einops.rearrange(out, "b c (h w) -> b c h w", h=h, w=w)
return self.to_out(out) + x
[docs]class ResBlock(nn.Module):
"""ResBlock for UNet
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
emb_dim (int): timestep embedding dim
groups (int): num groups in `nn.GroupNorm`
dropout (float): dropout applied in each conv
"""
def __init__(
self,
in_channels,
out_channels,
emb_dim,
groups=8,
dropout=0.0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self._first = conv2d(
in_channels,
out_channels,
kernel_size=3,
padding=1,
)
self._first_norm_then_act = nn.Sequential(
OrderedDict([("norm", self.norm), ("act", self.act)])
)
self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))
self._dropout = nn.Dropout2d(dropout)
self._second = conv2d(
out_channels,
out_channels,
kernel_size=3,
padding=1,
)
self._second_norm_then_act = nn.Sequential(
OrderedDict([("norm", self.norm), ("act", self.act)])
)
if in_channels != out_channels:
self._res_conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
else:
self._res_conv = None
self._act = self.act
[docs] def forward(self, x, t):
"""ResBlock with time embeddings
ResBlock with two convolution layers with residual connections.
Adds time embedding to the first layer's output using an mlp to match dimensions.
Then normalization, activation , dropout is applied in that order.
The second convolutional layer is identical to the basic resblock.
Args:
x (torch.Tensor): :math:`x_t`, tensor of shape :math:`(N, C, H, W)`
t (int): :math:`t`
Returns:
x (torch.Tensor): tensor of shape :math:`(N, C, H, W)` where :math:`C` is `out_channels`
"""
h = self._first(x)
h = h + einops.rearrange(self.mlp(t), "b c -> b c 1 1")
h = self._first_norm_then_act(h)
h = self._dropout(h)
h = self._second(h)
if self._res_conv is not None:
x = self._res_conv(x)
return self._second_norm_then_act(h + x)
@property
def norm(self):
"""Returns copies of normalizaiton layers"""
return nn.GroupNorm(self.groups, self.out_channels)
@property
def act(self):
"""Returns copies of activation layers"""
return nn.SiLU()
[docs]class DownSample(nn.Module):
"""Downsampling layer
Args:
dim (int): number of input and output channels
"""
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
[docs] def forward(self, x):
"""Downsample by a factor of 2 using convolutions
Returns:
x
"""
return self.conv(x)
[docs]def conv2d(
in_channels,
out_channels,
kernel_size,
padding,
norm=None,
act=None,
):
"""convolution layer builder with normalization and activation
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
kernel_size (int): kernel size
padding (int): padding
norm (nn.Module): normalization layer instance
act (nn.Module): activation function instance
"""
layers = OrderedDict()
layers["conv"] = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=padding
)
if norm is not None:
layers["norm"] = copy.deepcopy(norm)
if act is not None:
layers["act"] = copy.deepcopy(act)
return nn.Sequential(layers)