[转]统计学习方法—chapter2—感知机算法实现

描述:李航《统计学习方法》第二章感知机算法实现(Python)

原始形式:

 1 # _*_ encoding:utf-8 _*_
 2 
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 
 6 
 7 def createdata():
 8     """创建数据集和相应类标记"""
 9     samples = np.array([[3, 3], [4, 3], [1, 1]])
10     labels = np.array([1, 1, -1])
11     return samples, labels
12 
13 
14 
class Perceptron:
    """Perceptron model trained with the primal-form update rule.

    Attributes:
        x: training samples, one sample per row.
        y: class labels (+1 / -1, as produced by createdata).
        w: weight column vector of shape (numfeatures, 1), starts at zero.
        b: bias, starts at 0.
        a: learning rate.
    """

    def __init__(self, x, y, a=1):
        """Store the training data and initialise w, b and the learning rate."""
        self.x = x
        self.y = y
        self.w = np.zeros((x.shape[1], 1))
        self.b = 0
        # Honour the learning rate the caller passed in (it used to be
        # hard-coded to 1, which silently ignored the ``a`` argument).
        self.a = a
        self.numsamples = self.x.shape[0]
        self.numfeatures = self.x.shape[1]

    def sign(self, w, b, x):
        """Return the decision value f(x) = x . w + b of one sample as a float.

        Despite the name this is the raw value, not its sign; train() only
        checks whether y_i * f(x_i) <= 0.  The value is deliberately not
        truncated with int(): that would map e.g. 0.5 to 0 and make a
        correctly classified point count as misclassified whenever the
        learning rate is fractional.
        """
        # With a column-vector w, np.dot yields a 1-element array; .item()
        # extracts the scalar instead of calling int()/float() on an ndim>0
        # array, which NumPy has deprecated.
        return float(np.dot(x, w).item() + b)

    def update(self, label_i, data_i):
        """Apply one primal-form step: w += a * y_i * x_i and b += a * y_i."""
        tmp = label_i * self.a * data_i
        # data_i is a 1-D row; reshape the step to w's column shape.
        tmp = tmp.reshape(self.w.shape)
        self.w = tmp + self.w
        self.b = self.b + label_i * self.a

    def train(self, max_iter=None):
        """Run the perceptron algorithm until no sample is misclassified.

        Each epoch sweeps the samples in order and updates (w, b) on every
        sample with y_i * f(x_i) <= 0.

        Args:
            max_iter: optional cap on the number of epochs.  ``None`` (the
                default) keeps the original behaviour of looping until
                convergence, which never terminates on data that is not
                linearly separable.

        Returns:
            (w, b): the learned weight column vector and bias.

        Raises:
            RuntimeError: if ``max_iter`` epochs pass without convergence.
        """
        epoch = 0
        while True:
            if max_iter is not None and epoch >= max_iter:
                raise RuntimeError(
                    'perceptron did not converge within %d epochs' % max_iter)
            epoch += 1
            count = 0
            for xi, yi in zip(self.x, self.y):
                tmp = self.sign(self.w, self.b, xi)
                if tmp * yi <= 0:
                    print('误分类点为: ', xi, '此时的w和b为: ', self.w, self.b)
                    count += 1
                    self.update(yi, xi)
            if count == 0:
                print('最终训练得到的w和b为: ', self.w, self.b)
                return self.w, self.b


class Picture:
    """Visualise the toy samples together with the learned separating line."""

    def __init__(self, data, w, b):
        """Draw the decision line and the first three samples, then save original.png.

        Args:
            data: sample array; rows 0, 1 and 2 are drawn as points.
            w: learned weights, indexed as w[0] and w[1].
            b: learned bias.
        """
        self.w, self.b = w, b

        plt.figure(1)
        plt.title('Perceptron Learning Algorithm', size=14)
        plt.xlabel('x0-axis', size=14)
        plt.ylabel('x1-axis', size=14)

        line_x = np.linspace(0, 5, 100)
        plt.plot(line_x, self.expression(line_x), color='r', label='sample data')

        self._scatter_samples(data)

        plt.savefig('original.png', dpi=75)

    def _scatter_samples(self, data):
        """Plot rows 0, 1 and 2 of ``data``, each with its own fixed style."""
        styles = (
            dict(c='r', s=50),
            dict(c='g', s=50),
            dict(s=50, c='b', marker='x'),
        )
        for row, style in enumerate(styles):
            plt.scatter(data[row][0], data[row][1], **style)

    def expression(self, x):
        """Return x1 on the line w[0]*x0 + w[1]*x1 + b = 0 for each x0 in ``x``."""
        numerator = -self.b - self.w[0] * x
        return numerator / self.w[1]

    def show(self):
        """Display the current matplotlib figure."""
        plt.show()


if __name__ == '__main__':
    # Train on the toy data, then plot the separating line that was learned.
    train_x, train_y = createdata()
    w_learned, b_learned = Perceptron(train_x, train_y).train()
    Picture(train_x, w_learned, b_learned).show()

对偶形式:

  1 # _*_ encoding:utf-8 _*_
  2 
  3 import numpy as np
  4 import matplotlib.pyplot as plt
  5 
  6 def createdata():
  7     """创建数据集和相应的类标记"""
  8     samples = np.array([[3, 3], [4, 3], [1, 1]])
  9     labels = np.array([1, 1, -1])
 10     return samples, labels
 11 
 12 
 13 class Perceptron:
 14     """感知机模型"""
 15 
 16     def __init__(self, x, y, a=1):
 17         """初始化数据集,标记,学习率,参数等"""
 18         self.x = x
 19         self.y = y
 20         self.w = np.zeros((1, x.shape[0]))
 21         self.b = 0
 22         self.a = a
 23         self.numsamples = self.x.shape[0]
 24         self.numfeatures = self.x.shape[1]
 25         self.gmatrix = self.gMatrix()
 26 
 27     def gMatrix(self):
 28         """计算Gram矩阵"""
 29         gmatrix = np.zeros((self.numsamples, self.numsamples))
 30         for i in range(self.numsamples):
 31             for j in range(self.numsamples):
 32                 gmatrix[i][j] = np.dot(self.x[i, :], self.x[j, :])
 33         return gmatrix
 34 
 35     def sign(self, i):
 36         """计算f(x)"""
 37         y = np.dot(self.w*self.y, self.gmatrix[:, i]) + self.b
 38         return int(y)
 39 
 40     def update(self, i):
 41         """更新w和b"""
 42         self.w[:, i] = self.w[:, i] + self.a
 43         self.b = self.b + self.a * self.y[i]
 44 
 45     def cal_w(self):
 46         """计算最终的w"""
 47         w = np.dot(self.w*self.y, self.x)
 48         return w
 49 
 50     def train(self):
 51         """感知机模型训练"""
 52         isfind = False
 53         while not isfind:
 54             count = 0
 55             for i in range(self.numsamples):
 56                 if self.y[i]*self.sign(i) <= 0:
 57                     count += 1
 58                     print('误分类点为: ', self.x[i, :], '此时w和b分别为: ', self.cal_w(), ', ', self.b)
 59                     self.update(i)
 60             if count == 0:
 61                 print('最终的w和b为: ', self. cal_w(), ', ', self.b)
 62                 isfind = True
 63         weights = self.cal_w()
 64         return weights, self.b
 65 
 66 
 67 class Picture:
 68     """数据可视化"""
 69 
 70     def __init__(self, data, w, b):
 71         """"初始化画图参数"""
 72         self.w = w
 73         self.b = b
 74         plt.figure(1)
 75         plt.title('Perceptron Learning Algorithm of Duality', size=20)
 76         plt.xlabel('X0-axis', size=14)
 77         plt.ylabel('X1-axis', size=14)
 78 
 79         xdata = np.linspace(1, 5, 100)
 80         ydata = self.expression(xdata)
 81         plt.plot(xdata, ydata, c='r')
 82 
 83         plt.scatter(data[0][0], data[0][1], s=50)
 84         plt.scatter(data[1][0], data[1][1], s=50)
 85         plt.scatter(data[2][0], data[2][1], s=50, marker='x')
 86         plt.savefig('test.png', dpi=95)
 87 
 88     def expression(self, xdata):
 89         """计算超平面上的纵坐标"""
 90         y = (-self.b - self.w[:, 0]*xdata) / self.w[:, 1]
 91         return y
 92 
 93     def show(self):
 94         """画图"""
 95         plt.show()
 96 
 97 
 98 if __name__ == '__main__':
 99     samples, labels = createdata()
100     perceptron = Perceptron(x=samples, y=labels)
101     weights, b = perceptron.train()
102     picture = Picture(samples, weights, b)
103     picture.show()

参考自:https://blog.csdn.net/u010626937/article/details/72896144

原文地址:https://www.cnblogs.com/OoycyoO/p/9538055.html