This is an example from the official PyTorch site: an introduction to torch.nn.
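In general, we subclass Dataset (from torch.utils.data import Dataset, DataLoader) and override its data-reading methods to fit our own data. Alternatively, TensorDataset gives a more concise way to read data that is already held in tensors, and ImageFolder from torchvision can manage image data stored in folders. Any of these is enough to read data on its own; DataLoader sits on top of them and manages the data in batches more comprehensively, handling things such as batch size and shuffling.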
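The options above can be sketched roughly as follows. This is a minimal illustration, not the tutorial's code: the random tensors stand in for real data, and the class name FlatImageDataset is hypothetical.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Synthetic stand-in data (an assumption): 100 flat samples of 784 values, 10 classes.
x = torch.randn(100, 784)
y = torch.randint(0, 10, (100,))


class FlatImageDataset(Dataset):
    """Hypothetical custom Dataset: returns one (sample, label) pair per index."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


custom_ds = FlatImageDataset(x, y)
tensor_ds = TensorDataset(x, y)  # same indexing behaviour without writing a class

# DataLoader adds batching and shuffling on top of either dataset.
loader = DataLoader(tensor_ds, batch_size=32, shuffle=True)
xb, yb = next(iter(loader))
batch_shapes = (tuple(xb.shape), tuple(yb.shape))
num_batches = len(loader)  # the last batch is smaller because 100 is not a multiple of 32
```

Either dataset can be passed to DataLoader; writing a custom class is mainly worthwhile when samples have to be read or transformed lazily.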
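Now to the main topic. To feed MNIST data to a CNN, each sample has to be converted into a two-dimensional image, so the initial flat data must be reshaped. A common approach is to load the flat data and reshape it with view inside the model: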
from torch import nn


class Lambda(nn.Module):
    """Wraps a plain function so it can be used as a layer in nn.Sequential."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


def preprocess(x):
    # Reshape flat 784-value rows to (batch, channels, height, width).
    return x.view(-1, 1, 28, 28)


model = nn.Sequential(
    Lambda(preprocess),
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AvgPool2d(4),
    Lambda(lambda x: x.view(x.size(0), -1)),  # flatten to (batch, 10)
)
The tutorial gives another solution: wrap the DataLoader and move the data preprocessing into a generator.
from torch.utils.data import DataLoader

# train_ds, valid_ds and bs are assumed to be defined earlier.


def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )


def preprocess(x, y):
    # Reshape the inputs; labels pass through unchanged.
    return x.view(-1, 1, 28, 28), y


class WrappedDataLoader:
    """Applies func to every batch produced by the wrapped DataLoader."""

    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        batches = iter(self.dl)
        for b in batches:
            yield (self.func(*b))


train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)
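As a quick check of the wrapper idea, the sketch below runs a WrappedDataLoader over two epochs and records the batch shapes it yields. The random tensors are a stand-in for MNIST (an assumption), and the wrapper is repeated only so that the snippet runs on its own.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def preprocess(x, y):
    # Reshape flat rows to (batch, 1, 28, 28); labels pass through unchanged.
    return x.view(-1, 1, 28, 28), y


class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        # A new generator is created on every call, so each epoch starts fresh.
        for b in self.dl:
            yield self.func(*b)


# Synthetic stand-in for MNIST (an assumption): 64 flat samples with 10 classes.
ds = TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,)))
dl = WrappedDataLoader(DataLoader(ds, batch_size=16), preprocess)

epoch_batch_counts = []
shapes = set()
for epoch in range(2):
    count = 0
    for xb, yb in dl:
        shapes.add((tuple(xb.shape), tuple(yb.shape)))
        count += 1
    epoch_batch_counts.append(count)
```

Because __iter__ is a generator function, the wrapper can be iterated once per epoch like an ordinary DataLoader, and the training loop only ever sees batches that are already reshaped.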
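The model can then be written like this, without the initial reshaping layer. Note that nn.AvgPool2d(4) is also replaced by nn.AdaptiveAvgPool2d(1), which fixes the output size instead of the kernel size: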
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: x.view(x.size(0), -1)),
)
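To see what the switch from a fixed pooling window to adaptive pooling changes, the following sketch pushes random inputs of two sizes through the same convolution stack and compares the flattened output shapes. The 64x64 input size and the helper flat_shape are illustrative assumptions, not part of the tutorial.

```python
import torch
from torch import nn

# The convolutional body used in the model, without the pooling and flatten steps.
body = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

fixed_pool = nn.AvgPool2d(4)             # fixed 4x4 window
adaptive_pool = nn.AdaptiveAvgPool2d(1)  # always produces a 1x1 output


def flat_shape(pool, side):
    """Hypothetical helper: flattened output shape for a batch of 8 side x side images."""
    x = torch.randn(8, 1, side, side)
    with torch.no_grad():
        out = pool(body(x))
    return tuple(out.view(out.size(0), -1).shape)


shape_fixed_28 = flat_shape(fixed_pool, 28)        # feature map 4x4 -> pooled 1x1
shape_fixed_64 = flat_shape(fixed_pool, 64)        # feature map 8x8 -> pooled 2x2
shape_adaptive_28 = flat_shape(adaptive_pool, 28)
shape_adaptive_64 = flat_shape(adaptive_pool, 64)
```

With the fixed window, only the 28x28 input flattens to 10 values per sample (the 64x64 input gives 40), whereas the adaptive version gives 10 values for both sizes, so the model no longer depends on the input being 28x28.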