关于回归树的创建和剪枝

  之前对于树剪枝一直感到很神奇;最近参考介绍手工写了一下剪枝代码,才算理解到底什么是剪枝。

  首先要明白回归树作为预测的模式(剪枝是针对回归树而言),其实是叶子节点进行预测;所以在使用回归树进行预测的时候,本质都是在通过每层(每个层代表一个属性)的值的大于和小于来作为分值,进行二叉树的遍历。最后预测值其实叶子节点中左值或者右值;注意这里的叶子结点也是一个结构体,对于非叶子节点而言,他的左右值是一棵树,但是对于叶子结点而言,左右值则是一个单一的数值。

  那么剪枝的原始就是找到叶子节点,如上图所示的特征C和特征E,然后取左右值的均值,合并(merge)为一个节点。比如低于特征C,就是取值5.5,作为B树的左节点,这样特征C这个节点就被减掉了。

  但是在剪枝的时候注意一定要和原始场景进行比较,未剪枝前的偏差和剪枝后偏差做一个比较,看看到底哪个更优秀;如果剪枝后MSE值反而更加大了,就不要价值了。这里偏差的计算值是sum(power(yHat- y, 2))来进行比较即可。

  下面的就是剪枝的python实现:

 1 # 所谓剪枝即使遍历到叶子结点,然后看一下作为预测值的叶子结点,合并左右节点(即取左右子树平均数)为一个点
 2 # 但是需要比较一下合并之后的偏差和合并前的偏差,如果合并之后的方差变小了,则剪枝(取合并值),反之则保持原状
 3 def prune(tree, testData):
 4     m, n = shape(testData)
 5     # 如果测试在分类(分割)过程,某一类数据为0
 6     if(m == 0): return getMean(tree)
 7     # 下面一大段其实都是在做这一件事情:深入都叶子结点
 8     # 1. 只要左右子树中有一颗不是叶子结点,那么就以当前节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类
 9     # 获得的是二元分类的数据集left set和right set
10     if(isTree(tree["left"]) or isTree(tree["right"])):
11         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
12     # 2. 继续处理不是叶子结点左右子树,对其进行递归prune(本质就是要深入到叶子结点为止)
13     if(isTree(tree["left"])): 
14         tree["left"] = prune(tree["left"], lset)
15     if(isTree(tree["right"])): 
16         tree["right"] = prune(tree["right"], rset)
17     
18     # 左右子树都是叶子节点了
19     if(not isTree(tree["left"]) and not isTree(tree["right"])):
20         # 那么就以当前叶子节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类
21         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
22         # 计算测试数据集和预测值(叶子结点)之间的方差,剪枝前的偏差
23         errorNotMerge = sum(power(lset[:, -1] - tree["left"], 2)) + sum(power(rset[:, -1] - tree["right"],2))
24         treeMean = (tree["left"] + tree["right"]) / 2.0
25         # 测试数据全集和树均值(预测值)之间的方差,剪枝后偏差
26         errorMerge = sum(power(testData[:, -1] - treeMean, 2))
27         # 看看谁的方差小,如果测试数据全集和树均值的方差小,返回的是树均值(叶子结点)的均值
28         if(errorMerge < errorNotMerge):
29             #print("errorMerge < errorNotMerge, treeMean is: ")
30             #print(treeMean)
31             return treeMean
32         # 如果叶子节点(预测值)的和真实值之间的方差比较小,则返回的树,不需要剪枝
33         else:
34             #print("errorMerge > errorNotMerge, [tree] is: ")
35             #print(tree)
36             return tree
37     # 说明叶子结点剪枝效果不明显,不需要剪枝
38     else:
39         return tree
40             

  那么再汇过来,如何构建一个回归树呢?

  构建回归树有几个条件,首先要有样本数据,叶子节点的计算方式(regLeaf),以及计算一个数据集的偏差的公式(regErr);

1 from numpy import mean
2 
3 # 数据集中y值的均值
4 def regLeaf(dataset):
5     return mean(dataset[:, -1])
6 
7 # 数据集中y值的方差和
8 def regErr(dataset):
9     return var(dataset[:, -1]) * shape(dataset)[0]

  有了这三者之后,就可以进行构建树了。构建树的时候,首先将会选择一个区分度最好的特征以及特征值,做样本分割,然后基于分割后的样本分别构建左子树和右子树,这是一个递归的过程,发生变化的样本,以及基于变化的样本产生的新的分割特征以及特征值,这个递归过程一直到样本数据不再可分为止,此时获得就是一个value,这个就是叶子结点的left/right值(非叶子节点left/right仍然是一棵树)。

 1 # 获取最好的分割信息,这里包括分割的特征以及特征值,然后对数据进行分割,在以分割后数据为基础继续进行继续创建树,一直到数据无法再分割
 2 # (feature)为none为止。
 3 def createTree(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
 4     feature, value = chooseBsetSplit(dataset, leafType, errorType, ops)
 5     # left/right值直接就是数字(不再是树了)
 6     if(feature == None):
 7         return value
 8     retTree = {}
 9     retTree["spIndex"] = feature
10     retTree["spValue"] = value
11     # chooseBsetSplit其实应该一并把mat0和mat1返回,这样这里就不需要再计算了。
12     # 但是后来看了一下代码,返现该函数里面有的返回分支里面是没有mat0和mat1,所以这里再计算一下也是说的通的。
13     lset, rset = bindSplitDataset(dataset, feature, value)
14     retTree["left"] = createTree(lset, leafType, errorType, ops)
15     retTree["right"] = createTree(rset, leafType, errorType, ops)
16     
17     return retTree

  下面的代码就是获取最佳区分特征和特征值的实现

 1 # 寻找最好的区分特征;为了能够找到需要遍历所有的特征,以及所有的特征值,然后以该特征值做分割,获取两个矩阵
 2 # 计算两个矩阵的方差,不断选出方差小的作为bestIndex以及bestValue;最后将bestIndex对应的方差和原始矩阵
 3 # 方差进行比较,如果发现最佳区分特征对应的两分割矩阵方差明显小,并且两个矩阵的样本数量都不是十分小;
 4 # 则说明该特征是OK的
 5 
 6 # 返回的feature信息可能是None,代表该节点就是叶子结点中left/right的值,该函数
 7 def chooseBsetSplit(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
 8     # 可容忍的偏差,在程序开始的时候,通过errorType来计算一下dataset的y值的方差和;然后用dataset的方差
 9     # 和最好区分度的方差和做减法,如果发现差值比这个tolS要小,那么说明这次指定特征是失败的;理想的差值是要大于tols
10     # 方差一定要比原始数据小到一定程度,这次属性指定才有意义。
11     tolS = ops[0]
12     tolN = ops[1] # 特征划分的样本的阈值,如果一分为二后,任何一个分类样本数少于这个阈值,这次划分就取消
13     # 为什么==1就要退出?
14     if(len(set(dataset[:, -1].T.tolist()[0])) == 1):
15         #print("len(set(dataset[:, -1].T.tolist()[0])) == 1, return None feature")
16         return None, leafType(dataset)
17     m, n = shape(dataset)
18     # 注意这里errorType其实就是参数,这里参数就是一个函数,默认是regErr
19     S =errorType(dataset)
20     # 初始化best*
21     bestS = inf
22     bestIndex = 0
23     bestValue = 0
24     iterate_num = n-1
25     #print("iterate_num: " + str(iterate_num))
26     # 遍历所有的特征(最后一列是结果,跳过)
27     for featureIndex in range(iterate_num):
28         #print("++++++++++++++++++++++ %d turns +++++++++++++++++++++++" % (featureIndex))
29         # 遍历该特征的所有特征值
30         for splitValue in set(dataset[:, featureIndex].A.flatten().tolist()):
31             # 在所有训练样本上面(dataset)对于该特征,大于该特征值,小于特征值分别做数据分割,获得两个矩阵
32             mat0, mat1 = bindSplitDataset(dataset, featureIndex, splitValue)
33             # 如果分割的特征矩阵任意一个的样本数<tolN,那么将会跳过该特征的处理,经过分割一定要达到一定的样本数才有意义
34             # 任意一个矩阵的样本数少说明该特征的区分度不高
35             if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
36                 #print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN; splitValue: %f, shape(mat0)[0]: %d, (shape(mat1)[0]: %d, tolN: %d" % (splitValue, shape(mat0)[0], shape(mat1)[0], tolN))
37                 continue
38             #print("*************** one ok **********************")
39             # 和leafType一样,都是参数类型为函数,计算方差和
40             newS = errorType(mat0) + errorType(mat1)
41             # 如果方差小于bestS,则用当前的方差以及特征信息做替换;到此可以看到目标就是找到区分度高并且方差小的特征,作为最好
42             # 区分特征
43             if(newS < bestS):
44                 bestIndex = featureIndex
45                 bestS = newS
46                 bestValue = splitValue
47     # 如果S值和bestS值之差小于tolS;参见tolS的注释。
48     if(S -bestS) < tolS:
49         #print("(S -bestS) < tolS, return feature NULL, S: %s, bestS: %s, tolS: %s" % ( str(S), str(bestS), str(tolS)))
50         return None, leafType(dataset)
51     mat0, mat1 = bindSplitDataset(dataset, bestIndex, bestValue)
52     # 这里的判断有意义吗?在循环体中其实已经做了这个判断,如果不满足也不会成为bestIndex,bestvalue;
53     if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
54         print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN")
55         return None, leafType(dataset)
56     
57     return bestIndex, bestValue

 

原文地址:https://www.cnblogs.com/xiashiwendao/p/10507098.html