Pytorch_COCO数据集_dataset

Coco数据集

本文主要内容来源于《pytorch加载自己的coco数据集》一文,针对其内容做学习和理解,进一步加深对数据集的理解,并梳理把自己的数据封装成 dataset 的步骤。仅作学习用。
 了解输入和输出

代码示例

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

import os
import os.path
import json
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset
from torchvision.transforms import functional as F


# Step 1: define CoCo_DataSet as a Dataset subclass and override
# __init__, __len__ and __getitem__.
class CoCo_DataSet(Dataset):
    """Detection dataset backed by a COCO-style annotation JSON file.

    Expected layout under ``coco_root_dir``::

        annotations/coco_instance_train.json   (or coco_instance_val.json)
        images/train2021/<file_name>           (or images/val2021/...)

    Attributes:
        transforms: Callable ``(image, target) -> (image, target)`` or None.
        annotations_root: Directory holding the annotation JSON files.
        annotations_json: Path of the annotation file for the chosen split.
        image_root: Directory holding the images of the chosen split.
        coco_dict: The parsed annotation JSON.
        bbox_image: Maps a dataset index (position in ``coco_dict["images"]``)
            to a list of ``[category_id, bbox]`` pairs for that image.
    """

    def __init__(self, coco_root_dir,transforms,train_set=True):
        """Load the annotation file and group annotations per image.

        Args:
            coco_root_dir: Root directory of the dataset.
            transforms: Optional callable applied to ``(image, target)``.
            train_set: Select the train split when True, else the val split.

        Raises:
            AssertionError: If the annotation path does not exist.
        """
        self.transforms = transforms
        self.annotations_root = os.path.join(coco_root_dir, "annotations")
        if train_set:
            self.annotations_json = os.path.join(self.annotations_root, "coco_instance_train.json")
            self.image_root = os.path.join(coco_root_dir, "images", "train2021")
        else:
            self.annotations_json = os.path.join(self.annotations_root, "coco_instance_val.json")
            self.image_root = os.path.join(coco_root_dir, "images", "val2021")

        # Check that the annotation file exists.
        assert os.path.exists(self.annotations_json), "{} file not exist ".format(self.annotations_json)
        if not os.path.isfile(self.annotations_json):
            print(self.annotations_json + ' ## not a file!')

        # Read the JSON file; the context manager guarantees the handle is closed.
        with open(file=self.annotations_json, mode='r', encoding="utf8") as json_file:
            self.coco_dict = json.load(json_file)

        # Resolve every image id to its position in the image list, so the
        # lookup in __getitem__ is correct even when ids are not 1..N in
        # order. Entries without an "id" fall back to the 1-based position.
        index_by_image_id = {
            image_info.get("id", index + 1): index
            for index, image_info in enumerate(self.coco_dict["images"])
        }

        self.bbox_image = {}
        for annotation in self.coco_dict["annotations"]:
            index = index_by_image_id.get(annotation["image_id"])
            if index is None:
                # The annotation refers to an image that is not listed.
                continue
            self.bbox_image.setdefault(index, []).append(
                [annotation["category_id"], annotation["bbox"]]
            )

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.coco_dict["images"])

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image at position ``idx``.

        ``target`` is a dict with a float32 ``"bboxes"`` tensor and an int64
        ``"labels"`` tensor; both are empty when the image has no annotation.
        Box coordinates are passed through exactly as stored in the JSON.

        Returns None (after printing a message) when the image file is missing.
        """
        image_info = self.coco_dict["images"][idx]
        pict_path = os.path.join(self.image_root, image_info["file_name"])
        if not os.path.isfile(pict_path):
            print(pict_path + '@does not exist!')
            return None
        image = cv2.imread(pict_path)

        # Each entry is [category_id, bbox].
        annotations = self.bbox_image.get(idx, [])
        labels = [entry[0] for entry in annotations]
        bboxes = [entry[1] for entry in annotations]
        target = {
            "bboxes": torch.as_tensor(bboxes, dtype=torch.float32),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def collate_fn(self, batch):
        """Regroup a list of ``(image, target)`` samples into ``(images, targets)`` tuples."""
        return tuple(zip(*batch))



class Compose:
    """Chain several ``(image, target)`` transforms into a single callable."""

    def __init__(self, transforms):
        # Iterable of callables, each mapping (image, target) -> (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order and return the final ``(image, target)``."""
        current = (image, target)
        for transform in self.transforms:
            image, target = transform(*current)
            current = (image, target)
        return current

class ToTensor:
    """Transform that converts the image with ``F.to_tensor``; the target is untouched."""

    def __call__(self, image, target):
        """Return ``(F.to_tensor(image), target)``."""
        return F.to_tensor(image), target
# Resize transform for dict-style samples.
class Resize:
    """Resize the ``'image'`` entry of a sample dict with ``cv2.resize``."""

    def __init__(self, output_size: tuple):
        # Passed unchanged to cv2.resize as the target size.
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new dict holding the resized image and the original label."""
        resized = cv2.resize(sample['image'], self.output_size)
        return dict(image=resized, label=sample['label'])

# ToTensor transform for dict-style samples.
class MyToTensor:
    """Move the last image axis to the front and wrap the array as a tensor."""

    def __call__(self, sample):
        """Return a new dict with the converted image and the original label."""
        # Axis order (0, 1, 2) -> (2, 0, 1).
        channels_first = np.transpose(sample['image'], (2, 0, 1))
        tensor = torch.from_numpy(channels_first)
        return {'image': tensor, 'label': sample['label']}

if __name__ == "__main__":
    # Demo: build the training dataset and print every batch.
    # Both splits currently use the same tensor-only pipeline.
    data_transform = {
        "train": Compose([ToTensor()]),
        "val": Compose([ToTensor()]),
    }
    coco_root_path = r"D:\data\dataset\coco"
    mycocoDataset = CoCo_DataSet(coco_root_path, data_transform["train"])
    dataloader = torch.utils.data.DataLoader(
        mycocoDataset,
        batch_size=2,
        shuffle=True,
        collate_fn=mycocoDataset.collate_fn,
    )
    for sample_batch in dataloader:
        # collate_fn returns (images, targets): two tuples with one entry per
        # sample. Unpacking the pair (rather than indexing [0][0] / [0][1])
        # gives both images and targets, and also works when the final batch
        # holds a single sample.
        images_batch, targets_batch = sample_batch
        print(images_batch)
        # Each target is a dict with "bboxes" and "labels" tensors.
        print(targets_batch)

语法说明

 1.python3  判断字典中是否存在某个键 -例如arr_dict 是字典,判断"int_key" 是否存在
    01.使用方法 arr_dict.__contains__("int_key")

    02.使用 in 运算符(更符合 Python 习惯)
     if "int_key" in arr_dict:
         print("存在")
  2. mycocoDataset.__getitem__(1) 返回的数据是
  (image-tensor,{"bboxes":tensor,"labels":tensor }) 

参考:

 深度网络学习-PyTorch_自定义Datsset  https://www.cnblogs.com/ytwang/p/15239433.html
 pytorch加载自己的coco数据集 https://blog.csdn.net/yangyangne/article/details/120384069 
 DATASETS & DATALOADERS  https://pytorch.org/tutorials/beginner/basics/data_tutorial.html  
 目标检测系列一:如何制作数据集?  http://www.spytensor.com/index.php/archives/48/
原文地址:https://www.cnblogs.com/ytwang/p/15753180.html