class EMA:
    r"""Keep an exponential moving average ("shadow") of a model's state.

    Usage::

        model = ResNet(config)
        ema = EMA(model, alpha=0.999)
        ...
        # train an epoch
        ema.update_params(model)
        ema.apply_shadow(model)

    Attributes:
        shadow: dict mapping every ``state_dict()`` key of the tracked model
            (parameters and persistent buffers) to a detached copy.
        param_keys: names from ``named_parameters()``; only these shadow
            entries are averaged by ``update_params``.
        alpha: decay factor; ``update_params`` computes
            ``alpha * shadow + (1 - alpha) * current`` for each parameter.
    """

    def __init__(self, _net, _alpha=0.999, *, alpha=None):
        """Snapshot ``_net`` and store the decay factor.

        Args:
            _net: model whose ``state_dict()`` / ``named_parameters()`` seed
                the shadow.
            _alpha: decay factor (positional / legacy keyword form).
            alpha: keyword alias for ``_alpha`` so the documented
                ``EMA(model, alpha=...)`` call works; when given it takes
                precedence over ``_alpha``.
        """
        self.alpha = _alpha if alpha is None else alpha
        self.init_params(_net)

    def init_params(self, _model):
        """Reset the shadow to a fresh copy of ``_model``'s current state.

        Rebuilds ``shadow`` from ``_model.state_dict()`` (detached clones, so
        later training does not mutate it) and refreshes ``param_keys``.
        ``alpha`` is left unchanged.
        """
        self.shadow = {k: v.clone().detach() for k, v in _model.state_dict().items()}
        self.param_keys = [k for k, _ in _model.named_parameters()]

    def update_params(self, _model):
        """Fold ``_model``'s current state into the shadow, in place.

        For each name in ``param_keys``:
        ``shadow = alpha * shadow + (1 - alpha) * current``.
        Buffers reported by ``_model.named_buffers()`` that are also present
        in the shadow (e.g. BatchNorm running statistics) are copied
        verbatim, so the shadow does not keep the values captured at
        construction time.
        """
        state = _model.state_dict()
        for name in self.param_keys:
            shadow = self.shadow[name]
            # In-place form of alpha * shadow + (1 - alpha) * current; .to()
            # matches the shadow's dtype/device in case the model was moved
            # or cast after the snapshot was taken.
            shadow.mul_(self.alpha).add_(state[name].to(shadow), alpha=1 - self.alpha)
        for name, _ in _model.named_buffers():
            # Non-persistent buffers are absent from state_dict() and hence
            # from the shadow; skip them.
            if name in state and name in self.shadow:
                self.shadow[name].copy_(state[name])