PyTorch源码(1)DatatSamplerDataLoader

更新时间:2023-05-12 01:37:21 阅读: 评论:0

PyTorch源码(1)DatatSamplerDataLoader
⽂章⽬录
0. 前⾔
参考资料
1. Datat 相关源码
源码位于 torch/utils/data/datat.py
Datat:定义了数据集的基本形式(通过下标获取元素)。
IterableDatat:定义了iterable数据集的基本形式(通过迭代器获取元素)。
TensorDatat:输⼊若⼲个tensor,将每个tensor中对应元素组成为元组,作为数据集元素。
ConcatDatat:合并基本形式的数据集(通过下标获取元素)。
ChainDatat:合并ierable数据集。
Subt:定义基本形式数据集(通过下标获取元素)的⼦集。
random_split:将基本形式数据集(通过下标获取元素)分为若⼲⼦集。
1.1. Datat
定义了PyTorch中数据集基本形式,key-value形式,即通过key获取对应的样本数据。key可以是数字,也可以是字符串。
定义了两个基本函数,__getitem__实现key-value结构,__add__定义两个数据集叠加的操作。
class Datat(object):
def__getitem__(lf, index):
rai NotImplementedError
def__add__(lf, other):
return ConcatDatat([lf, other])
1.2. IterableDatat
Iterable的数据集,其实就是增加了⼀个 __iter__ 函数
注意,数据集合并后⽅法__add__的实现有所变化。
注释中给出了分布式训练时的样例,
class IterableDatat(Datat):
def__iter__(lf):
rai NotImplementedError
def__add__(lf, other):
return ChainDatat([lf, other])
1.3. TensorDatat
应⽤场景:有若⼲个tensor,每个样本是从每个tensor中获取⼀个元素构成的。
例如,有⼀个image name list和⼀个label list,那每个样本就是image list中的⼀个元素和label list中的⼀个元素组成。
实现的功能是:
输⼊⼀组tensor,要求每个tensor的第⼀维shape的数值是⼀样的。
数据集⼤⼩就时tensor的第⼀维shape数值。
每个样本就是tensor列表中分别获取⼀个元素,由这些元素组成的元组。
class TensorDatat(Datat):
def__init__(lf,*tensors):
asrt all(tensors[0].size(0)== tensor.size(0)for tensor in tensors)
def__getitem__(lf, index):
return tuple(tensor[index]for tensor sors)
def__len__(lf):
sors[0].size(0)
1.4. ConcatDatat
作⽤:多个⾮IterableDatat数据集的合并。
实现原理
底层各个数据库保存在⼀个list中。
记录⼀个cumulative_sizes列表,⽤于保存每个datat有多少元素。
在通过idx获取元素的时候,通过cumulative_sizes判断是第⼏个datat的第⼏个元素。
class ConcatDatat(Datat):
@staticmethod
def cumsum(quence):
r, s =[],0
for e in quence:
l =len(e)
r.append(l + s)
s += l
return r
def__init__(lf, datats):
super(ConcatDatat, lf).__init__()
asrt len(datats)>0,'datats should not be an empty iterable'
lf.datats =list(datats)
for d in lf.datats:
asrt not isinstance(d, IterableDatat),"ConcatDatat does not support IterableDatat"
lf.cumulative_sizes = lf.cumsum(lf.datats)
def__len__(lf):
return lf.cumulative_sizes[-1]
def__getitem__(lf, idx):
if idx <0:
if-idx >len(lf):
rai ValueError("absolute value of index should not exceed datat length")
idx =len(lf)+ idx
datat_idx = bict.bict_right(lf.cumulative_sizes, idx)
if datat_idx ==0:
sample_idx = idx
el:
sample_idx = idx - lf.cumulative_sizes[datat_idx -1]
# 通过数据集id和数据集中元素id来获取
return lf.datats[datat_idx][sample_idx]
@property
def cummulative_sizes(lf):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return lf.cumulative_sizes
1.5. ChainDatat
作⽤:多个IterableDatat的合并。
实现原理
在定义对象的时候,其实就是保存了⼀下输⼊的datats,其他啥都没做。所以⽂档中说,定义该对象是on-the-fly,⾮常⾼效。
实现过程也⾮常容易,其实就是通过迭代器、返回迭代器(即每个IterableDatat对象都是⼀个迭代器)。
class ChainDatat(IterableDatat):
def__init__(lf, datats):
super(ChainDatat, lf).__init__()
lf.datats = datats
def__iter__(lf):
for d in lf.datats:
asrt isinstance(d, IterableDatat),"ChainDatat only supports IterableDatat"
for x in d:
yield x
def__len__(lf):
total =0
for d in lf.datats:
asrt isinstance(d, IterableDatat),"ChainDatat only supports IterableDatat"
total +=len(d)
return total
1.6. Subt
作⽤:构建数据集的⼦集。
实现原理
输⼊⼀个数据集(成为raw datat)和⼀个下标集合(称为raw index),选择数据集中这些指定下标的元素,将这些结果作为⼀个⼦集(称为sub datat)。
在代码实现中,其实就是保存了下标集合(raw index)和原始数据集对象(raw datat),获取⼦集对象的时候就是通过⼦集下标(subt index)获取raw index中对应位置的id,然后通过获取的raw index获取raw datat中的元素。
class Subt(Datat):
def__init__(lf, datat, indices):
lf.datat = datat
lf.indices = indices
def__getitem__(lf, idx):
return lf.datat[lf.indices[idx]]
def__len__(lf):
return len(lf.indices)
1.7. ⽅法random_split
作⽤:将输⼊的数据集,分为指定长度的若⼲个⼦集。
def random_split(datat, lengths):
if sum(lengths)!=len(datat):
rai ValueError("Sum of input lengths does not equal the length of the input datat!")
indices = randperm(sum(lengths)).tolist()
return[Subt(datat, indices[offt - length:offt])for offt, length in zip(_accumulate(lengths), lengths)]
2. Sampler 源码
源码位于 torch/utils/data/sampler.py
Sampler:所有Sampler的⽗类。
SequentialSampler:顺序依次获取下标。
RandomSampler:乱序获取下标。
SubtRandomSampler:某个⼦集内乱序获取下标。
WeightedRandomSampler:为每个样本设置权重,权重⼤表⽰获取概率⾼。
BatchSampler:即将若⼲个样本形成⼀个batch。
2.1. Sampler
作⽤:定义了 Sampler 的基本形式,作⽤是定义从数据集中获取元素的⽅法。
实现原理
本质就是定义了构造器和集成了魔法⽅法__iter__。
构造器中包含⼀个数据来源。
集成魔法⽅法 __iter__ 是为了使得 Sampler 成为⼀个迭代器。
注释很清楚的写明了 Sampler 的作⽤:providing a way to iterate over indices of datat elements。
注解中说明了,Sampler 其实是需要⼀个 __len__ ⽅法,但如果要定义⼀个⽅法会存在⼀些BUG,最好的⽅法就是不定义。
其实我没看懂。

本文发布于:2023-05-12 01:37:21,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/593043.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   获取   元素   实现   下标
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图