在pytorch中自定义datat读取数据2021-1-8学习笔记

更新时间:2023-07-06 19:03:06 阅读: 评论:0

在pytorch中⾃定义datat读取数据2021-1-8学习笔记
utils
import os
import os
import json
import pickle
import random
import matplotlib.pyplot as plt
def read_split_data(root:str, val_rate:float=0.2):# val_rate划分验证集的⽐例
random.ed(0)# 保证随机结果可复现 #随机种⼦设置为0,⼤家划分的是⼀样的
asrt ists(root),"datat root: {} does not exist.".format(root)#不存在路径报错
# 遍历⽂件夹,⼀个⽂件夹对应⼀个类别
flower_class =[cla for cla in os.listdir(root)if os.path.isdir(os.path.join(root, cla))]#不是⽂件夹丢弃
# 排序,保证顺序⼀致
flower_class.sort()
# ⽣成类别名称以及对应的数字索引
和陌生人聊天class_indices =dict((k, v)for v, k in enumerate(flower_class))
json_str = json.dumps(dict((val, key)for key, val in class_indices.items()), indent=4)
with open('class_indices.json','w')as json_file:
json_file.write(json_str)
train_images_path =[]# 存储训练集的所有图⽚路径
train_images_label =[]# 存储训练集图⽚对应索引信息
val_images_path =[]# 存储验证集的所有图⽚路径
val_images_label =[]# 存储验证集图⽚对应索引信息
every_class_num =[]# 存储每个类别的样本总数
supported =[".jpg",".JPG",".png",".PNG"]# ⽀持的⽂件后缀类型
# 遍历每个⽂件夹下的⽂件
for cla in flower_class:
cla_path = os.path.join(root, cla)#获得该类别的路径
# 遍历获取supported⽀持的所有⽂件路径
images =[os.path.join(root, cla, i)for i in os.listdir(cla_path)
if os.path.splitext(i)[-1]in supported]#splitext(i)[-1]分割出⽂件名称和后缀名然后⽤in判断是否在supported⾥# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
every_class_num.append(len(images))
# 按⽐例随机采样验证样本
val_path = random.sample(images, k=int(len(images)* val_rate))
冬暖夏凉
for img_path in images:机器人总动员插曲
if img_path in val_path:# 如果该路径在采样的验证集样本中则存⼊验证集
val_images_path.append(img_path)
val_images_label.append(image_class)
el:# 否则存⼊训练集
train_images_path.append(img_path)
train_images_label.append(image_class)
print("{} images were found in the datat.".format(sum(every_class_num)))
plot_image =Fal
if plot_image:
# 绘制每种类别个数柱状图
plt.bar(range(len(flower_class)), every_class_num, align='center')
# 将横坐标0,1,2,3,4替换为相应的类别名称
# 在柱状图上添加数值标签
for i, v in enumerate(every_class_num):
<(x=i, y=v +5, s=str(v), ha='center')
# 设置x坐标
plt.xlabel('image class')
# 设置y坐标
plt.ylabel('number of images')
# 设置柱状图的标题
plt.title('flower class distribution')
plt.title('flower class distribution')
plt.show()
return train_images_path, train_images_label, val_images_path, val_images_label溯怎么读
def plot_data_loader_image(data_loader):
batch_size = data_loader.batch_size
plot_num =min(batch_size,4)submition
json_path ='./class_indices.json'
asrt ists(json_path), json_path +" does not exist."
json_file =open(json_path,'r')
class_indices = json.load(json_file)
for data in data_loader:
images, labels = data
for i in range(plot_num):
# [C, H, W] -> [H, W, C] transpo调整顺序
img = images[i].numpy().transpo(1,2,0)
# 反Normalize操作
img =(img *[0.229,0.224,0.225]+[0.485,0.456,0.406])*255
label = labels[i].item()
plt.subplot(1, plot_num, i+1)
plt.xlabel(class_indices[str(label)])
plt.imshow(img.astype('uint8'))
plt.show()
def write_pickle(list_info:list, file_name:str):
with open(file_name,'wb')as f:
pickle.dump(list_info, f)
def read_pickle(file_name:str)->list:
with open(file_name,'rb')as f:
info_list = pickle.load(f)
return info_list
mydatat
from PIL import Image
import torch
from torch.utils.data import Datat
class MyDataSet(Datat):
"""⾃定义数据集"""
def__init__(lf, images_path:list, images_class:list, transform=None):#初始化函数
lf.images_path = images_path
lf.images_class = images_class
def__len__(lf):#计算该数据集下所有的样本个数
return len(lf.images_path)
def__getitem__(lf, item):#每次传⼊⼀个索引,就返回该索引对应的图⽚以及标签信息
img = Image.open(lf.images_path[item])#获得img的路径,然后得到PIL格式图像,pytorch⽤PIL⽐openCV好# RGB为彩⾊图⽚,L为灰度图⽚
de !='RGB':
rai ValueError("image: {} isn't RGB mode.".format(lf.images_path[item]))#报错,如果是灰度,就把上⼀⾏改成L
label = lf.images_class[item]
ansform is not None:
img = lf.transform(img)#对图像进⾏预处理
return img, label
@staticmethod
#是个静态⽅法
韩文
def collate_fn(batch):#dataloader会使⽤
# 官⽅实现的default_collate可以参考
# /pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py        images, labels =tuple(zip(*batch))
#zip将图⽚和图⽚放⼀起,标签和标签放⼀起
images = torch.stack(images, dim=0)
#拼接,并会在dim=0的维度上进⾏拼接(就是拼成⼀个矩阵)
labels = torch.as_tensor(labels)#标签也转换成tensor
return images, labels
main
import os
import torch
from torchvision import transforms
from my_datat import MyDataSet
from utils import read_split_data, plot_data_loader_image
# sorflow/example_images/
root ="/home/wz/my_github/data_t/flower_data/flower_photos"# 数据集所在根⽬录
楔子是什么意思
def main():
device = torch.device("cuda"if torch.cuda.is_available()el"cpu")
print("using {} device.".format(device))
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)
data_transform ={
"train": transforms.Compo([transforms.RandomResizedCrop(224),#随机裁剪
transforms.RandomHorizontalFlip(),#⽔平翻转
transforms.ToTensor(),#转化成tensor格式
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),
"val": transforms.Compo([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])}
you belong with me歌词
##这个很重要,可以⾃⼰实现
#实例化datat
train_data_t = MyDataSet(images_path=train_images_path,#训练集图像列表
images_class=train_images_label,#训练集所有图像对应的标签信息
mute是什么意思transform=data_transform["train"])#预处理⽅法
batch_size =8
nw =min([os.cpu_count(), batch_size if batch_size >1el0,8])# number of workers
print('Using {} dataloader workers'.format(nw))
train_loader = torch.utils.data.DataLoader(train_data_t,#从实例化的datat当中取得图⽚,然后打包成⼀个⼀个batch,然后输⼊⽹络进⾏训练                                              batch_size=batch_size,
初次见面请多关照日语shuffle=True,#打乱数据集
num_workers=nw,#训练时建议nw,调试时建议0
collate_fn=train_llate_fn)
# plot_data_loader_image(train_loader)
for step, data in enumerate(train_loader):
images, labels = data
if __name__ =='__main__':
main()

本文发布于:2023-07-06 19:03:06,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/169172.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:对应   类别   样本   图像   路径   验证   柱状图
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图