pytorch使用自定义数据集

pytorch使用自定义数据集

DataLoader是pytorch提供的,一般我们要写的是Dataset,也就是DataLoader中的一个参数,其基本框架是:

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

由此可见,需要暴露的API只有__getitem____len__,还有一个构造函数

原文地址:https://www.cnblogs.com/jiading/p/12162234.html