【深度学习】翻译:60分钟入门PyTorch(二)——Autograd自动求导-程序员宅基地

技术标签: 机器学习  深度学习  人工智能  神经网络  大数据  

前言

原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz

翻译:林不清(https://www.zhihu.com/people/lu-guo-92-42-88)

目录

60分钟入门PyTorch(一)——Tensors

60分钟入门PyTorch(二)——Autograd自动求导

60分钟入门Pytorch(三)——神经网络

60分钟入门PyTorch(四)——训练一个分类器

Autograd:自动求导

torch.autograd是pytorch自动求导的工具,也是所有神经网络的核心。我们首先先简单了解一下这个包如何训练神经网络。

背景介绍

神经网络(NNs)是作用在输入数据上的一系列嵌套函数的集合,这些函数由权重和误差来定义,被存储在PyTorch中的tensors中。
神经网络训练的两个步骤:
前向传播:在前向传播中,神经网络通过将接收到的数据与每一层对应的权重和误差进行运算来对正确的输出做出最好的预测。
反向传播:在反向传播中,神经网络调整其参数使得其与输出误差成比例。反向传播基于梯度下降策略,是链式求导法则的一个应用,以目标的负梯度方向对参数进行调整。
更加详细的介绍可以参照下述地址:

[3Blue1Brown]:

https://www.youtube.com/watch?v=tIeHLnjs5U8

Pytorch应用

来看一个简单的示例,我们从torchvision加载一个预先训练好的resnet18模型,接着创建一个随机数据tensor来表示一有3个通道、高度和宽度为64的图像,其对应的标签初始化为一些随机值。

%matplotlib inline
import torch, torchvision
model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 64, 64)
labels = torch.rand(1, 1000)

接下来,我们将输入数据向输出方向传播到模型的每一层中来预测输出,这就是前向传播。

prediction = model(data) # 前向传播

我们利用模型的预测输出和对应的权重来计算误差,然后反向传播误差。完成计算后,您可以调用.backward()并自动计算所有梯度。此张量的梯度将累积到.grad属性中。

loss = (prediction - labels).sum()
loss.backward() # 反向传播

接着,我们加载一个优化器,在本例中,SGD的学习率为0.01,momentum 为0.9。我们在优化器中注册模型的所有参数。

optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

最后,我们调用.step()来执行梯度下降,优化器通过存储在.grad中的梯度来调整每个参数。

optim.step() #梯度下降

现在,你已经具备了训练神经网络所需所有条件。下面几节详细介绍了Autograd包的工作原理——可以跳过它们。


Autograd中的求导

先来看一下autograd是如何收集梯度的。我们创建两个张量a和b并设置requires_grad = True以跟踪它的计算。

import torch

a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)

接着在ab的基础上创建张量Q

Q = 3*a**3 - b**2

假设ab是一个神经网络的权重,Q是它的误差,在神经网络训练中,我们需要w.r.t参数的误差梯度,即

当我们调用Q.backward()时,autograd计算这些梯度并把它们存储在张量的 .grad属性中。我们需要在Q.backward()中显式传递gradientgradient是一个与Q相同形状的张量,它表示Q w.r.t本身的梯度,即

同样,我们也可以将 Q聚合为一个标量并隐式向后调用,如 Q.sum().backward()
external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)

现在梯度都被存放在a.gradb.grad

# 检查一下存储的梯度是否正确
print(9*a**2 == a.grad)
print(-2*b == b.grad)

可选阅读----用autograd进行向量计算

在数学上,如果你有一个向量值函数????⃗ =????(????⃗ ) ,则????⃗ 相对于????⃗ 的梯度是雅可比矩阵:

一般来说,torch.autograd是一个计算雅可比向量积的引擎。也就是说,给定任何向量????=(????1????2...????????)????,计算乘积????⋅????。如果????恰好是标量函数的梯度????=????(????⃗ ),即 然后根据链式法则,雅可比向量乘积将是????相对于????⃗ 的梯度

雅可比向量积的这种特性使得将外部梯度馈送到具有非标量输出的模型中非常方便。external_grad 代表 .

图计算

从概念上讲,autograd在由函数对象组成的有向无环图(DAG)中保存数据(tensor)和所有执行的操作(以及产生的新tensor)的记录,在这个DAG中,叶节点是输入数据,根节点是输出数据,通过从根节点到叶节点跟踪这个图,您可以使用链式法则自动计算梯度。

在前向传播中,autograd同时完成两件事情:

  • 运行所请求的操作来计算结果tensor

  • 保持DAG中操作的梯度

在反向传播中,当在DAG根节点上调用.backward()时,反向传播启动,autograd接下来完成:

  • 计算每一个.grad_fn的梯度

  • 将它们累加到各自张量的.grad属性中

  • 利用链式法则,一直传播到叶节点

下面是DAG的可视化表示的示例。图中,箭头表示前向传播的方向,节点表示向前传递中每个操作的向后函数。蓝色标记的叶节点代表叶张量 ab


注意

DAG在PyTorch中是动态的。值得注意的是图是重新开始创建的; 在调用每一个``.backward()``后,autograd开始填充一个新图,这就是能够在模型中使用控制流语句的原因。你可以根据需求在每次迭代时更改形状、大小和操作。

torch.autograd追踪所有requires_gradTrue的张量的相关操作。对于不需要梯度的张量,将此属性设置为False将其从梯度计算DAG中排除。操作的输出张量将需要梯度,即使只有一个输入张量requires_grad=True

x = torch.rand(5, 5)
y = torch.rand(5, 5)
z = torch.rand((5, 5), requires_grad=True)

a = x + y
print(f"Does `a` require gradients? : {a.requires_grad}")
b = x + z
print(f"Does `b` require gradients?: {b.requires_grad}")

在神经网络中,不计算梯度的参数通常称为冻结参数。如果您事先知道您不需要这些参数的梯度,那么“冻结”部分模型是很有用的(这通过减少autograd计算带来一些性能好处)。另外一个常见的用法是微调一个预训练好的网络,在微调的过程中,我们冻结大部分模型——通常,只修改分类器来对新的<标签>做出预测,让我们通过一个小示例来演示这一点。与前面一样,我们加载一个预先训练好的resnet18模型,并冻结所有参数。

from torch import nn, optim

model = torchvision.models.resnet18(pretrained=True)

# 冻结网络中所有的参数
for param in model.parameters():
    param.requires_grad = False

假设我们想在一个有10个标签的新数据集上微调模型。在resnet中,分类器是最后一个线性层模型model.fc。我们可以简单地用一个新的线性层(默认未冻结)代替它作为我们的分类器。

model.fc = nn.Linear(512, 10)

现在除了model.fc的参数外,模型的其他参数均被冻结,参与计算的参数是model.fc的权值和偏置。

# 只优化分类器
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

注意,尽管我们注册了优化器中所有参数,但唯一参与梯度计算(并因此在梯度下降中更新)的参数是分类器的权值和偏差。torch.no_grad()中也具有相同的功能。

拓展阅读

  • [就地修改操作以及多线程Autograd]:(https://pytorch.org/docs/stable/notes/autograd.html)

  • [反向模式autodiff的示例]:(https://colab.research.google.com/drive/1VpeE6UvEPRz9HmsHh1KS0XxXjYu533EC)

往期精彩回顾



适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
本站知识星球“黄博的机器学习圈子”(92416895)
本站qq群704220115。
加入微信群请扫码:

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/fengdu78/article/details/113488470

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签