从感知机到支持向量机—学习笔记

step 1

用高斯分布生成两类点:Point3 以 (50, 50) 为中心、标准差为 10,标签为 -1(红色);Point4 以 (90, 90) 为中心、标准差为 10,标签为 1(蓝色)。

 1 class Point3:
 2     def __init__(self):
 3         self.x = random.gauss(50, 10)
 4         self.y = random.gauss(50, 10)
 5 
 6         self.label = -1
 7         self.color = 'r'
 8 
 9 class Point4:
10     def __init__(self):
11         self.x = random.gauss(90, 10)
12         self.y = random.gauss(90, 10)
13 
14         self.label = 1
15         self.color = 'b'

step 2

画一条初始直线:先取两个点 (x1, 0) 和 (x2, 100),其中 x1 是 [0, 50] 内的随机整数,x2 是 [50, 100] 内的随机整数;过这两个点即可确定一条直线 w1*x + w2*y + b = 0。

 1 class Line:
 2     def __init__(self):
 3         self.x1 = random.randint(MIN, MAX//2)       # MAX=100 MIN=0 (0, 50)  随机生成一个整数
 4         self.x2 = random.randint(MAX//2, MAX)       # MAX=100 MIN=0 (50, 100)
 5         self.y1 = 0
 6         self.y2 = 100
 7 
 8         self.x = [self.x1, self.x2]
 9         self.y = [self.y1, self.y2]
10 
11         self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
12         self.w2 = 1
13         self.b = -(self.w1 * self.x1) + self.w2 * self.y1

step 3
判断误分类点
正确分类1:w1*x+w2*y+b>0 且 label=1
正确分类2:w1*x+w2*y+b<0 且 label=-1
两种情况可以合并为 label*(w1*x+w2*y+b)>0;下面的 sign 函数返回的正是这个乘积,其值小于等于 0 的点即为误分类点。

1  def sign(self, point):
2         # print(self.w1 * point.x + self.w2 * point.y + self.b)
3         # print(point.label * (self.w1 * point.x + self.w2 * point.y + self.b))
4         return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)

step 4
有了更新后的 w1、w2 和 b 之后,需要重新画出对应的直线。
首先需要找到直线上的两个点:保持 y1=0、y2=100 不变,由 w1*x + w2*y + b = 0 可得 x = (-b - w2*y) / w1,分别代入 y1 和 y2 即可求出对应的 x1、x2。

1     def update(self):
2         self.x1 = -self.b / self.w1
3         self.x2 = (-self.b - self.w2 * self.y2) / self.w1
4         self.x = [self.x1, self.x2]
5         self.y = [self.y1, self.y2]

step 5
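w1、w2 和 b 的更新规则:对满足 label*(w1*x+w2*y+b) < 1 的样本,令 w ← (1-step)*w + step*C*label*(x, y),b ← b + step*C*label。推导请参考博文《支持向量机》:http://www.carefree0910.com/posts/d455305a/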

 1 def preceptron_base_dis(all_points):
 2     line = Line()
 3     plt.plot(line.x, line.y, 'g--', linewidth=1)
 4     Flag = True
 5     while True:
 6         Flag = True
 7         for point in all_points:
 8             if line.sign(point) < 1:    # 只有误分类点才更新
 9                 line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
10                 line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
11                 line.b = line.b + step * C * point.label
12                 Flag = False
13         if Flag:
14             break
15         line.update()
16         #plt.plot(l.x, l.y, 'y--', linewidth=1)
17     plt.plot(line.x, line.y, '.-', linewidth=1)
18     plt.show()

全部代码汇总

  1 import matplotlib.pyplot as plt
  2 import numpy
  3 import random
  4 import sys
  5 
  6 MAX=100  # upper bound for the random coordinates drawn in Point, Point2 and Line
  7 MIN=0  # lower bound for the random coordinates drawn in Point, Point2 and Line
  8 POINT_NUM=20  # number of points initialPoint() generates for each class
  9 step=0.01  # step size applied to every weight and bias update
 10 C = 0.1  # factor multiplying the per-point term in preceptron_base_dis
 11 
 12 class Point:
 13     def __init__(self):
 14         self.x = random.uniform(MIN, MAX)
 15         self.y = random.uniform(MIN, MAX)
 16 
 17         if self.x > self.y:
 18             self.label = 1
 19             self.color = 'b'
 20         else:
 21             self.label = -1
 22             self.color = 'r'
 23 class Point2:
 24     def __init__(self):
 25         self.x = random.randint(MIN, MAX)
 26         if self.x > MAX // 2:
 27             self.y = random.randint(0, MAX // 4)
 28         else:
 29             self.y = random.randint(MAX * 2 // 4, MAX)
 30 
 31         if self.x > self.y:
 32             self.label = 1
 33             self.color = 'b'
 34         else:
 35             self.label = -1
 36             self.color = 'r'
 37 
 38 class Point3:
 39     def __init__(self):
 40         self.x = random.gauss(50, 10)
 41         self.y = random.gauss(50, 10)
 42 
 43         self.label = -1
 44         self.color = 'r'
 45 
 46 class Point4:
 47     def __init__(self):
 48         self.x = random.gauss(90, 10)
 49         self.y = random.gauss(90, 10)
 50 
 51         self.label = 1
 52         self.color = 'b'
 53 class Line:
 54     def __init__(self):
 55         self.x1 = random.randint(MIN, MAX//2)       # MAX=100 MIN=0 (0, 50)  随机生成一个整数
 56         self.x2 = random.randint(MAX//2, MAX)       # MAX=100 MIN=0 (50, 100)
 57         self.y1 = 0
 58         self.y2 = 100
 59 
 60         self.x = [self.x1, self.x2]
 61         self.y = [self.y1, self.y2]
 62 
 63         self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
 64         self.w2 = 1
 65         self.b = -(self.w1 * self.x1) + self.w2 * self.y1
 66 
 67     def sign(self, point):
 68         # print(self.w1 * point.x + self.w2 * point.y + self.b)
 69         # print(point.label * (self.w1 * point.x + self.w2 * point.y + self.b))
 70         return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)
 71 
 72     def update(self):
 73         self.x1 = -self.b / self.w1
 74         self.x2 = (-self.b - self.w2 * self.y2) / self.w1
 75         self.x = [self.x1, self.x2]
 76         self.y = [self.y1, self.y2]
 77 
 78 
 79 def initialPoint():
 80     plt.figure()
 81     all_point = []
 82     for idx in range(POINT_NUM):
 83         p = Point3()
 84         plt.plot(p.x, p.y, p.color + 'o', label="point")
 85         all_point.append(p)
 86 
 87     for idx in range(POINT_NUM):
 88         p = Point4()
 89         plt.plot(p.x, p.y, p.color + 'o', label="point")
 90         all_point.append(p)
 91     return all_point
 92 
 93 def preceptron_base_dis(all_points):
 94     line = Line()
 95     plt.plot(line.x, line.y, 'g--', linewidth=1)
 96     Flag = True
 97     while True:
 98         Flag = True
 99         for point in all_points:
100             if line.sign(point) < 1:    # 只有误分类点才更新
101                 line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
102                 line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
103                 line.b = line.b + step * C * point.label
104                 Flag = False
105         if Flag:
106             break
107         line.update()
108         #plt.plot(l.x, l.y, 'y--', linewidth=1)
109     plt.plot(line.x, line.y, '.-', linewidth=1)
110     plt.show()
111 
def preceptron(all_points, max_epochs=None):
    """Fit a separating line with the classic perceptron rule and plot it.

    A random ``Line`` is drawn as a green dashed line, then the points are
    swept repeatedly.  Every point with ``line.sign(point) <= 0`` moves w1,
    w2 and b by ``step * label`` times x, y and 1 respectively.  Training
    stops after the first sweep without any update; the final line is then
    plotted and the figure is shown.

    Args:
        all_points: sequence of objects exposing ``x``, ``y`` and ``label``;
            it is iterated once per sweep.
        max_epochs: optional upper bound on the number of sweeps.  ``None``
            (the default) sweeps until every point is classified correctly,
            which never happens if the points are not linearly separable.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        converged = True
        for point in all_points:
            # A non-positive value means the point is on the wrong side of
            # the line or exactly on it.
            if line.sign(point) <= 0:
                line.w1 += step * point.label * point.x
                line.w2 += step * point.label * point.y
                line.b += step * point.label
                converged = False
        if converged:
            break
        # Refresh the plotted end points so they match the new weights.
        line.update()
        epoch += 1
    plt.plot(line.x, line.y, 'o-', linewidth=1)
    plt.show()
130 
131 all_points = initialPoint()  # generate both point classes and scatter them on a new figure
132 preceptron_base_dis(all_points)  # fit the margin-based line, plot it and show the figure
原文地址:https://www.cnblogs.com/Joyce-song94/p/7594806.html