DDIM

DDIM

Reverse process and Sampling for DDIM

LitDDIM

LightningModule for sampling with DDIM using LitDDPM's checkpoints

Sampler

class dmme.ddim.DDIM(timesteps, tau_schedule='quadratic')

Reverse process and Sampling for DDIM

Parameters:
  • timesteps (int) – total timesteps \(T\)

  • tau_schedule (str) – tau schedule, “linear” or “quadratic”
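
As a rough illustration of what the two `tau_schedule` options could mean, the sketch below builds a sub-sequence of timesteps using the linear and quadratic spacings described in the DDIM paper. The helper name `make_tau` and its `sample_steps` argument are hypothetical and are not part of `dmme`; the library's actual construction may differ.

```python
from typing import List


def make_tau(timesteps: int, sample_steps: int, schedule: str = "quadratic") -> List[int]:
    """Hypothetical helper: choose an increasing sub-sequence of [0, timesteps)."""
    if not 0 < sample_steps <= timesteps:
        raise ValueError("sample_steps must be between 1 and timesteps")
    if schedule == "linear":
        # tau_i = floor(c * i): evenly spaced timesteps
        c = timesteps / sample_steps
        tau = [int(c * i) for i in range(sample_steps)]
    elif schedule == "quadratic":
        # tau_i = floor(c * i^2): steps are denser near t = 0
        c = timesteps / sample_steps**2
        tau = [int(c * i * i) for i in range(sample_steps)]
    else:
        raise ValueError(f"unknown schedule: {schedule!r}")
    # Flooring can produce repeated indices for small i; keep each timestep once.
    return sorted(set(tau))
```

With `timesteps=1000` and `sample_steps=50`, the linear variant steps by 20, while the quadratic variant takes small steps at low noise levels and large steps at high noise levels.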

reverse_process(model, x_t, t)

Reverse Denoising Process

Samples \(x_{t-1}\) from \(p_\theta(\bold{x}_{t-1}|\bold{x}_t) = \mathcal{N}(\bold{x}_{t-1};\mu_\theta(\bold{x}_t, t), \sigma_t\bold{I})\)

\[\begin{aligned} \bold\mu_\theta(\bold{x}_t, t) &= \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) \\ \sigma_t &= \beta_t \end{aligned} \]

Computes \(\bold{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) +\sigma_t\epsilon\)

Parameters:
  • model (nn.Module) – model for estimating noise

  • x_t (torch.Tensor) – noisy sample \(\bold{x}_t\) at timestep \(t\)

  • t (int) – current timestep

  • noise (torch.Tensor) – noise
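
The update equation above can be written out for a single scalar element as in the sketch below. This is only an illustration of the documented formula, not the library's code: the function name `reverse_step` and its arguments are hypothetical, the relation \(\beta_t = 1 - \alpha_t\) is assumed from the usual DDPM definitions, and `eps_pred` stands in for the model output \(\epsilon_\theta(\bold{x}_t, t)\).

```python
import math


def reverse_step(x_t: float, eps_pred: float, noise: float,
                 alpha_t: float, alpha_bar_t: float, sigma_t: float) -> float:
    """Hypothetical scalar version of the documented update for x_{t-1}."""
    # Assumption: beta_t = 1 - alpha_t, as in the standard DDPM notation.
    beta_t = 1.0 - alpha_t
    # Mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    mean = (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)
    # Add the scaled noise term sigma_t * epsilon.
    return mean + sigma_t * noise
```

On tensors the same arithmetic would be applied elementwise, with `eps_pred` obtained by calling the model on `x_t` and `t`.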

sample(model, x_t, t)

Sample from \(p_\theta(x_{t-1}|x_t)\)

Parameters:
  • model (nn.Module) – model for estimating noise

  • x_t (torch.Tensor) – image of shape \((N, C, H, W)\)

  • t (int) – starting \(t\) to sample from

Returns:

generated sample of shape \((N, C, H, W)\)

Return type:

(torch.Tensor)

Training

class dmme.ddim.LitDDIM(model: Module, lr: float = 0.0002, warmup: int = 5000, imgsize: Tuple[int, int, int] = (3, 32, 32), timesteps: int = 1000, decay: float = 0.9999, sample_steps: int = 50, tau_schedule: str = 'quadratic')

LightningModule for sampling with DDIM using LitDDPM’s checkpoints

Parameters:
  • model (nn.Module) – neural network predicting noise \(\epsilon_\theta\)

  • lr (float) – learning rate, defaults to \(2 \times 10^{-4}\)

  • warmup (int) – number of warmup steps over which the learning rate increases linearly until it reaches lr, defaults to 5000

  • imgsize (Tuple[int, int, int]) – image size in (C, H, W)

  • timesteps (int) – total timesteps for the forward and reverse process, \(T\)

  • decay (float) – EMA decay value

  • sample_steps (int) – sample steps for generation process

  • tau_schedule (str) – tau schedule to use for generation, “linear” or “quadratic”
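
To make the `warmup` and `decay` parameters more concrete, here is a minimal sketch of a linear learning-rate warmup and an exponential moving average (EMA) update. Both functions are hypothetical illustrations that use the default values from the signature above; how `LitDDIM` schedules the learning rate and maintains EMA weights internally is not shown in this page and may differ (for example in whether warmup starts at step 0 or step 1).

```python
def warmup_lr(step: int, lr: float = 2e-4, warmup: int = 5000) -> float:
    """Hypothetical linear warmup: ramp from 0 to `lr` over `warmup` steps."""
    if warmup <= 0:
        return lr
    # Scale factor grows linearly with the step and is capped at 1.
    return lr * min(1.0, step / warmup)


def ema_update(ema: float, value: float, decay: float = 0.9999) -> float:
    """Hypothetical EMA update for one scalar weight."""
    # Keep `decay` of the running average and blend in the new value.
    return decay * ema + (1.0 - decay) * value
```

A decay close to 1, such as the default 0.9999, makes the averaged weights change slowly relative to the raw weights.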

forward(x_t: Tensor, t: int)

Denoise image once using DDIM

Parameters:
  • x_t (torch.Tensor) – image of shape \((N, C, H, W)\)

  • t (int) – timestep \(t\) to denoise from

Returns:

generated sample of shape \((N, C, H, W)\)

Return type:

(torch.Tensor)

generate(x_t)

Iteratively sample from \(p_\theta(x_{t-1}|x_t)\) to generate images

Parameters:
  • x_t (torch.Tensor) – \(x_T\) to start from
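
The iteration described for `generate` can be pictured as a loop that walks a timestep sub-sequence from the largest value down to the smallest, applying one denoising step each time. The sketch below is an assumption about the general shape of such a loop, not the library's implementation: `generate_loop`, `tau`, and `step` are hypothetical names, and `step` stands in for a single call such as `forward(x_t, t)`.

```python
from typing import Any, Callable, Sequence


def generate_loop(x_T: Any, tau: Sequence[int], step: Callable[[Any, int], Any]) -> Any:
    """Hypothetical sampling loop: apply `step` for each timestep in decreasing order."""
    x = x_T
    # Start at the noisiest timestep and move toward t = 0.
    for t in reversed(tau):
        x = step(x, t)
    return x
```

For example, with `tau = [0, 5, 9]` the step function is called with `t = 9`, then `5`, then `0`.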