论文阅读 Vision Transformer - VIT_vision transformer论文-程序员宅基地

技术标签: 论文阅读  深度学习  transformer  papers  

1 摘要

1.1 核心

通过将图像切成patch线形层编码成token特征编码的方法,用transformer的encoder来做图像分类

2 模型架构

2.1 概览

在这里插入图片描述

2.2 对应CV的特定修改和相关理解

解决问题:

  1. transformer输入限制: 由于自注意力+backbone,算法复杂度为o(n²),token长度一般要<512才足够运算
    解决:a) 将图片转为token输入 b) 将特征图转为token输入 c)√ 切patch转为token输入
  2. transformer无先验知识:卷积存在平移不变性(同特征同卷积核同结果)和局部相似性(相邻特征相似结果),
    而transformer无卷积核概念,只有整个编解码器,需要从头学
    解决:大量数据训练
  3. cv的各种自注意力机制需要复杂工程实现:
    解决:直接用整个transformer模块
  4. 分类head:
    解决:直接沿用transformer cls token
  5. position编码:
    解决:1D编码

pipeline:
224x224输入切成16x16patch进行位置编码和线性编码后增加cls token 一起输入的encoder encoder中有L个selfattention模块
输出的cls token为目标类别

3 代码

如果理解了transformer,看完这个结构感觉真的很简单,这篇论文也只是开山之作,没有特别复杂的结构,所以想到代码里看看。

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # linear(1024 , 3072)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # [1, 65, 1024]
        x = self.norm(x)
        # [1, 65, 1024]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # self.to_qkv(x)                [1, 65, 3072]
        # self.to_qkv(x).chunk(3,-1)    [3, 1, 65, 1024]
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # q,k,v                         [1, 65, 1024] -> [1, 16, 65, 64]
        # 把 65个1024的特征分为 heads个65个d维的特征 然后每个heads去分别有自己要处理的隐藏层,对不同的特征建立不同学习能力
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # [1, 16, 65, 64] * [1, 16, 64, 65] -> [1, 16, 65, 65]
        # scale 保证在softmax前所有的值都不太大

        attn = self.attend(dots)
        # softmax [1, 16, 65, 65]
        
        attn = self.dropout(attn)
        # dropout [1, 16, 65, 65]
        
        out = torch.matmul(attn, v)
        # out [1, 16, 65, 64]
        
        out = rearrange(out, 'b h n d -> b n (h d)')
        # out [1, 65, 1024]
        
        return self.to_out(out)
        # out [1, 65, 1024]
        

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        # [1, 65, 1024]
        for attn, ff in self.layers:
            # [1, 65, 1024]
            x = attn(x) + x
            # [1, 65, 1024]
            x = ff(x) + x

        # [1, 65, 1024]
        return self.norm(x)
        # shape不会改变

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {
    'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # num_patches   64
        # patch_dim     3072
        # dim           1024
        self.to_patch_embedding = nn.Sequential(
            #Rearrange是einops中的一个方法
            # einops:灵活和强大的张量操作,可读性强和可靠性好的代码。支持numpy、pytorch、tensorflow等。
            # 代码中Rearrage的意思是将传入的image(3,224,224),按照(3,(h,p1),(w,p2))也就是224=hp1,224 = wp2,接着把shape变成b (h w) (p1 p2 c)格式的,这样把图片分成了每个patch并且将patch拉长,方便下一步的全连接层
            # 还有一种方法是采用窗口为16*16,stride 16的卷积核提取每个patch,然后再flatten送入全连接层。
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # 1. [1, 3, 256, 256]       输入img
        x = self.to_patch_embedding(img)
        # 2. [1, 64, 1024]          patch embd
        b, n, _ = x.shape
        # 3. [1, 1, 1024]           cls_tokens
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        # 4. [1, 65, 1024]          cat [cls_tokens, x]
        x = torch.cat((cls_tokens, x), dim=1)
        # 5. [1, 65, 1024]          add [x] [pos_embedding]
        x += self.pos_embedding[:, :(n + 1)]
        # 6. [1, 65, 1024]          dropout
        x = self.dropout(x)
        # 7. [1, 65, 1024]          N * transformer
        x = self.transformer(x)
        # 8. [1,1024]               cls_x output
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        # 9. [1,1024]               cls_x output mean
        x = self.to_latent(x)
        # 10.[1,1024]               nn.Identity()不改变输入和输出 占位层
        return self.mlp_head(x)
        # 11.[1,cls]                mlp_cls_head

4 总结

multihead和我原有的理解偏差修正。
我以为的是QKV会有N块相同的copy(),每一份去做后续的linear等操作。
代码里是直接用linear将QKV分为一整个大块,用permute/rearrange的操作切成了N块,f(Q,K)之后再恢复成一整个大块,很强。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/highoooo/article/details/135487767

智能推荐

JWT(Json Web Token)实现无状态登录_无状态token登录-程序员宅基地

文章浏览阅读685次。1.1.什么是有状态?有状态服务,即服务端需要记录每次会话的客户端信息,从而识别客户端身份,根据用户身份进行请求的处理,典型的设计如tomcat中的session。例如登录:用户登录后,我们把登录者的信息保存在服务端session中,并且给用户一个cookie值,记录对应的session。然后下次请求,用户携带cookie值来,我们就能识别到对应session,从而找到用户的信息。缺点是什么?服务端保存大量数据,增加服务端压力 服务端保存用户状态,无法进行水平扩展 客户端请求依赖服务.._无状态token登录

SDUT OJ逆置正整数-程序员宅基地

文章浏览阅读293次。SDUT OnlineJudge#include<iostream>using namespace std;int main(){int a,b,c,d;cin>>a;b=a%10;c=a/10%10;d=a/100%10;int key[3];key[0]=b;key[1]=c;key[2]=d;for(int i = 0;i<3;i++){ if(key[i]!=0) { cout<<key[i.

年终奖盲区_年终奖盲区表-程序员宅基地

文章浏览阅读2.2k次。年终奖采用的平均每月的收入来评定缴税级数的,速算扣除数也按照月份计算出来,但是最终减去的也是一个月的速算扣除数。为什么这么做呢,这样的收的税更多啊,年终也是一个月的收入,凭什么减去12*速算扣除数了?这个霸道(不要脸)的说法,我们只能合理避免的这些跨级的区域了,那具体是那些区域呢?可以参考下面的表格:年终奖一列标红的一对便是盲区的上下线,发放年终奖的数额一定一定要避免这个区域,不然公司多花了钱..._年终奖盲区表

matlab 提取struct结构体中某个字段所有变量的值_matlab读取struct类型数据中的值-程序员宅基地

文章浏览阅读7.5k次,点赞5次,收藏19次。matlab结构体struct字段变量值提取_matlab读取struct类型数据中的值

Android fragment的用法_android reader fragment-程序员宅基地

文章浏览阅读4.8k次。1,什么情况下使用fragment通常用来作为一个activity的用户界面的一部分例如, 一个新闻应用可以在屏幕左侧使用一个fragment来展示一个文章的列表,然后在屏幕右侧使用另一个fragment来展示一篇文章 – 2个fragment并排显示在相同的一个activity中,并且每一个fragment拥有它自己的一套生命周期回调方法,并且处理它们自己的用户输_android reader fragment

FFT of waveIn audio signals-程序员宅基地

文章浏览阅读2.8k次。FFT of waveIn audio signalsBy Aqiruse An article on using the Fast Fourier Transform on audio signals. IntroductionThe Fast Fourier Transform (FFT) allows users to view the spectrum content of _fft of wavein audio signals

随便推点

Awesome Mac:收集的非常全面好用的Mac应用程序、软件以及工具_awesomemac-程序员宅基地

文章浏览阅读5.9k次。https://jaywcjlove.github.io/awesome-mac/ 这个仓库主要是收集非常好用的Mac应用程序、软件以及工具,主要面向开发者和设计师。有这个想法是因为我最近发了一篇较为火爆的涨粉儿微信公众号文章《工具武装的前端开发工程师》,于是建了这么一个仓库,持续更新作为补充,搜集更多好用的软件工具。请Star、Pull Request或者使劲搓它 issu_awesomemac

java前端技术---jquery基础详解_简介java中jquery技术-程序员宅基地

文章浏览阅读616次。一.jquery简介 jQuery是一个快速的,简洁的javaScript库,使用户能更方便地处理HTML documents、events、实现动画效果,并且方便地为网站提供AJAX交互 jQuery 的功能概括1、html 的元素选取2、html的元素操作3、html dom遍历和修改4、js特效和动画效果5、css操作6、html事件操作7、ajax_简介java中jquery技术

Ant Design Table换滚动条的样式_ant design ::-webkit-scrollbar-corner-程序员宅基地

文章浏览阅读1.6w次,点赞5次,收藏19次。我修改的是表格的固定列滚动而产生的滚动条引用Table的组件的css文件中加入下面的样式:.ant-table-body{ &amp;amp;::-webkit-scrollbar { height: 5px; } &amp;amp;::-webkit-scrollbar-thumb { border-radius: 5px; -webkit-box..._ant design ::-webkit-scrollbar-corner

javaWeb毕设分享 健身俱乐部会员管理系统【源码+论文】-程序员宅基地

文章浏览阅读269次。基于JSP的健身俱乐部会员管理系统项目分享:见文末!

论文开题报告怎么写?_开题报告研究难点-程序员宅基地

文章浏览阅读1.8k次,点赞2次,收藏15次。同学们,是不是又到了一年一度写开题报告的时候呀?是不是还在为不知道论文的开题报告怎么写而苦恼?Take it easy!我带着倾尽我所有开题报告写作经验总结出来的最强保姆级开题报告解说来啦,一定让你脱胎换骨,顺利拿下开题报告这个高塔,你确定还不赶快点赞收藏学起来吗?_开题报告研究难点

原生JS 与 VUE获取父级、子级、兄弟节点的方法 及一些DOM对象的获取_获取子节点的路径 vue-程序员宅基地

文章浏览阅读6k次,点赞4次,收藏17次。原生先获取对象var a = document.getElementById("dom");vue先添加ref <div class="" ref="divBox">获取对象let a = this.$refs.divBox获取父、子、兄弟节点方法var b = a.childNodes; 获取a的全部子节点 var c = a.parentNode; 获取a的父节点var d = a.nextSbiling; 获取a的下一个兄弟节点 var e = a.previ_获取子节点的路径 vue

推荐文章

热门文章

相关标签