【机器学习】决策树-02

心得体会:

  1.使用嵌套字典(dict)表示决策树,并用matplotlib绘图

  2.决策树可以用pickle以二进制写模式‘wb’序列化存储到文件,再以二进制读模式‘rb’从文件中读取还原

#3.2Matplotlib注解绘制树形图
#使用文本注解绘制树节点
import matplotlib
import matplotlib.pyplot as plt

# Text-box style passed as ``bbox`` when drawing a decision (internal) node
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Text-box style intended for leaf nodes
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style passed as ``arrowprops`` for the connector between two nodes
arrow_args = {"arrowstyle": "<-"}

# Add one annotated node to the figure
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    """Draw a boxed text label at ``centerPt`` with an arrow linking it to ``parentPt``.

    Args:
        nodeTxt: Text shown inside the box.
        centerPt: (x, y) position of the text box, in axes-fraction coordinates.
        parentPt: (x, y) position of the other arrow end, in axes-fraction
            coordinates.
        nodeType: Box-style dict handed to ``annotate`` as ``bbox``.

    The drawing target is ``createPlot.ax1``, so ``createPlot`` must have set
    that attribute before this function is called.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xytext=centerPt,
        xycoords='axes fraction',
        textcoords='axes fraction',
        va="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
# def createPlot():
#     fig=plt.figure(1,facecolor='white')#图像编号1,背景色白色
#     fig.clf() # Clear figure清除所有轴,但是窗口打开,这样它可以被重复使用
#     createPlot.ax1=plt.subplot(111,frameon=False)# 1行1列,位置是1的子图——createPlot.ax1是plt子图的索引,可以通过ax1设计plt子图
#     plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
#     plotNode('叶节点',(0.8,0.1),(0.0,0.0),leafNode)
#     plt.show()

#注意:使用matplotlib时不要用qq输入法
# createPlot()

#构造注解树

# Count the leaf nodes of a tree
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    Args:
        myTree: A dict whose first key is a feature name mapping to a dict of
            ``{feature_value: subtree_or_label}``. Only the first key is read.

    Returns:
        The count of non-dict values reachable through the nested branches.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    branches = list(myTree.values())[0]
    # isinstance (rather than an exact type comparison) so that dict
    # subclasses such as OrderedDict are still recognised as subtrees.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )

# Compute the number of levels in a tree
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    Args:
        myTree: A dict whose first key is a feature name mapping to a dict of
            ``{feature_value: subtree_or_label}``. Only the first key is read.

    Returns:
        The largest number of decision nodes on any root-to-leaf path: a
        branch ending in a label counts 1, a branch holding a subtree counts
        1 plus that subtree's depth. 0 if the root has no branches.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    branches = list(myTree.values())[0]
    # isinstance (rather than an exact type comparison) so that dict
    # subclasses such as OrderedDict are still recognised as subtrees.
    return max(
        (
            1 + getTreeDepth(child) if isinstance(child, dict) else 1
            for child in branches.values()
        ),
        default=0,
    )

# Build a sample tree to experiment with
def retrieveTree():
    """Return the tree built by ``createTree`` from ``createDataSet``'s output.

    NOTE(review): ``createDataSet`` and ``createTree`` are not defined in this
    section of the file; presumably they come from the tree-building module —
    confirm they are imported before calling this.
    """
    dataSet, featureNames = createDataSet()
    return createTree(dataSet, featureNames)

# mytree=retrieveTree()
# print(getNumLeafs(mytree))
# print(getTreeDepth(mytree))

#plotTree函数
def plotMidTest(cntrPt,parentPt,txtString):
    """Write ``txtString`` halfway between ``cntrPt`` and ``parentPt``.

    Args:
        cntrPt: (x, y) of the child node.
        parentPt: (x, y) of the parent node.
        txtString: Label to place at the midpoint of the two points.

    The text is drawn on ``createPlot.ax1``, so ``createPlot`` must have set
    that attribute before this function is called.
    """
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    axes = createPlot.ax1
    axes.text(midX, midY, txtString)

def plotTree(myTree,parentPt,nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    Args:
        myTree: Nested-dict tree ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) of the parent node the root of ``myTree`` hangs from.
        nodeTxt: Label written on the edge between the parent and this root.

    Layout state is kept in function attributes that ``createPlot``
    initialises before the first call:
        plotTree.totalW / plotTree.totalD: leaf count / depth of the full tree.
        plotTree.xOff: x position of the most recently placed leaf.
        plotTree.yOff: y position of the level currently being drawn.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this decision node horizontally above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidTest(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level before drawing the children.
    plotTree.yOff -= 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff += 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            # Leaves use the leaf box style, not the decision-node style.
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidTest(leafPt, cntrPt, str(key))
    # Step back up so the caller's remaining siblings are drawn on its level.
    plotTree.yOff += 1.0 / plotTree.totalD

def createPlot(inTree):
    """Draw ``inTree`` in figure 1 and display it.

    Args:
        inTree: Nested-dict tree ``{feature: {value: subtree_or_label}}``.

    Side effects: clears figure 1, stores the drawing axes on
    ``createPlot.ax1``, initialises the layout attributes on ``plotTree``
    (``totalW``, ``totalD``, ``xOff``, ``yOff``) and calls ``plt.show()``.
    """
    figure = plt.figure(1, facecolor='white')  # white background
    figure.clf()  # start from an empty figure
    # Frameless axes without tick marks; plotNode/plotMidTest draw on it.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree size, used by plotTree to space the nodes.
    leafCount = float(getNumLeafs(inTree))
    levelCount = float(getTreeDepth(inTree))
    plotTree.totalW = leafCount
    plotTree.totalD = levelCount

    # Start half a leaf-slot left of the axes so the first leaf is centred
    # in its slot; start at the top of the axes.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

# createPlot(retrieveTree())

# 3-3测试和存储分类器
def classify(inputTree,featLabels,testVec):
    """Classify one sample by walking a nested-dict decision tree.

    Args:
        inputTree: Tree of the form ``{feature: {value: subtree_or_label}}``.
        featLabels: List of feature names; the position of a name is the
            index of that feature's value in ``testVec``.
        testVec: The sample's feature values, ordered like ``featLabels``.

    Returns:
        The label stored at the leaf reached by following the sample's values.

    Raises:
        ValueError: If the tree tests a feature missing from ``featLabels``
            (raised by ``list.index``), or if the sample's value for a tested
            feature has no matching branch in the tree.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, child in secondDict.items():
        if featValue == key:
            if isinstance(child, dict):
                return classify(child, featLabels, testVec)
            return child
    # No branch matched: fail with a clear message instead of falling through
    # to an unbound local variable.
    raise ValueError(
        "no branch for value {!r} of feature {!r}".format(featValue, firstStr)
    )

# Using the algorithm: persisting the decision tree
def storeTree(inputTree,filename):
    """Serialize ``inputTree`` to ``filename`` with pickle.

    Args:
        inputTree: The object to store.
        filename: Path of the file to write; opened in binary write mode, so
            an existing file is overwritten.
    """
    import pickle
    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    """Load and return the object pickled in ``filename``.

    Args:
        filename: Path of a file written with pickle; opened in binary read
            mode.

    Returns:
        The unpickled object.

    NOTE: unpickling can execute arbitrary code, so only load files from a
    trusted source.
    """
    import pickle
    # The context manager closes the file once loading finishes or fails.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

# myTree=retrieveTree()
# storeTree(myTree,"E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt")
# print(grabTree("E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt"))


# Example: use a decision tree to predict the contact-lens type
# NOTE(review): createTree is not defined in this section of the file;
# presumably it comes from the tree-building module — confirm it is imported.
lenses=[]
# The context manager closes the data file once every line has been read.
with open("E:/Python/《机器学习实战》代码/Ch03/lenses.txt") as fr:
    for data in fr:
        # One sample per line, fields separated by a single space.
        # NOTE(review): confirm the delimiter against the data file (the
        # dataset is often distributed tab-separated).
        lenses.append(data.strip().split(' '))
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=createTree(lenses,lensesLabels)
createPlot(lensesTree)
 
原文地址:https://www.cnblogs.com/LPworld/p/13272339.html