【transformer】【ViT】【code】ViT代码

更新时间:2023-07-07 08:12:33 阅读: 评论:0

【transformer】【ViT】【code】ViT代码桃叶⼉尖上尖,柳絮⼉飞满了天…
1 导⼊库
import torch
from torch import nn, einsum
functional as F
from einops import rearrange, repeat
from h import Rearrange
解释:其中einops库⽤于张量操作,增强代码的可读性,使⽤还是⽐较⽅便的。教程链接:
2 调⽤
if __name__=="__main__":
net = ViT(image_size=256,
patch_size=32,#pathces的尺⼨
num_class=1000,
dim=1024,#embddings的长度,也就是每个block的输⼊输出的尺⼨
depth=6,#⽹络深度,多少个block
heads=16,#注意⼒抽头的个数
mlp_dim=2048,#mlp中反瓶颈结构的中间维度,也就是先升维,再降维
dropout=0.1,
emb_dropout=0.1)
x = torch.rand((2,3,256,256))#测试数据
output = net(x)
从主⼲到分⽀解释代码。
3 ViT⽹络
class ViT(nn.Module):
def__init__(lf,*, image_size, patch_size, num_class, dim, depth, heads,
mlp_dim, pool ='cls', channels =3, dim_head =64, dropout =0., emb_dropout =0.):
super().__init__()
asrt image_size % patch_size ==0,'Image dimensions must be divisible by the patch size.'
num_patches =(image_size // patch_size)**2
营养排骨汤patch_dim = channels * patch_size **2
asrt pool in{'cls','mean'},'pool type must be either cls (cls token) or mean (mean pooling)'
<_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),#dim是embedding嵌⼊的空间
)
lf.pos_embedding = nn.Parameter(torch.randn(1, num_patches +1, dim))
#设置位置参数,这个计算的是块之间的位置,多设置⼀个class_tokens
lf.cls_token = nn.Parameter(torch.randn(1,1, dim))
lf.dropout = nn.Dropout(emb_dropout)卖桔者言
lf.pool = pool
<_latent = nn.Identity()
lf.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_class)
)
def forward(lf, img):
import torchsnooper
with torchsnooper.snoop():
#img(2,3,256,256)
#Rearrange(): (2, 8*8, 32*32*3)
#Linear(): (2, 8*8, 1024)  embeddings
水上游乐项目x = lf.to_patch_embedding(img)
b, n, _ = x.shape #b=2, n=64 n表⽰embeddings向量个数
cls_tokens = repeat(lf.cls_token,'() n d -> b n d', b = b)#(2, 1, 1024)#每个样本的都要增加⼀个,⽤
于从其他的注意⼒向量上交互信息            x = torch.cat((cls_tokens, x), dim=1)#(2, 65, 1024) 此处有broadcast
x += lf.pos_embedding[:,:(n +1)]#(2, 65, 1024) 加上位置信息
x = lf.dropout(x)
x = lf.transformer(x)#经过六个变换块
x = x.mean(dim =1)if lf.pool =='mean'el x[:,0]#获取所有向量的平均还是只需要第⼀个向量
x = lf.to_latent(x)
return lf.mlp_head(x)
4 Block
class Transformer(nn.Module):
def__init__(lf, dim, depth, heads, dim_head, mlp_dim, dropout =0.):
'''
dim:嵌⼊vectors, depth:⽹络深度, heads:注意⼒头的个数,dim_head:注意⼒头的维度
'''
super().__init__()
lf.layers = nn.ModuleList([])
for _ in range(depth):
lf.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),#注意⼒块
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))#前向块
]))
def forward(lf, x):
import torchsnooper
健身房教练with torchsnooper.snoop():
岁月悠悠正谱for attn, ff in lf.layers:
x = attn(x)+ x#残差块
x = ff(x)+ x
return x
class PreNorm(nn.Module):#注意⼒块或者前向块前加上LN
def__init__(lf, dim, fn):
super().__init__()
< = nn.LayerNorm(dim)
lf.fn = fn
def forward(lf, x,**kwargs):
return lf.(x),**kwargs)
class Attention(nn.Module):
def__init__(lf, dim, heads =8, dim_head =64, dropout =0.):
super().__init__()
inner_dim = dim_head *  heads #⼋个注意⼒vector变成⼀根vector
project_out =not(heads ==1and dim_head == dim)
lf.heads = heads
lf.scale = dim_head **-0.5#qkTv下的根号dim
lf.attend = nn.Softmax(dim =-1)
<_qkv = nn.Linear(dim, inner_dim *3, bias =Fal)#同时计算qkv
<_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)if project_out el nn.Identity()
def forward(lf, x):
import torchsnooper
with torchsnooper.snoop():
hh = lf.heads
b, n, _, h =*x.shape, lf.heads
#(b, 65, 1024, heads=8),65是64+1
qkv = lf.to_qkv(x).chunk(3, dim =-1)#沿着最后⼀维对此分块,此时是列表,其中有3个元素,均为(2,65,1024)
q, k, v =map(lambda t: rearrange(t,'b n (h d) -> b h n d', h = h), qkv)#将1024分成8个抽头,也就是说⼋个抽头是⼀块计算的,每个就是128 #q,k,v(2,16,65, 64)#16是因为在ViT的调⽤中设置了heads=16
dots = einsum('b h i d, b h j d -> b h i j', q, k)* lf.scale #(2, 16, 65, 65) qkT
attn = lf.attend(dots)#softmax
湘潭大学兴湘学院out = einsum('b h i j, b h j d -> b h i d', attn, v)#qkTv (2,16,65,64)
out = rearrange(out,'b h n d -> b n (h d)')#concat (2,65,16*64)将8个head进⾏合并
_out(out)#linear
class FeedForward(nn.Module):#反瓶颈结构,中间⾼def__init__(lf, dim, hidden_dim, dropout =0.):
super().__init__()
lf = nn.Sequential(
nn.Linear(dim, hidden_dim),勤怎么写
nn.GELU(),
博美狗狗多少钱一只nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(lf, x):
return lf(x)

本文发布于:2023-07-07 08:12:33,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/89/1071339.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:注意   代码   向量   位置   维度
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图