Machine Learning (Prof. Zhou Zhihua): Decision Tree Exercise 4.3

Building the decision tree

Exercise 4.3 asks for a decision tree grown by information-gain (entropy) splitting on the watermelon 3.0 dataset, which mixes six discrete attributes with two continuous ones (density 密度 and sugar content 含糖率).

import copy
import numpy as np
import pandas as pd
from PlotTree import createPlot #plotting subroutine from the referenced post

#compute the Shannon entropy of dataSet (last column is the class label)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt-=prob*np.log2(prob)
    return shannonEnt
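This is Ent(D) = -Σ_k p_k·log2(p_k), eq. (4.1) in the book. A minimal sanity check on made-up toy rows (not the watermelon data): an even two-class split should give exactly one bit.

#toy check (hypothetical rows; the last column is the class label)
toy = [['a', 'yes'], ['b', 'yes'], ['c', 'no'], ['d', 'no']]
print(calcShannonEnt(toy))   #-> 1.0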
    
#subset of examples whose discrete attribute at index axis equals value, with that column removed
def splitDataSet(dataSet,axis,value):
    subDataSet = []
    for example in dataSet:
        if example[axis] == value:
            reducedExample = example[:axis]          #columns before axis
            reducedExample.extend(example[axis+1:])  #columns after axis
            subDataSet.append(reducedExample)
    return subDataSet
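A quick illustration on hypothetical rows: selecting value '青绿' in column 0 keeps the two matching rows and drops that column.

#hypothetical rows: [color, texture, label]
rows = [['青绿', '清晰', '是'], ['乌黑', '清晰', '是'], ['青绿', '模糊', '否']]
print(splitDataSet(rows, 0, '青绿'))   #-> [['清晰', '是'], ['模糊', '否']]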

#split on a continuous attribute at threshold value: direction == 1 keeps examples with attribute <= value, anything else keeps those with attribute > value
def splitContinuousDataSet(dataSet,axis,value,direction):
    reDataSet = []
    for example in dataSet:
        keep = example[axis] <= value if direction == 1 else example[axis] > value
        if keep:
            subExample = example[:axis]
            subExample.extend(example[axis+1:])
            reDataSet.append(subExample)
    return reDataSet
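The candidate thresholds used below are the midpoints of adjacent sorted values, as in eq. (4.7) of the book. A small illustration, using three density values from the dataset:

#midpoints of adjacent sorted values form the candidate split set
vals = sorted([0.697, 0.774, 0.634])
print([(vals[j] + vals[j+1]) / 2.0 for j in range(len(vals) - 1)])
#-> [0.6655, 0.7355]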
            


#choose the best attribute to split on (information gain, ID3-style)
bestSplitDict = {}  #records the chosen threshold for each continuous attribute
def chooseBestFeatureToSplit(dataSet,labels):
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    
    for i in range(len(labels)):
        featureList = [example[i] for example in dataSet]
        #continuous attribute: search the candidate thresholds for the best binary split
        if type(featureList[0]).__name__ == 'float' or type(featureList[0]).__name__ == 'int':
            sortedFeatureList = sorted(featureList)
            splitList = []
            for j in range(len(sortedFeatureList)-1):
                splitList.append((sortedFeatureList[j]+sortedFeatureList[j+1])/2.0)
            bestContinuEnt = 10000
            bestSplit = 0
            for j in range(len(splitList)):
                featureEnt = 0.0
                #weighted entropy of the two partitions at threshold splitList[j]
                dataSet1 = splitContinuousDataSet(dataSet,i,splitList[j],1)
                prob1 = len(dataSet1)/float(len(dataSet))
                featureEnt += prob1*calcShannonEnt(dataSet1)
                dataSet0 = splitContinuousDataSet(dataSet,i,splitList[j],0)
                prob0 = len(dataSet0)/float(len(dataSet))
                featureEnt += prob0*calcShannonEnt(dataSet0)
                if featureEnt < bestContinuEnt:
                    bestContinuEnt = featureEnt
                    bestSplit = splitList[j]
            bestSplitDict[labels[i]] = bestSplit
            infoGain = baseEntropy - bestContinuEnt
        #discrete attribute: weighted entropy over each distinct value
        else:
            uniqueValue = set(featureList)
            featureEnt = 0.0
            for value in uniqueValue:
                subDataSet = splitDataSet(dataSet,i,value)
                prob = len(subDataSet)/float(len(dataSet))
                featureEnt += prob*calcShannonEnt(subDataSet)
            infoGain = baseEntropy - featureEnt
        if infoGain > bestInfoGain:
            bestFeature = i
            bestInfoGain = infoGain

    #if the best attribute is continuous, binarize its column at the chosen
    #threshold (<= threshold -> 1, otherwise 0); this runs once, after the
    #loop over all attributes. The '<=threshold' renaming of the node label
    #happens later, in createTree.
    if (type(dataSet[0][bestFeature]).__name__ == 'float' or
            type(dataSet[0][bestFeature]).__name__ == 'int'):
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        for i in range(len(dataSet)):
            if dataSet[i][bestFeature] <= bestSplitValue:
                dataSet[i][bestFeature] = 1
            else:
                dataSet[i][bestFeature] = 0
    return bestFeature
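A usage sketch on a hypothetical three-row dataset (one continuous column, one discrete); toyData and toyLabels are illustrative names only. The function returns the index of the winning attribute, records its threshold in the global bestSplitDict, and binarizes the continuous column in place.

#hypothetical rows: [density (continuous), texture (discrete), label]
toyData = [[0.697, '清晰', '是'], [0.634, '模糊', '是'], [0.243, '模糊', '否']]
toyLabels = ['密度', '纹理']
print(chooseBestFeatureToSplit(toyData, toyLabels))   #-> 0 (density)
print(bestSplitDict)   #-> {'密度': 0.4385}
print(toyData)   #density column is now 1 (<= 0.4385) or 0 (> 0.4385)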

#when no attribute is left to split on, pick the class by majority vote
def majorityCnt(classList):
    classCount={}
    for i in classList:
        if i not in classCount.keys():
            classCount[i]=0
        classCount[i]+=1
    return max(classCount,key=classCount.get)
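Up to tie-breaking, this is the same as collections.Counter(classList).most_common(1)[0][0]. A quick check:

print(majorityCnt(['是', '否', '是']))   #-> '是'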

#recursively build the tree; data_full/labels_full carry the complete dataset
#so that discrete values missing from the current subset still get a leaf
#(see the end of this function)
def createTree(dataSet,labels,data_full,labels_full):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(dataSet,labels)
    
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel:{}}
    
    featureVals = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featureVals)
    if type(dataSet[0][bestFeature]).__name__ == 'str':
        #discrete attribute: collect its full value set from data_full, so
        #values absent from the current subset can still get a branch below
        currentLabel = labels_full.index(labels[bestFeature])
        featureVals_full = [example[currentLabel] for example in data_full]
        uniqueVals_full = set(featureVals_full)
    else:
        #continuous attribute (already binarized to 1/0): label the node
        #with the recorded threshold
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        bestFeatureLabel = labels[bestFeature] + '<=' + str(bestSplitValue)
        myTree = {bestFeatureLabel:{}}
    del(labels[bestFeature])
    for value in uniqueVals:
        sublabels = labels[:]
        if type(dataSet[0][bestFeature]).__name__ == 'str':
            uniqueVals_full.remove(value)
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value), sublabels,data_full,labels_full)
    #fill in branches for discrete values that never appeared in this subset
    if type(dataSet[0][bestFeature]).__name__ == 'str':
        for val in uniqueVals_full:
            myTree[bestFeatureLabel][val] = majorityCnt(classList)
    return myTree

#main: load watermelon dataset 3.0 (tab-separated) and drop the ID column
df = pd.read_csv('WaterMelon_4_3.txt',sep='\t')
data_full = df.values[:,1:].tolist()
dataSet = copy.deepcopy(data_full)  #deep copy: binarization mutates rows in place
labels_full = df.columns[1:-1].tolist()
labels = labels_full[:]
myTree = createTree(dataSet,labels,data_full,labels_full)

createPlot(myTree)
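If the PlotTree helper from the referenced post is not at hand, the raw nested dict can be inspected directly; a minimal fallback:

#fallback when PlotTree is unavailable: print the nested-dict tree
import pprint
pprint.pprint(myTree)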

References:

ID3 decision tree

Plotting subroutine (PlotTree)

Plotting decision trees in Python

Result: the decision tree rendered by createPlot (figure omitted).

Original post: https://www.cnblogs.com/liudianfengmang/p/12900717.html