EMA

class dmme.callbacks.EMA(decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False)

Implements an Exponential Moving Average (EMA) of the model weights. During training, this callback maintains moving averages of the trained parameters. The code is taken from https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py

When evaluating, we use the moving-average copy of the trained parameters. When saving, we save an additional set of parameters with the prefix ema.

Parameters:
  • decay – the exponential decay used when calculating the moving average. Must be between 0 and 1.

  • validate_original_weights – validate the original weights, as opposed to the EMA weights.

  • every_n_steps – apply EMA every N steps.

  • cpu_offload – offload the EMA weights to CPU.
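The moving average described above follows the usual EMA recurrence, ema ← decay · ema + (1 − decay) · param, applied only on steps selected by every_n_steps. The sketch below illustrates that rule on plain Python floats. SimpleEMA is a hypothetical helper, not part of dmme, and the step-selection rule (step % every_n_steps == 0) is an assumption about how every_n_steps is interpreted.

```python
from typing import Dict


class SimpleEMA:
    """Hypothetical toy EMA tracker over a dict of float parameters."""

    def __init__(self, params: Dict[str, float], decay: float, every_n_steps: int = 1):
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be between 0 and 1")
        if every_n_steps < 1:
            raise ValueError("every_n_steps must be >= 1")
        self.decay = decay
        self.every_n_steps = every_n_steps
        # The EMA copy starts as an exact copy of the current weights.
        self.ema = dict(params)

    def update(self, params: Dict[str, float], step: int) -> bool:
        """Blend current weights into the EMA copy; return True if an update happened."""
        # Assumption: only steps divisible by every_n_steps trigger an update.
        if step % self.every_n_steps != 0:
            return False
        for name, value in params.items():
            # ema <- decay * ema + (1 - decay) * param
            self.ema[name] = self.decay * self.ema[name] + (1.0 - self.decay) * value
        return True


params = {"w": 0.0}
tracker = SimpleEMA(params, decay=0.5, every_n_steps=2)
params["w"] = 1.0
for step in range(1, 5):
    tracker.update(params, step)  # only steps 2 and 4 update the average
print(tracker.ema["w"])  # → 0.75
```

As a rule of thumb, a decay of d averages over roughly the last 1 / (1 − d) updates, so values such as 0.999 give a slow, smooth average. Attaching the real callback to a Lightning Trainer should look something like Trainer(callbacks=[EMA(decay=0.999)]).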

on_fit_start(trainer: Trainer, pl_module: LightningModule) → None

Called when fit begins.

on_validation_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation loop begins.

on_validation_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation loop ends.
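Because evaluation uses the moving-average copy unless validate_original_weights is set, a natural pattern is to swap the EMA weights into the model when an evaluation loop starts and swap them back when it ends. The sketch below shows that pattern on plain dicts. WeightSwapper and its on_eval_start / on_eval_end methods are hypothetical names, and the assumption that the real hooks swap weights this way is not confirmed by this page.

```python
from typing import Dict


class WeightSwapper:
    """Hypothetical helper: swap EMA weights into a model for evaluation and back."""

    def __init__(self, validate_original_weights: bool = False):
        self.validate_original_weights = validate_original_weights
        self._swapped = False

    @staticmethod
    def swap(model: Dict[str, float], ema: Dict[str, float]) -> None:
        # Exchange values in place so each dict ends up holding the other's weights.
        for name in model:
            model[name], ema[name] = ema[name], model[name]

    def on_eval_start(self, model: Dict[str, float], ema: Dict[str, float]) -> None:
        # Evaluate with EMA weights unless the original weights were requested.
        if not self.validate_original_weights:
            self.swap(model, ema)
            self._swapped = True

    def on_eval_end(self, model: Dict[str, float], ema: Dict[str, float]) -> None:
        # Restore the training weights only if they were swapped out.
        if self._swapped:
            self.swap(model, ema)
            self._swapped = False
```

Swapping in place, rather than copying, avoids holding a third copy of the weights during evaluation; the same pattern would apply to the test hooks.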

on_test_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the test begins.

on_test_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the test ends.

save_ema_model(trainer: Trainer)

Saves an EMA copy of the model, together with the EMA optimizer states, so that training can be resumed.
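The class description says saved checkpoints carry an additional set of parameters with the prefix ema. The sketch below shows one way such prefixed entries could be added to a Lightning-style checkpoint dict with a "state_dict" key. add_ema_to_checkpoint is a hypothetical helper; the "ema." key format is an assumption, and the sketch covers only the weights, not the EMA optimizer states that the real method also saves.

```python
from typing import Any, Dict


def add_ema_to_checkpoint(
    checkpoint: Dict[str, Any], ema: Dict[str, float], prefix: str = "ema"
) -> Dict[str, Any]:
    """Hypothetical: return a copy of `checkpoint` with EMA weights under prefixed keys."""
    out = dict(checkpoint)
    # Copy the existing state so the input checkpoint is left untouched.
    state = dict(checkpoint.get("state_dict", {}))
    for name, value in ema.items():
        # Assumed key format: "<prefix>.<parameter name>".
        state[f"{prefix}.{name}"] = value
    out["state_dict"] = state
    return out


ckpt = {"epoch": 3, "state_dict": {"layer.weight": 1.0}}
saved = add_ema_to_checkpoint(ckpt, {"layer.weight": 0.9})
print(sorted(saved["state_dict"]))  # → ['ema.layer.weight', 'layer.weight']
```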

on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) → None

Called when loading a model checkpoint; use it to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that was loaded by the Trainer.
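Reloading EMA state from the checkpoint dictionary amounts to separating the EMA entries from the regular model weights. The sketch below does this for a flat state dict, assuming, hypothetically, that EMA weights are stored under keys of the form "ema.<name>" as suggested by the class description's mention of the ema prefix. split_ema_state is an illustrative helper, not part of dmme, and the real hook may read EMA state from elsewhere in the checkpoint (for example, the optimizer states).

```python
from typing import Any, Dict, Tuple


def split_ema_state(
    state_dict: Dict[str, Any], prefix: str = "ema"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical: split a state dict into (model weights, EMA weights)."""
    marker = prefix + "."
    model_state: Dict[str, Any] = {}
    ema_state: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(marker):
            # Strip the assumed "ema." prefix to recover the parameter name.
            ema_state[key[len(marker):]] = value
        else:
            model_state[key] = value
    return model_state, ema_state


checkpoint = {"state_dict": {"layer.weight": 1.0, "ema.layer.weight": 0.9}}
model_w, ema_w = split_ema_state(checkpoint["state_dict"])
print(model_w, ema_w)  # → {'layer.weight': 1.0} {'layer.weight': 0.9}
```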