early stopping早停_pytorch学习

转自:https://blog.csdn.net/weixin_40446557/article/details/103387629

1.介绍

结合交叉验证法,可以防止模型过早拟合。在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题。

注:需要将数据集分为训练集和验证集。

早停法主要是训练时间泛化错误之间的权衡。 

        //...........
        early_stopping(valid_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        //...........

链接中提供的代码早停是针对验证集上的损失。

2.解释

知乎 https://www.zhihu.com/question/59201590/answer/167392763

经过每个神经元,在给定激活函数的情况下,它的激活能力是和参数有关系的。

网络一开始训练的时候会赋值权重为较小值,它的拟合能力弱,基本为线性,随着网络的训练,权重会增大,那么就早停,使得网络的参数不那么复杂,降低它的过拟合程度,降低拟合训练集的能力。

3.skorch中的earlystopping

skorch.callbacks.EarlyStopping(patience=args.earlystop)

class EarlyStopping(Callback):
        def __init__(
            self,
            monitor='valid_loss', #默认监控的是valid集上的损失
            patience=5,
            threshold=1e-4,
            threshold_mode='rel',
            lower_is_better=True,
            sink=print,
    ):

4.valid_loss计算

https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/ 的例子:

    with torch.no_grad():
        for i, data in prog_bar:
            counter += 1
            data, target = data[0].to(device), data[1].to(device)
            total += target.size(0)
            outputs = model(data)
            loss = criterion(outputs, target)
            
            val_running_loss += loss.item()
            _, preds = torch.max(outputs.data, 1)
            val_running_correct += (preds == target).sum().item()
        
        val_loss = val_running_loss / counter
        val_accuracy = 100. * val_running_correct / total
        return val_loss, val_accuracy

其中val_loss 是先计算所有batch的总和,然后在batch数目上取均值。

https://github.com/Cai-Yichao/torch_backbones/blob/e3d4850603a795cee0710bba8f83db74f5a70d68/train.py#L96 例子中:

    with torch.no_grad():
        for index, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss_curr = criterion(outputs, targets)

            loss += loss_curr.item()
     
    eval_loss = loss/(index+1)

 其中val_loss 是计算在每个batch上的损失。

https://blog.csdn.net/weixin_40446557/article/details/103387629 给的例子:

        for data, target in valid_loader:
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # record validation loss
            valid_losses.append(loss.item())

       valid_loss = np.average(valid_losses)
       early_stopping(valid_loss, model)

 其中val_loss 是计算在每个batch上的损失。

以上的三个例子都是valid_loss在batch水平的损失。

原文地址:https://www.cnblogs.com/BlueBlueSea/p/14563880.html