为了账号安全,请及时绑定邮箱和手机立即绑定

1 Dataset-庖丁解牛之pytorch

标签:
机器学习

1 数据库基类

用来实现数据的大小和索引。
pytorch的Dataset类是一个抽象类,只先实现了三个魔法方法

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError    def __len__(self):
        raise NotImplementedError    def __add__(self, other):
        return ConcatDataset([self, other])

如描述所说,这是一个抽象类,其他数据库类应该是它的子类,所有子类应该重载如下两个函数

* __len__函数,用来提供数据库的大小
* __getitem__函数,支持一个整形索引,重来获取单个数据,范围是__len__定义的,范围是[0, len(self)]

2 数据库的合并

其中Dataset.add函数返回一个ConcatDataset类,这个类实现了数据库的合并,针对从基类DataSet派生类,ConcatDataset实现了不同源的数据库整合,数据存储在链表datasets中,通过累计长度,可以查询不同的datasets,这个类的详细描述如下:

class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l        return r    def __init__(self, datasets):        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)        self.cumulative_sizes = self.cumsum(self.datasets)    def __len__(self):        return self.cumulative_sizes[-1]    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)        if dataset_idx == 0:
            sample_idx = idx        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]        return self.datasets[dataset_idx][sample_idx]

    @property    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)        return self.cumulative_sizes

注意的是给定索引的时候,需要先判定是哪个数据集,然后判定数据集的索引,getitem函数使用了bitsect.bitsec_right查找数据库索引,然后计算该数据库的内部索引。

3 子数据库Subset

ConcatDataset将不同数据集组成链表,在这个大数据集的基础上,通过索引可以建立一个虚拟数据集,实现不同数据集的一个子集,如果通过随机函数实现索引,可以混合所有数据集,Subset数据集的源码如下:

import torch
from torch.utils.data import Dataset, ConcatDataset, Subset, random_splitclass MyDataset(Dataset):
    def __init__(self, t=0, name="myDataset"):        super(MyDataset, self).__init__()        self.nums = []        if t == 0:            self.nums = [torch.randn(1).item() for _ in range(100)]
        elif t == 1:            self.nums = list(range(230))
        elif t == 2:            self.nums = torch.linspace(-1, 1, 250).data.numpy()        self.name = name        self.t = t    def __getitem__(self, i):        return self.nums[i]    def __len__(self):        return len(self.nums)if __name__ == "__main__":
    ds0 = MyDataset(0, "type_0")
    ds1 = MyDataset(1, "type_1")
    ds2 = MyDataset(2, "type_2")
    ds = ds0 + ds1
    ds = ds + ds2
    print(ds.datasets[0].datasets[0].name,ds.datasets[0].datasets[1].name,ds.datasets[1].name)
    print(len(ds))
    dss = random_split(ds, [310, 270]) # 第二个参数是长度,累积和是数据集长度

此处要注意的是 ds0和ds1首先进行合并,形成一个ConcatDataset,然后和ds2合并,再形成一个ConcatDataset,因此ds的datasets长度为2,第一个数据是ConcatDataset,第二个数据是MyDataset(2, "type_2")

4 Tensor向量化数据库

内存数据需要转为Tensor才能使用,pytorch提供了TensorDataset类可以直接对Tensor数据进行数据库封装

class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)    def __len__(self):
        return self.tensors[0].size(0)

最后介绍一个对数据集进行子集切分的函数

def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths))    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

数据集源码解读完毕了,虽然这是一个基类,但是提供了一个可迭代的思想,类似于道教的一分为二,二生四,......,提供了数据索引,合并,tensor,子集的等基本功能。

torchvision.dataset可以使用的数据集

LSUN,  大规模场景理解
LSUNClass
ImageFolder, 图片目录的数据集
DatasetFolder 文件目录的数据集
CocoCaptions,  微软 MS COCO 相关的 Image Captioning 
CocoDetection MS COCO数据集目标检测CIFAR10,  该数据集共有60000张彩色图像分类数据集CIFAR100 数据集包含100小类,每小类包含600个图像,其中有500个训练图像和100个测试图像。100类被分组为20个大类。每个图像带有1个小类的“fine”标签和1个大类“coarse”标签。
STL10 *   10个类:飞机,鸟,汽车,猫,鹿,狗,马,猴子,船,卡车。*   图像为96x96像素,颜色。*   500个训练图像(10个预定义的折叠),每个类800个测试图像。
MNIST, MNIST数据集是一个手写体数据集
EMNIST, 扩展手写体数据集
FashionMNIST FashionMNIST 是一个替代 MNIST 手写数字集[1] 的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。
SVHN
PhotoTour
FakeData
SEMEION 图像处理_Semeion Handwritten Digit Data Set(Semeion手写体数字数据集)
Omniglot Omniglot是一个在线的语言文字百科,其内涵盖了已知的全部书写系统



作者:readilen
链接:https://www.jianshu.com/p/5b65c43d45c0


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消