torch_13_自定义数据集实战

1.将图片的路径和标签写入csv文件并实现读取

 1  # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0
 2     def load_csv(self,filename):
 3         if not os.path.exists(os.path.join(self.root,filename)):
 4             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
 5             for name in self.name2label.keys():
 6                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
 7                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
 8                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
 9             random.shuffle(images)
10             with open(os.path.join(self.root,filename),'w') as f:
11                 writer = csv.writer(f)
12                 for img in images:
13                     name = img.split(os.sep)
14                     label = self.name2label[name[-2]]
15                     writer.writerow([img,label])
16 
17          # 从csv中读取文件
18         images, labels = [], []
19         with open(os.path.join(self.root,filename),'r') as f:
20             reader = csv.reader(f)
21             for row in reader:
22                 img,label = row
23                 label = int(label)
24                 images.append(img)
25                 labels.append(label)
26         assert len(images) == len(labels) # 保证数据长度一致
       return images,labels

 2.加载自定义数据集

  1 """
  2 自定义数据集
  3 image_resize
  4 data argumentation(数据增强):Rotate,crop
  5 normalize:mean,std
  6 ToTensor
  7 
  8 """
  9 import torch
 10 import os,glob
 11 import random,csv
 12 from torch.utils.data import Dataset,DataLoader
 13 from torchvision import transforms
 14 from PIL import Image
 15 import visdom
 16 
 17 
 18 class Pokemon(Dataset):
 19     def __init__(self,root,resize,mode):
 20         super(Pokemon,self).__init__()
 21         self.root = root
 22         self.resize = resize
 23         self.name2label = {}
 24         for name in os.listdir(os.path.join(root)): #把文件和dir都会加载近来
 25             if not sorted(os.path.isdir(os.path.join(root,name))):#排序后,文件夹顺序固定了
 26                 continue
 27             self.name2label[name] = len(self.name2label.keys())
 28         # name2label:{文件夹名,类别编号}
 29         # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0
 30         self.images, self.labels = self.load_csv('images.csv')
 31         # 对数据进行裁剪,mode:train-0.6,validation-0.2,test-0.2数据量是不同的
 32         if mode == 'train':
 33             self.images = self.images[:,int(len(self.images)*0.6)]
 34             self.labels = self.labels[:,int(len(self.images)*0.6)]
 35         elif mode == 'val':
 36             self.images = self.images[int(len(self.images)*0.6):int(len(self.images)*0.8)]
 37             self.labels = self.labels[int(len(self.labels)*0.6):int(len(self.labels)*0.8)]
 38         else:
 39             self.images = self.images[int(len(self.images) * 0.8):]
 40             self.labels = self.labels[int(len(self.labels) * 0.8):]
 41 
 42     def load_csv(self,filename):
 43         if not os.path.exists(os.path.join(self.root,filename)):
 44             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
 45             for name in self.name2label.keys():
 46                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
 47                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
 48                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
 49             random.shuffle(images)
 50             with open(os.path.join(self.root,filename),'w') as f:
 51                 writer = csv.writer(f)
 52                 for img in images:
 53                     name = img.split(os.sep)
 54                     label = self.name2label[name[-2]]
 55                     writer.writerow([img,label])
 56          # 从csv中读取文件
 57         images, labels = [], []
 58         with open(os.path.join(self.root,filename),'r') as f:
 59             reader = csv.reader(f)
 60             for row in reader:
 61                 img,label = row
 62                 label = int(label)
 63                 images.append(img)
 64                 labels.append(label)
 65         assert len(images) == len(labels) # 保证数据长度一致
 66         return images,labels
 67 
 68     def __len__(self):
 69         return len(self.images)
 70 
 71     def __getitem__(self, idx):
 72         # idx是[0-len(self.images]
 73         # self.images,self.label
 74         # img:pokemeon\mew\0001.jpg(这是一个路径)要转变成img数据
 75         # label:是数字
 76         img, label = self.images[idx], self.labels[idx]
 77         tf = transforms.Compose([
 78             lambda x:Image.open(x).convert('RGB'),# string path -> img data
 79             transforms.Resize(int(self.resize*1.25), int(self.resize*1.25)),
 80             transforms.Randomrotation(15), # 旋转度数
 81             transforms.CenterCrop(self.resize),#中心裁剪,保留resize大小
 82             transforms.ToTensor(),
 83             transforms.Normalize(mean=[0.485,0.456,0.406],
 84                                  std=[0.229,0.224,0.225])  # 归一化之后,范围为-1~1,之前的图片范围为0~1
 85             ])
 86         img = tf(img)  # 将path转换成数据
 87         label = torch.tensor(label)  # 将变量label转换成tensor
 88         return img,label
 89 
 90     def denormalize(self,x_hat):
 91         mean=[0.485,0.456,0.406]
 92         std=[0.229,0.224,0.225]
 93         # x:[c,h,w]
 94         # x_hat = (x-mean)/std
 95         # maen[3]->[3,1,1]
 96         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
 97         std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
 98         x = x_hat * std+mean
 99         return x
100 
101 def main():
102     import torchvision
103     vis = visdom.Visdom()
104     """
105     如果存储比较规范的话,可以使用下面简单的代码加载数据集,文件夹的标签从0开始编码
106     tf = transforms.Compose([
107         transforms.Resize((64,64)),
108         transforms.ToTensor()
109     ])
110     db = torchvision.datasets.ImageFolder('./pokemon',transform=tf)
111     loader = DataLoader(db,batch_size=32,shuffle=True)
112     print(db.class_to_idx) #查看类标签
113     
114     """
115     db = Pokemon('./pokemon', 224, 'train') # 根据idx,返回一个
116     x,y = next(iter(db))
117     print('sample:',x.shape,y.shape)
118     #可视化
119     vis.image(db.denormalize(x),win='sample_x',opts=dict(title = 'sample_x'))
120     # 加载一批
121     loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )
122     for x,y in loader:
123         vis.images(db.denormalize(x), nrow=8, win='batch',opts=dict(title='batch'))
124         vis.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
125 
126 
127 if __name__ == '__main__':
128     main()

 小结:

在加载自定义数据集时,一般步骤

1.定义一个类继承Dataset

2.在类中读取数据集(图片的路径),重写len函数,和getitem函数

在len函数中返回数据集的长度

在getitem函数中,处理一张图片,单个图片路径转换成图片数据(包括transform转换),返回该图片数据和标签

3,将处理好的数据集(均为张量)放入DataLoader中,进行分批

loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )

4.训练时通过enumerate遍历每个batchsize

原文地址:https://www.cnblogs.com/shuangcao/p/11905505.html