《机器学习实战》程序清单3-4 创建树的函数代码

有点乱,等我彻底想明白时再来整理清楚。

from math import log
import operator

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    #print("样本总数:" + str(numEntries))

    labelCounts = {} #记录每一类标签的数量

    #定义特征向量featVec
    for featVec in dataSet:
        
        currentLabel = featVec[-1] #最后一列是类别标签

        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0;

        labelCounts[currentLabel] += 1 #标签currentLabel出现的次数
        #print("当前labelCounts状态:" + str(labelCounts))

    shannonEnt = 0.0

    for key in labelCounts:
        
        prob = float(labelCounts[key]) / numEntries #每一个类别标签出现的概率

        #print(str(key) + "类别的概率:" + str(prob))
        #print(prob * log(prob, 2) )
        shannonEnt -= prob * log(prob, 2) 
        #print("熵值:" + str(shannonEnt))

    return shannonEnt

def createDataSet():
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
        #以下随意添加,用于测试熵的变化,越混乱越冲突,熵越大
        # [1, 1, 'no'],
        # [1, 1, 'no'],
        # [1, 1, 'no'],
        # [1, 1, 'no'],
        #[1, 1, 'maybe'],
        # [1, 1, 'maybe1']
        # 用下面的8个比较极端的例子看得会更清楚。如果按照这个规则继续增加下去,熵会继续增大。
        # [1,1,'1'],
        # [1,1,'2'],
        # [1,1,'3'],
        # [1,1,'4'],
        # [1,1,'5'],
        # [1,1,'6'],
        # [1,1,'7'],
        # [1,1,'8'],

        # 这是另一个极端的例子,所有样本的类别是一样的,有序,不混乱,此时熵为0
        # [1,1,'2'],
        # [1,1,'1'],
        # [1,1,'1'],
        # [1,1,'1'],
        # [1,1,'1'],
        # [1,1,'1'],
        # [1,1,'1'],
        # [1,1,'1'],        
    ]

    #print("dataSet[0]:" + str(dataSet[0]))
    #print(dataSet)

    labels = ['no surfacing', 'flippers']

    return dataSet, labels

def testCalcShannonEnt():

    myDat, labels = createDataSet()
    #print(calcShannonEnt(myDat))

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        #print("featVec:" + str(featVec))
        #print("featVec[axis]:" + str(featVec[axis]))
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            #print(reduceFeatVec)
            reduceFeatVec.extend(featVec[axis + 1:])
            #print('reduceFeatVec:' + str(reduceFeatVec))
            retDataSet.append(reduceFeatVec)
    #print("retDataSet:" + str(retDataSet))
    return retDataSet

def testSplitDataSet():
    myDat,labels = createDataSet()
    #print(myDat)
    a = splitDataSet(myDat, 0, 0)
    #print(a)


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 #减掉类别列,剩2列
    #print("特征数量:" + str(numFeatures))

    baseEntropy = calcShannonEnt(dataSet)
    #print("基础熵:" + str(baseEntropy))

    bestInfoGain = 0.0;
    bestFeature  = -1

    #numFeatures==2
    for i in range(numFeatures):
        #print("i的值" + str(i))
        featList = [example[i] for example in dataSet];
        #print("featList:" + str(featList))

        #在列表中创建集合是Python语言得到列表中唯一元素值的最快方法
        #集合对象是一组无序排列的可哈希的值。集合化,收缩
        #[1, 0, 1, 1, 1, 1]创建集合后,变为{0,1}
        uniqueVals = set(featList) 
        #print("uniqueVals" + str(uniqueVals))

        newEntropy = 0.0
        #uniqueVals=={0,1}
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            #print("subDataSet:" + str(subDataSet))
            prob = len(subDataSet) / float(len(dataSet))
            
            #print("subDataSet:" + str(subDataSet))
            #print("subDataSet的长度:" + str(len(subDataSet)))
            newEntropy += prob * calcShannonEnt(subDataSet)
            #print("newEntropy:" + str(newEntropy))

        #信息增益,新序列熵越小,增益越大,最终目标是把最大的增益找出来
        infoGain = baseEntropy - newEntropy 
        #print("infoGain:" + str(infoGain))
        #print("bestInfoGain:" + str(bestInfoGain))


        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i

    #print("bestFeature:" + str(bestFeature))
    return bestFeature
            
    
def testChooseBestFeatureToSplit():
    myDat, labels = createDataSet()
    chooseBestFeatureToSplit(myDat)

'''
输入:类别列表     
输出:类别列表中多数的类,即多数表决
这个函数的作用是返回字典中出现次数最多的value对应的key,也就是输入list中出现最多的那个值
'''
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): 
            classCount[vote] = 0
        classCount[vote] += 1
 
     #key=operator.itemgetter(0)或key=operator.itemgetter(1),决定以字典的键排序还是以字典的值排序
     #0以键排序,1以值排序
     #reverse(是否反转)默认是False,reverse == true 则反转由大到小排列

    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    print(sortedClassCount)

    return sortedClassCount[0][0]
def testMajorityCnt():
     list1 = ['a','b','a','a','b','c','d','d','d','e','a','a','a','a','c','c','c','c','c','c','c','c']
    
     print(majorityCnt(list1))

global n
n=0

def createTree(dataSet, labels):
    
    global n
    print("=================createTree"+str(n)+" begin=============")
    n += 1
    print(n)

    classList = [example[-1] for example in dataSet]

    print("" + str(n) + "次classList:" + str(classList))
    print("此时列表中的第1个元素为" + str(classList[0]) + ",数量为:" + str(classList.count(classList[0])) + ",列表总长度为:" + str(len(classList)))
    
    print("列表中"+str(classList[0])+"的数量:",classList.count(classList[0]))
    print("列表的长度:", len(classList))

    if classList.count(classList[0])== len(classList):
        print("判断结果为:所有类别相同,停止本组划分")
    else:
        print("判断结果为:类别不相同")

     #列表中有n个元素,并且n个都一致,则停止递归 
    if classList.count(classList[0]) == len(classList):
         return classList[0]

    print("dataSet[0]:" + str(dataSet[0]))

    if len(dataSet[0]) == 1:
        print("启动多数表决")  #书中的示例样本集合没有触发
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    print("bestFeat:" +str(bestFeat))
    print("bestFeatLabel:" + str(bestFeatLabel))

    myTree = {bestFeatLabel:{}}
    print("当前树状态:" + str(myTree))

    print("当前标签集合:" + str(labels))
    print("准备删除" + labels[bestFeat])
    del(labels[bestFeat])
    print("已删除")
    print("删除元素后的标签集合:" + str(labels))

    featValues = [example[bestFeat] for example in dataSet]
    print("featValues:",featValues)

    uniqueVals = set(featValues)
    print("uniqueVals:", uniqueVals) #{0,1}

    k = 0
    print("********开始循环******")
    for value in uniqueVals:

        k += 1
        print("",k,"次循环")
        subLabels = labels[:]
        print("传入参数:")
        print("        --待划分的数据集:",dataSet)
        print("        --划分数据集的特征:", bestFeat)
        print("        --需要返回的符合特征值:", value)
        splited = splitDataSet(dataSet, bestFeat, value)
        print("splited:", str(splited))
        myTree[bestFeatLabel][value] = createTree(splited, subLabels)  #递归调用
    print("*******结束循环*****")

    print("=================createTree"+str(n)+" end=============")
    return myTree
     
def testCreateTree():
     
     myDat,labels = createDataSet();
     myTree = createTree(myDat, labels);
     print("============testCreateTree=============")
     print(myTree)

if __name__ == '__main__':
    #测试输出信息熵
    #testCalcShannonEnt()

    #测试拆分结果集
    #testSplitDataSet()

    #选择最好的特征值
    #testChooseBestFeatureToSplit()
 
    #testMajorityCnt()

    testCreateTree()


    
原文地址:https://www.cnblogs.com/Sabre/p/8415124.html