torchvision

  torchvision中的datasets模块种包含了多种常用的分类数据集相关的下载、导入函数,如表格:

数据集对应的类 描述
datasets.MNIST() 手写字体数据集
datasets.FashionMNIST() 衣服、鞋子、包等10类
datasets.KMNIST() 一些文字的灰度数据
datasets.CocoCaptions() 用于图像检测标注的MS COCO数据
datasets.CocoDetection() 用于检测的MS COCO数据
datasets.LSUN() 10个场景和20个目标的分类数据集
datasets.CIFAR10() CIFAR10类数据集
datasets.CIFAR100() CIFAR100类数据集
datasets.STL10() 包含10类的分类数据集和大量的未标记数据
datasets.ImageFolder() 定义一个数据加载器从文件种读取数据

torchvision.transforms模块

对应的类 描述
transforms.Compose() 将多个transform组合起来使用
transforms.Scale() 按照指定的图像尺寸对图像进行调整
transforms.CenterCrop() 将图像进行中心切割,得到指定大小的图像
transforms.RandomCrop() 切割中心点的位置随机选取
transforms.RandomHorizontalFlip() 将图像进行随机水平翻转
transforms.RandomSizedCrop() 将给定的图像随机切割,然后再变换给定大小
transforms.Pad() 把图像所有的边用给定的pad value填充
transforms.ToTensor() 把一个取值范围为[0,255]的PIL图像或形状为[H,W,C]的数组,转换成形状为[C,H,W],取值范围为[0,1.0],的张量(torch.FloatTensor)
transforms.Normalize() 将给定的图像进行规范化操作
transforms.Lambda(lambd) 使用lambd作为转化器,可自定义图像操作方式

例如代码所示

def ImgSplit(img_root=img_root0,batch_size=BTACH_SIZE,trainrate=0.8):
    # 数据加载及处理,对数据进行翻转,亮度,对比度等数据增广
    #print("图像预处理中。。。。。")
    transform = transforms.Compose([
        transforms.Resize(224),             #将图片按照比例缩放至224*224
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),      #随机旋转
        torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
        transforms.ToTensor(),              #转为tensor
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    ])
    all_data = torchvision.datasets.ImageFolder(
        root=img_root,
        transform=transform
    )
    
    train_data, vaild_data = torch.utils.data.random_split(all_data, [int(trainrate * len(all_data)),
                                                                      len(all_data) - int(trainrate * len(all_data))])

    train_set = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True
    )
    test_set = torch.utils.data.DataLoader(
        vaild_data,
        batch_size=batch_size,
        shuffle=False
    )
    #print("图像预处完成。。。。。")
View Code

这里将图像路径为img_root0的数据集划分为80%的训练集和20%的测试集,每次放入训练的数据是BATCH_SIZE

原文地址:https://www.cnblogs.com/2020zxc/p/14629702.html