EMA 原理和实现(torch)

EMA(指数滑动平均),经常被用在深度学习任务中,来提神模型的鲁棒性或是涨点。在半监督的分类任务中,经常使用EMA的方法来给不同参数做self-ensemble。

算法流程

EMA

原理解析

Pytorch 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class EMA:
r"""
Usage:
model = ResNet(config)
ema = EMA(model, alpha=0.999)
... # train an epoch
ema.update_params(model)
ema.apply_shadow(model)
"""
def __init__(self, _net, _alpha=0.999):
self.shadow = {k: v.clone().detach() for k, v in _net.state_dict().items()}
self.param_keys = [k for k, _ in _net.named_parameters()]
self.alpha = _alpha

def init_params(self, _model):
self.shadow = {k: v.clone().detach() for k, v in _model.state_dict().items()}
self.param_keys = [k for k, _ in _model.named_parameters()]

def update_params(self, _model):
state = _model.state_dict()
for name in self.param_keys:
self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name])

def apply_shadow(self, _model):
_model.load_state_dict(self.shadow, strict=True)
作者

WangCH

发布于

2021-10-07

更新于

2021-10-08

许可协议

评论