Randomly splitting a torch dataset into training and test sets

1. torch.utils.data.random_split()

PyTorch offers several ways to split a dataset; this one is the simplest.

Source: https://www.cnblogs.com/marsggbo/p/10496696.html

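import torch

# full_dataset: the complete torch.utils.data.Dataset you want to split
train_size = int(0.8 * len(full_dataset))    # 80% for training
test_size = len(full_dataset) - train_size   # the remainder for testing
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])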
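random_split draws its permutation from PyTorch's global random number generator, so the split changes from run to run unless it is seeded. A minimal sketch, assuming a toy TensorDataset of the numbers 0 to 9 as a stand-in for full_dataset, that passes a seeded torch.Generator through random_split's generator argument to get the same split every time:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 10 samples (hypothetical stand-in for full_dataset)
full_dataset = TensorDataset(torch.arange(10))

train_size = int(0.8 * len(full_dataset))    # 8
test_size = len(full_dataset) - train_size   # 2

def split_with_seed(seed):
    # A dedicated, seeded generator makes the permutation reproducible
    gen = torch.Generator().manual_seed(seed)
    return random_split(full_dataset, [train_size, test_size], generator=gen)

train_a, test_a = split_with_seed(42)
train_b, test_b = split_with_seed(42)

# Same seed -> same indices in each split
print(list(train_a.indices) == list(train_b.indices))  # True
# Together the two subsets cover every index exactly once
print(sorted(list(train_a.indices) + list(test_a.indices)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Without the generator argument the call still works, but a second run will generally produce different train/test indices.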

After the split, both the training set and the test set are of type:

<class 'torch.utils.data.dataset.Subset'>

The type changes from the original Dataset to Subset (a subclass of Dataset). Both can be passed to torch.utils.data.DataLoader() to build an iterable DataLoader.
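A small sketch to check the type and plug the result into a DataLoader; the 10-sample TensorDataset, the fixed seed, and the batch size of 4 are illustrative assumptions, not part of the original:

```python
import torch
from torch.utils.data import Dataset, Subset, DataLoader, TensorDataset, random_split

# Toy dataset: 10 feature vectors with integer labels (illustrative only)
features = torch.randn(10, 3)
labels = torch.arange(10)
full_dataset = TensorDataset(features, labels)

train_dataset, test_dataset = random_split(
    full_dataset, [8, 2], generator=torch.Generator().manual_seed(0))

print(isinstance(train_dataset, Subset))      # True
print(isinstance(train_dataset, Dataset))     # True: Subset subclasses Dataset
print(train_dataset.dataset is full_dataset)  # True: no data is copied

# A Subset plugs into DataLoader exactly like the original Dataset
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
for x, y in train_loader:
    # 8 samples / batch_size 4 -> two batches of shape [4, 3] and [4]
    print(x.shape, y.shape)
```

Because a Subset only stores a reference to the parent dataset plus a list of indices, splitting is cheap even for large datasets.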

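When splitting randomly, the lengths you pass must sum to the length of the dataset; otherwise random_split raises a ValueError.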
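Computing the test length as the remainder, as the code at the top of this section does, is what keeps the sum exact. A sketch with an 11-sample toy dataset (an assumption chosen for illustration) where rounding both fractions separately would lose a sample: int(0.8 * 11) = 8 and int(0.2 * 11) = 2 add up to 10, which random_split rejects.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(11))  # 11 samples (toy example)

# Taking the second length as the remainder guarantees the sum matches,
# even when 0.8 * len(dataset) is not an integer.
train_size = int(0.8 * len(dataset))   # 8
test_size = len(dataset) - train_size  # 3
train_set, test_set = random_split(dataset, [train_size, test_size])
print(len(train_set), len(test_set))  # 8 3

# Rounding each fraction separately gives 8 + 2 = 10 != 11, which is rejected
try:
    random_split(dataset, [int(0.8 * 11), int(0.2 * 11)])
    rejected = False
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```

Recent PyTorch releases also accept fractions that sum to 1, e.g. [0.8, 0.2]; check the random_split documentation for your installed version before relying on that.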

2. torch.utils.data.Subset()

https://stackoverflow.com/questions/47432168/taking-subsets-of-a-pytorch-dataset

import torchvision
import torchvision.transforms as transforms
import torch

# ToTensor() lets DataLoader's default collate function batch the images;
# with transform=None each sample is a PIL image and batching fails.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transforms.ToTensor())

evens = list(range(0, len(trainset), 2))  # even indices
odds = list(range(1, len(trainset), 2))   # odd indices
trainset_1 = torch.utils.data.Subset(trainset, evens)  # Subset object
trainset_2 = torch.utils.data.Subset(trainset, odds)   # Subset object

# Build DataLoaders from the Subset objects
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)

The second argument to Subset is the list of indices of the samples to select.
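The index mapping is easiest to see on a dataset whose samples are their own indices. A sketch, using a 10-element TensorDataset as a hypothetical stand-in for CIFAR10:

```python
import torch
from torch.utils.data import TensorDataset, Subset

# Toy dataset whose i-th sample is just the number i (stand-in for CIFAR10)
dataset = TensorDataset(torch.arange(10))

evens = list(range(0, len(dataset), 2))  # [0, 2, 4, 6, 8]
odds = list(range(1, len(dataset), 2))   # [1, 3, 5, 7, 9]
subset_even = Subset(dataset, evens)
subset_odd = Subset(dataset, odds)

# Position k in a Subset maps to dataset[indices[k]]
print(len(subset_even))          # 5
print(subset_even[1][0].item())  # 2, i.e. dataset[evens[1]]
print(subset_odd[1][0].item())   # 3, i.e. dataset[odds[1]]
```

Unlike random_split, Subset does no shuffling of its own: the order of the index list is the order of the subset, and shuffling is left to the DataLoader.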

3. The SubsetRandomSampler class

 https://www.cnblogs.com/marsggbo/p/10496696.html

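import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

# Example settings (illustrative values); `dataset` is your own Dataset
batch_size = 16
validation_split = 0.2
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` shuffled indices go to validation, the rest to training
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Pass the sampler to DataLoader; do not also set shuffle=True,
# because DataLoader treats sampler and shuffle as mutually exclusive
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)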
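A self-contained sketch of the same idea, assuming a toy 10-sample TensorDataset and using torch.randperm with a seeded generator in place of NumPy's shuffling (a swap made only to keep the example torch-only). It checks that one epoch of the training loader visits each training index exactly once and never a validation index:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset: sample i is the number i (illustrative stand-in for `dataset`)
dataset = TensorDataset(torch.arange(10))

# Shuffle all indices once with a fixed seed, then carve off a validation part
gen = torch.Generator().manual_seed(42)
indices = torch.randperm(len(dataset), generator=gen).tolist()
split = int(0.2 * len(dataset))  # 2 validation samples
train_indices, val_indices = indices[split:], indices[:split]

train_loader = DataLoader(dataset, batch_size=4,
                          sampler=SubsetRandomSampler(train_indices))
val_loader = DataLoader(dataset, batch_size=4,
                        sampler=SubsetRandomSampler(val_indices))

# Each batch is a one-element list holding a 1-D tensor of sample values;
# since sample i equals i, the values are the dataset indices themselves.
seen = [int(v) for (batch,) in train_loader for v in batch]
print(sorted(seen) == sorted(train_indices))  # True
print(set(seen) & set(val_indices))           # set()
```

Unlike random_split and Subset, this approach never creates new dataset objects: both loaders wrap the same dataset, so training and validation samples necessarily go through the same transform.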
Original article: https://www.cnblogs.com/BlueBlueSea/p/14617713.html