MobileNetV1实战:使用MobileNetV1实现植物幼苗分类

更新时间:2023-06-25 12:30:04 阅读: 评论:0

MobileNetV1实战:使⽤MobileNetV1实现植物幼苗分类
⽂章⽬录
摘要
本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演⽰如何使⽤pytorch版本的MobileNetV1图像分类模型实现分类任务。
通过本⽂你和学到:
1、如何⾃定义MobileNetV1模型。
2、如何⾃定义数据集加载⽅式?
3、如何使⽤Cutout数据增强?
四级估分器4、如何使⽤Mixup数据增强。
5、如何实现训练和验证。
6、预测的两种写法。
MobileNetV1的论⽂翻译:
MobileNetV1解析:
Keras版本:
数据增强Cutout和Mixup
初次见面日语
为了提⾼成绩我在代码中加⼊Cutout和Mixup这两种增强⽅式。实现这两种增强需要安装torchtoolbox。安装命令:
pip install torchtoolbox
Cutout实现,在transforms中。
ansform import Cutout
# 数据预处理
transform = transforms.Compo([
transforms.Resize((224,224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])
Mixup实现,在train⽅法中。需要导⼊包:ls import mixup_data, mixup_criterion
for batch_idx,(data, target)in enumerate(train_loader):
data, target = (device, non_blocking=True), (device, non_blocking=True)
data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
<_grad()
output = model(data)
loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
loss.backward()
optimizer.step()
print_loss = loss.data.item()
项⽬结构
MobileNetV1_demo
├─data
│└─train
崇文区小学│├─Black-grass
│├─Charlock
│├─Cleavers
│├─Common Chickweed
│├─Common wheat
│├─Fat Hen
│├─Loo Silky-bent
│├─Maize
│├─Scentless Mayweed
│├─Shepherds Pur
│├─Small-flowered Cranesbill
│└─Sugar beet
├─datat
│└─datat.py
└─models
│└─mobilenetV1.py
├─train.py
├─test1.py
└─test.py
导⼊项⽬使⽤的库
import torch.optim as optim
import torch
房地产销售技巧 as nn
parallel
import torch.utils.data
import torch.utils.data.distributed
ansforms as transforms
from datat.datat import SeedlingData
from torch.autograd import Variable
bilenetv1 import MobileNetV1
ls import mixup_data, mixup_criterion
ansform import Cutout
设置全局参数
设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使⽤CPU。建议使⽤GPU,CPU太慢了。# 设置全局参数
modellr =1e-4
BATCH_SIZE =16
EPOCHS =300
DEVICE = torch.device('cuda'if torch.cuda.is_available()el'cpu')
图像预处理与增强
数据处理⽐较简单,加⼊了Cutout、做了Resize和归⼀化。
transform = transforms.Compo([
transforms.Resize((224,224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])
transform_test = transforms.Compo([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])
读取数据
将数据集解压后放到data⽂件夹下⾯,如图:
然后我们在datat⽂件夹下⾯新建 init.py和datat.py,在datats.py⽂件夹写⼊下⾯的代码:
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
del_lection import train_test_split
Labels ={'Black-grass':0,'Charlock':1,'Cleavers':2,'Common Chickweed':3,
'Common wheat':4,'Fat Hen':5,'Loo Silky-bent':6,'Maize':7,'Scentless Mayweed':8,
'Shepherds Pur':9,'Small-flowered Cranesbill':10,'Sugar beet':11}
class SeedlingData (data.Datat):
def__init__(lf, root, transforms=None, train=True, test=Fal):
"""
主要⽬标:获取所有图⽚的地址,并根据训练,验证,测试划分数据
"""
st:
imgs =[os.path.join(root, img)for img in os.listdir(root)]
lf.imgs = imgs
el:
imgs_labels =[os.path.join(root, img)for img in os.listdir(root)]
imgs =[]
for imglable in imgs_labels:
for imgname in os.listdir(imglable):
威尼斯商人英语剧本imgpath = os.path.join(imglable, imgname)
imgs.append(imgpath)
trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
if train:
lf.imgs = trainval_files
el:
lf.imgs = val_files
def__getitem__(lf, index):
"""
中日文在线翻译⼀次返回⼀张图⽚的数据
"""
img_path = lf.imgs[index]
img_path=place("\\",'/')
st:
label =-1
el:
labelname = img_path.split('/')[-2]
label = Labels[labelname]
data = Image.open(img_path).convert('RGB')
data = lf.transforms(data)
return data, label
def__len__(lf):
return len(lf.imgs)
说⼀下代码的核⼼逻辑:
第⼀步 建⽴字典,定义类别对应的ID,⽤数字代替类别。
第⼆步 在__init__⾥⾯编写获取图⽚路径的⽅法。测试集只有⼀层路径直接读取,训练集在train⽂件夹下⾯是类别⽂件夹,先获取到类别,再获取到具体的图⽚路径。然后使⽤sklearn中切分数据集的⽅法,按照7:3的⽐例切分训练集和验证集。
apathy
第三步 在__getitem__⽅法中定义读取单个图⽚和类别的⽅法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。
然后我们在train.py调⽤SeedlingData读取数据 ,记着导⼊刚才写的datat.py(from datat.datat import SeedlingData)
datat_train = SeedlingData('data/train', transforms=transform, train=True)韩语口语
datat_test = SeedlingData("data/train", transforms=transform_test, train=Fal)
# 读取数据
美女纸牌print(datat_train.imgs)
# 导⼊数据
train_loader = torch.utils.data.DataLoader(datat_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(datat_test, batch_size=BATCH_SIZE, shuffle=Fal)
设置模型
设置loss函数为nn.CrossEntropyLoss()。
设置模型为MobileNetV1,num_class设置为12。
优化器设置为adam。
学习率调整策略选择为余弦退⽕。
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = MobileNetV1(num_class=12)
(DEVICE)
# 选择简单暴⼒的Adam优化器,学习率调低
dare
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)定义训练和验证函数

本文发布于:2023-06-25 12:30:04,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/157152.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   训练   分类   实现   类别
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图