Source code for dmme.lit_modules.ddim

from typing import Optional

from torch import nn

from dmme.diffusion_models import DDIM
from dmme.models.ddpm import UNet

from .ddpm import LitDDPM


class LitDDIM(LitDDPM):
    r"""LightningModule that samples with DDIM on top of :code:`LitDDPM`.

    Subclasses :code:`LitDDPM` and swaps in a :code:`DDIM` diffusion model,
    so checkpoints produced by :code:`LitDDPM` can be used for accelerated
    sampling.

    Args:
        lr: learning rate, defaults to :code:`2e-4`
        warmup: number of steps over which the learning rate is linearly
            increased until `lr` is reached, defaults to 5000
        decay: EMA decay value, defaults to :code:`0.9999`
        diffusion_model: ready-made :code:`DDIM` instance to use as is; when
            given, `model`, `timesteps`, `sample_steps` and `tau_schedule`
            are ignored
        model: network handed to :code:`DDIM`; a default :code:`UNet()` is
            created when omitted
        timesteps: timesteps passed to :code:`DDIM`, defaults to 1000
        sample_steps: sample steps passed to :code:`DDIM`, defaults to 50
        tau_schedule: tau schedule passed to :code:`DDIM`, defaults to
            :code:`"quadratic"`
    """

    def __init__(
        self,
        lr: float = 2e-4,
        warmup: int = 5000,
        decay: float = 0.9999,
        diffusion_model: Optional[DDIM] = None,
        model: Optional[nn.Module] = None,
        timesteps: int = 1000,
        sample_steps: int = 50,
        tau_schedule: str = "quadratic",
    ):
        if diffusion_model is None:
            # Nothing was injected: assemble the default DDIM, falling back
            # to a stock UNet only if the caller did not bring a network.
            network = UNet() if model is None else model
            diffusion_model = DDIM(network, timesteps, sample_steps, tau_schedule)

        # Optimiser / EMA configuration is handled entirely by LitDDPM.
        super().__init__(lr, warmup, decay, diffusion_model)