首页 > 作文

ConvNeXt实战之实现植物幼苗分类

更新时间:2023-04-04 15:33:56 阅读: 评论:0

目录
前言convnext残差模块数据增强cutout和mixup项目结构数据集导入模型文件安装库,并导入需要的库设置全局参数数据预处理设置模型定义训练和验证函数测试第一种写法第二种写法

前言

convnexts 完全由标准 convnet 模块构建,在准确性和可扩展性方面与 transformer 竞争,实现 87.8% imagenet top-1 准确率,在 coco 检测和 ade20k 分割方面优于 swin transformers,同时保持标准 convnet 的简单性和效率。

论文链接:/d/file/titlepic/pp使用7×7的卷积核,在vgg、resnet等经典的cnn模型中,使用的是小卷积核,但是convnexts证明了大卷积和的有效性。作者尝试了几种内核大小,包括.pdf 3、5、7、9 和 11。网络的性能从 79.9% (3×3) 提高到 80.6% (7×7),而网络的 flops 大致保持不变, 内核大小的好处在 7×7 处达到饱和点。

使用gelu(高斯误差线性单元)激活函数。gelus是 dropout、zoneout、relus的综合,gelus对于输入乘以一个0,1组成的mask,而该mask的生成则是依概率随机的依赖于输入。实验效果要比relus与elus都要好。下图是实验数据:

使用layernorm而不是batchnorm。

倒置瓶颈。图 3 (a) 至 (b) 说明了这些配置。尽管深度卷积层的 flops 增加了,但由于下采样残差块的快捷 1×1 卷积层的 flops 显着减少,这种变化将整个网络的 flops 减少到 4.6g。成绩从 80.5% 提高到 80.6%。在 resnet-200/swin-b 方案中,这一步带来了更多的收益(81.9% 到 82.6%),同时也减少了 flop。

convnext残差模块

残差模块是整个模型的核心。如下图:

代码实现:

class block(nn.module):    r""" convnext block. there are two equivalent implementations:    (1) dwconv -&运动员代表发言gt; layernorm (channels_first) -> 1x1 conv -> gelu -> 1x1 conv; all in (n, c, h, w)    (2) dwconv -> permute to (n, h, w, c); layernorm (channels_last) -> linear -> gelu -> linear; permute back    we u (2) as we find it slightly faster in pytorch        args:        dim (int): number of input channels.        drop_path (float): stochastic depth rate. default: 0.0        layer_scale_init_value (float): init value for layer scale. default: 1e-6.    """    def __init__(lf, dim, drop_path=0., layer_scale_init_value=1e-6):        super().__init__()        lf.dwconv = nn.conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwi conv        lf.norm = layernorm(dim, eps=1e-6)        lf.pwconv1 = nn.linear(dim, 4 * dim) # pointwi/1x1 convs, implemented with linear layers        lf.act = nn.gelu()        lf.pwconv2 = nn.linear(4 * dim, dim)        lf.gamma = nn.parameter(layer_scale_init_value * torch.ones((dim)),                                     requires_grad=true) if layer_scale_init_value > 0 el none        lf.drop_path = droppath(drop_path) if drop_path > 0. el nn.identity()    def forward(lf, x):        input = x        x = lf.dwconv(x)        x = x.permute(0, 2, 3, 1) # (n, c, h, w) -> (n, h, w, c)        x = lf.norm(x)        x = lf.pwconv1(x)        x = lf.act(x)        x = lf.pwconv2(x)        if lf.gamma is not none:            x = lf.gamma * x        x = x.permute(0, 3, 1, 2) # (n, h, w, c) -> (n, c, h, w)        x = input + lf.drop_path(x)        return x

数据增强cutout和mixup

convnext使用了cutout和mixup,为了提高成绩我在我的代码中也加入这两种增强方式。官方使用timm,我没有采用官方的,而选择用torchtoolbox。安装命令:

pip install torchtoolbox

cutout实现,在transforms中。

from torchtoolbox.transform import cutout# 数据预处理transform = transforms.compo([  transforms.resize((224, 224)),  cutout(),  transforms.totensor(),  transforms.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

mixup实现,在train方法中。需要导入包:from torchtoolbox.tools import mixup_data, mixup_criterion

    for batch_idx, (data, target) in enumerate(train_loader):        data, target = data.to(device, non_blocking=true), target.to(device, non_blocking=true)        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)        optimizer.zero_grad()        output = model(data)        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)        loss.backward()        optimizer.step()        print_loss = loss.data.item()

项目结构

使用tree命令,打印项目结构

数据集

数据集选用植物幼苗分类,总共12类。数据集连接如下:

链接提取码:syng

在工程的根目录新建data文件夹,获取数据集后,将trian和test解压放到data文件夹下面,如下图:

导入模型文件

从官方的链接中找到convnext.py文件,将其放入model文件夹中。如图:

安装库,并导入需要的库

模型用到了timm库,如果没有需要安装,执行命令:

pip install timm

新建train_connext.py文件,导入所需要的包:

import torch.optim as optimimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.transforms as transformsfrom datat.datat import edlingdatafrom torch.autograd import variablefrom model.convnext import convnext_tinyfrom torchtoolbox.tools import mixup_data, mixup_criterionfrom torchtoolbox.transform import cutout

设置全局参数

设置使用gpu,设置学习率、batchsize、epoch等参数。

# 设置全局参数modellr = 1e-4batch_size = 8epochs = 300device = torch.device('cuda' if torch.cuda.is_available() el 'cpu')

数据预处理

数据处理比较简单,没有做复杂的尝试,有兴趣的可以加入一些处理。

# 数据预处理transform = transforms.compo([  transforms.resize((224, 224)),  cutout(),  transforms.totensor(),  transforms.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])transform_test = transforms.compo([  transforms.resize((224, 224)),  transforms.totensor(),  transforms.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

数据读取

然后我们在datat文件夹下面新建 init.py和datat.py,在mydatats.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的id,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

代码如下:

# coding:utf8import osfrom pil import imagefrom torch.utils import datafrom torchvision import transforms as tfrom sklearn.model_lection import train_test_splitlabels = {'black-grass': 0, 'charlock': 1, 'cleavers': 2, 'common chickweed': 3,     'common wheat': 4, 'fat hen': 5, 'loo silky-bent': 6, 'maize': 7, 'scentless mayweed': 8,     'shepherds pur': 9, 'small-flowered cranesbill': 10, 'sugar beet': 11}class edlingdata(data.datat):  def __init__(lf, root, transforms=none, train=true, test=fal):    """    主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据    """    lf.test = test    lf.transforms = transforms    if lf.test:      imgs = [os.path.join(root, img) for img in os.listdir(root)]      lf.imgs = imgs    el:      imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]      imgs = []      for imglable in imgs_labels:        for imgname in os.listdir(imglable):          imgpath = os.path.join(imglable, i非主流空间名mgname)          imgs.append托福培训班学费(imgpath)      trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)      if train:        lf.imgs = trainval_files      el:        lf.imgs = val_files  def __getitem__(lf, index):    """    一次返回一张图片的数据    """    img_path = lf.imgs[index]    img_path = img_path.replace("\\", '/')    if lf.test:      label = -1    el:      labelname = img_path.split('/')[-2]      label = labels[labelname]    data = image.open(img_path).convert('rgb')    data = lf.transforms(data)    return data, label  def __len__(lf):    return len(lf.imgs)

然后我们在train.py调用edlingdata读取数据 ,记着导入刚才写的datat.py(from mydatats import edlingdata)

# 读取数据datat_train = edlingdata('data/train', transforms=transform, train=true)datat_test = edlingdata("data/train", transforms=transform_test, train=fal)# 导入数据train_loader = torch.utils.data.dataloader(datat_train, batch_size=batch_size, shuffle=true)test_loader = torch.utils.data.dataloader(datat_test, batch_size=batch_size, shuffle=fal)

设置模型

设置loss函数为nn.crosntropyloss()。

设置模型为coatnet_0,修改最后一层全连接输出改为12(数据集的类别)。优化器设置为adam。学习率调整策略改为余弦退火
# 实例化模型并且移动到gpucriterion = nn.crosntropyloss()#criterion = softtargetcrosntropy()model_ft = convnext_tiny(pretrained=true)num9月10日_ftrs = model_ft.head.in_featuresmodel_ft.fc = nn.linear(num_ftrs, 12)model_ft.to(device)# 选择简单暴力的adam优化器,学习率调低optimizer = optim.adam(model_ft.parameters(), lr=modellr)cosine_schedule = optim.lr_scheduler.cosineannealinglr(optimizer=optimizer,t_max=20,eta_min=1e-9)

定义训练和验证函数

alpha=0.2 mixup所需的参数。

# 定义训练过程alpha=0.2def train(model, device, train_loader, optimizer, epoch):  model.train()  sum_loss = 0  total_num = len(train_loader.datat)  print(total_num, len(train_loader))  for batch_idx, (data, target) in enumerate(train_loader):    data, target = data.to(device, non_blocking=true), target.to(device, non_blocking=true)    data, labels_a, labels_b, lam = mixup_data(data, target, alpha)    optimizer.zero_grad()    output = model(data)    loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)    loss.backward()    optimizer.step()    print_loss = loss.data.item()    sum_loss += print_loss    if (batch_idx + 1) % 10 == 0:      print('train epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(        epoch, (batch_idx + 1) * len(data), len(train_loader.datat),           100. * (batch_idx + 1) / len(train_loader), loss.item()))  ave_loss = sum_loss / len(train_loader)  print('epoch:{},loss:{}'.format(epoch, ave_loss))acc=0# 验证过程def val(model, device, test_loader):  global acc  model.eval()  test_loss = 0  correct = 0  total_num = len(test_loader.datat)  print(total_num, len(test_loader))  with torch.no_grad():    for data, target in test_loader:      data, target = variable(data).to(device), variable(target).to(device)      output = model(data)      loss = criterion(output, target)      _, pred = torch.max(output.data, 1)      correct += torch.sum(pred == target)      print_loss = loss.data.item()      test_loss += print_loss    correct = correct.data.item()    acc = correct / total_num    avgloss = test_loss / len(test_loader)    print('\nval t: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'.format(      avgloss, correct, len(test_loader.datat), 100 * acc))    if acc > acc:      torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')      acc = acc# 训练for epoch in range(1, epochs + 1):  train(model_ft, device, train_loader, optimizer, epoch)  cosine_schedule.step()  val(model_ft, device, test_loader)

然后就可以开始训练了

训练10个epoch就能得到不错的结果:

测试

第一种写法

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

class = ('black-grass', 'charlock', 'cleavers', 'common chickweed',           'common wheat', 'fat hen', 'loo silky-bent',           'maize', 'scentless mayweed', 'shepherds pur', 'small-flowered cranesbill', 'sugar beet')

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

transform_test = transforms.compo([         transforms.resize((224, 224)),        transforms.totensor(),        transforms.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

第三步 加载model,并将模型放在device里。

母质device = torch.device("cuda:0" if torch.cuda.is_available() el "cpu")model = torch.load("model_8_0.971.pth")model.eval()model.to(device)

第四步 读取图片并预测图片的类别,在这里注意,读取图片用pil库的image。不要用cv2,transforms不支持。

path = 'data/test/'testlist = os.listdir(path)for file in testlist:    img = image.open(path + file)    img = transform_test(img)    img.unsqueeze_(0)    img = variable(img).to(device)    out = model(img)    # predict    _, pred = torch.max(out.data, 1)    print('image name:{},predict:{}'.format(file, class[pred.data.item()]))

测试完整代码:

import torch.utils.data.distributedimport torchvision.transforms as transformsfrom pil import imagefrom torch.autograd import variableimport osclass = ('black-grass', 'charlock', 'cleavers', 'common chickweed',     'common wheat', 'fat hen', 'loo silky-bent',     'maize', 'scentless mayweed', 'shepherds pur', 'small-flowered cranesbill', 'sugar beet')transform_test = transforms.compo([  transforms.resize((224, 224)),  transforms.totensor(),  transforms.normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])device = torch.device("cuda:0" if torch.cuda.is_available() el "cpu")model = torch.load("model_8_0.971.pth")model.eval()model.to(device)path = 'data/test/'testlist = os.listdir(path)for file in testlist:  img = image.open(path + file)  img = transform_test(img)  img.unsqueeze_(0)  img = variable(img).to(device)  out = model(img)  # predict  _, pred = torch.max(out.data, 1)  print('image name:{},predict:{}'.format(file, class[pred.data.item()]))

运行结果:

第二种写法

第二种,使用自定义的datat读取图片。前三步同上,差别主要在第四步。读取数据的时候,使用datat的edlingdata读取。

datat_test =edlingdata('data/test/', transform_test,test=true)print(len(datat_test))# 对应文件夹的label for index in range(len(datat_test)):    item = datat_test[index]    img, label = item    img.unsqueeze_(0)    data = variable(img).to(device)    output = model(data)    _, pred = torch.max(output.data, 1)    print('image name:{},predict:{}'.format(datat_test.imgs[index], class[pred.data.item()]))    index += 1

运行结果:

以上就是convnext实战之实现植物幼苗分类的详细内容,更多关于convnext植物幼苗分类的资料请关注www.887551.com其它相关文章!

本文发布于:2023-04-04 15:33:46,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/zuowen/9ba15c7d9d9ccdeeb4663573903a762b.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

本文word下载地址:ConvNeXt实战之实现植物幼苗分类.doc

本文 PDF 下载地址:ConvNeXt实战之实现植物幼苗分类.pdf

标签:数据   卷积   模型   类别
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图