# Source code for dmme.lit_modules.iddpm

from typing import Optional

from torch import nn

from dmme.diffusion_models import IDDPM
from dmme.models.iddpm import UNet

from .ddpm import LitDDPM


class LitIDDPM(LitDDPM):
    r"""LightningModule for Improved Denoising Diffusion Probabilistic Models.

    Builds an :code:`IDDPM` diffusion model (around a :code:`UNet` by default)
    unless one is supplied. It then passes ``lr``, ``warmup``, ``decay`` and
    the diffusion model on to :code:`LitDDPM`.

    Args:
        lr: learning rate, defaults to :code:`2e-4`
        warmup: linearly increases learning rate for `warmup` steps until `lr`
            is reached, defaults to 5000
        decay: EMA decay value, defaults to :code:`0.9999`
        diffusion_model: overrides the default diffusion model :code:`IDDPM`.
            When given, ``model``, ``timesteps``, ``loss_type``, ``gamma``,
            ``schedule``, ``offset``, ``start`` and ``end`` are ignored.
        model: overrides the default model (:code:`UNet`) passed to
            :code:`IDDPM`
        timesteps: default timesteps passed to :code:`IDDPM`
        loss_type: loss type to use, either "hybrid" or "simple"
        gamma: :math:`\gamma` in hybrid loss
        schedule: variance schedule to use, either "linear" or "cosine"
        offset: default offset for :code:`IDDPM` if cosine schedule is used
        start: default start for :code:`IDDPM` if linear schedule is used
        end: default end for :code:`IDDPM` if linear schedule is used
    """

    def __init__(
        self,
        lr: float = 0.0002,
        warmup: int = 5000,
        decay: float = 0.9999,
        diffusion_model: Optional[IDDPM] = None,
        model: Optional[nn.Module] = None,
        timesteps: int = 1000,
        loss_type: str = "hybrid",
        gamma: float = 0.001,
        schedule: str = "cosine",
        offset: float = 0.008,
        start: float = 0.0001,
        end: float = 0.02,
    ) -> None:
        if diffusion_model is None:
            # Only build a default network when the caller supplied neither a
            # diffusion model nor a model to wrap.
            if model is None:
                model = UNet()
            # NOTE(review): arguments are passed positionally, so this relies
            # on IDDPM's constructor taking (model, timesteps, loss_type,
            # gamma, schedule, offset, start, end) in exactly this order.
            diffusion_model = IDDPM(
                model, timesteps, loss_type, gamma, schedule, offset, start, end
            )

        # Optimizer/EMA settings and the diffusion model are handed to
        # LitDDPM unchanged.
        super().__init__(lr, warmup, decay, diffusion_model)