1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)

参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb

模拟一个回归模型,y = X * w + 随机数

如,y : n*1矩阵, X : n*2矩阵 , w : 1,2矩阵

设置true_w = [[-1.0], [2.0]] , 随机初始化w 比如[[1.0], [0.0]],目标是拟合出正确的w

代码如下:

#自己写一个test case
import torch

d = 2
n = 50
X = torch.randn(n,d)
true_w = torch.tensor([[-1.0], [2.0]])
y = X @ true_w + torch.randn(n,1) * 0.1
print('X shape', X.shape)
print('y shape', y.shape)
print('w shape', true_w.shape)

print(X.shape)
print(true_w.shape)

#w = torch.rand(2,1, requires_grad = True)
w = torch.tensor([[1.],[0]], requires_grad= True)
print("w:", w)

print('iter,	loss,	w')
for i in range(20):
    loss = torch.norm(y - torch.matmul(X,w))**2 / n
    
    loss.backward()

    w.data = w.data - 0.1 * w.grad
    
    print('{},	{:.2f},	{}'.format(i, loss.item(), w.view(2).detach().numpy()))
    
    w.grad.zero_()

print('
true w		', true_w.view(2).numpy())
print('estimated w	', w.view(2).detach().numpy())
X shape torch.Size([50, 2])
y shape torch.Size([50, 1])
w shape torch.Size([2, 1])
torch.Size([50, 2])
torch.Size([2, 1])
w: tensor([[1.],
        [0.]], requires_grad=True)
iter,	loss,	w
0,	6.20,	[0.7062541  0.32114884]
1,	4.45,	[0.45446268 0.59016734]
2,	3.20,	[0.2387764  0.81564814]
3,	2.30,	[0.05413117 1.0047431 ]
4,	1.65,	[-0.10385066  1.1634098 ]
5,	1.19,	[-0.23894812  1.2966142 ]
6,	0.86,	[-0.3544196  1.4084985]
7,	0.62,	[-0.4530713  1.5025208]
8,	0.45,	[-0.5373176  1.5815694]
9,	0.32,	[-0.60923356  1.6480589 ]
10,	0.24,	[-0.6706013  1.7040086]
11,	0.17,	[-0.7229501  1.7511086]
12,	0.13,	[-0.76759106  1.7907746 ]
13,	0.09,	[-0.8056477  1.8241924]
14,	0.07,	[-0.8380821  1.8523566]
15,	0.05,	[-0.86571765  1.8761011 ]
16,	0.04,	[-0.8892586  1.8961263]
17,	0.03,	[-0.90930706  1.91302   ]
18,	0.02,	[-0.9263775  1.9272763]
19,	0.02,	[-0.9409093  1.9393103]

true w		 [-1.  2.]
estimated w	 [-0.9409093  1.9393103]

出错点:

1.初始化tensor时,要为float,否则容易报错

2. 初始化时设置 requires_grad = True

3. 定义的loss要在for循坏之内

4. 要用w.data不用w,否则报错,可能和pytorch初始化有关

5. w.grad.zero_()  注意这个写法

原文地址:https://www.cnblogs.com/qiezi-online/p/13945836.html