决策树-李航

1、所谓决策树模型,是通过重要性依次向下绘出的,越重要的越在上面。

决策树有节点和有向边组成。结点有两种类型,内部节点和叶结点,内部节点表示一个属性,叶子节点表示一个类。

决策树的数学意义在于 ,条件概率分布。举一个简单的例子:一个人去银行贷款,他的年龄、收入、房子、车子都能决定他是否能贷到款。我们需要判断到底哪个是最重要的,只需要判断出P(贷到款 | (年龄、收入、房车))哪个条件概率最大就可以了。

2、一些必须要懂得前提

A、信息增益,熵是表示变量不确定的因素,我们用一个公式来从概率上表示一个分类方式的熵。基于结果的熵我们称为经验熵,基于某个因素的熵我们称为条件熵。

B、信息增益比。在信息增益的基础上,除以关于特征值的熵,既得信息增益比。

3、决策树的形成。

基本思路是:输入数据集,输入特征集,输入阈值

输出决策树

if then 结构。自己看书吧。

4、一些算法

算法包括ID3算法,C4.5算法,CART算法。

5、决策树的剪枝

样本节点的熵越小越好,我们令新的参数由经验熵和C(t)共同决定,t是样本节点个数。

我们发现,前者是越小越好,但是前者的小会导致后者大。

6、CART明天再弄。

7、今天写的代码

from math import log
import numpy

def CalEnt(database):
#计算熵,这里的熵是经验熵还是条件熵实际上由database决定,所以后面专门针对不同的标签生成了不同的database LabelNumber
= len(database) LabelDic = {} for line in database: LineLabel = line[-1] if LineLabel not in LabelDic.keys(): LabelDic[LineLabel] = 0 LabelDic[LineLabel] += 1 #一定要记得初始化某些数据 ShanonEnt = 0.0
#对字典做遍历,可以用key做他的i for key in LabelDic: prob = LabelDic[key]/LabelNumber ShanonEnt -= prob*log(prob,2) return ShanonEnt def CreatDatabase(): database =[[1,0,0,0,0],[1,0,0,1,0],[1,1,0,1,1],[1,1,1,0,1],[1,0,0,0,0],[2,0,0,0,0],[2,0,0,1,0],[2,1,1,1,1],[2,0,1,2,1],[2,0,1,2,0], [3,0,1,2,1],[3,0,1,1,1],[3,1,0,1,1],[3,1,0,2,1],[3,0,0,0,0]] return database #axis表示维度,value表示区别
#这个用来由维度和值提取矩阵
def SplitDatabase(database,axis,value): retDatabase = [] for line in database: if line[axis] == value: retDatabase.append(line) return retDatabase #用来选择最好的熵
提取基础的东西
def ChooseBest(database): baseShanon = CalEnt(database) bestInformationGain = 0.0 ConShanon = 0.0 FeaNum = len(database[1])-1 BestChoose = -1 LabelNum = len(database[:][0]) for i in range(FeaNum):
#这句话很灵性,将列向量转化为行向量,并且提取出来 FeatList
= [temp[i] for temp in database] PureFeatList = set(FeatList) for value in PureFeatList: subdatabase = SplitDatabase(database,i,value) prob = len(subdatabase)/float(len(database)) ConShanon -= prob*CalEnt(subdatabase) InformationGain = baseShanon - ConShanon #这个判断很好 if InformationGain > bestInformationGain: bestInformationGain = InformationGain BestChoose = i def test(): mydatabase = CreatDatabase() ''' s = CalEnt(mydatabase) t=SplitDatabase(mydatabase,0,1) print(s) print(t) ''' ChooseBest(mydatabase) print(i) test()
原文地址:https://www.cnblogs.com/baochen/p/9043614.html