SWA(随机权重平均)_得克特的博客-程序员信息网_随机权重平均

技术标签: 随机权重平均  深度学习  swa  

SWA(随机权重平均)

[Averaging Weights Leads to Wider Optima and Better Generalization](Averaging Weights Leads to Wider Optima and Better Generalization)
随机权重平均:在优化的末期取k个优化轨迹上的checkpoints,平均他们的权重,得到最终的网络权重,这样就会使得最终的权重位于flat曲面更中心的位置,缓解权重震荡问题,获得一个更加平滑的解,相比于传统训练有更泛化的解。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mH6n69eh-1648029152717)(C:\Users\haonan7\AppData\Roaming\Typora\typora-user-images\image-20220311202349305.png)]

效果如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EtwRUvyw-1648029152719)(C:\Users\haonan7\AppData\Roaming\Typora\typora-user-images\image-20220311202415255.png)]

SWA和EMA

EMA指数滑动平均(Exponential Moving Average)我们讨论了指数滑动平均,可以发现SWA和EMA是有相似之处:

  • 都是在训练之外的操作,不影响训练过程。
  • 与集成学习类似,都是一种权值的平均,EMA是一种指数平均,会赋予近期更多的权重,SWA则是平均赋权重。

所以这里参考了的SWA实现,添加了EMA的实现,两者不同在于影子权值的更新方式。

class WeightAverage(Optimizer):
    def __init__(self, optimizer, wa_start=None, wa_freq=None, wa_lr=None, mode='swa'):
        """实现参考:https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
        论文:Averaging Weights Leads to Wider Optima and Better Generalization
        两种权重平均的方式 swa 和 ema
        两种模式:自动模式和手动模式
        参数:
            optimizer (torch.optim.Optimizer): optimizer to use with SWA
            wa_start (int): SWA开始应用的step
            wa_freq (int): 更新SWA的频数
            wa_lr (float): 自动模式:从swa_start开始应用
        """
        if isinstance(mode, float):
            self.mode = 'ema'
            self.beta = mode
        else:
            self.mode = mode
        self._auto_mode, (self.wa_start, self.wa_freq) = self._check_params(wa_start, wa_freq)
        self.wa_lr = wa_lr
        # 参数检查
        if self._auto_mode:
            if wa_start < 0:
                raise ValueError("Invalid wa_start: {}".format(wa_start))
            if wa_freq < 1:
                raise ValueError("Invalid wa_freq: {}".format(wa_freq))
        else:
            if self.wa_lr is not None:
                warnings.warn("Some of wa_start, wa_freq is None, ignoring wa_lr")
            self.wa_lr = None
            self.wa_start = None
            self.wa_freq = None

        if self.wa_lr is not None and self.wa_lr < 0:
            raise ValueError("Invalid WA learning rate: {}".format(wa_lr))

        self.optimizer = optimizer
        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state

        for group in self.param_groups:
            # ema 不需要保存已经平均的个数,为了兼容swa不修改
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_start, swa_freq):
        """检查参数,确认执行模式,并将参数转为int
        """
        params = [swa_start, swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn("Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_start, swa_freq to int")
        return not any(params_none), params

    def _reset_lr_to_swa(self):
        """应用wa学习率
        """
        if self.wa_lr is None:
            return
        for param_group in self.param_groups:
            if param_group['step_counter'] >= self.wa_start:
                param_group['lr'] = self.wa_lr

    def update_swa_group(self, group):
        """更新一组参数的wa: 随机权重平均或者指数滑动平均
        """
        for p in group['params']:
            param_state = self.state[p]
            if 'wa_buffer' not in param_state:
                param_state['wa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['wa_buffer']
            if self.mode == 'swa':
                virtual_decay = 1 / float(group["n_avg"] + 1)
                diff = (p.data - buf) * virtual_decay  # buf + (p-buf) / (n+1) = (p + n*buf) / (n+1)
                buf.add_(diff)
            else:
                buf.mul_(self.beta).add_((1-self.beta) * p.data)
        group["n_avg"] += 1

    def update_swa(self):
        """手动模式:更新所有参数的swa
        """
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_sgd(self):
        """1.交换swa和模型的参数 2.训练结束时和评估时调用
        """
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'wa_buffer' not in param_state:
                    warnings.warn("WA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['wa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

    def step(self, closure=None):
        """1.梯度更新 2.如果是自动模式更新swa参数
        """
        self._reset_lr_to_swa()
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps > self.wa_start and steps % self.wa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        """打包 opt_state 优化器状态,swa_state SWA状态,param_groups 参数组
        """
        opt_state_dict = self.optimizer.state_dict()
        wa_state = {
    (id(k) if isinstance(k, torch.Tensor) else k): v
                     for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {
    "opt_state": opt_state, "wa_state": wa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        """加载swa和优化器的状态参数
        """
        wa_state_dict = {
    "state": state_dict["wa_state"],
                         "param_groups": state_dict["param_groups"]}
        opt_state_dict = {
    "state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(WeightAverage, self).load_state_dict(wa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

    def add_param_group(self, param_group):
        """将一组参数添加到优化器的 `param_groups`.
        """
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)

    @staticmethod
    def bn_update(loader, model, device=None):
        """更新 BatchNorm running_mean, running_var
        """
        if not _check_bn(model):
            return
        was_training = model.training
        model.train()
        momenta = {
    }
        model.apply(_reset_bn)
        model.apply(lambda module: _get_momenta(module, momenta))
        n = 0
        for input in loader:
            if isinstance(input, (list, tuple)):
                input = input[0]
            b = input.size(0)  # batch_size

            momentum = b / float(n + b)
            for module in momenta.keys():
                module.momentum = momentum

            if device is not None:
                input = input.to(device)

            model(input)
            n += b

        model.apply(lambda module: _set_momenta(module, momenta))
        model.train(was_training)


# BatchNorm utils
def _check_bn_apply(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def _check_bn(model):
    flag = [False]
    model.apply(lambda module: _check_bn_apply(module, flag))
    return flag[0]


def _reset_bn(module):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]

Stochastic Weight Averaging in PyTorch

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_40548136/article/details/123692386

智能推荐

编程珠玑第一章习题答案_weixin_30335353的博客-程序员信息网

习题1.1 如果不缺内存,如何使用一个具有库的语言来实现一种排序算法? 因为C++有sort,JAVA也有,这里以C++为例给出,记住如果用set集合来排序时,是不可以有元素重复的 代码: #include &lt;iostream&gt; #include &lt;cstring&gt; #include &lt;cmath&...

java监视和管理控制台怎么看_如何监视和管理IPMI管理控制台_Shepherd Young的博客-程序员信息网

如何监控和管理IPMI管理控制台本指南概述了如何使用Verax NMS监控和管理IPMI管理控制台。 智能P latform管理接口(IPMI)是系统管理员用于管理计算机系统和监视其操作的标准化计算机系统接口。指南分为以下几个部分:将IPMI管理控制台添加到受监视应用程序列表中。配置IPMI控制台的可用性传感器和性能计数器。系统的IPMI管理控制台概述。设置报警和通知策略。本指南中使用的工具:将I...

bzoj 1568: [JSOI2008]Blue Mary开公司(超哥线段树)_clover_hxy的博客-程序员信息网

1568: [JSOI2008]Blue Mary开公司Time Limit: 15 Sec  Memory Limit: 162 MBSubmit: 739  Solved: 250[Submit][Status][Discuss]DescriptionInput第一行 :一个整数N ,表示方案和询问的总数。 接下来N行,每行开头一个单词“Query”或“

详细分析MySQL事务日志(redo log和undo log)_mysql undolog记录时间_intimexy的博客-程序员信息网

innodb事务日志包括redo log和undo log。redo log是重做日志,提供前滚操作,undo log是回滚日志,提供回滚操作。undo log不是redo log的逆向过程,其实它们都算是用来恢复的日志:1.redo log通常是物理日志,记录的是数据页的物理修改,而不是某一行或某几行修改成怎样怎样,它用来恢复提交后的物理数据页(恢复数据页,且只能恢复到最后一次提交的位置)。2.undo用来回滚行记录到某个版本。undo log一般是逻辑日志,根据每行记录进行记录。1.re...

Kafka集群环境搭建_程序员杂谈的博客-程序员信息网

3台虚拟机均进行以下操作:// 解压下载好的kafka压缩包并重命名cd /usr/localwget http://mirror.bit.edu.cn/apache/kafka/1.0.0/kafka_2.11-1.0.0.tgztar -zxvf kafka_2.11-1.0.0.tgzmv kafka_2.12-0.11.0.0 kafka// 修改配置文件vi ./kaf...

node.js前端项目无法启动_RRR-Richard的博客-程序员信息网

node.js前端项目无法启动1.最新版本node导致依赖无法完全下载 会报错fix audit需要完全卸载2.之前设置的npm config set prefix "…node_global"影响现有node方法是删除C:\Users\Administrator.npmrc这个文件。如果.npmrc不在这个目录下,就全局搜一下啦。3.前端代码拉取问题重新拉取前端代码重新拉取前端代码...

随便推点

hdu1085母函数模板_8435123的博客-程序员信息网

#include#include#includeusing namespace std;const int maxn=8000;int c1[maxn+10],c2[maxn+10],last,last2;int main(){ int money[4]={0,1,2,5},num[4],x; while(scanf("%d%d%d",&num[1],&num[2],&

SpringBoot上传图片(文件)到本机文件夹_曾某人啊的博客-程序员信息网

application.properties配置文件# maxFileSize 单个数据大小spring.servlet.multipart.maxFileSize=10MB# maxRequestSize 是总数据大小spring.servlet.multipart.maxRequestSize=100MB新建controller接口import org.springframework.web.bind.annotation.*;import org.springframework.

百度搜索API_江风引雨的博客-程序员信息网

最近需要做一个爬取新闻网站的项目,但考虑到各个网站的搜索api都不同,且准确性较低,所以我想到了百度的高级搜索功能,于是就对其搜索api探索了一番。以下就是我整理的百度api参数列表,可能不够全面,欢迎补充.百度搜索API基本链接1http://www.baidu.com/s?wd=关键字&amp;cl=类型&amp;pn=页码&amp;ie=gb2312&amp;rn=显示条数&am...

[译] 使用 PHPStorm 开发 Laravel 应用_weixin_33971130的博客-程序员信息网

很多PHP程序员使用 [laravel] 创建他们的应用程序。[laravel] 是一个免费开源的PHP web应用程序框架。它基于多个Symfony 组件,提供了一个开发框架,包括authentication, routing, sessions, caching 等模块.去年夏天, 我们介绍了 支持Blade 。blade 是Lar...

推荐文章

热门文章

相关标签