【机器学习实战】k-近邻算法2.2约会网站预测函数

《机器学习实战》学习

书中使用Python2进行代码演示,我这里将其转换为Python3,并做了一些注释。要学会使用断点调试,方便很多

下面的代码是书中2.2节使用k-近邻算法改进约会网站的配对效果的完整测试代码:

  1 from numpy import *
  2 import operator
  3 import matplotlib
  4 import matplotlib.pyplot as plt
  5 
  6 
  7 def createDataSet():
  8     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
  9     labels = ['A', 'A', 'B', 'B']
 10     return group, labels
 11 
 12 
 13 def classify0(inX, dataSet, labels, k):
 14     '''
 15     k-近邻算法
 16     :param inX:用于分类的输入向量
 17     :param dataSet: 输入的训练样本集
 18     :param labels: 标签向量
 19     :param k: 用于选择最近邻居的数目
 20     :return: 返回k个邻居中距离最近且数量最多的类别作为预测类别
 21     '''
 22     dataSetSize = dataSet.shape[0]
 23     diffMat = tile(inX, (dataSetSize, 1)) - dataSet
 24     sqDiffMat = diffMat ** 2
 25     sqDistances = sqDiffMat.sum(axis=1)
 26     distances = sqDistances ** 0.5
 27     # 以上为计算输入向量与已有标签样本的欧式距离
 28     sortedDistIndicies = distances.argsort()  # argsort函数返回的是数组值从小到大的索引值,距离需要从小到大排序
 29     classCount = {}
 30     for i in range(k):
 31         voteIlabel = labels[sortedDistIndicies[i]]
 32         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
 33         # Python 字典(Dictionary) get() 函数返回指定键的值,如果值不在字典中返回默认值。
 34         # get(voteIlabel,0)表示当能查询到相匹配的字典时,就会显示相应key对应的value,如果不能的话,就会显示后面的这个参数。
 35     sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
 36     # 按照元祖中第2个值的大小降序排序
 37     # python2中的iteritems()方法需改为items()
 38     return sortedClassCount[0][0]
 39 
 40 
 41 def file2matrix(filename):
 42     # 将文本记录转换为NumPy的解析程序
 43     fr = open(filename)
 44     arrayOLines = fr.readlines()
 45     numberOfLines = len(arrayOLines)
 46     print(numberOfLines)
 47     returnMat = zeros((numberOfLines,3))  # 存放3种特征
 48     classLabelVector = []  # 存放标签
 49     index = 0
 50     for line in arrayOLines:
 51         line = line.strip()  # strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。
 52         listFromLine = line.split('	')  # split() 通过指定分隔符对字符串进行切片,组成列表
 53         returnMat[index, :] = listFromLine[0:3]  # 将当前列表的前3个值赋予returnMat的当前行
 54         classLabelVector.append(int(listFromLine[-1]))  # 将标签添加到classLabelVector中
 55         index += 1
 56     return returnMat, classLabelVector
 57 
 58 
 59 def autoNum(dataSet):
 60     minVals = dataSet.min(0)  # A.min(0) : 返回A每一列最小值组成的一维数组;
 61     maxVals = dataSet.max(0)  # A.max(0):返回A每一列最大值组成的一维数组;
 62     # https://blog.csdn.net/qq_41800366/article/details/86313052
 63     ranges = maxVals - minVals
 64     normDataSet = zeros(shape(dataSet))
 65     m = dataSet.shape[0]
 66     normDataSet = dataSet - tile(minVals, (m,1))
 67     # tile将minVals的行数乘以m次重复,列数乘以1次重复,每一行都减掉minVals
 68     normDataSet = normDataSet/tile(ranges,(m,1))
 69     # 每一行都除以ranges以是实现数据归一化
 70     return normDataSet,ranges, minVals
 71 
 72 
 73 def datingClassTest():
 74     hoRatio = 0.10  # 测试集比重
 75     m = normMat.shape[0]
 76     numTestVecs = int(m*hoRatio)  # 测试集数量
 77     errorCount = 0.0
 78     for i in range(numTestVecs):
 79         classfierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
 80         print("the classifierResult came back with: %d,the real answer is: %d"%(classfierResult,datingLabels[i]))
 81         if(classfierResult != datingLabels[i]):errorCount += 1.0
 82         print("the total error rate is: %f"%(errorCount/float(numTestVecs)))
 83 
 84 def classifyPerson():
 85     resultList = ['not at all', 'in small doses', 'in large doses']
 86     percentTats = float(input("percentage of time spent playing video games?"))
 87     # 在 Python3.x 中 raw_input( ) 和 input( ) 进行了整合,去除了 raw_input( ),仅保留了 input( ) 函数,
 88     # 其接收任意任性输入,将所有输入默认为字符串处理,并返回字符串类型。
 89     ffMiles = float(input("frequent flier miles earned per year?"))
 90     iceCream = float(input("liters of ice cream consumed per year?"))
 91     inArr = array([ffMiles, percentTats, iceCream])  # 输入测试向量
 92     classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)  # 得到分类结果
 93     print("You will probably like this person:",resultList[classifierResult-1])
 94 
 95 
 96 if __name__ == "__main__":
 97     '''
 98     group,labels = createDataSet()
 99     result = classify0([0,0],group,labels,3)
100     print(result)
101     '''
102     datingDataMat, datingLabels = file2matrix("./datingTestSet2.txt")  # 数据转换
103     # print(datingDataMat)
104     # print(datingLabels)
105     fig = plt.figure()
106     ax = fig.add_subplot()
107     ax.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))
108     plt.show()
109     normMat, ranges, minVals = autoNum(datingDataMat)  # 输入数据归一化
110     # print(normMat)
111     # print(ranges)
112     # print(minVals)
113     # datingClassTest()
114     classifyPerson()  # 分类
运行结果:
1
percentage of time spent playing video games?10 2 frequent flier miles earned per year?4000 3 liters of ice cream consumed per year?1 4 You will probably like this person: in small doses
# python2中的iteritems()方法需改为items()
原文地址:https://www.cnblogs.com/DJames23/p/13053974.html