Pytorch的DataLoader和Datat以及TensorDatat的源码分析和使用

更新时间:2023-06-22 22:28:01 阅读: 评论:0

Pytorch的DataLoader和Datat以及TensorDatat的源码分析和使⽤
1.为什么要⽤DataLoader和Datat
要对⼤量数据进⾏加载和处理时因为可能会出现内存不够⽤的情况,这时候就需要⽤到数据集类Datat或TensorDatat和数据集加载类DataLoader了。使⽤这些类后可以将原本的数据分成⼩块,在需要使⽤的时候再⼀部分⼀本分读进内存中,⽽不是⼀开始就将所有数据读进内存中。
2.Datet的使⽤
pytorch中的torch.utils.data.Datat是表⽰数据集的抽象类,但它⼀般不直接使⽤,⽽是通过⾃定义⼀个数据集来使⽤。来⾃定义数据集应该继承Datat并应该有实现返回数据集尺⼨的__len__⽅法和⽤来获取索引数据的__getitem__⽅法。Datat类的源码如下:
class Datat(object):
r"""An abstract class reprenting a :class:`Datat`.
All datats that reprent a map from keys to data samples should subclass
it. All subclass should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclass could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the datat by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices.  To make it work with a map-style
datat with non-integral indices/keys, a custom sampler must be provided.
"""
牛饲料配方
def__getitem__(lf, index):
rai NotImplementedError
def__add__(lf, other):
return ConcatDatat([lf, other])
# No `def __len__(lf)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Ba Class ]
# in pytorch/torch/utils/data/sampler.py
可以看到Datat类中没有__len__⽅法,虽然有__getitem__⽅法,但是并没有实现啥有⽤的功能。所以要写⼀个Datat类的⼦类来实现其应有的功能。
⾃定义类的实现举例:
import torch
from torch.utils.data import Datat, DataLoader, TensorDatat
from torch.autograd import Variable
import numpy as np
import pandas as pd
value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape)# (73700, 300)
value_size = value_array.shape[0]# 73700
train_size =int(0.7*value_size)
train_array = val_array[:train_size]
train_label_array = val_array[60:train_size+60]
class DealDatat(Datat):
"""
下载数据、初始化数据,都可以在这⾥完成
远离危险"""
def__init__(lf,*arrays):
asrt all(arrays[0].shape[0]== array.shape[0]for array in arrays)
lf.arrays = arrays
def__getitem__(lf, index):公共卫生医师
return tuple(array[index]for array in lf.arrays)
def__len__(lf):
return lf.arrays[0].shape[0]
# 实例化这个类,然后我们就得到了Datat类型的数据,记下来就将这个类传给DataLoader,就可以了。train_datat = DealDatat(train_array, train_label_array)
train_loader2 = DataLoader(datat=train_datat,
batch_size=32,
shuffle=True)
for epoch in range(2):
for i, data in enumerate(train_loader2):
# 将数据从 train_loader 中读出来,⼀次读取的样本数是32个
inputs, labels = data
# 将这些数据转换成Variable类型
inputs, labels = Variable(inputs), Variable(labels)
# 接下来就是跑模型的环节了,我们这⾥使⽤print来代替寒冷的近义词是什么
print("epoch:", epoch,"的第", i,"个inputs", inputs.data.size(),"labels", labels.data.size())
结果:
epoch:0的第0个inputs torch.Size([32,300]) labels torch.Size([32,300])
epoch:0的第1个inputs torch.Size([32,300]) labels torch.Size([32,300])
epoch:0的第2个inputs torch.Size([32,300]) labels torch.Size([32,300])
epoch:0的第3个inputs torch.Size([32,300]) labels torch.Size([32,300])
epoch:0的第4个inputs torch.Size([32,300]) labels torch.Size([32,300])
epoch:0的第5个inputs torch.Size([32,300]) labels torch.Size([32,300])
...
3.TensorDatat的使⽤
TensorDatat是可以直接使⽤的数据集类,它的源码如下:
class TensorDatat(Datat):
r"""Datat wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
def__init__(lf,*tensors):
asrt all(tensors[0].size(0)== tensor.size(0)for tensor in tensors)
def__getitem__(lf, index):
return tuple(tensor[index]for tensor sors)
def__len__(lf):
sors[0].size(0)
可以看到TensorDatat类是Datat类的⼦类,且拥有返回数据集尺⼨的__len__⽅法和⽤来获取索引
数据的__getitem__⽅法,所以可以直接使⽤。它的结构跟上⾯⾃定义的⼦类的结构是⼀样的,惟⼀的不同是TensorDatat已经规定了传⼊的数据必须是torch.Tensor类型的,⽽⾃定义⼦类可以⾃由设定。
使⽤举例:
import torch头层牛皮革
from torch.utils.data import Datat, DataLoader, TensorDatat
from torch.autograd import Variable
import numpy as np
import pandas as pd
value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape)# (73700, 300)
value_size = value_array.shape[0]# 73700
train_size =int(0.7*value_size)
train_array = val_array[:train_size]
train_tensor = sor(train_array, dtype=torch.float32).to(device)
train_label_array = val_array[60:train_size+60]
train_labels_tensor = sor(train_label_array,dtype=torch.float32).to(device)
train_datat = TensorDatat(train_tensor, train_labels_tensor)
如何显示桌面train_loader = DataLoader(datat=train_datat,
batch_size=100,
shuffle=Fal,
num_workers=0)
审美趣味
for epoch in range(2):
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
print(epoch, i,"inputs", inputs.data.size(),"labels", labels.data.size())
结果:
00 inputs torch.Size([100,300]) labels torch.Size([100,300])
01 inputs torch.Size([100,300]) labels torch.Size([100,300])
02 inputs torch.Size([100,300]) labels torch.Size([100,300])
03 inputs torch.Size([100,300]) labels torch.Size([100,300])
04 inputs torch.Size([100,300]) labels torch.Size([100,300])
05 inputs torch.Size([100,300]) labels torch.Size([100,300])
06 inputs torch.Size([100,300]) labels torch.Size([100,300])
07 inputs torch.Size([100,300]) labels torch.Size([100,300])
台风新动向08 inputs torch.Size([100,300]) labels torch.Size([100,300])
09 inputs torch.Size([100,300]) labels torch.Size([100,300]) 010 inputs torch.Size([100,300]) labels torch.Size([100,300]) ...

本文发布于:2023-06-22 22:28:01,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1016571.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   实现   内存   应该   返回
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图