LitDDPM

class dmme.lit_modules.LitDDPM(lr: float = 0.0002, warmup: int = 5000, decay: float = 0.9999, diffusion_model: Optional[DDPM] = None, model: Optional[Module] = None, timesteps: int = 1000)

LightningModule for training a DDPM (Denoising Diffusion Probabilistic Model)

Parameters:
  • lr – learning rate, defaults to 2e-4

  • warmup – number of steps over which the learning rate is increased linearly until it reaches lr, defaults to 5000

  • decay – EMA decay rate, defaults to 0.9999

  • diffusion_model – DDPM instance to use instead of the default one

  • model – network passed to the default DDPM in place of its default model

  • timesteps – number of diffusion timesteps passed to the default DDPM, defaults to 1000
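The timesteps value fixes the number of steps T in the forward noising process. As a minimal sketch, assuming the linear variance schedule of the original DDPM paper (\(\beta_t\) from 1e-4 to 0.02; the DDPM class used here may use a different schedule), the per-step variances \(\beta_t\) and the cumulative products \(\bar\alpha_t = \prod_{s \le t}(1-\beta_s)\) can be computed in plain Python. The helper linear_beta_schedule is hypothetical.

```python
from typing import List, Tuple


def linear_beta_schedule(
    timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: linearly spaced betas and their cumulative alpha products."""
    # Linearly interpolate beta_t from beta_start to beta_end, both endpoints included.
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    # alpha_t = 1 - beta_t, and alpha_bar_t is the running product of the alphas.
    alpha_bars = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


betas, alpha_bars = linear_beta_schedule(1000)
print(len(betas), round(betas[0], 6), round(betas[-1], 6), alpha_bars[-1])
```

With T = 1000, \(\bar\alpha_T\) ends up close to zero, so \(x_T\) is nearly pure Gaussian noise; fewer timesteps with the same endpoints leave more of the original signal at the last step.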

forward(x_t: Tensor, t: int)

Perform a single denoising step on an image using the DDPM

Parameters:
  • x_t – noisy image at timestep \(t\), of shape \((N, C, H, W)\)

  • t (int) – current timestep \(t\) to denoise from

  • noise – noise to use for sampling; if None, new noise is sampled

Returns:

generated sample of shape \((N, C, H, W)\)
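A single denoising step can be sketched with the standard DDPM ancestral update \(x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\big) + \sigma_t z\), taking \(\sigma_t^2 = \beta_t\) and adding no noise on the final step. This is a minimal sketch on flat Python lists rather than \((N, C, H, W)\) tensors; reverse_step and its eps_pred argument (standing in for the network's noise prediction) are hypothetical, and the library's DDPM may differ in details such as the variance choice or clipping.

```python
import math
import random
from typing import List, Optional


def reverse_step(
    x_t: List[float],
    eps_pred: List[float],
    t: int,
    betas: List[float],
    alpha_bars: List[float],
    noise: Optional[List[float]] = None,
) -> List[float]:
    """Hypothetical one-step DDPM update x_t -> x_{t-1}, with t 0-indexed."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # Coefficient on the predicted noise: beta_t / sqrt(1 - alpha_bar_t)
    eps_coef = beta_t / math.sqrt(1.0 - alpha_bars[t])
    mean = [(x - eps_coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps_pred)]
    if t == 0:
        # No noise is added on the final step.
        return mean
    if noise is None:
        # Sample fresh standard normal noise when none is supplied.
        noise = [random.gauss(0.0, 1.0) for _ in x_t]
    sigma_t = math.sqrt(beta_t)  # the sigma_t^2 = beta_t variance choice
    return [m + sigma_t * z for m, z in zip(mean, noise)]


# Toy two-step schedule; zero predicted noise and zero sampling noise for a deterministic check.
betas = [0.1, 0.2]
alpha_bars = [0.9, 0.9 * 0.8]
x_prev = reverse_step([1.0, -1.0], [0.0, 0.0], t=1, betas=betas, alpha_bars=alpha_bars, noise=[0.0, 0.0])
print(x_prev)
```

Here t is 0-indexed, so t = 0 is the last step and returns the mean without added noise. Passing noise explicitly makes the step reproducible, which mirrors the noise parameter documented above.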

training_step(batch, batch_idx)

Train the model using the simplified objective \(L_\text{simple}\)
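\(L_\text{simple}\) is the noise-prediction objective of the original DDPM paper: sample \(t\) and \(\epsilon \sim \mathcal{N}(0, I)\), form \(x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon\), and regress the network output \(\epsilon_\theta(x_t, t)\) onto \(\epsilon\) with a squared error. Below is a minimal single-example sketch on flat lists; l_simple and eps_model are hypothetical names, and averaging over elements (as torch.nn.functional.mse_loss does by default) is an assumption about the reduction used.

```python
import math
import random
from typing import Callable, List


def l_simple(
    x0: List[float],
    t: int,
    alpha_bars: List[float],
    eps_model: Callable[[List[float], int], List[float]],
    eps: List[float],
) -> float:
    """Hypothetical single-example L_simple: mean squared error between true and predicted noise."""
    a = alpha_bars[t]
    # Closed-form forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    eps_hat = eps_model(x_t, t)
    # Average the squared error over all elements.
    return sum((e - h) ** 2 for e, h in zip(eps, eps_hat)) / len(eps)


def zero_model(x_t: List[float], t: int) -> List[float]:
    # A stand-in "network" that always predicts zero noise.
    return [0.0 for _ in x_t]


# Usage with a toy schedule: sample t uniformly and eps ~ N(0, I), as in training.
alpha_bars = [0.99, 0.9, 0.5]
x0 = [0.3, -0.7]
t = random.randrange(len(alpha_bars))
eps = [random.gauss(0.0, 1.0) for _ in x0]
print(l_simple(x0, t, alpha_bars, zero_model, eps))
```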

test_step(batch, batch_idx)

Generate samples for evaluation

generate(img_size)

Generate samples using the internal diffusion_model

Parameters:

img_size – shape of the batch to generate, as a tuple \((N, C, H, W)\)

Returns:

generated images as a tensor of shape \((N, C, H, W)\)
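Sampling can be pictured as a loop that starts from Gaussian noise of shape img_size and applies one denoising step per timestep, from the last timestep down to the first. The sketch below assumes that structure; generate here is a stand-alone hypothetical function, values are kept in a flat list instead of a tensor, and denoise_step stands in for a single-step denoiser such as forward.

```python
import random
from typing import Callable, List, Tuple


def generate(
    img_size: Tuple[int, int, int, int],
    timesteps: int,
    denoise_step: Callable[[List[float], int], List[float]],
) -> List[float]:
    """Hypothetical sampler: start from Gaussian noise and apply one denoising step per timestep."""
    n, c, h, w = img_size
    # x_T ~ N(0, I), flattened to N*C*H*W values for simplicity.
    x = [random.gauss(0.0, 1.0) for _ in range(n * c * h * w)]
    # Walk the timesteps backwards: T-1, T-2, ..., 0 (0-indexed).
    for t in reversed(range(timesteps)):
        x = denoise_step(x, t)
    return x


# Usage with a dummy step that records which timesteps it was called with.
visited = []


def fake_step(x: List[float], t: int) -> List[float]:
    visited.append(t)
    return [v * 0.5 for v in x]


sample = generate((2, 1, 4, 4), timesteps=5, denoise_step=fake_step)
print(len(sample), visited)
```

The loop makes the cost of generation explicit: producing one batch requires timesteps sequential network evaluations.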

test_epoch_end(outputs)

Compute evaluation metrics and log them at the end of the test epoch

configure_optimizers()

Configure the optimizer for training: Adam with a linear learning-rate warmup
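A linear warmup can be expressed as a multiplier on the base learning rate that grows from 0 to 1 over warmup steps and then stays at 1. The sketch below assumes that shape; warmup_lr_lambda and effective_lr are hypothetical helpers. In PyTorch, such a function would typically be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR wrapping a torch.optim.Adam optimizer.

```python
def warmup_lr_lambda(step: int, warmup: int = 5000) -> float:
    """Hypothetical LR multiplier: linear ramp from 0 to 1 over `warmup` steps, then constant."""
    return min(step, warmup) / warmup


def effective_lr(step: int, lr: float = 2e-4, warmup: int = 5000) -> float:
    """Learning rate applied at a given optimizer step under this sketch."""
    return lr * warmup_lr_lambda(step, warmup)


# With the documented defaults (lr=2e-4, warmup=5000), the rate reaches 2e-4 at step 5000.
for step in (0, 1000, 2500, 5000, 20000):
    print(step, effective_lr(step))
```

Some implementations use (step + 1) / warmup instead, so the first update does not use a learning rate of exactly zero.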

configure_callbacks()

Configure the EMA callback; it replaces any other EMA callback passed to the Trainer
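The EMA callback presumably maintains a shadow copy of the weights, updated after each optimizer step as \(\theta_\text{ema} \leftarrow \text{decay}\cdot\theta_\text{ema} + (1-\text{decay})\cdot\theta\). The sketch below shows that update on a dictionary of scalar parameters; ema_update is a hypothetical name, and an actual callback would operate on tensors.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], params: Dict[str, float], decay: float = 0.9999) -> None:
    """Hypothetical in-place EMA update of scalar parameters."""
    for name, value in params.items():
        # Blend the running average toward the current parameter value.
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# Usage: the EMA of a constant parameter approaches it geometrically.
ema = {"w": 0.0}
params = {"w": 1.0}
for _ in range(3):
    ema_update(ema, params, decay=0.5)
print(ema["w"])
```

With decay = 0.9999, the average effectively spans on the order of 1 / (1 - decay) = 10,000 steps, so EMA weights lag the raw weights noticeably early in training.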