PyTorch基本使用

自定义一个数据集

from torch.utils.data import Dataset
import os
import cv2

# Define a dataset class by subclassing torch's Dataset.
class MyData(Dataset):
    """Image dataset over a single class folder, labelled by the folder name.

    Files are listed from ``os.path.join(root_dir, label_dir)``; every
    sample's label is the string ``label_dir``.

    Attributes:
        root_dir: Parent directory passed by the caller.
        label_dir: Name of the sub-folder to read; also used as the label.
        path: Full path of that sub-folder.
        img_path: Sorted file *names* (not full paths) inside ``path``.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that the index -> file mapping is reproducible everywhere.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th file.

        ``image`` is the array returned by ``cv2.imread``; ``label`` is
        ``self.label_dir``.

        Raises:
            IndexError: if ``index`` is out of range.
            OSError: if OpenCV cannot read or decode the file.
        """
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.path, img_name)
        img = cv2.imread(img_item_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail loudly here rather than handing None to downstream code.
        if img is None:
            raise OSError(f"cv2.imread could not read image: {img_item_path}")
        return img, self.label_dir

    def __len__(self):
        """Return the number of files in the class folder."""
        return len(self.img_path)

# Parent folder of the training images; the 'ants' sub-folder below is the
# class folder that gets loaded.
root_dir = 'dataset/hymenoptera_data/train'

# Dataset over the 'ants' folder; every sample's label is the string 'ants'.
ants_dataset = MyData(root_dir, 'ants')
# Fetch the first sample: (image read by cv2.imread, label string).
img, label = ants_dataset[0]
# Show the image in a window titled 'img'; waitKey(0) blocks until a key is
# pressed, then every OpenCV window is closed.
cv2.imshow('img', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

TensorBoard的使用

from torch.utils.tensorboard import SummaryWriter

# Plot the line y = x for x = 0..99 under the tag "y = x", writing event
# files into ./logs. Leaving the with-block calls writer.close(), which
# flushes any buffered events to disk.
with SummaryWriter('logs') as writer:
    for i in range(100):
        # add_scalar(tag, scalar_value, global_step): value and step are both i.
        writer.add_scalar("y = x", i, i)

Transforms的使用

from PIL import Image
from torchvision import transforms

img_path = 'dataset/hymenoptera_data/train/ants/6240329_72c01e663e.jpg'

# Create a ToTensor transform object; it is callable on a PIL image.
tensor_trans = transforms.ToTensor()

# Open the image in a context manager: PIL opens files lazily and keeps the
# handle open until the image is closed, so close it once conversion is done.
with Image.open(img_path) as img:
    # Convert the PIL image into a tensor.
    tensor_img = tensor_trans(img)
print(tensor_img)

结合torchvision中的数据集使用transforms

import torchvision
import ssl
from torch.utils.tensorboard import SummaryWriter

# Convert every PIL image in the datasets into a tensor.
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# Skip HTTPS certificate verification, presumably to work around certificate
# errors when downloading CIFAR10. Only do so for the duration of the
# downloads: the previous context factory is restored afterwards so the rest
# of the process keeps verifying certificates.
# NOTE(review): disabling verification is insecure; prefer fixing the local
# CA certificate store instead.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    train_set = torchvision.datasets.CIFAR10(root='./torch_dataset', train=True, transform=dataset_transforms, download=True)
    test_set = torchvision.datasets.CIFAR10(root='./torch_dataset', train=False, transform=dataset_transforms, download=True)
finally:
    ssl._create_default_https_context = _default_https_context

# Each sample is an (image, target) pair.
print(train_set[0])

# Log the first 100 test images under the "test_set" tag. The with-block
# closes the writer (the original never did), flushing events to disk.
with SummaryWriter("pytorch_dataset_logs") as writer:
    for i in range(100):
        img, target = test_set[i]
        writer.add_image("test_set", img, i)

DataLoader的使用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Test split of CIFAR10, converted to tensors. download is not requested
# here, so the files must already exist under ./torch_dataset.
test_data = torchvision.datasets.CIFAR10('./torch_dataset', transform=torchvision.transforms.ToTensor(), train=False)

# Shuffled batches of 64; the final, smaller batch is kept (drop_last=False).
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# Write one image grid per batch under the "test_data_loader" tag, using the
# batch index as the global step. The with-block closes the writer even if
# iteration raises.
with SummaryWriter('dataLoader') as writer:
    for step, (img, _) in enumerate(test_loader):
        writer.add_images("test_data_loader", img, step)

原文地址:https://www.cnblogs.com/Gazikel/p/15749910.html