Source code for ccsd.src.utils.ema

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""ema.py: code for the exponential moving average class for the parameters.

Adapted from Jo, J. et al. (2022), left almost untouched.
"""

from typing import Any, Dict, Iterable

import torch


class ExponentialMovingAverage:
    """Maintain an exponential moving average (EMA) of a set of parameters.

    A detached "shadow" copy is kept for every parameter that has
    ``requires_grad=True``. The shadow copies are pulled toward the live
    parameters on each call to :meth:`update`. :meth:`store`,
    :meth:`copy_to` and :meth:`restore` allow the averaged weights to be
    swapped into the model temporarily (e.g. for validation).
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.parameter.Parameter],
        decay: float,
        use_num_updates: bool = True,
    ) -> None:
        """Initialize the EMA class.

        Args:
            parameters (Iterable[torch.nn.parameter.Parameter]): Iterable of
                `torch.nn.Parameter`, initial parameters to use for EMA. Only
                those with ``requires_grad=True`` are tracked.
            decay (float): Decay rate for exponential moving average.
            use_num_updates (bool, optional): if True, initialize the number
                of updates to 0 (enabling the decay warm-up in `update`);
                if False, the counter is None. Defaults to True.

        Raises:
            ValueError: raise an error if decay is not between 0 and 1
                (NaN is rejected as well).
        """
        # Written as a chained comparison so that NaN, for which every
        # ordering comparison is False, is rejected too.
        if not 0.0 <= decay <= 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        # Detached clones: the averages must neither alias the live
        # parameters nor take part in autograd.
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        self.collected_params = []

    def __repr__(self) -> str:
        """Return the string representation of the EMA class.

        Returns:
            str: the string representation of the EMA class
        """
        return f"ExponentialMovingAverage(decay={self.decay}, num_updates={self.num_updates}, shadow_params={self.shadow_params}, collected_params={self.collected_params})"

    def update(self, parameters: Iterable[torch.nn.parameter.Parameter]) -> None:
        """Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result
        of the `optimizer.step()` call.

        When the update counter is enabled, the effective decay is
        ``min(decay, (1 + num_updates) / (10 + num_updates))``, so early
        updates use a smaller decay than the configured one.

        Args:
            parameters (Iterable[torch.nn.parameter.Parameter]): Iterable of
                `torch.nn.Parameter`; usually the same set of parameters used
                to initialize this object. Parameters with
                ``requires_grad=True`` are paired positionally with the
                shadow copies; pairing stops at the shorter sequence.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, trainable):
                # In place: shadow <- shadow - (1 - decay) * (shadow - param)
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters: Iterable[torch.nn.parameter.Parameter]) -> None:
        """Copy current averaged parameters into given collection of parameters.

        Args:
            parameters (Iterable[torch.nn.parameter.Parameter]): Iterable of
                `torch.nn.Parameter`; the parameters to be updated with the
                stored moving averages. Parameters with
                ``requires_grad=False`` are left untouched.
        """
        trainable = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, trainable):
            param.data.copy_(s_param.data)

    def store(self, parameters: Iterable[torch.nn.parameter.Parameter]) -> None:
        """Save the current parameters for restoring later.

        Every given parameter is cloned (no ``requires_grad`` filtering), and
        any previously stored collection is replaced.

        Args:
            parameters (Iterable[torch.nn.parameter.Parameter]): Iterable of
                `torch.nn.Parameter`; the parameters to be temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.parameter.Parameter]) -> None:
        """Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting
        the original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters (Iterable[torch.nn.parameter.Parameter]): Iterable of
                `torch.nn.Parameter`; the parameters to be updated with the
                stored parameters. They are paired positionally with the
                stored copies, so pass them in the same order as to `store`.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self) -> Dict[str, Any]:
        """Returns a dictionary containing the state of the EMA.

        Returns:
            Dict[str, Any]: dictionary with the keys ``decay``,
                ``num_updates`` and ``shadow_params``. ``shadow_params`` is
                the live list held by this object, not a copy.
        """
        return dict(
            decay=self.decay,
            num_updates=self.num_updates,
            shadow_params=self.shadow_params,
        )

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the dictionary containing the state of the EMA.

        Args:
            state_dict (Dict[str, Any]): dictionary as returned by
                `state_dict`, with the keys ``decay``, ``num_updates`` and
                ``shadow_params``.

        Raises:
            KeyError: if one of the expected keys is missing.
            ValueError: if the loaded decay is not between 0 and 1.
        """
        # Read and validate everything before assigning, so a bad state
        # dictionary cannot leave this object half-updated.
        decay = state_dict["decay"]
        num_updates = state_dict["num_updates"]
        shadow_params = state_dict["shadow_params"]
        if not 0.0 <= decay <= 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = num_updates
        self.shadow_params = shadow_params