pytorch.utils.data

概览

torch.utils.data主要是负责容纳数据集、数据打散、分批等操作。

这里面有三个概念:数据集dataset,抽样器sampler,数据加载器dataloader。其中第三个就是最终对外的接口,也是最重要的。

它们之间的关系是:首先需要根据源数据创建数据集dataset,然后根据dataset创建抽样器sampler,最后同时通过dataset和sampler来创建dataloader,这就是我们最终需要的。这个在训练、测试的时候,会得到batch数据。

dataset

第一个是dataset,就是常规理解的数据集。

数据集主要分为两种:map-style和iterable-style

map-style数据集,一般都是继承Dataset类 ,必须要实现__getitem__()__len__()方法,表示从索引或者key到数据样本的映射

iterable-style数据集,一般都是继承IterableDataset类,必须实现__iter__()方法,表示在数据样本上迭代。一般从一些流中实时获取数据(比如数据库、远程服务器或者日志),是无法进行随机读取的,这时就主要使用迭代式数据集。

一般如果数据量小,使用map-style就可以了,如果数据量很大,需要从数据流中获取,那就使用iterable-style

对应到具体的类,有以下六个:

  • torch.utils.data.Dataset
  • torch.utils.data.IterableDataset
  • torch.utils.data.TensorDataset
  • torch.utils.data.ConcatDataset
  • torch.utils.data.ChainDataset
  • torch.utils.data.Subset

除此之外,torch.utils.data还包含了两个函数

  • torch.utils.data.get_worker_info()
  • torch.utils.data.random_split()

sampler

sampler是抽样器,作用在dataset上面

抽样的方式也有几个方式:

按顺序抽样,随机抽样,在子集合中随机抽样,带权重的抽样等等

包括以下类:

  • class Sampler
  • class SequentialSampler
  • class RandomSampler
  • class SubsetRandomSampler
  • class WeightedRandomSampler
  • class BatchSampler
  • class DistributedSampler

生成sampler的最终目的就是为了创建dataloader。

dataLoader

DataLoader是核心。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

构建DataLoader有几个重要的参数:

  • dataset是数据集,
  • batch_size
  • shuffle 是否每一轮都将数据进行打散,最好通过sampler来打散,否则使用SequentialSampler的时候也会被打散。
  • sampler 生成indices
  • collate_fn
  • pin_memory 含义参考pytorch pinned memory

实例1:通过TensorDataset快速生成dataloader

数据中有字符串类型的时候慎用。

import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset, RandomSampler
import numpy as np


# 创建TensorDataset
feature = torch.tensor(np.arange(100))
dataset = TensorDataset([feature, feature])
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=5, sampler=sampler)

for epoch in range(2):
    print('epoch=', epoch)
    for index, batch in enumerate(dataloader):
        print(batch)
        if index > 10:
            break
epoch= 0
[tensor([79,  6, 81, 35, 21], dtype=torch.int32), tensor([79,  6, 81, 35, 21], dtype=torch.int32)]
[tensor([43, 98, 86, 23, 68], dtype=torch.int32), tensor([43, 98, 86, 23, 68], dtype=torch.int32)]
[tensor([ 0, 36, 60,  1, 91], dtype=torch.int32), tensor([ 0, 36, 60,  1, 91], dtype=torch.int32)]
[tensor([71, 59, 72, 75, 52], dtype=torch.int32), tensor([71, 59, 72, 75, 52], dtype=torch.int32)]
[tensor([45,  2, 73, 46, 95], dtype=torch.int32), tensor([45,  2, 73, 46, 95], dtype=torch.int32)]
[tensor([82, 37, 24, 12, 16], dtype=torch.int32), tensor([82, 37, 24, 12, 16], dtype=torch.int32)]
[tensor([90, 11, 70, 31, 53], dtype=torch.int32), tensor([90, 11, 70, 31, 53], dtype=torch.int32)]
[tensor([15,  7, 64, 22, 65], dtype=torch.int32), tensor([15,  7, 64, 22, 65], dtype=torch.int32)]
[tensor([ 3, 87,  4, 17, 99], dtype=torch.int32), tensor([ 3, 87,  4, 17, 99], dtype=torch.int32)]
[tensor([83, 20, 19, 89, 42], dtype=torch.int32), tensor([83, 20, 19, 89, 42], dtype=torch.int32)]
[tensor([97, 58,  8, 38, 30], dtype=torch.int32), tensor([97, 58,  8, 38, 30], dtype=torch.int32)]
[tensor([54, 56, 48, 27, 57], dtype=torch.int32), tensor([54, 56, 48, 27, 57], dtype=torch.int32)]
epoch= 1
[tensor([66, 15, 37, 82, 47], dtype=torch.int32), tensor([66, 15, 37, 82, 47], dtype=torch.int32)]
[tensor([75, 70,  5, 99, 33], dtype=torch.int32), tensor([75, 70,  5, 99, 33], dtype=torch.int32)]
[tensor([80, 76, 55, 29, 41], dtype=torch.int32), tensor([80, 76, 55, 29, 41], dtype=torch.int32)]
[tensor([79, 17, 63, 92, 74], dtype=torch.int32), tensor([79, 17, 63, 92, 74], dtype=torch.int32)]
[tensor([52, 53, 58, 38, 87], dtype=torch.int32), tensor([52, 53, 58, 38, 87], dtype=torch.int32)]
[tensor([84, 59, 77, 48, 71], dtype=torch.int32), tensor([84, 59, 77, 48, 71], dtype=torch.int32)]
[tensor([56, 16, 27, 81, 60], dtype=torch.int32), tensor([56, 16, 27, 81, 60], dtype=torch.int32)]
[tensor([50, 73, 46, 28, 32], dtype=torch.int32), tensor([50, 73, 46, 28, 32], dtype=torch.int32)]
[tensor([45, 40, 10, 25,  9], dtype=torch.int32), tensor([45, 40, 10, 25,  9], dtype=torch.int32)]
[tensor([12, 49, 22, 51, 20], dtype=torch.int32), tensor([12, 49, 22, 51, 20], dtype=torch.int32)]
[tensor([ 6, 68, 72, 24, 67], dtype=torch.int32), tensor([ 6, 68, 72, 24, 67], dtype=torch.int32)]
[tensor([57, 96, 23, 97, 98], dtype=torch.int32), tensor([57, 96, 23, 97, 98], dtype=torch.int32)]

自定义Dataset

import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self):
        self.Data = np.arange(32).reshape(16, 2).tolist()
        self.Target = np.random.randint(0, 2, (16,1)).tolist()

    def __getitem__(self, index):
        txt = torch.LongTensor(self.Data[index])
        label = torch.LongTensor(self.Target[index])
        return txt, label
    
    def __len__(self):
        return len(self.Data)
原文地址:https://www.cnblogs.com/YoungF/p/13941346.html