TensorFlow实现线性回归模型代码

模型构建

1.示例代码linear_regression_model.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np

class linearRegressionModel:
    """Linear regression with an L2 weight penalty, trained by Adam on mini-batches.

    Written against the TensorFlow 1.x graph/session API (``tf.placeholder``,
    ``tf.Session``). Constructing an instance builds the graph ops, opens a
    session and initializes the variables. Call :meth:`close`, or use the
    instance in a ``with`` statement, to release the session.
    """

    def __init__(self, x_dimen, learning_rate=0.1, l2_weight=0.15):
        """Build the model graph and start a session.

        Args:
            x_dimen: Number of input features per sample.
            learning_rate: Step size passed to the Adam optimizer.
            l2_weight: Coefficient of the mean-squared-weight penalty in the loss.
        """
        self.x_dimen = x_dimen
        self.learning_rate = learning_rate
        self.l2_weight = l2_weight
        self._index_in_epoch = 0
        self.constructModel()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def close(self):
        """Close the TensorFlow session owned by this model."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def weight_variable(self, shape):
        """Return a variable of ``shape`` drawn from a truncated normal (stddev 0.1)."""
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(self, shape):
        """Return a variable of ``shape`` with every entry set to 0.1."""
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples and labels of the training data.

        Batches are consecutive slices of the stored data. When the next slice
        would run past the end, the data and labels are reshuffled together and
        slicing restarts from the beginning; the samples left over at the end
        of the previous pass are skipped.

        Args:
            batch_size: Number of samples to return; must be between 1 and the
                number of stored samples.

        Returns:
            Tuple ``(datas, labels)`` holding ``batch_size`` rows each.

        Raises:
            ValueError: If ``batch_size`` is not in ``[1, number of samples]``.
        """
        # Validate before touching any state so a bad call leaves the cursor
        # and the data order untouched.
        if batch_size <= 0 or batch_size > self._num_datas:
            raise ValueError(
                'batch_size must be in [1, %d], got %r'
                % (self._num_datas, batch_size))

        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_datas:
            # Not enough samples left in this pass: reshuffle and start over.
            perm = np.arange(self._num_datas)
            np.random.shuffle(perm)
            # Fancy indexing copies, so the caller's arrays are not reordered.
            self._datas = self._datas[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return self._datas[start:end], self._labels[start:end]

    def constructModel(self):
        """Create the placeholders, variables, loss and training op.

        The prediction is ``x @ w + b`` with ``x`` of shape ``[None, x_dimen]``
        and ``w`` of shape ``[x_dimen, 1]``. The loss is the mean squared error
        against ``y`` (shape ``[None, 1]``) plus ``l2_weight`` times the mean
        of the squared weights.
        """
        self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
        self.y = tf.placeholder(tf.float32, [None, 1])
        self.w = self.weight_variable([self.x_dimen, 1])
        self.b = self.bias_variable([1])
        self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

        mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
        l2 = tf.reduce_mean(tf.square(self.w))
        self.loss = mse + self.l2_weight * l2
        self.train_step = tf.train.AdamOptimizer(
            self.learning_rate).minimize(self.loss)

    def train(self, x_train, y_train, x_test, y_test,
              num_steps=5000, batch_size=100, log_every=10):
        """Run mini-batch training and print the batch loss periodically.

        Args:
            x_train: Training samples, array-like of shape ``(n, x_dimen)``.
            y_train: Training targets of shape ``(n, 1)``; a 1-D array of
                length ``n`` is reshaped to a column.
            x_test: Unused; kept so existing callers keep working.
            y_test: Unused; kept so existing callers keep working.
            num_steps: Number of optimizer steps to run.
            batch_size: Samples per step; must not exceed ``n``.
            log_every: Print the loss every this many steps; a falsy value
                disables printing.

        Raises:
            ValueError: If the sample and target counts differ, or if
                ``batch_size`` is invalid for the data.
        """
        x_train = np.asarray(x_train)
        y_train = np.asarray(y_train)
        if y_train.ndim == 1:
            # The label placeholder is [None, 1], so targets must be a column.
            y_train = y_train.reshape(-1, 1)
        if x_train.shape[0] != y_train.shape[0]:
            raise ValueError(
                'x_train has %d samples but y_train has %d'
                % (x_train.shape[0], y_train.shape[0]))

        self._datas = x_train
        self._labels = y_train
        self._num_datas = x_train.shape[0]
        # Restart batching from the first sample; otherwise a second call to
        # train() would continue from the cursor left by the previous data.
        self._index_in_epoch = 0

        for step in range(num_steps):
            batch_x, batch_y = self.next_batch(batch_size)
            feed = {self.x: batch_x, self.y: batch_y}
            self.sess.run(self.train_step, feed_dict=feed)
            if log_every and step % log_every == 0:
                # The loss is measured on the batch just trained on, so it is
                # reported as a training loss rather than a test loss.
                train_loss = self.sess.run(self.loss, feed_dict=feed)
                print('step %d,train_loss %f' % (step, train_loss))

    def predict_batch(self, arr, batch_size):
        """Yield consecutive slices of ``arr`` holding at most ``batch_size`` items."""
        for i in range(0, len(arr), batch_size):
            yield arr[i:i + batch_size]

    def predict(self, x_predict, batch_size=100):
        """Return predictions for ``x_predict`` as an array of shape ``(n, 1)``.

        The input is evaluated in chunks of ``batch_size`` rows and the chunk
        results are stacked vertically. Empty input yields an empty ``(0, 1)``
        array instead of failing inside ``np.vstack``.
        """
        pred_list = [
            self.sess.run(self.y_prec, {self.x: x_chunk})
            for x_chunk in self.predict_batch(x_predict, batch_size)
        ]
        if not pred_list:
            return np.empty((0, 1), dtype=np.float32)
        return np.vstack(pred_list)

 2.创建run.py

#!/usr/bin/python
# -*- coding: utf-8 -*-

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from linear_regression_model import linearRegressionModel as lrm

if __name__ == '__main__':
    # Compare the TensorFlow model against scikit-learn's LinearRegression on
    # a synthetic regression problem split 50/50 into train and test sets.
    x, y = make_regression(7000)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
    # The TensorFlow model feeds labels into a [None, 1] placeholder, so it
    # needs column vectors; scikit-learn uses the 1-D targets directly.
    y_lrm_train = y_train.reshape(-1, 1)
    y_lrm_test = y_test.reshape(-1, 1)

    linear = lrm(x.shape[1])
    linear.train(x_train, y_lrm_train, x_test, y_lrm_test)
    tf_predict = linear.predict(x_test)
    # r2_score takes (y_true, y_pred); the score is not symmetric, so the
    # ground truth must come first.
    print("Tensorflow R2: ", r2_score(y_lrm_test.ravel(), tf_predict.ravel()))

    lr = LinearRegression()
    sk_predict = lr.fit(x_train, y_train).predict(x_test)
    # Score the scikit-learn baseline with the same r2_score metric.
    print("Sklearn R2: ", r2_score(y_test, sk_predict))

 运行结果(注意:日志中的 test_loss 实际是当前训练批次上的损失,并非测试集上的损失):

step 0,test_loss 27078.781250
step 10,test_loss 29246.253906
step 20,test_loss 21168.052734
step 30,test_loss 22109.154297
step 40,test_loss 28030.435547
step 50,test_loss 24265.765625
step 60,test_loss 28433.816406
step 70,test_loss 24395.164062
step 80,test_loss 19135.515625
step 90,test_loss 20932.734375
step 100,test_loss 17176.033203
step 110,test_loss 19729.275391
step 120,test_loss 18076.587891
step 130,test_loss 24546.722656
step 140,test_loss 22370.619141
step 150,test_loss 17227.343750
step 160,test_loss 21498.363281
step 170,test_loss 17482.292969
step 180,test_loss 16188.901367
step 190,test_loss 17961.816406
step 200,test_loss 15168.850586
step 210,test_loss 14205.447266
step 220,test_loss 15992.610352
step 230,test_loss 12878.104492
step 240,test_loss 15663.670898
step 250,test_loss 11105.211914
step 260,test_loss 11135.759766
step 270,test_loss 12083.872070
step 280,test_loss 9544.156250
step 290,test_loss 12040.689453
step 300,test_loss 8685.537109
step 310,test_loss 11533.030273
step 320,test_loss 11031.776367
step 330,test_loss 11258.272461
step 340,test_loss 9219.499023
step 350,test_loss 7839.248047
step 360,test_loss 9757.743164
step 370,test_loss 7579.228027
step 380,test_loss 8326.705078
step 390,test_loss 8823.761719
step 400,test_loss 8431.373047
step 410,test_loss 8025.544922
step 420,test_loss 7954.462891
step 430,test_loss 9809.444336
step 440,test_loss 5645.476074
step 450,test_loss 7813.232422
step 460,test_loss 6410.347656
step 470,test_loss 6623.901367
step 480,test_loss 7697.770508
step 490,test_loss 5924.088867
step 500,test_loss 5174.365234
step 510,test_loss 5223.140625
step 520,test_loss 5655.796387
step 530,test_loss 4949.434570
step 540,test_loss 4330.499023
step 550,test_loss 5321.663086
step 560,test_loss 4629.940918
step 570,test_loss 3220.557373
step 580,test_loss 4162.278320
step 590,test_loss 4546.246582
step 600,test_loss 4487.117188
step 610,test_loss 5037.617676
step 620,test_loss 3526.248047
step 630,test_loss 3432.793457
step 640,test_loss 3385.915527
step 650,test_loss 3272.809814
step 660,test_loss 2710.681396
step 670,test_loss 3326.879883
step 680,test_loss 3275.361084
step 690,test_loss 2347.117432
step 700,test_loss 2957.036621
step 710,test_loss 1699.123535
step 720,test_loss 2293.731445
step 730,test_loss 2275.772705
step 740,test_loss 2176.456055
step 750,test_loss 2457.974121
step 760,test_loss 2203.473877
step 770,test_loss 1920.002686
step 780,test_loss 2047.632446
step 790,test_loss 1736.505615
step 800,test_loss 2039.262451
step 810,test_loss 2055.947510
step 820,test_loss 1908.234375
step 830,test_loss 1280.326904
step 840,test_loss 1412.927856
step 850,test_loss 1737.114258
step 860,test_loss 1251.464111
step 870,test_loss 1589.670532
step 880,test_loss 1396.735474
step 890,test_loss 1706.040527
step 900,test_loss 1558.866333
step 910,test_loss 1334.543213
step 920,test_loss 1306.657471
step 930,test_loss 942.939819
step 940,test_loss 1200.833008
step 950,test_loss 932.249695
step 960,test_loss 1328.827271
step 970,test_loss 1191.408081
step 980,test_loss 832.388062
step 990,test_loss 1052.487427
step 1000,test_loss 896.287964
step 1010,test_loss 707.095093
step 1020,test_loss 622.292297
step 1030,test_loss 798.665649
step 1040,test_loss 789.424316
step 1050,test_loss 606.861450
step 1060,test_loss 573.976074
step 1070,test_loss 465.951965
step 1080,test_loss 631.956543
step 1090,test_loss 679.685913
step 1100,test_loss 440.278046
step 1110,test_loss 476.793945
step 1120,test_loss 450.453278
step 1130,test_loss 541.740479
step 1140,test_loss 502.860077
step 1150,test_loss 363.825653
step 1160,test_loss 378.313232
step 1170,test_loss 364.206024
step 1180,test_loss 359.042999
step 1190,test_loss 304.770569
step 1200,test_loss 354.092407
step 1210,test_loss 296.288147
step 1220,test_loss 313.082031
step 1230,test_loss 321.331512
step 1240,test_loss 327.985718
step 1250,test_loss 257.409210
step 1260,test_loss 250.276291
step 1270,test_loss 191.458878
step 1280,test_loss 216.972244
step 1290,test_loss 229.754684
step 1300,test_loss 219.731140
step 1310,test_loss 197.320190
step 1320,test_loss 185.500366
step 1330,test_loss 180.765671
step 1340,test_loss 223.783081
step 1350,test_loss 166.295975
step 1360,test_loss 146.334641
step 1370,test_loss 191.004700
step 1380,test_loss 137.425964
step 1390,test_loss 155.957443
step 1400,test_loss 137.031784
step 1410,test_loss 144.765793
step 1420,test_loss 123.946625
step 1430,test_loss 133.717957
step 1440,test_loss 136.200287
step 1450,test_loss 109.962036
step 1460,test_loss 107.478485
step 1470,test_loss 111.343063
step 1480,test_loss 113.355667
step 1490,test_loss 110.620399
step 1500,test_loss 116.955994
step 1510,test_loss 102.297958
step 1520,test_loss 107.474968
step 1530,test_loss 88.769562
step 1540,test_loss 88.092247
step 1550,test_loss 93.228027
step 1560,test_loss 78.206909
step 1570,test_loss 99.623810
step 1580,test_loss 67.202003
step 1590,test_loss 77.569229
step 1600,test_loss 78.516144
step 1610,test_loss 76.165176
step 1620,test_loss 64.493408
step 1630,test_loss 70.672768
step 1640,test_loss 68.577499
step 1650,test_loss 72.143890
step 1660,test_loss 63.308643
step 1670,test_loss 64.004288
step 1680,test_loss 64.626549
step 1690,test_loss 59.137959
step 1700,test_loss 63.122589
step 1710,test_loss 56.314068
step 1720,test_loss 51.382557
step 1730,test_loss 58.105713
step 1740,test_loss 57.619289
step 1750,test_loss 54.326633
step 1760,test_loss 51.271332
step 1770,test_loss 56.553986
step 1780,test_loss 51.459373
step 1790,test_loss 49.371822
step 1800,test_loss 52.714359
step 1810,test_loss 50.442295
step 1820,test_loss 49.796776
step 1830,test_loss 48.404625
step 1840,test_loss 47.714275
step 1850,test_loss 49.141331
step 1860,test_loss 46.075230
step 1870,test_loss 47.250427
step 1880,test_loss 47.220695
step 1890,test_loss 47.975838
step 1900,test_loss 47.080906
step 1910,test_loss 45.991798
step 1920,test_loss 45.940758
step 1930,test_loss 45.241516
step 1940,test_loss 45.457054
step 1950,test_loss 44.415176
step 1960,test_loss 44.690414
step 1970,test_loss 44.910900
step 1980,test_loss 43.690544
step 1990,test_loss 42.880653
step 2000,test_loss 42.956898
step 2010,test_loss 43.080429
step 2020,test_loss 43.176693
step 2030,test_loss 43.030117
step 2040,test_loss 43.170925
step 2050,test_loss 42.681801
step 2060,test_loss 42.610954
step 2070,test_loss 42.576504
step 2080,test_loss 42.255066
step 2090,test_loss 42.081310
step 2100,test_loss 42.341095
step 2110,test_loss 42.025223
step 2120,test_loss 42.204201
step 2130,test_loss 42.335026
step 2140,test_loss 41.973049
step 2150,test_loss 42.003143
step 2160,test_loss 41.904259
step 2170,test_loss 41.881233
step 2180,test_loss 41.608265
step 2190,test_loss 41.525867
step 2200,test_loss 41.472271
step 2210,test_loss 41.472610
step 2220,test_loss 41.598587
step 2230,test_loss 41.459789
step 2240,test_loss 41.376347
step 2250,test_loss 41.300011
step 2260,test_loss 41.316811
step 2270,test_loss 41.432549
step 2280,test_loss 41.290428
step 2290,test_loss 41.279583
step 2300,test_loss 41.197216
step 2310,test_loss 41.269833
step 2320,test_loss 41.240284
step 2330,test_loss 41.202190
step 2340,test_loss 41.211605
step 2350,test_loss 41.224072
step 2360,test_loss 41.169403
step 2370,test_loss 41.151337
step 2380,test_loss 41.162971
step 2390,test_loss 41.127731
step 2400,test_loss 41.094795
step 2410,test_loss 41.089066
step 2420,test_loss 41.137642
step 2430,test_loss 41.085999
step 2440,test_loss 41.096901
step 2450,test_loss 41.096237
step 2460,test_loss 41.072151
step 2470,test_loss 41.094440
step 2480,test_loss 41.049301
step 2490,test_loss 41.062485
step 2500,test_loss 41.053036
step 2510,test_loss 41.042328
step 2520,test_loss 41.049831
step 2530,test_loss 41.078171
step 2540,test_loss 41.013088
step 2550,test_loss 41.039490
step 2560,test_loss 41.040127
step 2570,test_loss 41.047153
step 2580,test_loss 41.059521
step 2590,test_loss 41.067646
step 2600,test_loss 41.027416
step 2610,test_loss 41.019939
step 2620,test_loss 41.030586
step 2630,test_loss 41.028877
step 2640,test_loss 41.027557
step 2650,test_loss 41.026352
step 2660,test_loss 41.023903
step 2670,test_loss 41.006763
step 2680,test_loss 41.024330
step 2690,test_loss 41.046272
step 2700,test_loss 41.018227
step 2710,test_loss 41.016628
step 2720,test_loss 41.025139
step 2730,test_loss 41.019703
step 2740,test_loss 41.016834
step 2750,test_loss 41.033138
step 2760,test_loss 41.031982
step 2770,test_loss 41.027203
step 2780,test_loss 41.036865
step 2790,test_loss 41.039066
step 2800,test_loss 41.015831
step 2810,test_loss 41.021862
step 2820,test_loss 41.037052
step 2830,test_loss 41.030590
step 2840,test_loss 41.026188
step 2850,test_loss 41.019707
step 2860,test_loss 41.021141
step 2870,test_loss 41.019894
step 2880,test_loss 41.020607
step 2890,test_loss 41.024086
step 2900,test_loss 41.037041
step 2910,test_loss 41.023495
step 2920,test_loss 41.011646
step 2930,test_loss 41.022732
step 2940,test_loss 41.017460
step 2950,test_loss 41.042557
step 2960,test_loss 41.025982
step 2970,test_loss 41.023857
step 2980,test_loss 41.029766
step 2990,test_loss 41.021320
step 3000,test_loss 41.036278
step 3010,test_loss 41.026100
step 3020,test_loss 41.029068
step 3030,test_loss 41.007935
step 3040,test_loss 41.024139
step 3050,test_loss 41.023842
step 3060,test_loss 41.023033
step 3070,test_loss 41.041313
step 3080,test_loss 41.013794
step 3090,test_loss 41.021595
step 3100,test_loss 41.023506
step 3110,test_loss 41.027863
step 3120,test_loss 41.049881
step 3130,test_loss 41.037209
step 3140,test_loss 41.013416
step 3150,test_loss 41.044666
step 3160,test_loss 41.022858
step 3170,test_loss 41.026386
step 3180,test_loss 41.025173
step 3190,test_loss 41.025276
step 3200,test_loss 41.031715
step 3210,test_loss 41.019821
step 3220,test_loss 41.023750
step 3230,test_loss 41.026768
step 3240,test_loss 41.025543
step 3250,test_loss 41.030800
step 3260,test_loss 41.032837
step 3270,test_loss 41.020596
step 3280,test_loss 41.024185
step 3290,test_loss 41.014019
step 3300,test_loss 41.017628
step 3310,test_loss 41.039688
step 3320,test_loss 41.036552
step 3330,test_loss 41.041679
step 3340,test_loss 41.010323
step 3350,test_loss 41.019321
step 3360,test_loss 41.003582
step 3370,test_loss 41.039524
step 3380,test_loss 41.041386
step 3390,test_loss 41.014439
step 3400,test_loss 41.031914
step 3410,test_loss 41.047981
step 3420,test_loss 41.020836
step 3430,test_loss 41.035324
step 3440,test_loss 41.021690
step 3450,test_loss 41.026123
step 3460,test_loss 41.029877
step 3470,test_loss 41.027092
step 3480,test_loss 41.027649
step 3490,test_loss 41.023071
step 3500,test_loss 41.027126
step 3510,test_loss 41.018978
step 3520,test_loss 41.030590
step 3530,test_loss 41.026154
step 3540,test_loss 41.021610
step 3550,test_loss 41.014198
step 3560,test_loss 41.032345
step 3570,test_loss 41.030876
step 3580,test_loss 41.013630
step 3590,test_loss 41.025135
step 3600,test_loss 41.035576
step 3610,test_loss 41.018707
step 3620,test_loss 41.019424
step 3630,test_loss 41.028542
step 3640,test_loss 41.039867
step 3650,test_loss 41.014717
step 3660,test_loss 41.035339
step 3670,test_loss 41.031448
step 3680,test_loss 41.016773
step 3690,test_loss 41.025093
step 3700,test_loss 41.030968
step 3710,test_loss 41.027367
step 3720,test_loss 41.039196
step 3730,test_loss 41.024532
step 3740,test_loss 41.039036
step 3750,test_loss 41.003342
step 3760,test_loss 41.035763
step 3770,test_loss 41.035271
step 3780,test_loss 41.009220
step 3790,test_loss 41.030884
step 3800,test_loss 41.029705
step 3810,test_loss 41.029217
step 3820,test_loss 41.028343
step 3830,test_loss 41.020901
step 3840,test_loss 41.039314
step 3850,test_loss 41.045189
step 3860,test_loss 41.028725
step 3870,test_loss 41.026402
step 3880,test_loss 41.014465
step 3890,test_loss 41.027691
step 3900,test_loss 41.027061
step 3910,test_loss 41.023037
step 3920,test_loss 41.028137
step 3930,test_loss 41.035686
step 3940,test_loss 41.021793
step 3950,test_loss 41.014446
step 3960,test_loss 41.018074
step 3970,test_loss 41.037655
step 3980,test_loss 41.019314
step 3990,test_loss 41.022900
step 4000,test_loss 41.026077
step 4010,test_loss 41.035042
step 4020,test_loss 41.022713
step 4030,test_loss 41.029526
step 4040,test_loss 41.026649
step 4050,test_loss 41.033508
step 4060,test_loss 41.028713
step 4070,test_loss 41.031872
step 4080,test_loss 41.017612
step 4090,test_loss 41.031342
step 4100,test_loss 41.024128
step 4110,test_loss 41.021511
step 4120,test_loss 41.028091
step 4130,test_loss 41.025402
step 4140,test_loss 41.028831
step 4150,test_loss 41.025154
step 4160,test_loss 41.028797
step 4170,test_loss 41.023502
step 4180,test_loss 41.023289
step 4190,test_loss 41.026257
step 4200,test_loss 41.023941
step 4210,test_loss 41.017677
step 4220,test_loss 41.018219
step 4230,test_loss 41.021465
step 4240,test_loss 41.022671
step 4250,test_loss 41.035088
step 4260,test_loss 41.028889
step 4270,test_loss 41.015503
step 4280,test_loss 41.011471
step 4290,test_loss 41.034992
step 4300,test_loss 41.024700
step 4310,test_loss 41.021152
step 4320,test_loss 41.033760
step 4330,test_loss 41.022285
step 4340,test_loss 41.023975
step 4350,test_loss 41.047928
step 4360,test_loss 41.040417
step 4370,test_loss 41.015713
step 4380,test_loss 41.021191
step 4390,test_loss 41.028423
step 4400,test_loss 41.046730
step 4410,test_loss 41.019470
step 4420,test_loss 41.023933
step 4430,test_loss 41.023426
step 4440,test_loss 41.044052
step 4450,test_loss 41.023289
step 4460,test_loss 41.037994
step 4470,test_loss 41.027950
step 4480,test_loss 41.018356
step 4490,test_loss 41.026508
step 4500,test_loss 41.024136
step 4510,test_loss 41.032318
step 4520,test_loss 41.028934
step 4530,test_loss 41.027802
step 4540,test_loss 41.034740
step 4550,test_loss 41.018875
step 4560,test_loss 41.009151
step 4570,test_loss 41.028728
step 4580,test_loss 41.013172
step 4590,test_loss 41.023643
step 4600,test_loss 41.036564
step 4610,test_loss 41.023758
step 4620,test_loss 41.010895
step 4630,test_loss 41.016830
step 4640,test_loss 41.025158
step 4650,test_loss 41.031147
step 4660,test_loss 41.030773
step 4670,test_loss 41.014057
step 4680,test_loss 41.012878
step 4690,test_loss 41.020706
step 4700,test_loss 41.024204
step 4710,test_loss 41.030964
step 4720,test_loss 41.042183
step 4730,test_loss 41.004620
step 4740,test_loss 41.043163
step 4750,test_loss 41.026157
step 4760,test_loss 41.016129
step 4770,test_loss 41.028667
step 4780,test_loss 41.033478
step 4790,test_loss 41.032280
step 4800,test_loss 41.029270
step 4810,test_loss 41.032330
step 4820,test_loss 41.026970
step 4830,test_loss 41.034531
step 4840,test_loss 41.038826
step 4850,test_loss 41.033676
step 4860,test_loss 41.037766
step 4870,test_loss 41.026272
step 4880,test_loss 41.024136
step 4890,test_loss 41.020840
step 4900,test_loss 41.028576
step 4910,test_loss 41.013222
step 4920,test_loss 41.042625
step 4930,test_loss 41.035049
step 4940,test_loss 41.023026
step 4950,test_loss 41.023335
step 4960,test_loss 41.028851
step 4970,test_loss 41.024628
step 4980,test_loss 41.019810
step 4990,test_loss 41.026733
Tensorflow R2:  0.999997486127
Sklearn R2:  1.0
原文地址:https://www.cnblogs.com/gnool/p/8197029.html