DDIM
- DDIM – Reverse process and sampling for DDIM
- LitDDIM – LightningModule for sampling with DDIM using LitDDPM’s checkpoints
Sampler
- class dmme.ddim.DDIM(timesteps, tau_schedule='quadratic')
Reverse process and Sampling for DDIM
- Parameters:
timesteps (int) – total timesteps \(T\)
tau_schedule (str) – tau schedule, “linear” or “quadratic”
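The tau schedule selects the increasing subsequence \(\tau\) of timesteps that DDIM visits during generation. The DDIM paper (Song et al.) describes a linear rule \(\tau_i = \lfloor c\,i \rfloor\) and a quadratic rule \(\tau_i = \lfloor c\,i^2 \rfloor\) for a constant \(c\) chosen so the last entry lands near \(T\). The sketch below builds such a subsequence; `make_tau` is a hypothetical helper, and the 0-indexed timesteps and the particular choice of \(c\) are assumptions, not necessarily what `dmme.ddim.DDIM` does.

```python
from typing import List


def make_tau(timesteps: int, sample_steps: int, schedule: str = "quadratic") -> List[int]:
    """Hypothetical helper: increasing subsequence of {0, ..., timesteps - 1}."""
    if not 1 <= sample_steps <= timesteps:
        raise ValueError("need 1 <= sample_steps <= timesteps")
    if schedule == "linear":
        # tau_i = floor(c * i) with c = T / S (integer arithmetic avoids float rounding)
        tau = [timesteps * i // sample_steps for i in range(sample_steps)]
    elif schedule == "quadratic":
        # tau_i = floor(c * i^2) with c = (T - 1) / (S - 1)^2, so the last entry is T - 1
        denom = max(sample_steps - 1, 1) ** 2
        tau = [(timesteps - 1) * i * i // denom for i in range(sample_steps)]
    else:
        raise ValueError(f"unknown tau schedule: {schedule!r}")
    # small i can collide under the quadratic rule; keep each timestep once
    return sorted(set(tau))


print(make_tau(1000, 50, "linear")[:4])     # [0, 20, 40, 60]
print(make_tau(1000, 50, "quadratic")[:4])  # [0, 1, 3, 6]
```

The quadratic rule spends more steps near \(t = 0\), where fine detail is resolved; because of the deduplication it can return slightly fewer than `sample_steps` entries.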
- reverse_process(model, x_t, t)
Reverse Denoising Process
Samples \(\bold{x}_{t-1}\) from \(p_\theta(\bold{x}_{t-1}|\bold{x}_t) = \mathcal{N}(\bold{x}_{t-1};\bold{\mu}_\theta(\bold{x}_t, t), \sigma_t^2\bold{I})\), where
\[\begin{aligned} \bold{\mu}_\theta(\bold{x}_t, t) &= \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) \\ \sigma_t^2 &= \beta_t \end{aligned} \]
It computes \(\bold{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\bigg(\bold{x}_t -\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(\bold{x}_t,t)\bigg) +\sigma_t\epsilon\) with \(\epsilon \sim \mathcal{N}(\bold{0}, \bold{I})\).
- Parameters:
model (nn.Module) – model for estimating noise
x_t (torch.Tensor) – noisy sample \(\bold{x}_t\) at timestep \(t\)
t (int) – current timestep
noise (torch.Tensor) – Gaussian noise \(\epsilon\)
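A scalar sketch of the update above follows. `reverse_step` and `linear_betas` are hypothetical names, timesteps are 0-indexed, and the DDPM linear \(\beta\) schedule from \(10^{-4}\) to \(0.02\) is an assumption, since this page does not state which schedule is used. Only `+`, `-`, `*` and `/` touch `x_t`, `eps_pred` and `noise`, so the same function also accepts `torch.Tensor` inputs.

```python
import math
from typing import List


def linear_betas(timesteps: int, beta_1: float = 1e-4, beta_T: float = 0.02) -> List[float]:
    # Assumed DDPM-style linear schedule; not stated on this page
    span = max(timesteps - 1, 1)
    return [beta_1 + (beta_T - beta_1) * i / span for i in range(timesteps)]


def reverse_step(x_t, eps_pred, t: int, betas: List[float], noise):
    """x_t -> x_{t-1}: mu_theta(x_t, t) + sigma_t * noise, with sigma_t^2 = beta_t."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # alpha_bar_t = prod_{s <= t} (1 - beta_s); recomputed each call for clarity, not speed
    alpha_bar_t = math.prod(1.0 - b for b in betas[: t + 1])
    mean = (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)
    sigma_t = math.sqrt(beta_t)
    return mean + sigma_t * noise
```

Following the usual DDPM convention, a caller would pass `noise = 0` on the final step so that the returned \(\bold{x}_0\) is the mean rather than a noisy draw.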
- sample(model, x_t, t)
Sample from \(p_\theta(x_{t-1}|x_t)\)
- Parameters:
model (nn.Module) – model for estimating noise
x_t (torch.Tensor) – image of shape \((N, C, H, W)\)
t (int) – starting \(t\) to sample from
- Returns:
generated sample of shape \((N, C, H, W)\)
- Return type:
(torch.Tensor)
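This page does not spell out the update that `sample` iterates. As a hedged sketch, assuming the deterministic DDIM update (\(\eta = 0\)) from Song et al., a loop over an increasing timestep subsequence \(\tau\) first estimates \(\hat{\bold{x}}_0 = (\bold{x}_{\tau_i} - \sqrt{1-\bar\alpha_{\tau_i}}\,\epsilon_\theta)/\sqrt{\bar\alpha_{\tau_i}}\) and then sets \(\bold{x}_{\tau_{i-1}} = \sqrt{\bar\alpha_{\tau_{i-1}}}\,\hat{\bold{x}}_0 + \sqrt{1-\bar\alpha_{\tau_{i-1}}}\,\epsilon_\theta\). `ddim_sample`, `cumulative_alpha_bars` and the `eps_model(x, t)` callable are hypothetical; this is not necessarily how `dmme.ddim.DDIM.sample` is implemented.

```python
import math
from typing import Callable, List


def cumulative_alpha_bars(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod_{s <= t} (1 - beta_s), for 0-indexed t."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def ddim_sample(eps_model: Callable, x_T, alpha_bars: List[float], tau: List[int]):
    """Deterministic DDIM (eta = 0) from tau[-1] down to x_0; tau must be increasing."""
    x = x_T
    for i in reversed(range(len(tau))):
        t = tau[i]
        a_t = alpha_bars[t]
        # alpha_bar of the next (earlier) timestep; 1.0 once we step to x_0
        a_prev = alpha_bars[tau[i - 1]] if i > 0 else 1.0
        eps = eps_model(x, t)
        x0_pred = (x - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)
        x = math.sqrt(a_prev) * x0_pred + math.sqrt(1.0 - a_prev) * eps
    return x
```

With an exact noise predictor the loop returns \(\bold{x}_0\) exactly, which makes a convenient sanity check. A stochastic variant (\(\eta > 0\)) would add \(\sigma_{\tau_i}\epsilon\) and shrink the \(\epsilon_\theta\) coefficient to \(\sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\).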
Training
- class dmme.ddim.LitDDIM(model: Module, lr: float = 0.0002, warmup: int = 5000, imgsize: Tuple[int, int, int] = (3, 32, 32), timesteps: int = 1000, decay: float = 0.9999, sample_steps: int = 50, tau_schedule: str = 'quadratic')
LightningModule for sampling with DDIM using LitDDPM’s checkpoints
- Parameters:
model (nn.Module) – neural network predicting noise \(\epsilon_\theta\)
lr (float) – learning rate, defaults to \(2 \times 10^{-4}\)
warmup (int) – number of steps over which the learning rate is increased linearly until it reaches lr, defaults to 5000
imgsize (Tuple[int, int, int]) – image size in (C, H, W), defaults to (3, 32, 32)
timesteps (int) – total timesteps \(T\) for the forward and reverse process, defaults to 1000
decay (float) – EMA decay rate, defaults to 0.9999
sample_steps (int) – number of sampling steps used for generation, defaults to 50
tau_schedule (str) – tau schedule to use for generation, “linear” or “quadratic”, defaults to “quadratic”
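The `warmup` and `decay` parameters describe two standard training aids: a linear learning-rate ramp and an exponential moving average (EMA) of weights. The sketch below shows both on plain floats; `warmup_factor` and `ema_update` are hypothetical helpers, and the exact ramp convention (`min(step, warmup) / warmup`) is an assumption rather than what `LitDDIM` is confirmed to use.

```python
from typing import List


def warmup_factor(step: int, warmup: int = 5000) -> float:
    """Multiplier on the base lr: rises linearly from 0 to 1 over `warmup` steps, then stays at 1."""
    if warmup <= 0:
        return 1.0
    return min(step, warmup) / warmup


def ema_update(ema_params: List[float], params: List[float], decay: float = 0.9999) -> List[float]:
    """One EMA step per parameter: ema <- decay * ema + (1 - decay) * param."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


base_lr = 2e-4
print(base_lr * warmup_factor(2500))  # half of base_lr midway through warmup
```

In PyTorch, a factor function like `warmup_factor` could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, stepped once per optimizer step. With `decay = 0.9999`, the EMA averages over roughly the last \(1/(1-0.9999) = 10^4\) updates.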