数据挖掘实践(13):基础理论(十三)KNN算法(二)KNN算法的实现

算法实现

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distance is Euclidean. When two labels receive the same number of votes,
    the label first encountered while walking the neighbours from nearest to
    farthest wins.

    Args:
        inX: query vector; must broadcast against the rows of ``dataSet``
            (e.g. shape ``(n_features,)`` or ``(1, n_features)``).
        dataSet: training samples, one row per sample
            (array-like of shape ``(n_samples, n_features)``).
        labels: sequence of ``n_samples`` labels aligned with the rows of
            ``dataSet``.
        k: number of nearest neighbours that vote; values larger than
            ``n_samples`` are clamped to ``n_samples``.

    Returns:
        The label with the most votes among the ``k`` nearest neighbours.

    Raises:
        ValueError: if ``k < 1`` or ``dataSet`` contains no samples.
    """
    from collections import Counter

    dataSet = np.asarray(dataSet)
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    # Never ask for more neighbours than there are samples.
    k = min(k, dataSetSize)

    # Euclidean distance to every training row. Broadcasting subtracts inX
    # from each row, so no np.tile copy is needed.
    diffMat = dataSet - np.asarray(inX)
    distances = np.sqrt((diffMat ** 2).sum(axis=1))

    # Indices of the k closest samples, nearest first.
    nearest = distances.argsort()[:k]

    # Counter keeps insertion order and most_common(1) returns the first
    # maximal entry, so ties go to the label seen first among the nearest.
    classCount = Counter(labels[i] for i in nearest)
    return classCount.most_common(1)[0][0]
def img2vector(filename):
    """Read a 32x32 text image from ``filename`` into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are parsed with
    ``int`` and stored row-major, so character ``j`` of line ``i`` lands at
    column ``32 * i + j``.

    Args:
        filename: path to the text image file.

    Returns:
        A float ``numpy`` array of shape ``(1, 1024)``.

    Raises:
        IndexError: if a line has fewer than 32 characters.
        ValueError: if one of the read characters is not a digit.
    """
    returnVect = np.zeros((1, 1024))
    # The context manager guarantees the file is closed, even on a parse error.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def _labelFromFileName(fileNameStr):
    """Return the integer class label encoded in a file name like '7_45.txt'.

    The label is the text before the first '_' once everything from the first
    '.' (the extension) has been removed.
    """
    fileStr = fileNameStr.split('.')[0]  # take off .txt
    return int(fileStr.split('_')[0])


def handwritingClassTest(trainingDir='2.KNN/trainingDigits',
                         testDir='2.KNN/testDigits', k=3):
    """Train a KNN digit classifier from ``trainingDir`` and evaluate it on ``testDir``.

    Every file in both directories is read with ``img2vector``. Its label is
    parsed from the file name prefix (``'<label>_<n>.txt'``), so
    non-conforming files in those directories will raise. Each test file's
    prediction is printed, followed by the total error count and the error
    rate.

    Args:
        trainingDir: directory with the training text images.
        testDir: directory with the test text images.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate (misclassified test files / number of test files).

    Raises:
        ValueError: if either directory contains no files.
    """
    # 1. Load the training data: one 1x1024 row per file, label from the name.
    trainingFileList = os.listdir(trainingDir)
    m = len(trainingFileList)
    if m == 0:
        raise ValueError("no training files found in %r" % (trainingDir,))
    hwLabels = []
    trainingMat = np.zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_labelFromFileName(fileNameStr))
        # Flatten the 32x32 text image into a 1x1024 row vector.
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    # 2. Classify every test file and count the misclassifications.
    testFileList = os.listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        # Fail clearly instead of dividing by zero in the error rate below.
        raise ValueError("no test files found in %r" % (testDir,))
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _labelFromFileName(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0

    errorRate = errorCount / float(mTest)
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
import os
import numpy as np
import operator
# Run the digit-recognition experiment at module top level.
# NOTE(review): this call also executes whenever the module is imported;
# consider wrapping it in `if __name__ == "__main__":` if the file is ever
# imported from elsewhere.
handwritingClassTest()
优点:准确率高,对异常值不敏感。
缺点:时间、空间复杂度都很高——预测每个样本时都要保存全部训练数据,并逐一计算它与每个训练样本的距离。
原文地址:https://www.cnblogs.com/qiu-hua/p/14322144.html