PyTorch: resuming training from a checkpoint

    checkpoint = torch.load('.pth')  # path to the saved checkpoint file
    net.load_state_dict(checkpoint['net'])
    criterion_mse = torch.nn.MSELoss().to(cfg.device)
    criterion_L1 = L1Loss()
    optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=cfg.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['lr_schedule'])  # must be a call, not an assignment
    start_epoch = checkpoint['epoch']

    for idx_epoch in range(start_epoch + 1, 80):
        for idx_iter, (data, label) in enumerate(train_loader):  # unpacking is illustrative; use your loader's actual fields
            out = net(data)
            loss = criterion_mse(out, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()  # since PyTorch 1.1, step the scheduler after the optimizer, once per epoch

        if idx_epoch % 1 == 0:  # save a checkpoint every epoch
            checkpoint = {
                'net': net.state_dict(),              # model weights
                'optimizer': optimizer.state_dict(),  # optimizer state
                'epoch': idx_epoch,                   # number of epochs trained
                'lr_schedule': scheduler.state_dict() # how the lr evolves
            }
            torch.save(checkpoint, os.path.join(save_path, filename))
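The save/load round-trip above can be verified in isolation. A minimal, self-contained sketch (the tiny `Linear` net, the lr, the `StepLR` settings, and the `demo.pth` filename are placeholders, not the post's actual configuration): train a few epochs, checkpoint, then rebuild everything from scratch and restore the state dicts.

```python
import torch

net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Simulate 3 epochs of training, stepping the scheduler once per epoch.
for epoch in range(3):
    optimizer.zero_grad()
    loss = net(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()

checkpoint = {
    'net': net.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 2,
    'lr_schedule': scheduler.state_dict(),
}
torch.save(checkpoint, 'demo.pth')

# Rebuild from scratch and restore the saved state.
net2 = torch.nn.Linear(4, 2)
optimizer2 = torch.optim.Adam(net2.parameters(), lr=1e-3)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)

ckpt = torch.load('demo.pth')
net2.load_state_dict(ckpt['net'])
optimizer2.load_state_dict(ckpt['optimizer'])
scheduler2.load_state_dict(ckpt['lr_schedule'])

# The restored model and scheduler match the originals exactly.
assert torch.equal(net2.weight, net.weight)
assert scheduler2.get_last_lr() == scheduler.get_last_lr()
```

Note that the freshly built `scheduler2` must wrap the freshly built `optimizer2` *before* `load_state_dict` is called on either, mirroring the construction order in the code above.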
           
Training straight through (no resume):
a mean psnr:  28.160327919812364
a mean ssim:  0.8067064184409644
b mean psnr:  25.01364162100755
b mean ssim:  0.7600019779915981
c mean psnr:  25.83471135230011
c mean ssim:  0.7774989383731079

Resumed from checkpoint:
a mean psnr:  28.15391601255439
a mean ssim:  0.8062857339309237
b mean psnr:  25.01115760689137
b mean ssim:  0.7596963993692107
c mean psnr:  25.842269038618145
c mean ssim:  0.7772710729947427

Resumed training matches straight-through training almost exactly, but the metrics are consistently a hair lower; I will dig into the gap in a follow-up.
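One common cause of such a small gap is that the checkpoint stores the model, optimizer, and scheduler state but not the random-number-generator states, so data shuffling and any stochastic layers diverge after resuming. A hedged sketch of capturing and restoring RNG state alongside the checkpoint (the helper names are my own, not from the original post; CUDA and NumPy states would need the same treatment):

```python
import random
import torch

def get_rng_state():
    # Snapshot Python's and PyTorch's CPU RNG states.
    return {
        'python': random.getstate(),
        'torch': torch.get_rng_state(),
    }

def set_rng_state(state):
    # Restore both RNGs to the snapshot.
    random.setstate(state['python'])
    torch.set_rng_state(state['torch'])

saved = get_rng_state()
a = torch.rand(3)      # consumes randomness
set_rng_state(saved)   # roll the RNGs back
b = torch.rand(3)      # replays the same draw
assert torch.equal(a, b)
```

Storing such a snapshot under an extra key in the checkpoint dict, and restoring it right after `load_state_dict`, makes a resumed run replay the same shuffles and dropout masks as an uninterrupted one.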

Original post (Chinese): https://www.cnblogs.com/tingtin/p/14091452.html