k-近邻算法学习笔记

刚开始研读《machine learning in action》这本书,介绍的第一个算法就是k-近邻算法。

机器学习算法可分为监督学习和无监督学习,其中监督学习解决的是问题是分类和回归这两类问题,而无监督学习由于没有目标值和类别信息,将数据集合进行聚类。无监督学习尚不了解,以前课题用到了神经网络也并未对无监督学习方法涉及。

分类问题即给定一个实例数据,将其划分至合适的类别中;回归问题解决的是预测值,最简单的回归问题应该是物理实验课上用的一元二次回归了。

对于分类问题,k-近邻算法是一种简单有效的算法,其思路特别简单。假定存在一个已知样本集S,S中每个样本si对应有一个类别cj,其中类别集合C是有限的。那么给定一个待分类数据d,可由如下方法给出:

  1. 计算d与S中每个si之间的欧氏距离;
  2. 对所有的距离进行升序排列;
  3. 取距离最近的k个样本集s1~sk,其对应的类别为c1~ck;
  4. c1~ck中出现频率最高的类别就是d的类别。

算法实现起来也很简单,python版本如下:

#-*-coding: utf-8 -*-

from numpy import * 
import operator
import matplotlib
import matplotlib.pyplot as pyplot
from dircache import listdir

def create_data_set():
    group   = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels  = ['A', 'A', 'B', 'B']
    return group, labels

def classify0(inX, data_set, labels, k):
    data_set_size = data_set.shape[0]   # number of samples
    diff_mat = tile(inX, (data_set_size, 1)) - data_set # tile: expand to 1 cols and data_set_size rows
    diff_mat2 = diff_mat ** 2
    distances2 = diff_mat2.sum(axis = 1)
    distances = distances2 ** 0.5
    sorted_dist_index = distances.argsort()
    class_count = {}
    for i in xrange(k):
        vote_ilabel = labels[sorted_dist_index[i]]
        class_count[vote_ilabel] = class_count.get(vote_ilabel, 0) + 1
    sorted_class_count = sorted(class_count.iteritems(), key = operator.itemgetter(1), reverse = True)
    return sorted_class_count[0][0]

def test_classify0():
    group, labels = create_data_set()
    res1 = classify0([0, 0], group, labels, 3)
    print '分类结果', res1

if __name__ == '__main__':
    test_classify0()

其中classify0就是k-近邻算法的实现。这里用到了numpy包。

测试结果

image

对于《machine learning in action》中给的几个例子我也重新做了一遍,其实大同小异,大部分工作都是如何将外部的数据导入:)

约会网站示例

#-*-coding: utf-8 –*- 
 
from numpy import * 
import operator
import matplotlib
import matplotlib.pyplot as pyplot
from dircache import listdir

def file2matrix(filename):
    fr = open(filename)
    read_lines = fr.readlines()
    sample_count = len(read_lines)
    print '%d lines in "%s"' % (sample_count, filename)
    sample_matrix = zeros((sample_count, 3))    # 3个特征
    label_vector = []
    isample = 0
    for line in read_lines:
        line = line.strip()
        one_sample_list = line.split('	')
        sample_matrix[isample, :] = [double(item) for item in one_sample_list[0 : 3]]    
        label_vector.append(int(one_sample_list[-1]))   # 每行最后一个值为类别
        isample += 1
    return sample_matrix, label_vector
 
def auto_normalize(data_set):
    min_val = data_set.min(0)
    max_val = data_set.max(0)
    ranges = max_val - min_val
    m = data_set.shape[0]
    norm_set = data_set - tile(min_val, (m, 1))
    norm_set = norm_set / tile(ranges, (m, 1))
    return norm_set, ranges, min_val
 
def test_dating_classify():
    dating_matrix, dating_label = file2matrix('datingTestSet2.txt')
#     fig = pyplot.figure()
#     ax = fig.add_subplot(111)
#     ax.scatter(dating_matrix[:, 0], dating_matrix[:, 1], 15.0 * array(dating_label), 15 * array(dating_label))
#     pyplot.show()
    norm_matrix, _, _ = auto_normalize(dating_matrix)
    verify_ratio = 0.1
    samples_count = norm_matrix.shape[0]
    verify_count = int(verify_ratio * samples_count)
    error_count = 0.0
    for i in xrange(verify_count):
        classify_result = classify0(norm_matrix[i, :], norm_matrix[verify_count : samples_count, :], 
                                    dating_label[verify_count : samples_count], 9)
        print '分类器识别为%d,真实类别为%d' % (classify_result, dating_label[i])
        if (classify_result != dating_label[i]):
            error_count += 1
    print '分类错误率为:%.2f' % (error_count / float(verify_count))

if __name__ == '__main__':
    test_dating_classify()

测试结果

image

手写识别实例

#-*-coding: utf-8 -*-

from numpy import * 
import operator
import matplotlib
import matplotlib.pyplot as pyplot
from dircache import listdir

def img2vector(filename):
    img_vector = zeros((1, 1024))
    fr = open(filename)
    for i in xrange(32):
        line_str = fr.readline()
        line_str = line_str.strip()
        for j in xrange(32):
            img_vector[0, 32 * i + j] = int(line_str[j])
    return img_vector
    
def test_handwritting_classify():
    handwritting_labels = []
    training_files = listdir('trainingDigits')
    samples_count = len(training_files)
    training_matrix = zeros((samples_count, 1024))
    # construct training matrix
    for i in xrange(samples_count):
        file_name_str = training_files[i]
        file_str = file_name_str.split('.')[0]
        label_str = int(file_str.split('_')[0])
        handwritting_labels.append(label_str)
        training_matrix[i, :] = img2vector('trainingDigits/%s' % file_name_str)
    # test
    test_files = listdir('testDigits')
    error_count = 0
    tests_count = len(test_files)
    for i in xrange(tests_count):
        file_name_str = test_files[i]
        file_str = file_name_str.split('.')[0]
        label_str = int(file_str.split('_')[0])
        vector_under_test = img2vector('testDigits/%s' % file_name_str)
        classify_result = classify0(vector_under_test, training_matrix, handwritting_labels, 3)
        print '手写识别为%d, 实际为%d' % (classify_result, label_str)
        if (classify_result != label_str):
            error_count += 1
    print '手写识别错误共计%d, 错误率%.2f' % (error_count, error_count / float(tests_count))

if __name__ == '__main__':
    test_handwritting_classify()

 

测试结果

image

总结

从原理上讲,k-近邻算法是精确有效的,也符合人的分类习惯,说白了,离谁最近就是谁。

但从使用的情况看,k-近邻算法运行速度非常慢(计算复杂度高),存储空间要求有很大(空间复杂度高)。

原文地址:https://www.cnblogs.com/robert-cai/p/3466131.html