Vision Transformer (ViT)


Paper link:
I. The overall structure of ViT
A brief description of the structure
Let us first look at the input to ViT.
An image is split into small patches; in ViT-L/16, for example, each patch is 16×16 pixels. Each patch is passed through the embedding layer (Linear Projection of Flattened Patches), which produces a vector for it, called a token. In the figure the image is divided into 9 patches, so 9 token embeddings are obtained after the embedding layer. (For a 224×224 input with 16×16 patches, this would give (224/16)² = 196 patches.)
Next, a new token, the class token (marked * in the figure), is prepended to this sequence; its dimension is the same as that of the other tokens.
After the class token and the positional information have been added to the original tokens, the sequence is fed into the Transformer encoder.
The Transformer encoder works as follows: the Encoder Block is stacked L times, then the output corresponding to the class token is extracted and fed into the MLP Head described below for classification, which gives the final result.
II. ViT component breakdown
From the description above, ViT can be divided into three main parts:
Embedding layer
Encoder
MLP Head (used for classification)
The processing flow is as follows (a rough shape walkthrough is sketched after this list):
1. Split the image into patches
2. Convert each patch into an embedding
3. Add the positional embeddings to the token embeddings
4. Feed the sequence into the ViT model
5. Use the CLS token's output for multi-class classification
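A rough shape walkthrough of this flow, assuming the commonly used ViT-B/16 settings (224×224 RGB input, 16×16 patches, embedding dimension 768); the numbers are illustrative only:

image_size, patch_size, channels, dim = 224, 16, 3, 768

num_patches = (image_size // patch_size) ** 2     # (224/16)^2 = 196 patches
patch_dim = channels * patch_size * patch_size    # 3*16*16 = 768 values per flattened patch
seq_len = num_patches + 1                         # +1 for the class token -> 197 tokens

# Shapes along the pipeline:
# image             (B, 3, 224, 224)
# flattened patches (B, 196, 768)
# token embeddings  (B, 196, dim)   after the linear projection
# + cls token, pos  (B, 197, dim)   positional embeddings added element-wise
# encoder x L       (B, 197, dim)   shape preserved by every block
# cls output        (B, dim)        token 0, fed to the MLP Head -> (B, num_classes)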
Embedding layer
A standard Transformer module expects a sequence of token embedding vectors as input; the conversion from an image to this sequence follows steps 1, 2 and 3 annotated in the figure.
The embedding stage involves three operations:
1. Generate the class token (the * in the figure)
2. Generate positional encodings for the whole sequence (light purple in the figure)
3. Add the positional encodings to the token embeddings
In the figure, the original image is first converted into multiple patches, each of size 3×16×16. Each patch is then flattened into a token embedding of dimension 768. Turning a patch into an embedding takes two operations:
Flatten the patch
Map the flattened patch to the dimension required by the encoder
The cls token is prepended to this sequence of embeddings; positional encodings are then generated and added to the token embeddings, giving the final input embeddings.
About the positional encoding:
In a Transformer, the encoder processes the whole sequence in parallel rather than waiting for earlier outputs, so positional encodings are needed to tell the model where each element sits; in ViT they encode the ordering of the image patches. A minimal sketch of the whole embedding step follows.
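A minimal sketch of the embedding step, assuming 16×16 patches, 3 input channels, embedding dimension 768 and learnable positional embeddings (as in ViT); the class and parameter names here are illustrative, not the exact implementation shown later:

import torch
from torch import nn
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    """Illustrative sketch: flatten patches, project to dim, prepend cls token, add positions."""
    def __init__(self, image_size=224, patch_size=16, channels=3, dim=768):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size * patch_size
        # operation 1: flatten each patch; operation 2: map it to the encoder dimension
        self.to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

    def forward(self, img):
        tokens = self.to_tokens(img)                        # (B, N, dim)
        cls = self.cls_token.expand(img.shape[0], -1, -1)   # (B, 1, dim)
        x = torch.cat([cls, tokens], dim=1)                 # (B, N + 1, dim)
        return x + self.pos_embedding                       # add positional information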
Encoder
The Encoder module of ViT is similar to the encoder of the original Transformer.
Combining the encoder described in the paper with a concrete implementation gives the Encoder Block.
Compared with the encoder of the original Transformer, the LayerNorm (LN) operation is moved in front of each sub-layer (pre-norm), and because the image is split into patches of identical size there is no padding operation. A minimal pre-norm block sketch follows.
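A minimal pre-norm Encoder Block sketch, assuming residual connections around each sub-layer; Attention and FeedForward stand in for the modules defined in Section III:

from torch import nn

class EncoderBlock(nn.Module):
    """Illustrative pre-norm block: x + Attn(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, dim, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.norm1, self.attn = nn.LayerNorm(dim), attn
        self.norm2, self.mlp = nn.LayerNorm(dim), mlp

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # LN applied *before* attention (pre-norm)
        x = x + self.mlp(self.norm2(x))    # LN applied before the MLP as well
        return x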
MLP Head
The MLP Head is much clearer when read together with its code:
class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
As the code shows, the MLP consists only of fully connected layers, a GELU activation and Dropout layers; here it maps the encoder output to the multi-class prediction. A usage sketch follows.
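A hedged usage sketch: taking the class-token output of the encoder and mapping it to logits with the Mlp above. The sizes (dim 768, hidden 3072, 1000 classes) are illustrative; note that reference ViT implementations often use a plain Linear (or LayerNorm + Linear) head instead.

import torch

dim, num_classes = 768, 1000                       # illustrative ViT-B-like sizes
encoder_out = torch.randn(1, 197, dim)             # (B, 197, dim) output of the stacked encoder
cls_out = encoder_out[:, 0]                        # class-token output, (B, dim)
head = Mlp(in_features=dim, hidden_features=3072, out_features=num_classes)
logits = head(cls_out)                             # (B, num_classes)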
III. A concise ViT implementation
A few of the key modules are explained below.
Attention
The attention module is similar to the one in the original Transformer and implements the multi-head mechanism. In forward, the to_qkv projection and chunk produce the combined Q, K and V matrices in a single step, which are then split into the per-head q, k and v. This differs from the original Transformer, which generates Q, K and V with separate Linear layers; the difference is possible because ViT has no decoder, so Q, K and V always come from the same input. A quick shape check follows the code.
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # split the tensor, e.g. x: (1, 197, 1024)
        # to_qkv expands the last dimension to 3x the original size;
        # chunk then yields a tuple of length 3, each element of shape (1, 197, 1024)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # print('qkv is ', qkv)  # debug output from the original post
        # split q, k, v into multiple heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        # weight the V matrix by the attention
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
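A quick shape check of this module (the sizes mirror the comments above, i.e. a 197-token sequence with dim 1024, and are illustrative; it assumes the torch/einops imports from the complete-code section below):

import torch

attn = Attention(dim=1024, heads=16, dim_head=64)   # the Attention class defined above
x = torch.randn(1, 197, 1024)                        # (batch, tokens, dim)
out = attn(x)
print(out.shape)                                     # torch.Size([1, 197, 1024])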
For comparison, the multi-head attention of the original Transformer produces Q, K and V with separate Linear layers:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # the incoming Q, K and V are identical; separate Linear layers map them
        # through the parameter matrices Wq, Wk, Wv
        # (excerpt: d_model, d_k, d_v and n_heads are assumed to be defined elsewhere)
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
Transformer
Following the architecture in the paper, depth (L) Encoder blocks are stacked; a quick usage check follows the code.
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # stack `depth` encoder blocks
        for _ in range(depth):
            # built following the structure in the paper
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
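With the PreNorm, Attention and FeedForward modules from the complete code below, a quick shape check using illustrative ViT-B-like sizes:

import torch

blocks = Transformer(dim=768, depth=12, heads=12, dim_head=64, mlp_dim=3072)
x = torch.randn(1, 197, 768)       # (batch, 196 patch tokens + 1 cls token, dim)
print(blocks(x).shape)             # torch.Size([1, 197, 768])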
ViT
The ViT module splits the image into patches, flattens them and maps them to the input embeddings, adds the positional encoding, and assembles all of the modules above; see the complete code below.
Complete code
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
# class
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # split the tensor, e.g. x: (1, 197, 1024)
        # to_qkv expands the last dimension to 3x the original size;
        # chunk then yields a tuple of length 3, each element of shape (1, 197, 1024)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # print('qkv is ', qkv)  # debug output from the original post
        # split q, k, v into multiple heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        # weight the V matrix by the attention
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # stack `depth` encoder blocks
        for _ in range(depth):
            # built following the structure in the paper
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        # size of the original image
        image_height, image_width = pair(image_size)    # e.g. 224, 224
        # size of each patch
        patch_height, patch_width = pair(patch_size)    # e.g. 16, 16
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'
        # number of patches, as in the paper: num_patches = H*W / P^2
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of a flattened patch
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # map the flattened patches to the dimension dim required by the encoder
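        # --- The original post breaks off here; the remainder is a completion sketch based on
        # --- the widely used vit-pytorch (lucidrains) implementation that the code above follows,
        # --- so treat it as illustrative rather than the author's exact code.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )
        # learnable positional embeddings and class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # stacked encoder blocks
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        # classification head applied to the pooled token
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)                        # (B, N, dim)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)                   # prepend the class token
        x = x + self.pos_embedding[:, :(n + 1)]                 # add positional embeddings
        x = self.dropout(x)
        x = self.transformer(x)                                 # (B, N + 1, dim)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]   # mean pooling or cls token
        x = self.to_latent(x)
        return self.mlp_head(x)                                 # (B, num_classes)

# Minimal usage sketch (illustrative ViT-B/16-like settings):
if __name__ == '__main__':
    model = ViT(image_size=224, patch_size=16, num_classes=1000, dim=768,
                depth=12, heads=12, mlp_dim=3072)
    img = torch.randn(1, 3, 224, 224)
    print(model(img).shape)    # torch.Size([1, 1000])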
