AI艺术鉴赏挑战赛

AI研习社 AI艺术鉴赏挑战赛 - 看画猜作者

亚军方案:

  • 主干网络为 ResNeSt-200,输入尺寸 448,在不同 loss 下取得 5 组最好结果,最后投票得到最终分数。单模型最高得分 93.75。

'''
import os
import math
import copy
import shutil
import time
import random
import pickle
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict, namedtuple
from sklearn.metrics import roc_auc_score, average_precision_score
import se_resnext101_32x4d
from efficientnet_pytorch import EfficientNet
from data_augmentation import FixedRotation
from inceptionv4 import inceptionv4
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.models import resnet101,resnet50,resnet152,resnet34
import torchvision.transforms as transforms
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from resnest.torch import resnest200,resnest269,resnest101
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings("ignore")

def main(index):
    """Train, validate and predict with one ResNeSt-200 (448px) model on data fold `index`.

    Checkpoints are written to ./model/<file_name>/ and logs/predictions to
    ./result/<file_name>.txt and ./result/<file_name>/, where
    file_name = "resnest200_448_all_<index>".
    """
    # Restrict the visible GPUs before anything initialises CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7'

    # Fixed seeds so each fold's run is repeatable (cuDNN kernels may still be nondeterministic).
    np.random.seed(359)
    torch.manual_seed(359)
    torch.cuda.manual_seed_all(359)
    random.seed(359)

    batch_size = 48
    workers = 16

    # Learning-rate schedule: `lr` is divided by `lr_decay` each time validation
    # accuracy fails to improve for `patience` epochs (one "stage"); training
    # stops after total_stages - 1 such decays or total_epochs epochs.
    lr = 5e-4
    lr_decay = 10
    weight_decay = 1e-4

    stage = 0
    start_epoch = 0
    total_epochs = 200
    patience = 4
    no_improved_times = 0
    total_stages = 3
    best_score = 0

    print_freq = 20
    pre_model = 'senet'  # only recorded as 'arch' in saved checkpoints
    evaluate = False
    use_pre_model = False

    num_classes = 49  # number of classes (artists) to predict
    file_name = "resnest200_448_all_{}".format(index)
    img_size = 448

    resume = ''

    # Create the folders for saved models and results.
    os.makedirs('./model/%s' % file_name, exist_ok=True)
    os.makedirs('./result/%s' % file_name, exist_ok=True)

    # Append a timestamped run header to this run's log ('a' also creates the file).
    with open('./result/%s.txt' % file_name, 'a') as acc_file:
        acc_file.write('\n%s %s\n' % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), file_name))

    # Build the model: pretrained ResNeSt-200 with a fresh num_classes-way head.
    # It must not be replaced by another architecture afterwards; the rest of the
    # pipeline (CrossEntropyLoss on 49 labels, score()) assumes num_classes outputs.
    model = resnest200(pretrained=True)
    model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model = torch.nn.DataParallel(model).cuda()

def load_pre_cloth_model_dict(self, state_dict):
    """Copy matching tensors from `state_dict` into module `self` in place.

    Keys that `self.state_dict()` does not contain are ignored, and so is every
    key containing the substring 'fc', so a checkpoint with a different
    classifier head can still warm-start the backbone.
    """
    target = self.state_dict()
    wanted = (
        (key, value)
        for key, value in state_dict.items()
        if key in target and 'fc' not in key
    )
    for key, value in wanted:
        # Older serialized checkpoints may hold nn.Parameter objects; copy their raw tensor.
        source = value.data if isinstance(value, nn.Parameter) else value
        target[key].copy_(source)

# Optionally warm-start from another model's weights ('fc' layers are skipped).
if use_pre_model:
    print('using pre model')
    pre_model_path = ''  # NOTE(review): empty path; must be set before enabling use_pre_model
    pretrained_state = torch.load(pre_model_path)['state_dict']
    load_pre_cloth_model_dict(model, pretrained_state)

# Optionally restore weights and training counters from a save_checkpoint() file.
if resume:
    if not os.path.isfile(resume):
        print("=> no checkpoint found at '{}'".format(resume))
    else:
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = torch.load(resume)
        start_epoch = checkpoint['epoch']
        best_score = checkpoint['best_score']
        stage = checkpoint['stage']
        lr = checkpoint['lr']
        model.load_state_dict(checkpoint['state_dict'])
        no_improved_times = checkpoint['no_improved_times']
        # A zero counter means the checkpointed epoch was the best so far; reload
        # model_best explicitly (presumably the same weights as the checkpoint).
        if no_improved_times == 0:
            best_state = torch.load('./model/%s/model_best.pth.tar' % file_name)['state_dict']
            model.load_state_dict(best_state)
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

def default_loader(root_dir,path):
    """Open `<root_dir>/<path>.jpg` and return it as an RGB PIL image.

    `path` is passed through str() first, so numeric image ids are accepted.
    """
    image_path = os.path.join(root_dir, str(path)) + ".jpg"
    return Image.open(image_path).convert('RGB')

class TrainDataset(Dataset):
    """Labelled training images listed in a DataFrame.

    Args:
        label_list: DataFrame with 'filename' and 'label' columns.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label.
        loader: callable(root_dir, filename) returning an image.
        root_dir: directory handed to `loader` (default '../train/').

    Items are (image, label).
    """

    def __init__(self, label_list, transform=None, target_transform=None, loader=default_loader,
                 root_dir='../train/'):
        # Zip the two columns instead of iterrows(): each column keeps its own
        # dtype, so int labels/filenames are not upcast to float by a mixed row.
        self.imgs = list(zip(label_list['filename'], label_list['label']))
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.root_dir = root_dir

    def __getitem__(self, index):
        filename, label = self.imgs[index]
        img = self.loader(self.root_dir, filename)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)

class ValDataset(Dataset):
    """Labelled validation images that also report their filename.

    Args:
        label_list: DataFrame with 'filename' and 'label' columns.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label.
        loader: callable(root_dir, filename) returning an image.
        root_dir: directory handed to `loader` (default '../train/'; the
            validation list refers to images in the training folder).

    Items are (image, label, filename).
    """

    def __init__(self, label_list, transform=None, target_transform=None, loader=default_loader,
                 root_dir='../train/'):
        # Column-wise zip keeps per-column dtypes (iterrows may upcast to float).
        self.imgs = list(zip(label_list['filename'], label_list['label']))
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.root_dir = root_dir

    def __getitem__(self, index):
        filename, label = self.imgs[index]
        img = self.loader(self.root_dir, filename)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label, filename

    def __len__(self):
        return len(self.imgs)

class TestDataset(Dataset):
    """Unlabelled test images.

    Args:
        label_list: DataFrame with a 'filename' column. A 'label' column is
            optional and only kept in `self.imgs`; it is None when absent.
        transform: optional callable applied to each loaded image.
        target_transform: stored for interface parity; no label is returned.
        loader: callable(root_dir, filename) returning an image.
        root_dir: directory handed to `loader` (default '../test/').

    Items are (image, filename).
    """

    def __init__(self, label_list, transform=None, target_transform=None, loader=default_loader,
                 root_dir='../test/'):
        # NOTE(review): presumably test.csv may lack labels; tolerate that instead of raising KeyError.
        labels = label_list['label'] if 'label' in label_list else [None] * len(label_list)
        self.imgs = list(zip(label_list['filename'], labels))
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.root_dir = root_dir

    def __getitem__(self, index):
        filename, _label = self.imgs[index]
        img = self.loader(self.root_dir, filename)
        if self.transform is not None:
            img = self.transform(img)
        return img, filename

    def __len__(self):
        return len(self.imgs)

# Per-fold split: data/train_{index}.csv for training, data/test_{index}.csv for validation.
train_data_list = pd.read_csv("data/train_{}.csv".format(index), sep=",")
val_data_list = pd.read_csv("data/test_{}.csv".format(index), sep=",")
test_data_list = pd.read_csv("../test.csv", sep=",")

# Missing values in the training list are replaced by 0.
train_data_list = train_data_list.fillna(0)

# Class probabilities over dim 1 of (batch, classes) logits. The dim is explicit:
# nn.Softmax() without it is deprecated and only guesses dim=1 for 2-D input.
smax = nn.Softmax(dim=1)
# Standard ImageNet mean/std normalisation.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def _eval_transform(*flips):
    """Deterministic pipeline: resize to img_size x img_size, apply `flips`, then tensor + normalize."""
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        *flips,
        transforms.ToTensor(),
        normalize,
    ])


train_data = TrainDataset(train_data_list,
                          transform=transforms.Compose([
                              transforms.Resize((img_size, img_size)),
                              transforms.ColorJitter(0.3, 0.3, 0.3, 0.15),
                              transforms.RandomHorizontalFlip(),
                              transforms.RandomVerticalFlip(),
                              transforms.RandomGrayscale(),
                              # Rotate by one of -16, -14, ..., 16 degrees.
                              FixedRotation(list(range(-16, 17, 2))),
                              transforms.ToTensor(),
                              normalize,
                          ]))

val_data = ValDataset(val_data_list, transform=_eval_transform())
test_data = TestDataset(test_data_list, transform=_eval_transform())

# Flipped copies of the test set for test-time augmentation. p=1.0 makes each
# flip unconditional, and every variant uses the same img_size as training.
test_data_hflip = TestDataset(test_data_list,
                              transform=_eval_transform(transforms.RandomHorizontalFlip(p=1.0)))
test_data_vflip = TestDataset(test_data_list,
                              transform=_eval_transform(transforms.RandomVerticalFlip(p=1.0)))
test_data_vhflip = TestDataset(test_data_list,
                               transform=_eval_transform(transforms.RandomHorizontalFlip(p=1.0),
                                                         transforms.RandomVerticalFlip(p=1.0)))

eval_batch_size = batch_size * 2
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
                          num_workers=workers, drop_last=True)
# drop_last=False: every validation image must be scored; dropping the last
# partial batch would silently exclude up to eval_batch_size - 1 samples from acc/F1.
val_loader = DataLoader(val_data, batch_size=eval_batch_size, shuffle=False, pin_memory=False,
                        num_workers=workers, drop_last=False)
test_loader = DataLoader(test_data, batch_size=eval_batch_size, shuffle=False, pin_memory=False,
                         num_workers=workers)
test_loader_hflip = DataLoader(test_data_hflip, batch_size=eval_batch_size, shuffle=False,
                               pin_memory=False, num_workers=workers)
test_loader_vflip = DataLoader(test_data_vflip, batch_size=eval_batch_size, shuffle=False,
                               pin_memory=False, num_workers=workers)
test_loader_vhflip = DataLoader(test_data_vhflip, batch_size=eval_batch_size, shuffle=False,
                                pin_memory=False, num_workers=workers)


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over `train_loader`, printing progress every `print_freq` batches.

    Args:
        train_loader: yields (images, target) batches of tensors.
        model: network already placed on the GPU.
        criterion: loss taking (logits, labels).
        optimizer: stepped once per batch.
        epoch: epoch index, used only for logging.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # Move the existing tensors to the GPU directly; torch.tensor(tensor)
        # would make an extra copy and emits a UserWarning.
        image_var = images.cuda(non_blocking=True)
        label = target.cuda(non_blocking=True)

        y_pred = model(image_var)
        loss = criterion(y_pred, label)

        # Record loss and top-1 accuracy for this batch.
        prec, PRED_COUNT = accuracy(y_pred.data, target, topk=(1, 1))
        losses.update(loss.item(), images.size(0))
        acc.update(prec, PRED_COUNT)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuray {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time,
                      loss=losses, acc=acc))

def validate(val_loader, model, criterion):
    """Score `model` on `val_loader`; save per-image probabilities and return (acc, macro-F1).

    Writes ./result/<file_name>/val_score.csv with columns 'img_path', 'preds'
    (';'-joined class probabilities) and 'label'. `criterion` is kept for
    interface compatibility; the validation loss is not needed for scoring.
    """
    batch_time = AverageMeter()
    model.eval()

    val_imgs, val_preds, val_labels = [], [], []

    end = time.time()
    for i, (images, labels, img_path) in enumerate(val_loader):
        image_var = images.cuda(non_blocking=True)
        with torch.no_grad():
            # Explicit class dimension for (batch, classes) logits.
            probs = F.softmax(model(image_var), dim=1)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % (print_freq * 5) == 0:
            print('TrainVal: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(i, len(val_loader),
                                                                             batch_time=batch_time))

        # Numeric filenames are collated into a tensor; unwrap them so the CSV holds plain ids.
        val_imgs.extend(p.item() if torch.is_tensor(p) else p for p in img_path)
        val_preds.extend(';'.join(str(v) for v in row) for row in probs.tolist())
        val_labels.extend(labels.tolist())

    val_score = pd.DataFrame({'img_path': val_imgs, 'preds': val_preds, 'label': val_labels})
    val_score.to_csv('./result/%s/val_score.csv' % file_name, index=False)
    acc, f1 = score(val_score)
    print('acc: %.4f, f1: %.4f' % (acc, f1))
    # best_score in main() holds the best validation accuracy, so report accuracy next to it.
    print(' * Score {final_score:.4f}'.format(final_score=acc), '(Previous Best Score: %.4f)' % best_score)
    return acc, f1

def test(test_loader, model):
    """Predict the test set and write the submission files.

    Writes ./result/<file_name>/submission.csv (FileName, type = argmax class,
    probability = ';'-joined class probabilities) and final_submission.csv
    (FileName, type only).
    """
    csv_map = OrderedDict({'FileName': [], 'type': [], 'probability': []})
    model.eval()
    for images, filepath in tqdm(test_loader):
        # Numeric filenames are collated into a tensor; str() of a 0-d tensor
        # would give 'tensor(7)', so unwrap it first.
        filenames = [str(p.item()) if torch.is_tensor(p) else str(p) for p in filepath]

        with torch.no_grad():
            # Explicit class dimension for (batch, classes) logits.
            probs = F.softmax(model(images.cuda(non_blocking=True)), dim=1)

        csv_map['FileName'].extend(filenames)
        for row in probs.tolist():
            csv_map['probability'].append(';'.join(str(v) for v in row))
            csv_map['type'].append(np.argmax(row))

    result = pd.DataFrame(csv_map)
    result.to_csv('./result/%s/submission.csv' % file_name, index=False)
    result[['FileName', 'type']].to_csv('./result/%s/final_submission.csv' % file_name, index=False)

def save_checkpoint(state, is_best, filename='./model/%s/checkpoint.pth.tar' % file_name):
    """Serialize `state` to `filename`; when `is_best`, also copy it to model_best.pth.tar."""
    torch.save(state, filename)
    if not is_best:
        return
    best_path = './model/%s/model_best.pth.tar' % file_name
    shutil.copyfile(filename, best_path)

class AverageMeter(object):
    """Track the most recent value and a count-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted by `n` samples, in the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard n == 0 on a fresh meter, which would otherwise divide by zero.
        self.avg = self.sum / self.count if self.count else 0

def adjust_learning_rate():
    nonlocal lr
    lr = lr / lr_decay
    return optim.Adam(model.parameters(), lr, weight_decay=weight_decay, amsgrad=True)

def accuracy(y_pred, y_actual, topk=(1,)):
    """Return (accuracy in percent, number of samples) for one batch.

    A sample counts as correct when its target is among the max(topk)
    highest-scoring classes; with topk=(1,) or (1, 1) this is top-1 accuracy.

    Args:
        y_pred: (batch, num_classes) score tensor.
        y_actual: (batch,) tensor of class indices (may live on another device).
        topk: only max(topk) is used.
    """
    maxk = max(topk)
    PRED_COUNT = y_actual.size(0)
    if PRED_COUNT == 0:
        return 0, PRED_COUNT

    _, pred = y_pred.topk(maxk, 1, True, True)
    # Vectorised comparison; avoids a per-element GPU->CPU sync.
    target = y_actual.view(-1, 1).long().to(pred.device)
    PRED_CORRECT_COUNT = int(pred.eq(target).any(dim=1).sum().item())
    return PRED_CORRECT_COUNT / PRED_COUNT * 100, PRED_COUNT

def softmax(x):
    """Softmax along axis 0 (column-wise for 2-D input), computed stably.

    The per-column maximum is subtracted before exponentiating. The result is
    mathematically unchanged, but large inputs no longer overflow to inf/NaN.
    """
    x = np.asarray(x, dtype=float)
    shifted = np.exp(x - np.max(x, axis=0, keepdims=True))
    return shifted / np.sum(shifted, axis=0, keepdims=True)

def doitf(tp, fp, fn):
    """Return the F1 score for one class from its tp/fp/fn counts.

    Returns 0 whenever precision or recall is undefined (no predictions or
    no actual samples), or when both are zero.
    """
    predicted = tp + fp
    actual = tp + fn
    if predicted == 0 or actual == 0:
        return 0
    precision = float(tp) / float(predicted)
    recall = float(tp) / float(actual)
    total = precision + recall
    return 0 if total == 0 else (2 * precision * recall) / total

def score(val_score, num_classes=49):
    """Compute accuracy and macro-averaged F1 from a validation-score DataFrame.

    Args:
        val_score: DataFrame with 'preds' (';'-joined class probabilities) and
            integer 'label' columns. Mutated in place: 'preds' is parsed into
            lists of floats and a 'preds_label' (argmax) column is added.
        num_classes: number of classes F1 is averaged over (49 by default).

    Returns:
        (acc, f1): accuracy as a fraction in [0, 1], and the unweighted mean of
        per-class F1 (a class's F1 is 0 when undefined, via doitf). Both are
        0.0 for an empty frame.
    """
    val_score['preds'] = val_score['preds'].map(lambda x: [float(i) for i in x.split(';')])

    print(val_score.head(10))

    val_score['preds_label'] = val_score['preds'].apply(lambda x: np.argmax(x))

    n = val_score.shape[0]
    if n == 0:
        return 0.0, 0.0

    preds = val_score['preds_label'].to_numpy(dtype=int)
    labels = val_score['label'].to_numpy(dtype=int)
    hit = preds == labels
    # A correct prediction is a TP for its class. A wrong one is an FP for the
    # predicted class and an FN for the true class.
    tp = np.bincount(labels[hit], minlength=num_classes)
    fp = np.bincount(preds[~hit], minlength=num_classes)
    fn = np.bincount(labels[~hit], minlength=num_classes)

    f1_tot = sum(doitf(tp[c], fp[c], fn[c]) for c in range(num_classes)) / num_classes
    acc = int(hit.sum()) / n
    return acc, f1_tot

# Loss function (criterion) and optimizer.
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay, amsgrad=True)

if evaluate:
    validate(val_loader, model, criterion)
else:
    for epoch in range(start_epoch, total_epochs):
        # Stop once the learning rate has been decayed total_stages - 1 times.
        if stage >= total_stages - 1:
            break

        train(train_loader, model, criterion, optimizer, epoch)
        acc, f1 = validate(val_loader, model, criterion)

        with open('./result/%s.txt' % file_name, 'a') as acc_file:
            acc_file.write('Epoch: %2d, acc: %.8f, f1: %.8f\n' % (epoch, acc, f1))

        # Model selection is by validation accuracy.
        is_best = acc > best_score
        best_score = max(acc, best_score)
        no_improved_times = 0 if is_best else no_improved_times + 1

        print('stage: %d, no_improved_times: %d' % (stage, no_improved_times))

        # Patience exhausted: move to the next stage with a decayed learning rate.
        if no_improved_times >= patience:
            stage += 1
            optimizer = adjust_learning_rate()

        state = {
            'epoch': epoch + 1,
            'arch': pre_model,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'no_improved_times': no_improved_times,
            'stage': stage,
            'lr': lr,
        }
        save_checkpoint(state, is_best)

        # Start the new stage from the best weights seen so far.
        if no_improved_times >= patience:
            no_improved_times = 0
            model.load_state_dict(torch.load('./model/%s/model_best.pth.tar' % file_name)['state_dict'])
            print('Step into next stage')
            with open('./result/%s.txt' % file_name, 'a') as acc_file:
                acc_file.write('---------------------Step into next stage---------------------\n')

with open('./result/%s.txt' % file_name, 'a') as acc_file:
    acc_file.write('* best acc: %.8f  %s\n' % (best_score, os.path.basename(__file__)))
with open('./result/best_acc.txt', 'a') as acc_file:
    acc_file.write('%s  * best acc: %.8f  %s\n' % (
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), best_score,
        os.path.basename(__file__)))

# Predict the test set with the best checkpoint.
# NOTE(review): in evaluate mode nothing is saved here, so model_best must already exist.
best_model = torch.load('model/{}/model_best.pth.tar'.format(file_name))
model.load_state_dict(best_model['state_dict'])
test(test_loader=test_loader, model=model)

torch.cuda.empty_cache()

if __name__ == '__main__':
    # One model per data fold (data/train_{i}.csv / data/test_{i}.csv), i = 1..5.
    for index in range(1, 6):
        main(index)

'''

季军方案:

  • 基于 ResNeXt-50 和 EfficientNet-B3,分别以 448、512、600 的图像尺寸训练模型,取得分最高的 4 组结果进行投票。

'''
from torch.utils.data import DataLoader
from ArtModel import BaseModel
import time
import numpy as np
import random
from torch.optim import lr_scheduler
from torch.backends import cudnn
import argparse
import os
import torch
import torch.nn as nn
from dataload import Dataset

# Command-line configuration for the third-place training script.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='resnext50', type=str)
parser.add_argument('--savepath', default='./Art/', type=str)
# NOTE(review): --loss is not read anywhere in this script; CrossEntropyLoss is always used.
parser.add_argument('--loss', default='ce', type=str)
parser.add_argument('--num_classes', default=49, type=int)
parser.add_argument('--pool_type', default='avg', type=str)
parser.add_argument('--metric', default='linear', type=str)
parser.add_argument('--down', default=0, type=int)
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
# Restrict to the schedules the training code implements, so a typo fails at startup.
parser.add_argument('--scheduler', default='cos', type=str, choices=['step', 'multi', 'cos'],
                    help='learning-rate schedule: step, multi or cos (warm-up + cosine)')
parser.add_argument('--resume', default=None, type=str)
parser.add_argument('--lr_step', default=25, type=int)
parser.add_argument('--lr_gamma', default=0.1, type=float)
parser.add_argument('--total_epoch', default=60, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--multi-gpus', default=0, type=int)  # read back as args.multi_gpus
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--seed', default=2020, type=int)
parser.add_argument('--pretrained', default=1, type=int)
parser.add_argument('--gray', default=0, type=int)

args = parser.parse_args()

def train():
    """Train `model` for one epoch over `trainloader` and return its mean loss and top-1 accuracy.

    Uses the module-level globals set up in the __main__ section: model,
    trainloader, device, criterion, optimizer and savepath. Appends a
    'loss/acc ->' fragment (no newline) to <savepath>/log.txt; test() completes the line.

    Returns:
        dict with 'loss' (sample-weighted mean) and 'acc' (fraction correct).
    """
    model.train()

    epoch_loss = 0
    correct = 0.
    total = 0.
    t1 = time.time()
    for data, labels in trainloader:
        data, labels = data.to(device), labels.long().to(device)

        # NOTE(review): in train mode BaseModel appears to return a 3-tuple
        # (logits, se, feat_flat), whereas test() uses a single output; confirm in ArtModel.
        out, _se, _feat_flat = model(data)

        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * data.size(0)
        total += data.size(0)
        _, pred = torch.max(out, 1)
        correct += pred.eq(labels).sum().item()

    acc = correct / total
    loss = epoch_loss / total

    print(f'loss:{loss:.4f} acc@1:{acc:.4f} time:{time.time() - t1:.2f}s', end=' --> ')

    with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
        f.write('loss:{:.4f}, acc:{:.4f} ->'.format(loss, acc))

    return {'loss': loss, 'acc': acc}

def test(epoch):
    """Evaluate `model` on `valloader`, checkpoint it, and log the result for `epoch`.

    Saves <savepath>/last.pth every call and <savepath>/best.pth when accuracy
    beats the global `best_acc` (which is then updated together with `best_epoch`).

    Returns:
        dict with 'loss' (sample-weighted mean) and 'acc' (fraction correct).
    """
    global best_acc, best_epoch

    model.eval()

    epoch_loss = 0
    correct = 0.
    total = 0.
    with torch.no_grad():
        for data, labels in valloader:
            data, labels = data.to(device), labels.long().to(device)

            out = model(data)

            loss = criterion(out, labels)

            epoch_loss += loss.item() * data.size(0)
            total += data.size(0)
            _, pred = torch.max(out, 1)
            correct += pred.eq(labels).sum().item()

    acc = correct / total
    loss = epoch_loss / total

    print(f'test loss:{loss:.4f} acc@1:{acc:.4f}', end=' ')

    state = {
        'net': model.state_dict(),
        'acc': acc,
        'epoch': epoch
    }

    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        torch.save(state, os.path.join(savepath, 'best.pth'))
        print('*')
    else:
        print()

    torch.save(state, os.path.join(savepath, 'last.pth'))

    with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
        f.write('epoch:{}, loss:{:.4f}, acc:{:.4f}\n'.format(epoch, loss, acc))

    return {'loss': loss, 'acc': acc}

def plot(d, mode='train', best_acc_=None):
    """Save loss and accuracy curves for history `d` to <savepath>/<mode>.jpg.

    Args:
        d: dict with equal-length 'loss' and 'acc' lists (one entry per epoch).
        mode: used in the figure title and as the output file name.
        best_acc_: optional [epoch, acc] point drawn in red on the accuracy panel.
    """
    import matplotlib.pyplot as plt

    epochs = np.arange(len(d['acc']))

    plt.figure(figsize=(10, 4))
    plt.suptitle('%s_curve' % mode)
    plt.subplots_adjust(wspace=0.2, hspace=0.2)

    plt.subplot(1, 2, 1)
    plt.plot(epochs, d['loss'], label='loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper left')

    plt.subplot(1, 2, 2)
    plt.plot(epochs, d['acc'], label='acc')
    if best_acc_ is not None:
        plt.scatter(best_acc_[0], best_acc_[1], c='r')
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.legend(loc='upper left')

    plt.savefig(os.path.join(savepath, '%s.jpg' % mode), bbox_inches='tight')
    plt.close()

if __name__ == '__main__':
    best_epoch = 0
    best_acc = 0.
    use_gpu = False

    if args.seed is not None:
        print('use random seed:', args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
        # Deterministic cuDNN stays off (and benchmark is enabled below), so GPU
        # results are not bit-for-bit reproducible even with a fixed seed.
        cudnn.deterministic = False

    if torch.cuda.is_available():
        use_gpu = True
        cudnn.benchmark = True

    # loss
    criterion = nn.CrossEntropyLoss()
    # dataloader
    trainset = Dataset(mode='train')
    valset = Dataset(mode='val')

    trainloader = DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True, drop_last=True)

    valloader = DataLoader(dataset=valset, batch_size=128, shuffle=False, num_workers=args.num_workers,
                           pin_memory=True)

    # model
    model = BaseModel(model_name=args.model_name, num_classes=args.num_classes, pretrained=args.pretrained,
                      pool_type=args.pool_type, down=args.down, metric=args.metric)
    if args.resume:
        state = torch.load(args.resume)
        print('best_epoch:{}, best_acc:{}'.format(state['epoch'], state['acc']))
        model.load_state_dict(state['net'])

    if torch.cuda.device_count() > 1 and args.multi_gpus:
        print('use multi-gpus...')
        # NOTE(review): CUDA has already been queried above, so this assignment most
        # likely no longer changes the visible devices; set it before launching instead.
        os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.distributed.init_process_group(backend="nccl", init_method='tcp://localhost:23456',
                                             rank=0, world_size=1)
        model = model.to(device)
        model = nn.parallel.DistributedDataParallel(model)
    else:
        device = ('cuda:%d' % args.gpu if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
    print('device:', device)

    # optim
    optimizer = torch.optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr}],
        weight_decay=args.weight_decay, momentum=args.momentum)

    print('init_lr={}, weight_decay={}, momentum={}'.format(args.lr, args.weight_decay, args.momentum))

    if args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma, last_epoch=-1)
    elif args.scheduler == 'multi':
        # NOTE(review): milestones 150/225 lie beyond the default total_epoch (60),
        # so with default settings this schedule never decays.
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=args.lr_gamma, last_epoch=-1)
    elif args.scheduler == 'cos':
        warm_up_step = 10
        # Keep the cosine span positive so total_epoch <= warm_up_step cannot divide by zero.
        cos_span = max(args.total_epoch - warm_up_step, 1)

        def lambda_(epoch):
            """LR multiplier: linear warm-up over warm_up_step epochs, then half-cosine decay."""
            if epoch < warm_up_step:
                return (epoch + 1) / warm_up_step
            return 0.5 * (np.cos((epoch - warm_up_step) / cos_span * np.pi) + 1)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)
    else:
        # Fail fast rather than hitting a NameError on `scheduler` after the first epoch.
        raise ValueError('unknown --scheduler: %s' % args.scheduler)

    # savepath
    savepath = os.path.join(args.savepath, args.model_name + args.pool_type + args.metric + '_' + str(args.down))

    print('savepath:', savepath)

    os.makedirs(savepath, exist_ok=True)

    with open(os.path.join(savepath, 'setting.txt'), 'w') as f:
        for k, v in vars(args).items():
            f.write('{}:{}\n'.format(k, v))

    # Start a fresh log for this run; train() and test() append to it.
    with open(os.path.join(savepath, 'log.txt'), 'w'):
        pass

    total = args.total_epoch
    start = time.time()

    train_info = {'loss': [], 'acc': []}
    test_info = {'loss': [], 'acc': []}

    for epoch in range(total):
        print('epoch[{:>3}/{:>3}]'.format(epoch, total), end=' ')
        d_train = train()
        scheduler.step()
        d_test = test(epoch)

        for k in train_info:
            train_info[k].append(d_train[k])
            test_info[k].append(d_test[k])

        plot(train_info, mode='train')
        plot(test_info, mode='test', best_acc_=[best_epoch, best_acc])

    end = time.time()
    print('total time:{}m{:.2f}s'.format((end - start) // 60, (end - start) % 60))
    print('best_epoch:', best_epoch)
    print('best_acc:', best_acc)
    with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
        f.write('# best_acc:{:.4f}, best_epoch:{}'.format(best_acc, best_epoch))

'''

原文地址:https://www.cnblogs.com/yzm10/p/13934164.html