《统计学习方法》第六章,逻辑斯蒂回归2,多分类

▶ 使用逻辑斯蒂回归模型进行多分类:采用 one-vs-rest(一对其余)的方式训练 k 个二分类器(k 为类别数),预测时选择各分类器给出的概率最高的类别作为最终结果

● 代码:向下兼容二分类;由于每个样本都要更新 k 个分类器,计算量比单个分类器的版本大了约 k 倍

  1 import numpy as np
  2 import matplotlib.pyplot as plt
  3 from mpl_toolkits.mplot3d import Axes3D
  4 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
  5 from matplotlib.patches import Rectangle
  6 
  7 dataSize = 10000
  8 trainRatio = 0.3
  9 ita = 0.05
 10 epsilon = 0.01
 11 defaultTurn = 200
 12 trans = 0.5
 13 
 14 def myColor(x):                                                                     # 颜色函数,用于对散点染色
 15     r = np.select([x < 1/2, x < 3/4, x <= 1, True],[0, 4 * x - 2, 1, 0])
 16     g = np.select([x < 1/4, x < 3/4, x <= 1, True],[4 * x, 1, 4 - 4 * x, 0])
 17     b = np.select([x < 1/4, x < 1/2, x <= 1, True],[1, 2 - 4 * x, 0, 0])
 18     return [r**2,g**2,b**2]
 19 
 20 def sigmoid(x):
 21     return 1.0 / (1 + np.exp(-x))
 22 
 23 def function(x, para):                                                              # 回归函数
 24     vector = np.array( [ np.exp( - np.sum(x * para[0][i]) - para[1][i]) for i in range(len(para[0])) ])
 25     return vector                                                                   #return vector / np.sum(vector)
 26 
 27 def judge(x, para):                                                                 # 分类函数
 28     return np.argmin(function(x, para))
 29 
 30 def dataSplit(x, y, part):    
 31     return x[:part], y[:part],x[part:],y[part:]
 32 
 33 def createData(dim, kind, count = dataSize):                                        # 创建数据集
 34     np.random.seed(103)       
 35     X = np.random.rand(count, dim)
 36     Y = ((3 - 2 * dim)*X[:,0] + 2 * np.sum(X[:,1:], 1) > 0.5).astype(int)           # 只考虑 {0,1} 的二分类         
 37     if kind == 2:                           
 38         Y = ((3 - 2 * dim) * X[:,0] + 2 * np.sum(X[:,1:], 1) > 0.5).astype(int)    
 39     else:
 40         randomVector = np.random.rand(dim)
 41         randomVector /= np.sum(randomVector)
 42         Y = (np.sum(X * randomVector,1) * kind).astype(int)
 43     print("dim = %d, kind = %d, dataSize = %d"%(dim, kind, count))
 44     kindCount = np.zeros(kind ,dtype = int)                                         # 各类别的占比
 45     for i in range(count):
 46         kindCount[Y[i]] += 1
 47     for i in range(kind):
 48         print("kind %d -> %4f"%(i, kindCount[i]/count))         
 49     return X, Y
 50 
 51 def gradientDescent(dataX, dataY, turn = defaultTurn):    
 52     count, dim = np.shape(dataX)
 53     kind = len(set(dataY))
 54     xE = np.concatenate((dataX, np.ones(count)[:,np.newaxis]), axis = 1)    
 55     w = np.ones([kind, dim + 1])    
 56     
 57     for t in range(turn):
 58         errorCount = 0
 59         for i in range(count):
 60             for j in range(kind):
 61                 error = int(j == dataY[i]) - sigmoid( np.sum(xE[i] * w[j]) )        # dataYi 类当成 1 号类,其他类当成 0 号类,error = yReal - yPredict
 62                 w[j] += ita * error * xE[i]
 63                 errorCount += int(abs(error) > 0.5)                            
 64         print(w)
 65         if errorCount < count * epsilon:
 66             break
 67     
 68     resultOnTrainData = [ judge(x, (w[:,:-1], w[:,-1])) for x in dataX]
 69     errorRatioOnTrainData = np.sum( ((np.array(resultOnTrainData) != dataY)).astype(int)**2 ) / count
 70     print("errorRatioOnTrainData = %4f
"%(errorRatioOnTrainData))
 71     return (w[:,:-1], w[:,-1])
 72 
 73 def test(dim, kind):                                                
 74     allX, allY = createData(dim, kind)
 75     trainX, trainY, testX, testY = dataSplit(allX, allY, int(dataSize * trainRatio))
 76     
 77     para = gradientDescent(trainX, trainY)                                          # 训练   
 78     
 79     myResult = [ judge(x, para) for x in testX]                                     
 80     errorRatio = np.sum( ((np.array(myResult) != testY)).astype(int)**2 ) / (dataSize * (1 - trainRatio))
 81     print("dim = %d, errorRatio = %4f
"%(dim, errorRatio))
 82     
 83     if dim >= 4:                                                                    # 4维以上不画图,只输出测试错误率
 84         return
 85     errorP = []                                                    
 86     classP = [ [] for i in range(kind) ]                           
 87     for i in range(len(testX)):
 88         if myResult[i] != testY[i]:
 89             if dim == 1:
 90                 errorP.append(np.array([testX[i], testY[i]]))
 91             else:
 92                 errorP.append(np.array(testX[i]))
 93         else:
 94             classP[myResult[i]].append(testX[i])
 95     errorP = np.array(errorP)
 96     classP = [ np.array(classP[i]) for i in range(kind) ]  
 97 
 98     fig = plt.figure(figsize=(10, 8))                  
 99     
100     if dim == 1:
101         plt.xlim(-0.1, 1.1)
102         plt.ylim(-0.1, 1.1)
103         for i in range(kind):
104             plt.scatter(classP[i], np.ones(len(classP[i])) * i / (kind-1), color = myColor(i / kind), s = 2, label = "class" + str(i) + "Data")
105         if len(errorP) != 0:
106             plt.scatter(errorP[:,0], errorP[:,1], color = myColor(1), s = 16, label = "errorData")                       
107         R = [ Rectangle((0,0),0,0, color = myColor(i / kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
108         plt.legend(R, [ "class" + str(i) for i in range(kind) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
109 
110     if dim == 2:
111         plt.xlim(-0.1, 1.1)
112         plt.ylim(-0.1, 1.1)
113         for i in range(kind):
114             plt.scatter(classP[i][:,0], classP[i][:,1], color = myColor(i/kind), s = 8, label = "class" + str(i))            
115         if len(errorP) != 0:
116             plt.scatter(errorP[:,0], errorP[:,1], color = myColor(1), s = 16, label = "errorData")
117         R = [ Rectangle((0,0),0,0, color = myColor(i/kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
118         plt.legend(R, [ "class" + str(i) for i in range(kind) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
119 
120     if dim == 3:
121         ax = Axes3D(fig)
122         ax.set_xlim3d(-0.1, 1.1)
123         ax.set_ylim3d(-0.1, 1.1)
124         ax.set_zlim3d(-0.1, 1.1)
125         ax.set_xlabel('X', fontdict={'size': 15, 'color': 'k'})
126         ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'k'})
127         ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'k'})
128         #v = [(0, 0, 0.25), (0, 0.25, 0), (0.5, 1, 0), (1, 1, 0.75), (1, 0.75, 1), (0.5, 0, 1)]
129         #f = [[0,1,2,3,4,5]]
130         #poly3d = [[v[i] for i in j] for j in f]
131         #ax.add_collection3d(Poly3DCollection(poly3d, edgecolor = 'k', facecolors = [0.5,0.25,0.0,0.5], linewidths=1))      
132         for i in range(kind):
133             ax.scatter(classP[i][:,0], classP[i][:,1],classP[i][:,2], color = myColor(i/kind), s = 8, label = "class" + str(i))
134         if len(errorP) != 0:
135             ax.scatter(errorP[:,0], errorP[:,1],errorP[:,2], color = myColor(1), s = 16, label = "errorData")
136         R = [ Rectangle((0,0),0,0, color = myColor(i/kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
137         plt.legend(R, [ "class" + str(i) for i in range(kind) ] + ["errorData"], loc=[0.85, 0.02], ncol=1, numpoints=1, framealpha = 1)
138 
139     fig.savefig("R:\dim" + str(dim) + "kind" + str(kind) + ".png")
140     plt.close()
141 
if __name__=='__main__':
    # (dim, kind) pairs: binary problems in 1-5 dimensions first, then multi-class ones.
    experiments = [(1, 2), (2, 2), (3, 2), (4, 2), (5, 2),
                   (1, 3), (2, 3), (2, 4), (3, 3), (3, 4), (4, 4), (5, 6)]
    for dim, kind in experiments:
        test(dim, kind)

● 输出结果

dim = 1, kind = 2, dataSize = 10000
kind 0 -> 0.491000
kind 1 -> 0.509000
[[-6.71486872  3.25224943]
 [ 6.90100937 -3.34871095]]
[[-9.4024679   4.63658391]
 [ 9.51290414 -4.69311656]]
[[-11.18847685   5.54813528]
 [ 11.26986341  -5.58954968]]
[[-12.5673646    6.2486029 ]
 [ 12.63308806  -6.2819321 ]]
[[-13.70905382   6.82693293]
 [ 13.76484542  -6.85516245]]
[[-14.69334986   7.32458602]
 [ 14.7422152   -7.34927213]]
[[-15.56456028   7.76445884]
 [ 15.60828366  -7.78652142]]
[[-16.35002989   8.1606272 ]
 [ 16.38976341  -8.18065822]]
[[-17.06791905   8.52241138]
 [ 17.10445313  -8.54081596]]
errorRatioOnTrainData = 0.000000
dim = 1, errorRatio = 0.000857

dim = 2, kind = 2, dataSize = 10000
kind 0 -> 0.504000
kind 1 -> 0.496000
[[ 3.13914102 -6.57280315  1.55975592]
 [-3.0336306   6.74084399 -1.70421297]]
[[ 4.299232   -9.1666085   2.28471823]
 [-4.2828194   9.25105756 -2.33797473]]
...
[[  9.06732636 -18.39799714   4.5422244 ]
 [ -9.07689803  18.41643909  -4.54667908]]
[[  9.34598742 -18.9326538    4.67067573]
 [ -9.35504382  18.95010089  -4.67488586]]
errorRatioOnTrainData = 0.008333
dim = 2, errorRatio = 0.006286

dim = 3, kind = 2, dataSize = 10000
kind 0 -> 0.501800
kind 1 -> 0.498200
[[ 5.68320585 -3.41622995 -3.29229634  0.291568  ]
 [-5.56080847  3.5668759   3.4417622  -0.53161699]]
[[ 7.73185326 -4.86686995 -4.78136275  0.72516127]
 [-7.70673669  4.93636661  4.84909001 -0.8170064 ]]
...
[[ 20.80913419 -13.92078602 -14.0109533    3.30945084]
 [-20.81629159  13.92562185  14.01584334  -3.31075675]]
[[ 21.06744447 -14.09526106 -14.18737273   3.35653806]
 [-21.0744349   14.09998155  14.19214546  -3.35781137]]
errorRatioOnTrainData = 0.011333
dim = 3, errorRatio = 0.011429

dim = 4, kind = 2, dataSize = 10000
kind 0 -> 0.503100
kind 1 -> 0.496900
[[ 6.39482357 -2.22198617 -2.18901346 -2.1453085  -0.11805323]
 [-6.28885654  2.39260481  2.35279566  2.30920127 -0.2083632 ]]
[[ 8.75113897 -3.18484341 -3.20304568 -3.10944871  0.2151458 ]
 [-8.72994808  3.26770859  3.2833056   3.18813441 -0.35640864]]
...
[[ 23.98904034  -9.39850863  -9.62271368  -9.41844857   2.13802609]
 [-23.99935695   9.40265978   9.62690629   9.42265612  -2.13912499]]
[[ 24.27733048  -9.51450089  -9.73983514  -9.53600906   2.16868557]
 [-24.28741351   9.51855807   9.74393065   9.54012085  -2.16975711]]
errorRatioOnTrainData = 0.003000
dim = 4, errorRatio = 0.004000

dim = 5, kind = 2, dataSize = 10000
kind 0 -> 0.500000
kind 1 -> 0.500000
[[ 6.89705474 -1.42540518 -1.41940664 -1.48996056 -1.32395489 -0.50462332]
 [-6.75241758  1.58499742  1.59950655  1.65460537  1.49287702  0.09657823]]
[[ 9.40073388 -2.10821137 -2.14136888 -2.16572326 -2.01953875 -0.37016685]
 [-9.34976596  2.19289023  2.23746207  2.24904918  2.10980714  0.16772995]]
...
[[ 35.23149069  -9.72293551 -10.03130329  -9.42816884  -9.61618479  1.79576885]
 [-35.23648585   9.72438093  10.03278997   9.42956334   9.61760598 -1.7961253 ]]
[[ 35.39225483  -9.76945223 -10.0791467   -9.47305001  -9.66192163  1.80723642]
 [-35.3972061    9.77088478  10.08062007   9.4744323    9.66333016 -1.80758947]]
errorRatioOnTrainData = 0.003667
dim = 5, errorRatio = 0.005714

dim = 1, kind = 3, dataSize = 10000
kind 0 -> 0.321300
kind 1 -> 0.344100
kind 2 -> 0.334600
[[-6.66135149  1.88485463]
 [-0.02324545 -0.47039097]
 [ 6.1365322  -4.22324277]]
[[-9.42188847  2.89492636]
 [-0.04642244 -0.45771294]
 [ 8.62181647 -5.85759954]]
...
[[-5.10898588e+01  1.69452874e+01]
 [-4.93172361e-02 -4.56130326e-01]
 [ 4.72560247e+01 -3.16276995e+01]]
[[-5.11769199e+01  1.69744423e+01]
 [-4.93172361e-02 -4.56130326e-01]
 [ 4.73360203e+01 -3.16809702e+01]]
errorRatioOnTrainData = 0.014333
dim = 1, errorRatio = 0.014714

dim = 2, kind = 3, dataSize = 10000
kind 0 -> 0.227200
kind 1 -> 0.530300
kind 2 -> 0.242500
[[-5.00676085 -2.45044419  1.97399071]
 [ 0.22096798 -0.01350022  0.13430097]
 [ 4.20653754  2.23601111 -4.88443453]]
[[-7.28718161 -3.77134478  3.32839949]
 [ 0.18256302 -0.06175728  0.17951199]
 [ 6.15014501  3.44052674 -6.86263964]]
...
[[-42.98370113 -23.84773265  22.1558293 ]
 [  0.17536418  -0.06894655   0.18704851]
 [ 38.02649615  21.76101314 -39.93451349]]
[[-43.05806049 -23.8893055   22.19431298]
 [  0.17536418  -0.06894655   0.18704851]
 [ 38.09248731  21.79835741 -40.00352016]]
errorRatioOnTrainData = 0.007667
dim = 2, errorRatio = 0.015143

dim = 2, kind = 4, dataSize = 10000
kind 0 -> 0.126800
kind 1 -> 0.364700
kind 2 -> 0.372200
kind 3 -> 0.136300
[[-3.98654929 -2.59558945  0.63871501]
 [-2.98919471 -0.69136057  1.08710997]
 [ 3.07843511  0.64949316 -2.34760386]
 [ 2.97827188  1.97871338 -4.597351  ]]
[[-6.01515149 -3.94558636  1.76243622]
 [-3.5549713  -0.991616    1.49906178]
 [ 3.66384428  0.91725841 -2.82881226]
 [ 4.65164848  3.12899868 -6.4712698 ]]
...
[[-39.81632455 -22.97514101  15.60182524]
 [ -3.75722716  -1.11931026   1.6551282 ]
 [  3.93424393   1.06984973  -3.06739184]
 [ 32.71042376  19.01196407 -38.71261752]]
[[-39.88439704 -23.01307333  15.62864229]
 [ -3.75722716  -1.11931026   1.6551282 ]
 [  3.93424393   1.06984973  -3.06739184]
 [ 32.7681305   19.04391614 -38.77955666]]
errorRatioOnTrainData = 0.102000
dim = 2, errorRatio = 0.104429

dim = 3, kind = 3, dataSize = 10000
kind 0 -> 0.170600
kind 1 -> 0.651200
kind 2 -> 0.178200
[[-2.80037838 -1.85733668 -3.24123098  1.57259022]
 [-0.16142963  0.24714771 -0.19162875  0.6026281 ]
 [ 2.58163287  1.19894974  2.85973347 -4.87914069]]
[[-4.24260115 -2.85160253 -4.82350704  3.1797816 ]
 [-0.27409502  0.16345042 -0.29640289  0.77147737]
 [ 3.84381624  2.04689038  4.19587946 -7.05644654]]
...
[[-28.0745761  -19.73079916 -31.09831116  26.15657118]
 [ -0.30475613   0.13589837  -0.32360268   0.81945196]
 [ 23.68723658  16.08356041  26.22820206 -43.78315189]]
[[-28.1238723  -19.76537271 -31.15279228  26.20290789]
 [ -0.30475613   0.13589837  -0.32360268   0.81945196]
 [ 23.72821623  16.11223028  26.27366977 -43.85969012]]
errorRatioOnTrainData = 0.023333
dim = 3, errorRatio = 0.024286

dim = 3, kind = 4, dataSize = 10000
kind 0 -> 0.067900
kind 1 -> 0.429700
kind 2 -> 0.426400
kind 3 -> 0.076000
[[-2.00252838 -1.72767335 -2.24466305 -0.5296669 ]
 [-2.10120768 -1.22514198 -2.89788163  2.66060746]
 [ 1.95193431  1.34684809  2.75742256 -3.30044301]
 [ 1.38776371  0.68175952  1.46934361 -4.12383637]]
[[-3.07733838 -2.51564061 -3.4173077   0.63571256]
 [-2.85622162 -1.79855655 -3.71498848  3.76282969]
 [ 2.56857273  1.85537708  3.45570769 -4.31289162]
 [ 2.39085453  1.41231651  2.50636325 -5.890961  ]]
...
[[-23.5413656  -16.6567255  -25.3634883   16.160629  ]
 [ -3.59081061  -2.36125037  -4.49314391   4.81119705]
 [  3.19014428   2.37566951   4.13596461  -5.3188518 ]
 [ 18.54073398  13.1811811   19.81375414 -38.83427155]]
[[-23.58185249 -16.68489085 -25.40812839  16.18962668]
 [ -3.59081061  -2.36125037  -4.49314391   4.81119705]
 [  3.19014428   2.37566951   4.13596461  -5.3188518 ]
 [ 18.57306601  13.20422866  19.84937411 -38.90186847]]
errorRatioOnTrainData = 0.086000
dim = 3, errorRatio = 0.097429

dim = 4, kind = 4, dataSize = 10000
kind 0 -> 0.062600
kind 1 -> 0.428800
kind 2 -> 0.437600
kind 3 -> 0.071000
[[-0.71632006 -2.59391172 -1.87113508 -0.80861663 -0.56502796]
 [-0.11663555 -3.59928303 -1.15899786 -0.70171     2.44850899]
 [ 0.37202371  3.49748029  1.09665309  0.59203424 -3.0682648 ]
 [-0.12479086  1.7385499   1.16253629  0.31590619 -3.95319091]]
[[-0.9326788  -4.01402554 -2.80183734 -1.15387173  0.54926063]
 [-0.4657425  -4.55847041 -1.71715683 -1.14886744  3.620152  ]
 [ 0.6652545   4.29734591  1.55022533  0.91784882 -4.0629789 ]
 [ 0.20215012  2.88925454  2.05115192  0.76774619 -5.70430072]]
...
[[ -7.80293379 -32.00422274 -18.32788029  -9.67368173  16.64301744]
 [ -0.87738729  -5.56011488  -2.32996241  -1.63780016   4.88436562]
 [  1.00962029   5.08689554   2.05662977   1.29336137  -5.13507463]
 [  5.79210512  23.64914995  15.05904647   8.41324924 -39.64004545]]
[[ -7.81695093 -32.0618173  -18.35982709  -9.69111765  16.67358517]
 [ -0.87738729  -5.56011488  -2.32996241  -1.63780016   4.88436562]
 [  1.00962029   5.08689554   2.05662977   1.29336137  -5.13507463]
 [  5.80255199  23.69329045  15.08405843   8.42725548 -39.7101195 ]]
errorRatioOnTrainData = 0.118333
dim = 4, errorRatio = 0.111143

dim = 5, kind = 6, dataSize = 10000
kind 0 -> 0.005500
kind 1 -> 0.106600
kind 2 -> 0.374800
kind 3 -> 0.391000
kind 4 -> 0.118300
kind 5 -> 0.003800
[[-0.93151419 -1.15264489 -1.09286506 -1.07529811 -1.0044742  -2.86144712]
 [-0.42948115 -2.20080396 -1.46486443 -1.97822921 -0.81411715  1.04242008]
 [-0.05671104 -1.7655673  -1.16561097 -1.21070352 -0.62971664  1.73063998]
 [ 0.06718988  1.69013898  1.22077026  1.1620007   0.39884712 -2.95602739]
 [ 0.17581079  1.70913602  1.03362367  1.50272166  0.70368526 -4.39604687]
 [-0.7302654  -0.79562672 -0.71042569 -0.69781866 -0.78849863 -3.11233474]]
[[-1.03026903 -1.35431366 -1.30340961 -1.27357011 -1.10123572 -2.8378497 ]
 [-0.81491271 -3.34293023 -2.30346758 -3.01864634 -1.40046136  2.63998911]
 [-0.31184047 -2.24459839 -1.59020725 -1.63260382 -0.97931141  2.64284953]
 [ 0.26039654  2.10258419  1.58891349  1.51831021  0.64976697 -3.7885964 ]
 [ 0.53260501  2.68629705  1.78230865  2.36776549  1.28603549 -6.49030701]
 [-0.68845101 -0.70282747 -0.60406175 -0.5841734  -0.73161019 -3.46754385]]
...
[[ -4.64430182  -9.88889103 -10.32735806  -9.71878585  -5.75540649  6.03823936]
 [ -4.85415656 -14.00558042 -10.59404533 -12.56228907  -7.56849661 16.2053294 ]
 [ -0.52624566  -2.60362638  -1.91668611  -1.96394575  -1.259864    3.34867158]
 [  0.45294325   2.43597606   1.89645053   1.82899407   0.87841299 -4.5130601 ]
 [  4.66907859  14.56942264  10.74890854  12.60299926   7.68186565-33.74429057]
 [  2.95492003   5.2915096    6.37632582   5.97786528   3.9272511 -21.7415685 ]]
[[ -4.65279681  -9.90877286 -10.34768486  -9.7390159   -5.7680122   6.05417403]
 [ -4.85537082 -14.00906197 -10.59674944 -12.56542149  -7.57040832 16.20950731]
 [ -0.52624566  -2.60362638  -1.91668611  -1.96394575  -1.259864    3.34867158]
 [  0.45294325   2.43597606   1.89645053   1.82899407   0.87841299 -4.5130601 ]
 [  4.67283876  14.58138225  10.75770271  12.61334672   7.68803566-33.77157783]
 [  2.96057027   5.30377548   6.39212882   5.9914551    3.93604998-21.78509134]]
errorRatioOnTrainData = 0.097667
dim = 5, errorRatio = 0.106429

● 画图(一维)

● 画图(二维)

● 画图(三维)

原文地址:https://www.cnblogs.com/cuancuancuanhao/p/11254950.html