2. PyTorch: Linear model (most basic version, for understanding the framework and memorizing the pattern) (using the torch.nn module)

# define y = X @ w
import torch
from torch import nn

# Block 1: initialize the data
n = 100
X = torch.rand(n, 2)
true_w = torch.tensor([[-1.], [2.]])
y = X @ true_w + torch.rand(n, 1)  # targets with uniform noise added
w = torch.tensor([[1.], [1.]], requires_grad=True)  # leftover from the manual-gradient version; unused below
# A deeper alternative (note the activation is nn.Tanh, not nn.tanh):
# model = nn.Sequential(nn.Linear(2, 3), nn.Tanh(), nn.Linear(3, 1), nn.Tanh())

# Block 2: define the model, the loss function, and the optimizer
model = nn.Linear(2, 1)
loss_func = nn.MSELoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Block 3: the training loop. Compute y_hat and the loss, then the three steps:
# zero the optimizer's gradients, backpropagate from the loss, update the parameters.
print("epoch loss w")
epochs = 100
for i in range(epochs):
    y_hat = model(X)
    loss = loss_func(y_hat, y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(f"{i} {loss} {model.weight.reshape(2).detach()}")

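Running the nn.Linear script prints one line per epoch: the MSE loss and the two entries of model.weight. A representative run: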
epoch	loss	w
0	 3.0193979740142822	 tensor([-0.4203, -0.2541])
1	 1.724941372871399	 tensor([-0.3228, -0.1113])
2	 1.0764166116714478	 tensor([-0.2580, -0.0022])
3	 0.7487826347351074	 tensor([-0.2163,  0.0831])
4	 0.5806469917297363	 tensor([-0.1907,  0.1515])
5	 0.49187105894088745	 tensor([-0.1765,  0.2078])
6	 0.44265982508659363	 tensor([-0.1703,  0.2556])
7	 0.41324958205223083	 tensor([-0.1695,  0.2971])
8	 0.39382266998291016	 tensor([-0.1726,  0.3341])
9	 0.3794998526573181	 tensor([-0.1784,  0.3679])
10	 0.3678540289402008	 tensor([-0.1860,  0.3993])
11	 0.3576757311820984	 tensor([-0.1948,  0.4288])
12	 0.34836021065711975	 tensor([-0.2044,  0.4569])
13	 0.3396031856536865	 tensor([-0.2145,  0.4839])
14	 0.331249475479126	 tensor([-0.2250,  0.5100])
15	 0.32321831583976746	 tensor([-0.2356,  0.5354])
16	 0.3154657781124115	 tensor([-0.2463,  0.5601])
17	 0.3079665005207062	 tensor([-0.2570,  0.5843])
18	 0.30070433020591736	 tensor([-0.2677,  0.6080])
19	 0.29366785287857056	 tensor([-0.2783,  0.6313])
20	 0.28684815764427185	 tensor([-0.2888,  0.6542])
21	 0.2802375555038452	 tensor([-0.2992,  0.6766])
22	 0.2738291919231415	 tensor([-0.3095,  0.6987])
23	 0.2676165997982025	 tensor([-0.3196,  0.7204])
24	 0.2615937292575836	 tensor([-0.3296,  0.7418])
25	 0.255754679441452	 tensor([-0.3394,  0.7628])
26	 0.25009381771087646	 tensor([-0.3492,  0.7835])
27	 0.24460570514202118	 tensor([-0.3587,  0.8039])
28	 0.23928505182266235	 tensor([-0.3682,  0.8239])
29	 0.2341267466545105	 tensor([-0.3775,  0.8437])
30	 0.22912582755088806	 tensor([-0.3867,  0.8631])
31	 0.22427748143672943	 tensor([-0.3957,  0.8822])
32	 0.21957707405090332	 tensor([-0.4046,  0.9011])
33	 0.2150200754404068	 tensor([-0.4134,  0.9196])
34	 0.21060210466384888	 tensor([-0.4220,  0.9379])
35	 0.2063189297914505	 tensor([-0.4305,  0.9558])
36	 0.20216642320156097	 tensor([-0.4389,  0.9735])
37	 0.19814060628414154	 tensor([-0.4472,  0.9910])
38	 0.194237619638443	 tensor([-0.4553,  1.0081])
39	 0.1904536783695221	 tensor([-0.4633,  1.0250])
40	 0.18678519129753113	 tensor([-0.4712,  1.0416])
41	 0.18322861194610596	 tensor([-0.4790,  1.0580])
42	 0.17978054285049438	 tensor([-0.4867,  1.0741])
43	 0.176437646150589	 tensor([-0.4943,  1.0900])
44	 0.17319674789905548	 tensor([-0.5017,  1.1056])
45	 0.17005468904972076	 tensor([-0.5091,  1.1210])
46	 0.16700850427150726	 tensor([-0.5163,  1.1361])
47	 0.1640552133321762	 tensor([-0.5234,  1.1510])
48	 0.16119202971458435	 tensor([-0.5304,  1.1657])
49	 0.15841616690158844	 tensor([-0.5374,  1.1802])
50	 0.15572498738765717	 tensor([-0.5442,  1.1944])
51	 0.1531158685684204	 tensor([-0.5509,  1.2084])
52	 0.15058636665344238	 tensor([-0.5575,  1.2222])
53	 0.1481340080499649	 tensor([-0.5640,  1.2358])
54	 0.14575643837451935	 tensor([-0.5704,  1.2491])
55	 0.14345139265060425	 tensor([-0.5768,  1.2623])
56	 0.14121665060520172	 tensor([-0.5830,  1.2753])
57	 0.13905006647109985	 tensor([-0.5892,  1.2880])
58	 0.1369495391845703	 tensor([-0.5952,  1.3006])
59	 0.1349131017923355	 tensor([-0.6012,  1.3129])
60	 0.13293875753879547	 tensor([-0.6071,  1.3251])
61	 0.13102462887763977	 tensor([-0.6129,  1.3371])
62	 0.12916886806488037	 tensor([-0.6186,  1.3489])
63	 0.1273697167634964	 tensor([-0.6242,  1.3605])
64	 0.1256254017353058	 tensor([-0.6297,  1.3719])
65	 0.123934306204319	 tensor([-0.6352,  1.3832])
66	 0.12229477614164352	 tensor([-0.6406,  1.3943])
67	 0.12070523947477341	 tensor([-0.6459,  1.4052])
68	 0.11916416138410568	 tensor([-0.6511,  1.4159])
69	 0.11767008155584335	 tensor([-0.6562,  1.4265])
70	 0.11622155457735062	 tensor([-0.6613,  1.4369])
71	 0.11481721699237823	 tensor([-0.6663,  1.4472])
72	 0.11345569044351578	 tensor([-0.6712,  1.4573])
73	 0.11213566362857819	 tensor([-0.6761,  1.4672])
74	 0.11085589230060577	 tensor([-0.6809,  1.4770])
75	 0.10961514711380005	 tensor([-0.6856,  1.4867])
76	 0.10841222107410431	 tensor([-0.6902,  1.4961])
77	 0.10724597424268723	 tensor([-0.6948,  1.5055])
78	 0.10611527413129807	 tensor([-0.6993,  1.5147])
79	 0.10501907020807266	 tensor([-0.7038,  1.5237])
80	 0.10395626723766327	 tensor([-0.7081,  1.5326])
81	 0.1029258519411087	 tensor([-0.7124,  1.5414])
82	 0.10192685574293137	 tensor([-0.7167,  1.5500])
83	 0.10095832496881485	 tensor([-0.7209,  1.5585])
84	 0.10001931339502335	 tensor([-0.7250,  1.5669])
85	 0.09910892695188522	 tensor([-0.7291,  1.5752])
86	 0.09822628647089005	 tensor([-0.7331,  1.5833])
87	 0.09737054258584976	 tensor([-0.7370,  1.5913])
88	 0.0965408906340599	 tensor([-0.7409,  1.5991])
89	 0.09573652595281601	 tensor([-0.7448,  1.6069])
90	 0.09495667368173599	 tensor([-0.7485,  1.6145])
91	 0.09420059621334076	 tensor([-0.7523,  1.6220])
92	 0.09346755594015121	 tensor([-0.7559,  1.6294])
93	 0.09275685995817184	 tensor([-0.7596,  1.6367])
94	 0.09206782281398773	 tensor([-0.7631,  1.6438])
95	 0.09139978885650635	 tensor([-0.7666,  1.6509])
96	 0.09075210243463516	 tensor([-0.7701,  1.6578])
97	 0.09012416005134583	 tensor([-0.7735,  1.6647])
98	 0.08951534330844879	 tensor([-0.7769,  1.6714])
99	 0.08892509341239929	 tensor([-0.7802,  1.6780])
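Over 100 epochs the loss falls from about 3.02 to 0.089 and the weights drift toward true_w = [-1, 2]; with lr = 0.1 this run stops short of full convergence, and the 0.5 mean of the uniform noise is absorbed by the layer's bias. The deeper nn.Sequential variant commented out in Block 1 needs nn.Tanh (capital T) to run. A minimal sketch, assuming the same X, y, and training loop:

# The nn.Sequential variant from the Block 1 comment, made runnable.
# The trailing nn.Tanh() is dropped: it would clamp outputs to (-1, 1),
# while the targets here range roughly over (-1, 3).
model = nn.Sequential(
    nn.Linear(2, 3),  # hidden layer: 2 inputs -> 3 units
    nn.Tanh(),        # nn.Tanh is a Module; nn.tanh does not exist
    nn.Linear(3, 1),  # output layer: 3 units -> 1 prediction
)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

The rest of the loop is unchanged, except that nn.Sequential has no single model.weight to print; inspect model[0].weight or list(model.parameters()) instead.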
Original post: https://www.cnblogs.com/qiezi-online/p/13947702.html