1.使用RNN做MNIST分类

第一次用LSTM,从简单做起吧~~

注意事项:

  • batch_first=True 意味着输入的格式为(batch_size,time_step,input_size),False 意味着输入的格式为(time_step,batch_size,input_size)
  • 取r_out[:,-1,:],即取最后一个时间步的输出,相当于LSTM把一张图片全部扫描完后返回的状态向量(此时的维度为(64,64),前面的64是batch_size,后面的64是隐藏层的神经元个数,即hidden_size)
 1 import torch
 2 from torch.autograd import Variable
 3 from torchvision import datasets,transforms
 4 #超参数
 5 EPOCH=1
 6 BATCH_SIZE=64
 7 TIME_STEP=28#run time step/image height
 8 INPUT_SIZE=28#run input size/image width
 9 LR=0.01
10 DOWNLOAD_MNIST=True
11 
12 
13 train_data=datasets.MNIST(root='./mnist',train=True,transform=transforms.ToTensor(),download=DOWNLOAD_MNIST)
14 train_loader=torch.utils.data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
15 
16 test_data=datasets.MNIST(root='./mnist',train=False,transform=transforms.ToTensor(),download=DOWNLOAD_MNIST)
17 test_loader=torch.utils.data.DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=True)
18 
class RNN(torch.nn.Module):
    """LSTM sequence classifier with a linear head on the last time step.

    Intended for MNIST images fed row by row: each sample is a sequence of
    image rows (time steps), each row a vector of ``input_size`` pixels.

    Args:
        input_size: features per time step; when None, the module-level
            ``INPUT_SIZE`` constant is used (the original behavior).
        hidden_size: number of LSTM hidden units (default 64, as before).
        num_layers: number of stacked LSTM layers (default 1).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, input_size=None, hidden_size=64, num_layers=1, num_classes=10):
        super(RNN, self).__init__()
        if input_size is None:
            # Resolved at call time so the class can be defined before INPUT_SIZE.
            input_size = INPUT_SIZE

        self.rnn = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # input/output are (batch, time_step, features)
        )
        # Head width is derived from hidden_size so the two layers stay in sync.
        self.out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` must be (batch, time_step, input_size) because of batch_first=True.
        """
        # No initial state passed: (h_0, c_0) default to zeros.
        # r_out is (batch, time_step, hidden_size); (h_n, c_n) are unused here.
        r_out, (h_n, c_n) = self.rnn(x)
        # Output of the last time step: the state after the whole sequence
        # (e.g. every image row) has been read.
        return self.out(r_out[:, -1, :])
35 
# Instantiate the classifier and show its layer summary.
# (Had the LSTM been built with batch_first=False, it would expect inputs
# shaped (time_step, batch, input_size) instead.)
rnn = RNN()
print(rnn)

# Cross-entropy works on the raw logits from the model's linear head plus
# integer class labels; Adam updates every parameter of the model.
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=LR)
42 
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        # Reshape each image batch to (batch, time_step, input_size) so every
        # image row becomes one LSTM time step. Plain tensors already carry
        # autograd information, so the deprecated Variable() wrapper is dropped.
        b_x = x.view(-1, TIME_STEP, INPUT_SIZE)
        # Labels are a 1-D tensor of class indices. The old .squeeze() would
        # turn a batch of size 1 into a 0-d tensor and break CrossEntropyLoss.
        b_y = y

        output = rnn(b_x)  # logits, one row per sample
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # Evaluate on the whole test set and report ONE aggregate accuracy
            # (previously a separate accuracy was printed for every test batch).
            rnn.eval()
            correct = 0
            total = 0
            with torch.no_grad():  # evaluation needs no autograd graph
                for test_x, test_y in test_loader:
                    test_output = rnn(test_x.view(-1, TIME_STEP, INPUT_SIZE))
                    pred_y = torch.max(test_output, 1)[1]  # predicted class indices
                    correct += (pred_y == test_y).sum().item()
                    total += test_y.size(0)
            rnn.train()
            acc = correct / total
            print('Epoch: %d | step: %d | train loss: %.4f | test accuracy: %.4f'
                  % (epoch, step, loss.item(), acc))
原文地址:https://www.cnblogs.com/tangweijqxx/p/10601394.html