Transformer入门-轨迹预测实例解析_transformer轨迹预测-程序员宅基地

技术标签: 深度学习  transformer  人工智能  

最近动手玩了一下Transformer,找到了一个很适合练手的小例子,基于https://github.com/cxl-ustb/AISTransforemr的代码做了一些修改(感谢原作者),改进后的代码地址:GitHub - BITcsy/AISTransformer: 利用transformer进行船舶轨迹预测。


1. 任务简介:

该代码功能是处理船只的轨迹、状态预测(经度,维度,速度,朝向)。每条数据涵盖11个点,输入是完整的11个点(Encoder输入前10个点,Decoder输入后10个点,模型整体输出后10个点),如下图,训练数据140条,测试数据160条。整个任务本身并没有什么意义(已知轨迹再输出部分轨迹),并没有做什么预测任务。不过整体例子简单明了,调试难度小,比较适合做入门练手的例子。

 2. 模型结构:

原作者实现了比较完整的Transformer结构,当然对于轨迹预测类任务很多结构是冗余的,因为Decoder那部分输入在语音识别上可能有用,但是在轨迹预测上应该没有必要,不过可以学习一下。

(1)Transformer参考资料

这网上很多,这里引用几个个人觉得比较好的:

Transformer: NLP里的变形金刚 --- 详述 - 知乎

多图详解attention和mask。从循环神经网络、transformer到GPT2,我悟了 - 知乎

(2)本文模型结构及主要变量的对应图

费劲把这个代码里的整体数据流和主要的模型结构整理了下,可以看出来和最典型的Transformer的结构是一样的。其中标颜色的几个模块单独再打开来看吧,左下角的几个变量和word embedding及positional encoding相关,也单独来看。

(3)word embedding & positional encoding

word embedding参考资料:词嵌入向量(Word Embedding)的原理和生成方法 - 程序员大本营

nn.embedding: PyTorch中的nn.Embedding - 知乎

positional encoding: positional encoding位置编码详解:绝对位置与相对位置编码对比_夕小瑶的博客-程序员宅基地

Positional Encoding的原理和计算 - 知乎   有比较形象的图形解释。

总之,word embedding想要对序列化的输入(比如一句话,或者这里的一条轨迹)进行降维表达,把input映射成embedding。疑点1:nn.embedding这个模块貌似比较适合NLP任务这种比较稀疏的embedding表达,做轨迹类任务不一定合适。那代码中的n_src_vocab, n_trg_vocab应该就是对应于字典的长度,d_word_src, d_model应该就是隐层要用多少维去表达这些词。疑点2:按道理d_word_src应该小于n_src_vocab才对,但是代码里分别设置了512和500

positional encoding的作用是为了解决attention这种并行计算的结构丢失了RNN、LSTM这种具有先后顺序的网络特性,因为NLP或者轨迹预测类任务还是需要看特征的时序上的改变的,因此用positional encoding来区分不同的位置,结合其公式应该比较容易理解。代码中position设置为200,按道理这个数设置为大于最大序列长度的数就可以了(本代码最大序列长度就是10)。

word embedding和positional encoding这块的整体计算原理大概如下图,在这个代码里,d_word和d_model其实是一个意思,但是如果是其他场景,d_model的含义应该更广,毕竟是dimension of the model。

(4)注意力部分代码详解

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q
        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn

根据上面的代码,画了一下attention详细的计算过程,如下图:

 几个需要注意的地方

(1)attention的mask是可选的,在时序任务里是需要的,因为时序是有因果性的。

(2)Q*K才是attention,它的维度是Batch*heads*lens*lens,也就意味着是有明确物理意义的,只和最初的序列长度有关系,可以可视化出来做分析用。

(3)ScaledDotProduct算完后,最后一维是由dv决定的,想回到q的维度,还需要做linear的操作。

(4)contiguous & view:代码里的操作实际就是把dv数据按照heads concate到了一起,具体为何需要使用contiguous的原因PyTorch中的contiguous - 知乎,其实可以用reshape来代替 PyTorch:view() 与 reshape() 区别详解_reshape和view_地球被支点撬走啦的博客-程序员宅基地

(5)residual:类似于Resnet的设计。

3. 增加的功能:

(1)DataLoader:原代码直接用的pickle来的raw data,为了更好的进行训练控制,我改写成了Dataloader的形式,实测同参数训练要比原来慢1/10到1/5左右。

(2)tqdm:给训练加了个进度条,方便监测比较细节的训练进度。

(3)tensorboard:方便调试(tensorboardX和tensorboard都要安装,tensorboard --logdir XXX)。

(4)模型断点训练:因为自己电脑是10年前的机器,性能很差,所以增加了一个断点续训的功能,避免每次都要从头训练。这里是用checkpoint形式写的保存细节信息。

(5)pt2onxx.py里面用netron可以查看网络结构,略改了一下,可以实现,不过感觉netron显示的太细节,比较难看出最主要的网络设计。

4. 调试笔记

(1)learning rate & 模型大小:原代码直接拿过来可以训,但是收敛不了,loss一直是好几千。分析出来2个主要问题:

i. 原learning rate做了特殊设计,但是用tessorboard看,原learning rate是随训练次数上升的,很令人困惑,所以改写逐步下降的形式。

ii. 改了learning rate还是收敛不了的,分析到训练数据只有140条,但是又搞了个这么大的模型,可能根本喂不饱网络,所以直接把attention部分的大小改小了,multiheadattention的head数8->2,attention layer数6->1,这样改完网络收敛了,至少在训练集上可以过拟合,loss到1左右。

 (2)模型保存

关于模型保存,有一类需求,就是本地需要load一个已有的模型pth文件,然后接着训练调试。有一个很奇怪的地方,当我load一个已存的pth文件后,模型做forward的时候,凭什么知道找我源码里同名model的forward函数?我尝试改了一下模型的类名(把Transformer改成了Transformers),果然test的时候就报错了,所以在load模型后,会自动找同名模型的实现:

return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Transformer' on <module 'transformer.Models' from '/home/XXX/AISTransformer/AISTransformer/transformer/Models.py'>

(3)训练集表现 & 关于model.eval的问题

在测试的过程中发现了一个很奇怪的现象,就是如果我把训好的模型save后,再重新load上来,用训练集做测试,loss差距很大(训练的loss是1.x,但是重新load的模型在同数据集下loss是1000+),做了几个尝试

(4)attention & 测试集表现

在训练集上loss约为1,在测试集上测试loss是2k-3k,泛化能力很差,当然数据太少是一个比较显著的问题。为了在小数据集上提升模型泛化能力,可以通过attention来加入一些先验信息引导模型训练。可能的尝试:

  • 分析attention:这个轨迹预测的任务,因为直接拷贝trg_seq就可以了,所以对于decoder部分的self attention应该有非常明显的对应位的attention的结果,(encoder的self att和deocder的cross att不一定有这种结果,因为可能的关系已经比较间接了)
  • attention mask:用mask的方式约束训练
  • attention loss(guide attention):用loss的方式引导训练

(5)Shuffle的问题:因为发现测试集上表现不好,考虑到原代码训练时并没有shuffle,可能会让网络学到一些关于顺序的东西(https://zhuanlan.zhihu.com/p/57108650),所以在引入Dataloader,加了下shuffle,但是训出来感觉测试集上差别不大。

(6)结构对比:做这个简单的任务用这么复杂的模型可能往往是难以达到较好的效果的,还有用了例如word embedding这样的结构,看上去比较奇怪。用MLP简单实现了下这个网络。

  • 直接实现了一个4层的MLP:各层维度4,25,50,25,4,输入只用transformer的encoder部分的10*4的输入,因为网络比较小,训练速度会提高很多,30000epochs 2min训练完,但是训练是比较难收敛的,30000epochs的train loss仍然在550+,但是test的error也是550左右。在此感谢@追梦5号的建议,预测轨迹的时序信息才是核心信息,所以各层维度调整为10,25,50,25,10,效果要比之前好
  • 如果和transformer一样,把11个坐标点分成两段输入,那train loss直接0.01(当然这个任务也没什么意义),test error 0.03。
  • 所以几个个模型结构对比下,在本例子给的任务下,如果MLP模型和transformer同样的输入,小MLP网络的收敛性和泛化性要比transformer好很多,合理。
  • 输入 epochs 训练时长 train loss test loss
    transformer

    1. traj[:,:-1,:]

    2. traj[:,1:,:]

    8000 >1h@cpu(11min@GTX1080) 1.0+ 2k+
    MLP traj[:,-1:,:] 30000 2min 550+ 550+
    MLP traj[:,:,-1:] 30000 2min 20+ 20+
    MLP

    1. traj[:,:-1,:]

    2. traj[:,1:,:]

    30000 2min 0.01 0.03

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/BIT_csy/article/details/129971757

智能推荐

JWT(Json Web Token)实现无状态登录_无状态token登录-程序员宅基地

文章浏览阅读685次。1.1.什么是有状态?有状态服务,即服务端需要记录每次会话的客户端信息,从而识别客户端身份,根据用户身份进行请求的处理,典型的设计如tomcat中的session。例如登录:用户登录后,我们把登录者的信息保存在服务端session中,并且给用户一个cookie值,记录对应的session。然后下次请求,用户携带cookie值来,我们就能识别到对应session,从而找到用户的信息。缺点是什么?服务端保存大量数据,增加服务端压力 服务端保存用户状态,无法进行水平扩展 客户端请求依赖服务.._无状态token登录

SDUT OJ逆置正整数-程序员宅基地

文章浏览阅读293次。SDUT OnlineJudge#include<iostream>using namespace std;int main(){int a,b,c,d;cin>>a;b=a%10;c=a/10%10;d=a/100%10;int key[3];key[0]=b;key[1]=c;key[2]=d;for(int i = 0;i<3;i++){ if(key[i]!=0) { cout<<key[i.

年终奖盲区_年终奖盲区表-程序员宅基地

文章浏览阅读2.2k次。年终奖采用的平均每月的收入来评定缴税级数的,速算扣除数也按照月份计算出来,但是最终减去的也是一个月的速算扣除数。为什么这么做呢,这样的收的税更多啊,年终也是一个月的收入,凭什么减去12*速算扣除数了?这个霸道(不要脸)的说法,我们只能合理避免的这些跨级的区域了,那具体是那些区域呢?可以参考下面的表格:年终奖一列标红的一对便是盲区的上下线,发放年终奖的数额一定一定要避免这个区域,不然公司多花了钱..._年终奖盲区表

matlab 提取struct结构体中某个字段所有变量的值_matlab读取struct类型数据中的值-程序员宅基地

文章浏览阅读7.5k次,点赞5次,收藏19次。matlab结构体struct字段变量值提取_matlab读取struct类型数据中的值

Android fragment的用法_android reader fragment-程序员宅基地

文章浏览阅读4.8k次。1,什么情况下使用fragment通常用来作为一个activity的用户界面的一部分例如, 一个新闻应用可以在屏幕左侧使用一个fragment来展示一个文章的列表,然后在屏幕右侧使用另一个fragment来展示一篇文章 – 2个fragment并排显示在相同的一个activity中,并且每一个fragment拥有它自己的一套生命周期回调方法,并且处理它们自己的用户输_android reader fragment

FFT of waveIn audio signals-程序员宅基地

文章浏览阅读2.8k次。FFT of waveIn audio signalsBy Aqiruse An article on using the Fast Fourier Transform on audio signals. IntroductionThe Fast Fourier Transform (FFT) allows users to view the spectrum content of _fft of wavein audio signals

随便推点

Awesome Mac:收集的非常全面好用的Mac应用程序、软件以及工具_awesomemac-程序员宅基地

文章浏览阅读5.9k次。https://jaywcjlove.github.io/awesome-mac/ 这个仓库主要是收集非常好用的Mac应用程序、软件以及工具,主要面向开发者和设计师。有这个想法是因为我最近发了一篇较为火爆的涨粉儿微信公众号文章《工具武装的前端开发工程师》,于是建了这么一个仓库,持续更新作为补充,搜集更多好用的软件工具。请Star、Pull Request或者使劲搓它 issu_awesomemac

java前端技术---jquery基础详解_简介java中jquery技术-程序员宅基地

文章浏览阅读616次。一.jquery简介 jQuery是一个快速的,简洁的javaScript库,使用户能更方便地处理HTML documents、events、实现动画效果,并且方便地为网站提供AJAX交互 jQuery 的功能概括1、html 的元素选取2、html的元素操作3、html dom遍历和修改4、js特效和动画效果5、css操作6、html事件操作7、ajax_简介java中jquery技术

Ant Design Table换滚动条的样式_ant design ::-webkit-scrollbar-corner-程序员宅基地

文章浏览阅读1.6w次,点赞5次,收藏19次。我修改的是表格的固定列滚动而产生的滚动条引用Table的组件的css文件中加入下面的样式:.ant-table-body{ &amp;amp;::-webkit-scrollbar { height: 5px; } &amp;amp;::-webkit-scrollbar-thumb { border-radius: 5px; -webkit-box..._ant design ::-webkit-scrollbar-corner

javaWeb毕设分享 健身俱乐部会员管理系统【源码+论文】-程序员宅基地

文章浏览阅读269次。基于JSP的健身俱乐部会员管理系统项目分享:见文末!

论文开题报告怎么写?_开题报告研究难点-程序员宅基地

文章浏览阅读1.8k次,点赞2次,收藏15次。同学们,是不是又到了一年一度写开题报告的时候呀?是不是还在为不知道论文的开题报告怎么写而苦恼?Take it easy!我带着倾尽我所有开题报告写作经验总结出来的最强保姆级开题报告解说来啦,一定让你脱胎换骨,顺利拿下开题报告这个高塔,你确定还不赶快点赞收藏学起来吗?_开题报告研究难点

原生JS 与 VUE获取父级、子级、兄弟节点的方法 及一些DOM对象的获取_获取子节点的路径 vue-程序员宅基地

文章浏览阅读6k次,点赞4次,收藏17次。原生先获取对象var a = document.getElementById("dom");vue先添加ref <div class="" ref="divBox">获取对象let a = this.$refs.divBox获取父、子、兄弟节点方法var b = a.childNodes; 获取a的全部子节点 var c = a.parentNode; 获取a的父节点var d = a.nextSbiling; 获取a的下一个兄弟节点 var e = a.previ_获取子节点的路径 vue