可视化分类网络的feature map

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os

# Prefer the second GPU (as the original script did) when one exists, then
# the first GPU, and finally the CPU. Hard-coding 'cuda:1' made the script
# fail on machines that have exactly one GPU.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Training settings
batch_size = 64
# Relative to the current working directory; MNIST is downloaded here if missing.
root = 'pytorch-master/mnist_data'
train_dataset = datasets.MNIST(root=root,
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root=root,
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# drop_last=True: every training step sees exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)

# Evaluation keeps the final partial batch so that every test image is
# scored. With drop_last=True the last few images were skipped, while the
# loss/accuracy were still reported over the full dataset size.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=False)

# Directory for the per-epoch checkpoints written by train().
save_path = os.path.join(root, 'savepath')
os.makedirs(save_path, exist_ok=True)


class Net(nn.Module):
    """Small CNN classifier used for the feature-map visualisation demo.

    Three (conv -> max-pool -> ReLU) stages are followed by two fully
    connected layers. ``fc1`` expects 2560 = 40 * 8 * 8 flattened features,
    which is what an 84x84 single-channel input produces:
    84 -conv5-> 80 -pool-> 40 -conv5-> 36 -pool-> 18 -conv3-> 16 -pool-> 8.

    ``mp1`` is a separate pooling module so that a forward hook can be
    attached to the last stage alone. The hook sees the pooled output
    *before* ReLU, because ReLU is applied functionally in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.conv3 = nn.Conv2d(20, 40, 3)
        # Pooling has no parameters; `mp` is reused by the first two stages.
        self.mp = nn.MaxPool2d(2)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(2560, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        stages = (
            (self.conv1, self.mp),
            (self.conv2, self.mp),
            (self.conv3, self.mp1),
        )
        for conv, pool in stages:
            x = F.relu(pool(conv(x)))
        flat = x.view(x.size(0), -1)
        # NOTE(review): there is no activation between fc1 and fc2, so the
        # two layers compose to a single affine map. This is kept as-is so
        # that existing checkpoints keep producing the same outputs.
        logits = self.fc2(self.fc1(flat))
        return F.log_softmax(logits, dim=1)


# Module-level model/optimizer pair used by train() and test().
model = Net()
model = model.to(device)

optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)


def data_enhance(data, batch_idx):
    """Place an image batch in one cell of a 3x3 canvas and add noise.

    The whole batch is written into the same cell, chosen by
    ``batch_idx % 9``. Cells 0-2 walk down the left column, 3-5 the middle
    column and 6-8 the right column. The canvas is then blended with
    uniform noise: ``0.7 * U[0, 1) + 0.3 * canvas``.

    Args:
        data: image batch indexed as (N, C, H, W). The original
            implementation only accepted 28x28 images; any H and W work now.
        batch_idx: integer that selects the cell.

    Returns:
        A tensor of shape (N, C, 3*H, 3*W) in the default floating dtype,
        on the same device as ``data``.
    """
    n, c, h, w = data.shape
    canvas = torch.zeros((n, c, 3 * h, 3 * w), device=data.device)
    noise = torch.rand(canvas.size(), device=data.device)

    cell = batch_idx % 9
    # Column-major walk over the grid, matching the original if/elif table.
    top = (cell % 3) * h
    left = (cell // 3) * w
    canvas[:, :, top:top + h, left:left + w] = data

    return noise * 0.7 + canvas * 0.3


def train(epoch):
    """Run one training epoch over ``train_loader`` and checkpoint the model.

    Each batch goes through data_enhance() before the forward pass. Every
    200 batches the current loss is printed and appended to
    ``<root>/log.txt``. After the epoch the weights are saved to
    ``<save_path>/<epoch>.pth``.

    Args:
        epoch: epoch number, used in the log line and the checkpoint name.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data_enhance(data, batch_idx).to(device)
        target = target.to(device)

        # Clear gradients before the backward pass so that nothing left over
        # from earlier code leaks into this step.
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            # The original literal was broken by a raw newline (a syntax
            # error). '\n' keeps one entry per line in log.txt.
            contest = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\n'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            print(contest)
            with open(os.path.join(root, 'log.txt'), 'a') as f:
                f.write(contest)

    torch.save(model.state_dict(), os.path.join(save_path, str(epoch) + '.pth'))


def test():
    """Evaluate the global model on ``test_loader`` and log loss/accuracy.

    Inputs go through data_enhance() exactly as in training. The summed NLL
    loss and the number of correct predictions are averaged over the number
    of samples actually evaluated. The result is printed and appended to
    ``<root>/log.txt``.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    seen = 0
    # Evaluation needs no gradients; this avoids building an autograd graph
    # for every batch.
    with torch.no_grad():
        for index, (data, target) in enumerate(test_loader):
            data = data_enhance(data, index).to(device)
            target = target.to(device)
            output = model(data)
            # reduction='sum' is the replacement for the deprecated
            # size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            seen += target.size(0)

    # Divide by the samples actually scored, not len(dataset): the loader
    # may drop a final partial batch. max(..., 1) guards an empty loader.
    denom = max(seen, 1)
    test_loss /= denom
    contest = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n\n'.format(
        test_loss, correct, seen,
        100. * correct / denom)
    print(contest)
    with open(os.path.join(root, 'log.txt'), 'a') as f:
        f.write(contest)

from torchvision.utils import save_image

# Outputs captured by get_features_hook(); show() reads the newest entry.
feature = []


def get_features_hook(self, input, output):
    """Forward hook that appends a module's output to ``feature``.

    It is registered with ``register_forward_hook``, so PyTorch calls it as
    ``hook(module, inputs, output)``. ``self`` is therefore the hooked
    module and ``input`` its positional inputs; both are unused.

    The output is detached before it is stored, so that entries in the list
    do not keep the autograd graph (and the activations it references)
    alive.
    """
    feature.append(output.detach())


def show(para_path):
    """Save a few noisy input batches and their ``mp1`` feature maps as images.

    Loads weights from ``para_path`` into a fresh ``Net``. For the first
    five test batches (indices 0-4) it writes two files under
    ``<root>/show``: ``<index>_img.jpg`` (the data_enhance() input) and
    ``<index>_feat.jpg`` (the channel-wise max of the ``mp1`` output).

    Args:
        para_path: path to a state_dict checkpoint, such as those saved by
            train().
    """
    print('device:{}'.format(device))
    show_path = os.path.join(root, 'show')
    os.makedirs(show_path, exist_ok=True)

    model = Net()
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (newer PyTorch versions offer weights_only=True).
    model.load_state_dict(torch.load(para_path, map_location='cpu'))
    model = model.to(device)
    model.eval()

    # Register the hook once and always remove it, even if a batch fails.
    handle = model.mp1.register_forward_hook(get_features_hook)
    try:
        with torch.no_grad():
            for index, (data, _) in enumerate(test_loader):
                print(index)
                data = data_enhance(data, index).to(device)
                save_image(data, os.path.join(show_path, str(index) + '_img.jpg'))
                model(data)
                # Collapse the channels into one map by taking their maximum.
                # NOTE(review): this is the pre-ReLU pooled output and is not
                # rescaled, so values may fall outside [0, 1]. Consider
                # save_image(..., normalize=True) if the maps look saturated.
                feat = torch.max(feature[-1], dim=1, keepdim=True)[0]
                save_image(feat, os.path.join(show_path, str(index) + '_feat.jpg'))
                if index > 3:
                    break
    finally:
        handle.remove()


if __name__ == '__main__':
    # act == 1: train for epochs 1-99, evaluating after each one.
    # Any other value: visualise feature maps from a saved checkpoint.
    act = 2
    if act == 1:
        print('start training...')
        for epoch in range(1, 100):
            train(epoch)
            test()
    else:
        print('start show..')
        # Load from the directory train() writes to. The previous absolute
        # path '/pytorch-master/...' pointed at the filesystem root instead
        # of the relative `root` used by the rest of the script.
        show(os.path.join(save_path, '40.pth'))

输入（为了增加难度，对 MNIST 数据集的图片进行了平移和加噪声操作；示例图片见原文）：

可视化效果（可以看出，网络确实学习到了数字特征（至少是位置信息），最终能达到 0.96 的准确率；示例图片见原文）：

 

原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/13325322.html