# Custom data-loading class built on torch.utils.data.Dataset (利用torch.utils.data.Dataset自定义数据加载类)

import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

import torchvision.transforms as T

# Default preprocessing pipeline applied to each PIL image, in order:
#   1. Resize(224)     - scale so the shorter side is 224 px (aspect kept).
#   2. CenterCrop(224) - keep the central 224x224 region.
#   3. ToTensor()      - convert to a channels-first float tensor in [0, 1].
#   4. Normalize(...)  - per channel (x - 0.5) / 0.5, i.e. rescale to [-1, 1].
transforms = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# A map-style Dataset subclass must override __getitem__() and __len__().
class CatDog(data.Dataset):
    """Map-style dataset over a flat directory of cat/dog images.

    Each entry directly inside ``root`` is one sample. Its label is derived
    from the file name: 1 if the name contains ``"dog"``, otherwise 0.

    Args:
        root: Directory whose entries are the image files. Every entry is
            treated as an image, so the directory should contain nothing else.
        transforms: Optional callable applied to the loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``). When ``None``, the
            image is returned as a tensor built from ``np.array(image)``.
    """

    def __init__(self, root, transforms=None):
        # Only the file paths are stored here; images are opened lazily in
        # __getitem__. Sorting makes the index -> file mapping deterministic,
        # since os.listdir returns entries in arbitrary order.
        self.imgs = [
            os.path.join(root, name) for name in sorted(os.listdir(root))
        ]
        self.transforms = transforms

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name contains "dog", else 0.

        NOTE(review): assumes Kaggle "Dogs vs. Cats" style names such as
        ``dog.123.jpg`` / ``cat.45.jpg`` - confirm against the actual data.
        """
        # Inspect only the base name so a parent directory such as "dogs/"
        # cannot flip the label. os.path.basename uses the platform separator
        # instead of a hard-coded "/".
        return 1 if "dog" in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the image at ``index``.

        ``sample`` is ``self.transforms(image)`` when a transform was given.
        Otherwise it is ``torch.from_numpy(np.array(image))``, because the
        DataLoader default collate function cannot batch PIL images.
        """
        img_path = self.imgs[index]
        label = self._label_from_path(img_path)
        # Image.open is lazy: force decoding inside the context manager so
        # the underlying file handle is closed before returning.
        with Image.open(img_path) as img:
            img.load()
            if self.transforms is not None:
                sample = self.transforms(img)
            else:
                # np.array (not np.asarray) copies into a writable buffer,
                # which torch.from_numpy expects.
                sample = t.from_numpy(np.array(img))
        return sample, label

    def __len__(self):
        """Return the number of entries found in ``root``."""
        return len(self.imgs)

# Source article (原文地址): https://www.cnblogs.com/liujianing/p/12320539.html