Decision Trees

Building a decision tree with the ID3 algorithm

# Author: Qian Chenglong
# labels: the names of the features; dataSet: rows of n feature values plus a class label

from math import log
import operator

'''Compute the Shannon entropy of a dataset'''
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:                 # tally how often each class label appears
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries   # probability of each class label
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
'''The higher the entropy, the more mixed the data'''
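
As a quick sanity check (a minimal sketch; the toy dataset below is made up for illustration), a perfect 50/50 split between two classes should give an entropy of exactly 1.0:

toyData = [[1, 1, 'yes'], [1, 0, 'yes'], [0, 1, 'no'], [0, 0, 'no']]
print(calcShannonEnt(toyData))   # 1.0: two equally likely classes
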
'''Split the dataset on a given feature'''
def splitDataSet(dataSet, axis, value):   # dataset to split, index of the splitting feature, feature value to match
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]            # keep every value except the one we split on
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
'''Pulls out the rows that match the given feature value'''
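
For instance, with the made-up toyData from above, splitting on feature 0 with value 1 keeps the two matching rows and drops that column:

toyData = [[1, 1, 'yes'], [1, 0, 'yes'], [0, 1, 'no'], [0, 0, 'no']]
print(splitDataSet(toyData, 0, 1))   # [[1, 'yes'], [0, 'yes']]
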
'''Try every feature and pick the split with the lowest entropy (highest information gain)'''
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      # number of features; the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the unsplit dataset
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]   # all values of feature i across the dataset
        uniqueVals = set(featList)                       # a set drops the duplicate values
        newEntropy = 0.0
        for value in uniqueVals:                         # iterate over the distinct values of this feature
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)   # weighted entropy after the split
        infoGain = baseEntropy - newEntropy    # entropy reduction, i.e. information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
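
On a slightly richer toy set (again made up for illustration), the function returns the index of the most informative feature:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(chooseBestFeatureToSplit(dataSet))   # 0: splitting on feature 0 gives the larger gain
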
'''Majority vote: return the class label that occurs most often'''
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)   # reverse=True: descending; reverse=False (the default): ascending
    return sortedClassCount[0][0]
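
A quick illustration of the vote:

print(majorityCnt(['yes', 'no', 'yes']))   # 'yes'
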
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]    # list of class labels
    if classList.count(classList[0]) == len(classList): # all labels identical: only one class left,
        return classList[0]                             # so stop splitting
    if len(dataSet[0]) == 1:                            # all features used up, only the class column remains:
        return majorityCnt(classList)                   # return the majority class over the remaining instances
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}                        # the tree is a nested dict keyed by feature name
    del(labels[bestFeat])                               # remove the feature name that was just used
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                           # copy the labels so recursion doesn't clobber them
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
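
Putting the pieces together (a minimal sketch; the five-row dataset and the feature names 'no surfacing' and 'flippers' are made up for illustration), the tree comes back as a nested dict. Note that createTree deletes entries from labels, so pass a copy if the original list is still needed:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
myTree = createTree(dataSet, labels[:])   # labels[:] keeps the caller's list intact
print(myTree)   # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
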
'''Count the leaf nodes of a tree'''
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]   # first key; dict views can't be indexed in Python 3, hence list()
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # an inner dict is a subtree
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1                              # anything else is a leaf
    return numLeafs

'''Get the depth (number of levels) of a tree'''
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]   # first key, as in getNumLeafs
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
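
Both helpers walk the nested dict recursively; on the illustrative tree built above:

sampleTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(getNumLeafs(sampleTree))    # 3
print(getTreeDepth(sampleTree))   # 2
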
'''Classify a sample with the decision tree'''
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]    # feature name at the root of this subtree
    secondDict = inputTree[firstStr]        # the dict one level down
    featIndex = featLabels.index(firstStr)  # position of this feature in testVec
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):       # still a subtree: keep descending
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:                                   # a leaf: the predicted class
        classLabel = valueOfFeat
    return classLabel
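
For example, with the illustrative tree and feature names from above:

featLabels = ['no surfacing', 'flippers']
sampleTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(sampleTree, featLabels, [1, 0]))   # 'no'
print(classify(sampleTree, featLabels, [1, 1]))   # 'yes'
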
'''Save a tree to disk'''
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')   # pickle requires a binary-mode file
    pickle.dump(inputTree, fw)
    fw.close()

'''Load a tree from disk'''
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')   # binary mode, matching storeTree
    return pickle.load(fr)
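
A round trip looks like this (the .pkl filename is arbitrary):

sampleTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
storeTree(sampleTree, 'classifierStorage.pkl')
print(grabTree('classifierStorage.pkl'))   # the same nested dict back
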
Original source: https://www.cnblogs.com/long5683/p/9340083.html