《统计学习方法》学习笔记(2) K近邻

  给定一个训练数据集,其中的实例类别已定。要确定新的实例类别时,根据训练数据集中k个最邻近的实例的类别,通过多数表决等方式进行预测。k近邻实际是利用训练数据集对特征向量空间进行划分,不具有显式的学习过程。

  特征空间中两个实例点的距离是两个实例点相似程度的反映,k近邻模型一般使用欧式距离,但也可以是其他距离,如Lp距离、Minkowski距离。在应用中,k值一般去一个比较小的值,通常采用交叉验证法来选取最优的k值。

  最简单的实现方法是线性扫描,这时要计算输入实例与每一个训练实例的距离。当训练集很大时,这种方法是不可行的。为了提高搜索效率,可以采用kd树方法。

  kd树是k-dimension tree的缩写,是对数据点在k维空间中划分的一种数据结构,主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。

  k-d树的构造:

 1 # -*- coding: gbk -*-
 2 #《统计学习方法》,例3.2,构造kd树
 3 
 4 class Tree:
 5     left = None
 6     right= None
 7     value = []
 8     def __init__(self, newValue, parent=None):
 9         self.value = newValue
10         self.parent=parent
11 
12 def selectNewNode(dataSet, k):
13     dataSet.sort(key=lambda x:x[k-1])
14     middle = int(len(dataSet)/2)
15     leftSet = dataSet[0:middle]
16     if (len(dataSet) > 2):
17         rightSet = dataSet[middle+1:]
18     else:
19         rightSet = []        
20     return leftSet,rightSet,dataSet[middle]
21 
22 def createTree(dataSet, treeRoot = None, j=0):
23     k = j%2 + 1
24     leftSet, rightSet, point = selectNewNode(dataSet, k)
25 
26     if(j==0):
27         treeRoot.value = point
28         myTree = treeRoot
29     else:
30         myTree = Tree(point, treeRoot)
31         
32     if(len(leftSet) == 1):
33         myTree.left = Tree(leftSet[0], myTree)
34     if(len(rightSet) == 1):
35         myTree.right = Tree(rightSet[0], myTree)
36 
37     j=j+1
38     parent = myTree
39     if(len(leftSet) > 1):
40         myTree.left = createTree(leftSet, parent, j)
41     if(len(rightSet) > 1):
42         myTree.right = createTree(rightSet, parent, j)
43     return myTree
44 
45 def printTree(myTree):  #preorder traversal
46     print(myTree.value)
47     if(myTree.left != None):
48         printTree(myTree.left)
49     if(myTree.right != None):
50         printTree(myTree.right)
51 
52 def test():
53     dataSet = [[2,3],
54                [5,4],
55                [9,6],
56                [4,7],
57                [8,1],
58                [7,2]]
59     root = Tree([])
60     createTree(dataSet,root)
61     printTree(root)

  利用kd树的最近邻搜索:

                                                  

  

  以目标点(5,3)为例:

  (1)从根节点(7, 2)开始,5<7,移动到左子树节点(5, 4),3<4,移动到左子树节点(2, 3)。节点(2, 3)为当前最近点,到目标点(5, 3)的距离为3。

  (2) 回溯查询:回溯到(2, 3)的父节点(5, 4)。(5, 4)与目标点(5,3)的距离为1<3,当前最近点更新为(5, 4)。判断在该父节点的其他子节点空间中是否有距离目标点更近的点。以(5, 3)为圆心,1为半径画圆,该圆与y=4的超平面没有交割,不需要进入(5, 4)的右子树进行查询。

  (3)回溯到(5, 4)的父节点(7, 2)。(7, 2)与目标点(5,3)的距离为2.236>1。当前最近点仍然为(5, 4)。判断在根节点(7, 2)的其他子节点空间中是否有距离目标点更近的数据点。以(5, 3)为圆心,1为半径画圆。该圆和x = 7的超平面没有交割,所以不用进入(7, 2)的右子树搜索。

  (4)至此,搜索路径中的节点已全部回溯完,结束搜索。返回最近邻点(5, 4),最近距离为1。

原文地址:https://www.cnblogs.com/thelongroad/p/3137955.html