EMA
- class dmme.callbacks.EMA(decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False)
Implements Exponential Moving Averaging (EMA). During training, this callback maintains moving averages of the trained parameters. Code from https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
When evaluating, the moving-average copy of the trained parameters is used. When saving, an additional set of parameters is saved with the prefix ema.
- Parameters:
decay – the exponential decay used when calculating the moving average. Must be between 0 and 1.
validate_original_weights – validate the original weights, as opposed to the EMA weights.
every_n_steps – update the moving average every N steps.
cpu_offload – offload the EMA weights to CPU.
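The update rule behind decay can be sketched in plain Python. This is a minimal illustration, not the dmme or NeMo implementation: EMATracker and ema_update are hypothetical names, parameters are plain lists of floats rather than tensors, and the way steps are counted for every_n_steps is an assumption.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def ema_update(ema: Params, params: Params, decay: float) -> None:
    """Update the EMA dict in place: ema <- decay * ema + (1 - decay) * param."""
    for name, values in params.items():
        ema[name] = [decay * e + (1.0 - decay) * p
                     for e, p in zip(ema[name], values)]


class EMATracker:
    """Hypothetical helper: keeps an EMA copy of a parameter dict and
    updates it only on every `every_n_steps`-th call to `on_step`."""

    def __init__(self, params: Params, decay: float, every_n_steps: int = 1):
        # Mirrors the documented constraint that decay lies between 0 and 1.
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be between 0 and 1")
        if every_n_steps < 1:
            raise ValueError("every_n_steps must be >= 1")
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.step = 0
        # The EMA copy starts as an exact copy of the initial parameters.
        self.ema: Params = {k: list(v) for k, v in params.items()}

    def on_step(self, params: Params) -> None:
        # Count every training step, but only fold in the current
        # parameters when the step count is a multiple of every_n_steps.
        self.step += 1
        if self.step % self.every_n_steps == 0:
            ema_update(self.ema, params, self.decay)
```

As a rule of thumb, the average effectively spans about 1 / (1 - decay) updates, so decay=0.999 smooths over roughly the last 1,000 updates; values closer to 1 give a smoother but slower-moving copy.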
- on_validation_start(trainer: Trainer, pl_module: LightningModule) → None
Called when the validation loop begins.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) → None
Called when the validation loop ends.
- on_test_start(trainer: Trainer, pl_module: LightningModule) → None
Called when the test begins.
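Since evaluation uses the moving-average copy unless validate_original_weights is set, the evaluation hooks above can be thought of as swapping weights in and out. The sketch below assumes that swap-based design; SwapOnEval is a hypothetical class operating on plain dicts, not the dmme code, which works on the LightningModule's tensors.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


class SwapOnEval:
    """Hypothetical sketch: swap live weights and the EMA copy around evaluation."""

    def __init__(self, model: Params, ema: Params,
                 validate_original_weights: bool = False):
        self.model = model
        self.ema = ema
        self.validate_original_weights = validate_original_weights

    def _swap(self) -> None:
        # Exchange each live parameter with its EMA counterpart.
        for name in self.model:
            self.model[name], self.ema[name] = self.ema[name], self.model[name]

    def on_validation_start(self) -> None:
        if not self.validate_original_weights:
            self._swap()  # evaluate with the EMA weights

    def on_validation_end(self) -> None:
        if not self.validate_original_weights:
            self._swap()  # put the trained weights back before training resumes

    def on_test_start(self) -> None:
        # Same decision as for validation; restoring the weights after
        # testing would need a matching end-of-test hook (not sketched here).
        if not self.validate_original_weights:
            self._swap()
```

Swapping references instead of copying values keeps the switch cheap, which matters when validation runs many times during training.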
- save_ema_model(trainer: Trainer)
Saves an EMA copy of the model, plus the EMA optimizer states, so that training can be resumed.
- on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) → None
Called when loading a model checkpoint; use it to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that was loaded by the Trainer.
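The "prefix ema" saving scheme and on_load_checkpoint can be illustrated with a round trip through a checkpoint dictionary. This is a sketch under assumptions: the "state_dict" entry and the exact "ema." key prefix are illustrative choices, and add_ema_to_checkpoint and load_ema_from_checkpoint are hypothetical helpers, not part of dmme.

```python
from typing import Any, Dict, List

Params = Dict[str, List[float]]
EMA_PREFIX = "ema."  # assumed key prefix; the real checkpoint layout may differ


def add_ema_to_checkpoint(checkpoint: Dict[str, Any], ema: Params) -> None:
    """Store a copy of each EMA parameter next to the regular weights."""
    state = checkpoint.setdefault("state_dict", {})
    for name, values in ema.items():
        state[EMA_PREFIX + name] = list(values)


def load_ema_from_checkpoint(checkpoint: Dict[str, Any]) -> Params:
    """Rebuild the EMA copy from a loaded checkpoint dict.

    The prefixed entries are popped, so the remaining keys match the
    model's own parameter names.
    """
    state = checkpoint.get("state_dict", {})
    # Collect the keys first: the dict is mutated while building the result.
    ema_keys = [k for k in state if k.startswith(EMA_PREFIX)]
    return {k[len(EMA_PREFIX):]: state.pop(k) for k in ema_keys}
```

Popping rather than copying the prefixed entries is a design choice: it leaves the checkpoint's weight dictionary containing only the model's own parameters.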