vision_transformer实战总结:非常简单的VIT入门教程,一定不要错过

更新时间:2023-06-25 12:25:34 阅读: 评论:0

vision_transformer实战总结:⾮常简单的VIT⼊门教程,⼀定不要错过
⽂章⽬录
摘要
本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演⽰如何使⽤pytorch版本的VIT图像分类模型实现分类任务。通过本⽂你和学到:
1、如何构建VIT模型?
2、如何⽣成数据集?
3、如何使⽤Cutout数据增强?
4、如何使⽤Mixup数据增强。
5、如何实现训练和验证。
6、如何使⽤余弦退⽕调整学习率?
7、预测的两种写法。
这篇⽂章的代码没有做过多的修饰,⽐较简单,容易理解。
项⽬结构
VIT_demo
├─models
│└─vision_transformer.py
├─data
│├─Black-grass
│├─Charlock
│├─Cleavers
│├─Common Chickweed
│├─Common wheat
│├─Fat Hen
│├─Loo Silky-bent
│├─Maize
│├─Scentless Mayweed
│├─Shepherds Pur
2010考研英语作文│├─Small-flowered Cranesbill
│└─Sugar beet
├─mean_std.py芝麻街
├─makedata.py
├─train.py
├─test1.py
gosh不来梅英文
└─test.py
mean_std.py:计算mean和std的值。
makedata.py:⽣成数据集。
计算mean和std
为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插⼊代码:
from torchvision.datats import ImageFolder
import torch
from torchvision import transforms
def get_mean_and_std(train_data):
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=Fal, num_workers=0,
pin_memory=True)
mean = s(3)
std = s(3)
for X, _ in train_loader:
for d in range(3):
mean[d]+= X[:, d,:,:].mean()
std[d]+= X[:, d,:,:].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()),list(std.numpy())
if __name__ =='__main__':
train_datat = ImageFolder(root=r'data1', transform=transforms.ToTensor())
print(get_mean_and_std(train_datat))
数据集结构:
运⾏结果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])把这个结果记录下来,后⾯要⽤!
⽣成数据集
我们整理还的图像分类的数据集结构是这样的
husband什么意思data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loo Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Pur
├─Small-flowered Cranesbill
└─Sugar beet
pytorch和keras默认加载⽅式是ImageNet数据集格式,格式是
││├─Black-grass
││├─Charlock
││├─Cleavers
││├─Common Chickweed
youtour││├─Common wheat
││├─Fat Hen
││├─Loo Silky-bent
││├─Maize
││├─Scentless Mayweed
││├─Shepherds Pur
││├─Small-flowered Cranesbill
││└─Sugar beet
│└─train
│├─Black-grass
│├─Charlock
│├─Cleavers
│├─Common Chickweed
│├─Common wheat
│├─Fat Hen
│├─Loo Silky-bent
│├─Maize
│├─Scentless Mayweed
│├─Shepherds Pur
│├─Small-flowered Cranesbill
│└─Sugar beet
新增格式转化脚本makedata.py,插⼊代码:
import shutil
image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if ists(file_dir):
print('true')
#os.rmdir(file_dir)
<(file_dir)#删除再建⽴
os.makedirs(file_dir)
el:
os.makedirs(file_dir)
del_lection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
file_place("\\","/").split('/')[-2]
file_place("\\","/").split('/')[-1]
file_class=os.path.join(train_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)新路径英语
for file in val_files:
file_place("\\","/").split('/')[-2]
file_place("\\","/").split('/')[-1]
file_class=os.path.join(val_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
清大教育在线
数据增强Cutout和Mixup
为了提⾼成绩我在代码中加⼊Cutout和Mixup这两种增强⽅式。实现这两种增强需要安装torchtoolbox。安装命令:pip install torchtoolbox
Cutout实现,在transforms中。
ansform import Cutout
# 数据预处理
transform = transforms.Compo([
transforms.Resize((224,224)),
Cutout()
服装设计软件下载
])
Mixup实现,在train⽅法中。需要导⼊包:ls import mixup_data, mixup_criterion
for batch_idx,(data, target)in enumerate(train_loader):
data, target = (device, non_blocking=True), (device, non_blocking=True)
data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
<_grad()
output = model(data)
loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
loss.backward()
optimizer.step()
print_loss = loss.data.item()
导⼊项⽬使⽤的库
import torch.optim as optim
import torch
as nn
parallel
研究生网上报名流程
import torch.utils.data
import torch.utils.data.distributed
ansforms as transforms
import torchvision.datats as datats
from models.vision_transformer import deit_tiny_patch16_224
ls import mixup_data, mixup_criterion
ansform import Cutout
设置全局参数
设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使⽤CPU。建议使⽤GPU,CPU太慢了。# 设置全局参数
modellr =1e-4
BATCH_SIZE =16
EPOCHS =300
DEVICE = torch.device('cuda'if torch.cuda.is_available()el'cpu')
图像预处理与增强
数据处理⽐较简单,加⼊了Cutout、做了Resize和归⼀化。在transforms.Normalize中写⼊上⾯求得的mean和std的值。# 数据预处理
transform = transforms.Compo([
transforms.Resize((224,224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.3281186,0.28937867,0.20702125],[0.09407319,0.09732835,0.106712654])
])
transform_test = transforms.Compo([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.3281186,0.28937867,0.20702125],[0.09407319,0.09732835,0.106712654])
])
读取数据
使⽤pytorch默认读取数据的⽅式,然后将datat_train.class_to_idx打印出来,预测的时候要⽤到。

本文发布于:2023-06-25 12:25:34,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/78/1036187.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   分类   实现   预测   图像   参数
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图