# DecisionTree_python

#coding:utf-8
from numpy import *
from math import *
import operator
def file2matrix(filename):
    """Load the whitespace-separated numeric watermelon dataset.

    Each non-blank line of *filename* is one sample: one value per feature
    name in ``label`` followed by the class value, all parseable as floats.

    Returns:
        (rematrix, label): ``rematrix`` is a float numpy array of shape
        (n_samples, len(label) + 1) whose last column is the class;
        ``label`` is the list of feature names.

    Raises:
        ValueError: if a line does not have exactly len(label) + 1 fields,
            or a field is not numeric.
    """
    label = ["seze", "gendi", "qiaoshen", "wenli", "qibu", "chugan"]  # watermelon feature names (pinyin)
    ncols = len(label) + 1  # one column per feature plus the trailing class column
    with open(filename) as fr:
        # split() tolerates repeated spaces/tabs; blank lines are skipped
        rows = [line.split() for line in fr if line.strip()]
    rematrix = zeros((len(rows), ncols))
    for index, fields in enumerate(rows):
        if len(fields) != ncols:
            # Without this check numpy would silently broadcast a short row.
            raise ValueError("row %d: expected %d fields, got %d"
                             % (index + 1, ncols, len(fields)))
        rematrix[index, :] = [float(x) for x in fields]
    return rematrix, label
def singlesplit(data, axis, value):
    """Return the rows of *data* whose column *axis* equals *value*, with that column removed.

    Each returned row is a new list holding every column of the original row
    except column *axis*, so the class label (last column) is preserved and
    the remaining features stay available for further splits. *data* itself
    is not modified. Rows may be lists or 1-D numpy arrays.
    """
    newlistt = []
    for feat in data:
        if feat[axis] == value:
            # Normalize a negative axis so the slices below skip the right column.
            cut = axis if axis >= 0 else axis + len(feat)
            newlist = list(feat[:cut])
            newlist.extend(feat[cut + 1:])
            newlistt.append(newlist)
    return newlistt
def allsplit(data):
    """Return the index of the feature column with the largest information gain.

    Information gain for feature ``i`` is the entropy of the class column of
    *data* minus the size-weighted entropy of the subsets obtained by
    splitting on each distinct value of column ``i``. The last column is the
    class label and is never considered.

    Ties keep the earliest feature. The first feature is always taken as the
    initial candidate, so a valid index is returned even when every gain is
    zero. Returns -1 only when *data* has no rows or no feature columns.
    """
    # len() rather than truthiness: data may be a 2-D numpy array.
    if len(data) == 0 or len(data[0]) < 2:
        return -1
    baseEntry = calcshannon(data)
    ordermax = 0.0
    bestfuture = -1
    for i in range(len(data[0]) - 1):
        uniq = set(example[i] for example in data)  # distinct values of feature i
        newEntry = 0.0
        for j in uniq:
            cooldata = singlesplit(data, i, j)
            prob = len(cooldata) / float(len(data))
            newEntry += prob * calcshannon(cooldata)
        info = baseEntry - newEntry
        if bestfuture == -1 or info > ordermax:
            ordermax = info
            bestfuture = i
    return bestfuture
def calcshannon(data):
    """Compute the base-2 Shannon entropy of the class labels in *data*.

    The class label of each row is its last element. An empty *data* gives 0.0.
    """
    total = len(data)
    counts = {}
    for row in data:
        cls = row[-1]
        counts[cls] = counts.get(cls, 0) + 1
    entropy = 0.0
    for n in counts.values():
        p = n / float(total)
        entropy -= p * log(p, 2)
    return entropy
def selectbigger(label):
    """Return the most frequent value in *label* (majority vote).

    On ties, the value whose first occurrence comes earliest wins; this
    ordering relies on dict insertion order (Python 3.7+).

    Raises:
        ValueError: if *label* is empty.
    """
    calcdict = {}
    for line in label:
        calcdict[line] = calcdict.get(line, 0) + 1
    if not calcdict:
        raise ValueError("selectbigger() needs at least one label")
    # max() returns the first maximal item, matching a stable descending sort.
    return max(calcdict.items(), key=operator.itemgetter(1))[0]
def createTree(data, label):
    """Recursively build an ID3 decision tree as nested dictionaries.

    Args:
        data: sequence of rows; the last element of each row is the class.
        label: feature names, one per feature column of *data*. It is NOT
            modified, so the caller can reuse it (e.g. for classifier()).

    Returns:
        Either a class value (a leaf), or ``{feature_name: {value: subtree}}``
        where the inner keys are the distinct values of that feature in *data*.
    """
    labellist = [tt[-1] for tt in data]
    if labellist.count(labellist[0]) == len(labellist):  # all samples share one class
        return labellist[0]
    if len(data[0]) == 1:  # no features left, only the class column
        return selectbigger(labellist)
    bestfuture = allsplit(data)
    if bestfuture < 0:  # no usable feature was found; fall back to majority vote
        return selectbigger(labellist)
    bestlabel = label[bestfuture]
    tree = {bestlabel: {}}  # nested dicts represent the tree recursively
    # Build a fresh list instead of deleting from `label`, so neither the
    # caller's list nor sibling branches see the removed feature name.
    sublabel = label[:bestfuture] + label[bestfuture + 1:]
    for value in set(tt[bestfuture] for tt in data):
        tree[bestlabel][value] = createTree(singlesplit(data, bestfuture, value), sublabel)
    return tree
def classifier(inputree, featurelabel, clsdata):
    """Classify one sample by walking a nested-dict tree produced by createTree.

    Args:
        inputree: ``{feature_name: {value: subtree_or_class}}``.
        featurelabel: the full list of feature names, in the column order
            used by *clsdata*.
        clsdata: the sample's feature values.

    Returns:
        The class stored at the reached leaf, or '' if the sample has a
        feature value for which the tree has no branch.
    """
    firststr = next(iter(inputree))  # feature tested at this node (works on Python 2 and 3)
    secondict = inputree[firststr]
    featindex = featurelabel.index(firststr)
    samplevalue = clsdata[featindex]
    for key, subtree in secondict.items():
        if samplevalue == key:
            # A dict means an internal node: keep descending; otherwise it is a leaf.
            if isinstance(subtree, dict):
                return classifier(subtree, featurelabel, clsdata)
            return subtree
    return ''
if __name__ == "__main__":
    # Train on out.txt and classify a single sample.
    dataset, label = file2matrix("out.txt")
    # Pass a copy so createTree cannot alter the feature-name list that
    # classifier() needs to map names back to column indexes.
    mytree = createTree(dataset, label[:])
    print(classifier(mytree, label, [3, 1, 1, 3, 3, 1]))
# Original article: https://www.cnblogs.com/semen/p/6959056.html