ML(4.2): R CART

     CART模型 :即Classification And Regression Trees。它和一般回归分析类似,是用来对变量进行解释和预测的工具,也是数据挖掘中的一种常用算法。如果因变量是连续数据,相对应的分析称为回归树,如果因变量是分类数据,则相应的分析称为分类树。决策树是一种倒立的树结构,它由内部节点、叶子节点和边组成。其中最上面的一个节点叫根节点。 构造一棵决策树需要一个训练集,一些例子组成,每个例子用一些属性(或特征)和一个类别标记来描述。构造决策树的目的是找出属性和类别间的关系,一旦这种关系找出,就能用它来预测将来未知类别的记录的类别。这种具有预测功能的系统叫决策树分类器。

    CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,因此CART算法生成的决策树是结构简洁的二叉树。由于CART算法构成的是一个二叉树,它在每一步的决策时只能 是“是”或者“否”,即使一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤

  • 将样本递归划分进行建树过程
  • 用验证数据进行剪枝

  在R包中,有如下的算法包可完成CART 分类计算,如下,分别以鸢尾花数据集为例进行验证

  •  rpart::rpart
  •  tree::tree

rpart::rpart


  • rpart包中有针对CART决策树算法提供的函数,比如rpart函数,以及用于剪枝的prune函数
  • rpart函数的基本形式:rpart(formula,data,subset,na.action=na.rpart,method.parms,control,...)
  • 安装所需R包
    install.packages("mboost")
    install.packages("rpart")
    install.packages("maptree")
  • 数据集划分训练集和测试,比例是2:1
    set.seed(1234)
    index <-sample(1:nrow(iris),100)
    iris.train <-iris[index,]
    iris.test <-iris[-index,]
  •  构建CART模型,查看模型结构,在结构中能看到很多有意思的内容

    library(rpart)
    model.CART <-rpart(Species~.,data=iris.train)
    str(model.CART)
  •  

  •  control:对树进行一些设置 

    1. minsplit是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止
    2. minbucket:树中叶节点包含的最小样本数 
    3. maxdepth:决策树最大深度
    4. xval:交叉验证的次数
    5. cp (complexity parameter),指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度。(即是每次分割对应的复杂度系数)
  • variable.importance:变量的重要性
    > model.CART$variable.importance
     Petal.Width Petal.Length Sepal.Length  Sepal.Width 
        60.58917     56.38914     39.79006     26.00328 
  • 预测数据: vector: 预测数值   class: 预测类别  prob: 预测类别的概率
    > p.rpart <- predict(model.CART, iris.test,type="class") 
    > table(p.rpart,iris.test$Species)
           
    p.rpart setosa versicolor virginica
          1     12          0         0
          2      0         21         3
          3      0          0        14
  •  可视化,需要rpart.plot包
#可视化决策树
#install.packages("rpart.plot")
library(rpart.plot)
rpart.plot(model.CART) 
  • 效果如下图:
  • CART剪枝:
    1. prune函数可以实现最小代价复杂度剪枝法,对于CART的结果,每个节点均输出一个对应的cp
    2. prune函数通过设置cp参数来对决策树进行修剪,cp为复杂度系数
    3. 可以用下面的办法选择具有最小xerror的cp的办法:
      model.CART.pru<-prune(model.CART, cp= model.CART$cptable[which.min(model.CART$cptable[,"xerror"]),"CP"]) 
      model.CART.pru$cp
  • CART剪枝后的模型进行预测 

    p.rpart1<-predict(model.CART.pru,iris.test,type="class")
    table(p.rpart1,iris.test$Species)
  •  

tree::tree


  • 数据集划分训练集和测试见上节
  • 构建模型,查看生成模型结构,如下图,错误率为:0.02667
    > #install.packages("tree")
    > library(tree)  
    > ir.tr <- tree(Species~., iris)  
    > summary(ir.tr)
    
    Classification tree:
    tree(formula = Species ~ ., data = iris)
    Variables actually used in tree construction:
    [1] "Petal.Length" "Petal.Width"  "Sepal.Length"
    Number of terminal nodes:  6 
    Residual mean deviance:  0.1253 = 18.05 / 144 
    Misclassification error rate: 0.02667 = 4 / 150
  • 查看生成决策树及图例
    plot(ir.tr)
    text(ir.tr,pretty = 0) 
  • 结果验证
    > tree_predict <- predict(ir.tr,iris.test,type="class")
    > table(tree_predict,iris.test$Species)
                
    tree_predict setosa versicolor virginica
      setosa         12          0         0
      versicolor      0         20         1
      virginica       0          1        16
  • 用误分类率来剪枝,做交叉验证,代码如下:
    > cv.carseats=cv.tree(ir.tr, FUN=prune.misclass)
    > str(cv.carseats)
    List of 4
     $ size  : int [1:5] 6 4 3 2 1
     $ dev   : num [1:5] 11 11 10 96 121
     $ k     : num [1:5] -Inf 0 2 44 50
     $ method: chr "misclass"
     - attr(*, "class")= chr [1:2] "prune" "tree.sequence"
  •  可视化模型

    par(mfrow=c(1, 2))
    plot(cv.carseats$size, cv.carseats$dev, type="b")
    plot(cv.carseats$k, cv.carseats$dev, type="b")
  •  图表示例

  • 随着树的节点越来越多(树越来越复杂),deviance逐渐减小,然后又开始增大
  • 随着对模型复杂程度的惩罚越来越重(k越来越大),deviance逐渐减小,然后又开始增大 (此图暂看不起来)
  • 从左边的图可以看出,当树的节点个数为 3 时,deviance达到最小,画出3个叶子节点的树
    #画出3个叶子节点的树
    par(new = TRUE) 
    prune.carseats <- prune.misclass(ir.tr, best=3)
    plot(prune.carseats)
    text(prune.carseats, pretty=0)
  • 图示例
  • 测试及结果
    > tree.pred  <- predict(prune.carseats, iris.test, type="class")
    > summary(tree.pred)
        setosa versicolor  virginica 
            12         24         14 
    > table(tree.pred,iris.test$Species)
                
    tree.pred    setosa versicolor virginica
      setosa         12          0         0
      versicolor      0         21         3
      virginica       0          0        14
  •  

原文地址:https://www.cnblogs.com/tgzhu/p/6697564.html