EMA 原理和实现(torch)

EMA(指数滑动平均),经常被用在深度学习任务中,来提神模型的鲁棒性或是涨点。在半监督的分类任务中,经常使用EMA的方法来给不同参数做self-ensemble。

算法流程

EMA

原理解析

Pytorch 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class EMA:
r"""
Usage:
model = ResNet(config)
ema = EMA(model, alpha=0.999)
... # train an epoch
ema.update_params(model)
ema.apply_shadow(model)
"""
def __init__(self, _net, _alpha=0.999):
self.shadow = {k: v.clone().detach() for k, v in _net.state_dict().items()}
self.param_keys = [k for k, _ in _net.named_parameters()]
self.alpha = _alpha

def init_params(self, _model):
self.shadow = {k: v.clone().detach() for k, v in _model.state_dict().items()}
self.param_keys = [k for k, _ in _model.named_parameters()]

def update_params(self, _model):
state = _model.state_dict()
for name in self.param_keys:
self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name])

def apply_shadow(self, _model):
_model.load_state_dict(self.shadow, strict=True)

quote-JiXianlin

今天仍然拼命看书,因为明天就要考了。学期的成绩就全仗这两天挣,现在更感到考试无用与无聊。 ——季羡林 《清华园日记》

Linux复习

期末linux复习笔记,记录了大部分,实在是复习不完了,考试不难,注意实验中的细节和线程、进程的细节。特别注意一下每个函数的参数形式,返回值。
阅读更多