PyTorch 基础入门之 ANN

本部分介绍如何用 PyTorch 构建人工神经网络(ANN),这里的 ANN 具有三个隐含层。

本部分与上一篇的逻辑斯蒂回归使用相同的数据集 MNIST;下面代码中用到的 train_loader、test_loader 和 num_epochs 需沿用上一篇中的数据加载部分。

第一部分:构造模型

# Import Libraries
import torch
import torch.nn as nn
from torch.autograd import Variable

# Create ANN Model
class ANNModel(nn.Module):
    """Fully connected network with three hidden layers.

    Layer stack, applied in order by ``forward``::

        fc1 (input_dim  -> hidden_dim) -> ReLU
        fc2 (hidden_dim -> hidden_dim) -> Tanh
        fc3 (hidden_dim -> hidden_dim) -> ELU
        fc4 (hidden_dim -> output_dim)           # readout, no activation

    Args:
        input_dim: size of each flattened input vector (28*28 in this article).
        hidden_dim: width shared by all three hidden layers.
        output_dim: number of output scores (10 in this article).

    ``forward`` returns the raw readout of ``fc4``; no softmax is applied,
    so the output is passed directly to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        # Hidden layer 1: input_dim --> hidden_dim, ReLU non-linearity
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()

        # Hidden layer 2: hidden_dim --> hidden_dim, Tanh non-linearity
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.tanh2 = nn.Tanh()

        # Hidden layer 3: hidden_dim --> hidden_dim, ELU non-linearity
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.elu3 = nn.ELU()

        # Readout layer: hidden_dim --> output_dim (left linear; the loss
        # function is responsible for the softmax)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs to unnormalised output scores.

        Args:
            x: tensor whose last dimension is ``input_dim``; the training
                code in this article passes ``(batch, 28*28)``.

        Returns:
            Tensor with the same leading dimensions and ``output_dim``
            features in the last dimension.
        """
        # Hidden layer 1
        out = self.fc1(x)
        out = self.relu1(out)

        # Hidden layer 2
        out = self.fc2(out)
        out = self.tanh2(out)

        # Hidden layer 3
        out = self.fc3(out)
        out = self.elu3(out)

        # Readout
        out = self.fc4(out)
        return out

# Build the network: flattened 28x28 images in, 10 outputs from the readout layer.
# hidden_dim is a hyperparameter that should really be tuned; 150 is just an
# arbitrary starting value, not the result of any search.
input_dim, hidden_dim, output_dim = 28 * 28, 150, 10
model = ANNModel(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)

# Loss: cross entropy computed directly on the model's raw readout scores.
error = nn.CrossEntropyLoss()

# Optimizer: plain stochastic gradient descent over all model parameters.
learning_rate = 0.02
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

 第二部分:训练模型

# ANN model training
# NOTE(review): num_epochs, train_loader and test_loader are not defined in this
# file; presumably they come from the previous (logistic regression) article's
# data-loading code - confirm before running.
count = 0
loss_list = []       # training loss (Python float) sampled every 50 iterations
iteration_list = []  # iteration number at which each sample was taken
accuracy_list = []   # test-set accuracy in percent (Python float) at those iterations
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # Flatten each 28x28 image into a 784-dim row. Plain tensors are used:
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        train = images.view(-1, 28 * 28)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward propagation
        outputs = model(train)

        # Softmax + cross entropy loss (CrossEntropyLoss applies log-softmax itself)
        loss = error(outputs, labels)

        # Calculating gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        count += 1

        if count % 50 == 0:
            # Evaluate accuracy on the full test set. no_grad() skips building
            # the autograd graph, which is not needed for evaluation.
            correct = 0
            total = 0
            model.eval()
            with torch.no_grad():
                for test_images, test_labels in test_loader:
                    test = test_images.view(-1, 28 * 28)

                    # Forward propagation
                    test_outputs = model(test)

                    # Predicted class = index of the maximum score
                    predicted = torch.max(test_outputs, 1)[1]

                    # Total number of labels
                    total += len(test_labels)

                    # .item() keeps `correct` a plain Python int
                    correct += (predicted == test_labels).sum().item()
            model.train()

            accuracy = 100 * correct / float(total)

            # Store plain numbers (not tensors) for logging and plotting
            loss_list.append(loss.item())
            iteration_list.append(count)
            accuracy_list.append(accuracy)
        # 500 is a multiple of 50, so `accuracy` has always been computed by now.
        if count % 500 == 0:
            # Print Loss
            print('Iteration: {}  Loss: {}  Accuracy: {} %'.format(count, loss.item(), accuracy))

 结果:

Iteration: 500  Loss: 0.8311067223548889  Accuracy: 77 %
Iteration: 1000  Loss: 0.4767582416534424  Accuracy: 87 %
Iteration: 1500  Loss: 0.21807175874710083  Accuracy: 89 %
Iteration: 2000  Loss: 0.2915269732475281  Accuracy: 90 %
Iteration: 2500  Loss: 0.3073478937149048  Accuracy: 91 %
Iteration: 3000  Loss: 0.12328791618347168  Accuracy: 92 %
Iteration: 3500  Loss: 0.24098418653011322  Accuracy: 93 %
Iteration: 4000  Loss: 0.06471655517816544  Accuracy: 93 %
Iteration: 4500  Loss: 0.3368555009365082  Accuracy: 94 %
Iteration: 5000  Loss: 0.12026549130678177  Accuracy: 94 %
Iteration: 5500  Loss: 0.217212975025177  Accuracy: 94 %
Iteration: 6000  Loss: 0.20914879441261292  Accuracy: 94 %
Iteration: 6500  Loss: 0.10008767992258072  Accuracy: 95 %
Iteration: 7000  Loss: 0.13490895926952362  Accuracy: 95 %
Iteration: 7500  Loss: 0.11741413176059723  Accuracy: 95 %
Iteration: 8000  Loss: 0.17519493401050568  Accuracy: 95 %
Iteration: 8500  Loss: 0.06657659262418747  Accuracy: 95 %
Iteration: 9000  Loss: 0.05512683466076851  Accuracy: 95 %
Iteration: 9500  Loss: 0.02535334974527359  Accuracy: 96 %

 第三部分:可视化展示

# Draw one figure per recorded metric against the iteration count.
# NOTE(review): `plt` is never imported in this file; presumably it is
# matplotlib.pyplot imported in the previous article's setup - confirm.
curves = (
    (loss_list, "Loss", "ANN: Loss vs Number of iteration", {}),
    (accuracy_list, "Accuracy", "ANN: Accuracy vs Number of iteration", {"color": "red"}),
)
for values, y_label, plot_title, plot_kwargs in curves:
    plt.plot(iteration_list, values, **plot_kwargs)
    plt.xlabel("Number of iteration")
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.show()

结果(损失曲线与准确率曲线两张图,此处略):

 
 
原文地址:https://www.cnblogs.com/zhuozige/p/14696695.html