Machine Learning in Action source code: constructing the decision tree

from math import log
import operator

def createDataSet():
    # Toy dataset from Machine Learning in Action, Ch. 3: the first two
    # columns are the features ("no surfacing", "flippers"); the last
    # column is the class label.
    dataSet = [[1, 1, "yes"],
               [1, 1, "yes"],
               [1, 0, "no"],
               [0, 1, "no"],
               [0, 1, "no"]]
    labels = ["no surfacing", "flippers"]
    return dataSet, labels
def calcShannonEnt(dataSet):
    # Shannon entropy of the class labels: H = -sum(p * log2(p)).
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
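# Worked example: the sample dataset has 2 "yes" and 3 "no" labels, so
# H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710, which is what
# calcShannonEnt(createDataSet()[0]) returns.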
def splitDataSet(dataSet, axis, value):
    # Return the rows whose feature at index `axis` equals `value`,
    # with that feature column removed.
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
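# Example: splitDataSet(myDat, 0, 1) keeps the rows whose feature 0 equals 1
# and drops that column, giving [[1, 'yes'], [1, 'yes'], [0, 'no']].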
def chooseBestFeatureToSplit(dataSet):
    # Pick the feature whose split yields the largest information gain.
    numFeatures = len(dataSet[0]) - 1    # last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
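# On the sample data the information gain is about 0.420 for "no surfacing"
# and 0.171 for "flippers", so chooseBestFeatureToSplit(myDat) returns 0.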
def majorityCnt(classList):
    # Majority vote among the remaining class labels (used when the
    # features are exhausted but the labels are still mixed).
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
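# Example: majorityCnt(["yes", "no", "no"]) returns "no".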
def createTree(dataSet, labels):
    # Recursively build the tree as nested dicts:
    # {feature label: {feature value: subtree or class label}}.
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]              # all labels identical: leaf node
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)    # no features left: majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]                 # note: mutates the caller's list
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]            # copy so sibling branches are unaffected
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
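# Not in the original listing: a minimal sketch (an assumed helper, not
# verbatim book code) of how the finished tree can classify a new sample.
# `featLabels` must be the original, unmodified label list.
def classify(inputTree, featLabels, testVec):
    firstFeat = next(iter(inputTree))        # feature name stored at this node
    secondDict = inputTree[firstFeat]
    featIndex = featLabels.index(firstFeat)  # map the name back to its column
    subTree = secondDict[testVec[featIndex]]
    if isinstance(subTree, dict):            # internal node: keep descending
        return classify(subTree, featLabels, testVec)
    return subTree                           # leaf: the class label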
if __name__ == "__main__":
    myDat, labels = createDataSet()
    # print(calcShannonEnt(myDat))
    # print(splitDataSet(myDat, 0, 1))
    # print(chooseBestFeatureToSplit(myDat))
    myTree = createTree(myDat, labels)
    print(myTree)
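# Running the script prints the nested-dict tree:
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Sketch usage of the classify() helper above (pass a fresh copy of the
# original label list, since createTree mutates `labels`):
# print(classify(myTree, ["no surfacing", "flippers"], [1, 0]))  # -> 'no'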
Original post: https://www.cnblogs.com/guochangyu/p/7718230.html