机器学习(四):决策树

七、代码实现(python)

以下代码来自Peter Harrington《Machine Learing in Action》
本例代码实现算法5,生成最小二乘回归树。
代码如下(保存为CART.py):

# -- coding: utf-8 --
from numpy import *

def loadDataSet(fileName):
    # 获取训练集
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('	')
        fltLine = map(float,curLine)
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    # 该函数接收3个参数,数据集、第几个特征(切分变量)、划分条件(切分点),根据选择的特征和划分条件将数据分成两个区域
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

def regLeaf(dataSet):
    # 获取数据集dataSet最后一列的平均值
    return mean(dataSet[:,-1])

def regErr(dataSet):
    # 根据式(4)计算数据集dataSet的平方误差
    # var用于计算方差
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # 该函数用于寻找对于数据集dataSet的最好切分变量及切分点(即使得平方误差最小),ops用于控制函数停止机制
    tolS = ops[0]                              # 容许的误差下降值
    tolN = ops[1]                              # 切分的最小样本数
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)         # 若所有类别值相等,退出,此时无最好切分量
    m,n = shape(dataSet)
    S = errType(dataSet)                       # 存储数据集的平方误差
    bestS = inf
    bestIndex = 0                              # 初始化切分变量
    bestValue = 0                              # 初始化切分点
    for featIndex in range(n-1):
        # 循环特征数目,featIndex此时为切分变量
        for splitVal in set(dataSet[:,featIndex]):
            # 循环数据集行数,splitVal此时为切分点
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)      # 根据循环到的切分变量与切分点将数据分成两个区域
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若切分后的样本点小于最小样本数,退出此次循环,继续下一个循环
            newS = errType(mat0) + errType(mat1)# 计算划分后数据集的平方误差
            if newS < bestS:                    # 若新的平方误差更小,更新各个数据
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:                      # 若误差下降值小于容许的误差下降值,退出,此时无最好切分量
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)              # 用新的切分变量与切分点将数据分成两个区域
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):                   # 若切分后的样本点小于最小样本数,退出,此时无最好切分量
        return None, leafType(dataSet)
    return bestIndex,bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # 该函数根据接收的数据集创建决策树(子树)
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)             # 寻找对于数据集dataSet的最好切分变量及切分点
    if feat == None: return val                 # 若无最好的切分点,则返回数据集均值作为叶节点
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)                         # 用最好的切分变量与切分点将数据分成两个区域,作为左右子树
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

 

以上全部内容参考书籍如下:
李航《统计学习方法》

原文地址:https://www.cnblogs.com/pengfeiz/p/11392684.html