PyTorch memo

1. Gradients are not needed at test time:

with torch.no_grad():
    out = net(input)
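A minimal runnable sketch of the snippet above, assuming a hypothetical nn.Linear stands in for `net` and a random tensor `x` stands in for `input` (renamed so it does not shadow Python's built-in). It also shows that net.eval() and torch.no_grad() do different jobs: eval() switches layers such as Dropout/BatchNorm to inference behaviour, while no_grad() stops autograd from recording a graph.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for `net` and `input` from the snippet above.
net = nn.Linear(3, 1)
x = torch.randn(2, 3)

net.eval()  # inference behaviour for Dropout/BatchNorm; does NOT disable autograd
with torch.no_grad():
    out = net(x)  # no computation graph is recorded inside this block

y = net(x)  # outside the block, autograd tracks operations again
print(out.requires_grad, y.requires_grad)
```

Skipping the graph saves memory and time during evaluation; calling out.backward() on the no_grad result would fail because it has no grad_fn.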

2. Visualize the images generated by the network (the whole batch):

from torchvision.utils import save_image

out = net(input)  # expected shape (B, C, H, W), values in [0, 1]
save_image(out, './a.jpg')  # the whole batch is tiled into one grid image

3. Modify tensor values by condition:

out = torch.where(out>0.5, torch.full_like(out, 1), torch.full_like(out, 0))
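A short sketch comparing the torch.where form with a shorter equivalent, using a hypothetical four-element tensor. Casting the boolean mask back to the input's dtype gives the same 0/1 result; torch.where is still the tool of choice when one branch should keep the original values.

```python
import torch

# Hypothetical values standing in for the network output.
out = torch.tensor([0.2, 0.5, 0.7, 0.9])

# The form used above: 1 where out > 0.5, else 0 (note 0.5 itself maps to 0).
a = torch.where(out > 0.5, torch.full_like(out, 1), torch.full_like(out, 0))

# Equivalent and shorter: cast the boolean mask to out's dtype.
b = (out > 0.5).to(out.dtype)

# Keep the original value above the threshold, zero elsewhere.
c = torch.where(out > 0.5, out, torch.zeros_like(out))
print(a, b, c)
```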

4. Load a pretrained model:

net.load_state_dict(torch.load('/home/dell/checkpoint.pt'))
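A save-and-reload round trip, assuming a hypothetical nn.Linear as the network and a temporary file instead of '/home/dell/checkpoint.pt'. Only the state_dict (the parameter tensors) is saved, so the same architecture must be rebuilt before loading. map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # hypothetical stand-in for the real network
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')

# Save only the parameters, not the whole module object.
torch.save(net.state_dict(), path)

# Rebuild the same architecture, then load the saved weights into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()  # switch to inference mode before testing
```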

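5. Pretrained model keys don't match (e.g. a checkpoint saved from a model wrapped in nn.DataParallel, whose keys all start with "module."):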

from collections import OrderedDict

import torch

def load_GPU(model, model_path, mapLoc='cpu'):
    # the checkpoint is expected to store its weights under the 'net' key
    state_dict = torch.load(model_path, map_location=mapLoc)['net']
    # create a new OrderedDict whose keys do not contain `module.`
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # strip `module.` (7 characters) only when the key actually has it
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    return model
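A self-contained sketch that reproduces the mismatch and the fix, assuming a hypothetical nn.Linear model. Instead of actually wrapping it in nn.DataParallel, the "module." prefix is added by hand to simulate such a checkpoint, and the weights are saved under the 'net' key, matching the layout load_GPU reads. The prefix is removed only where present, so keys without it pass through unchanged.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical model for the demonstration
# Simulate a DataParallel checkpoint: every key becomes 'module.<name>'.
prefixed = {'module.' + k: v for k, v in model.state_dict().items()}
path = os.path.join(tempfile.mkdtemp(), 'ckpt.pt')
torch.save({'net': prefixed}, path)

state_dict = torch.load(path, map_location='cpu')['net']
prefix = 'module.'
clean = {(k[len(prefix):] if k.startswith(prefix) else k): v
         for k, v in state_dict.items()}

# Loading `state_dict` directly would fail with missing/unexpected keys;
# the cleaned dict matches a plain, unwrapped model.
fresh = nn.Linear(4, 2)
fresh.load_state_dict(clean)
print(sorted(state_dict), sorted(clean))
```

In the opposite situation (a plain checkpoint loaded into a model that is itself wrapped in nn.DataParallel), loading into the wrapped model's .module attribute avoids renaming the keys at all.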

6. Assign tensor values (boolean-mask indexing):

out[out > 0.5] = 1
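A sketch of mask assignment on a hypothetical four-element tensor. The assignment modifies the tensor in place, so it is applied to a clone here to keep the input intact; masked_fill is shown as an out-of-place alternative that returns a new tensor. Note that in-place modification of a leaf tensor that requires grad raises a RuntimeError, so this pattern is usually applied to outputs or under torch.no_grad().

```python
import torch

t = torch.tensor([0.1, 0.6, 0.4, 0.8])  # hypothetical values

# Out-of-place: returns a new tensor, t itself is unchanged.
filled = t.masked_fill(t > 0.5, 1.0)

out = t.clone()
# In-place: every element where the mask is True is overwritten with 1.
out[out > 0.5] = 1
# Masks combine with & and |; unselected elements keep their old value.
out[(out > 0.3) & (out < 0.5)] = 0
print(filled, out)
```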

Original post: https://www.cnblogs.com/jiangnanyanyuchen/p/13696796.html