【决策树】— C4.5算法建立决策树JAVA练习

以下程序是我练习时写的,不一定完全正确,也没有做存储优化。有问题请留言交流。转载请附上原文链接。

当前的属性为:age income student credit_rating

当前的数据集为(最后一列是TARGET_VALUE):

---------------------------------

youth     high   no   fair      no
youth     high   no   excellent   no
middle_aged   high   no   fair     yes
senior     low    yes  fair     yes
senior     low    yes  excellent   no
middle_aged   low    yes  excellent   yes
youth     medium  no   fair     no
youth     low     yes  fair     yes
senior     medium  yes    fair     yes
youth     medium  yes    excellent   yes
middle_aged   high   yes  fair        yes
senior     medium  no     excellent   no
---------------------------------

C4.5建立树类

package C45Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Builds a decision tree by recursively splitting the data on the attribute
 * with the highest gain ratio (the C4.5 criterion), as computed by
 * {@link InfoGain}.
 *
 * <p>Every row of the data set holds one value per attribute, in the same
 * order as the attribute list, followed by the target value in the last
 * column.
 */
public class DecisionTree {

    /**
     * Recursively builds the (sub)tree for the given rows.
     *
     * <p>A leaf is produced when all rows share one target value, when no
     * attributes are left, or when no remaining attribute yields a positive
     * gain ratio; in the last two cases the leaf carries the majority target
     * value of the rows. The current rows and attributes are printed to
     * standard output on every call as a trace.
     *
     * @param data          rows to split; each row is attribute values followed by the target value
     * @param attributeList attribute names, index-aligned with the row columns
     * @return the root node of the tree built for {@code data}
     * @throws IllegalArgumentException if {@code data} is null or has no rows
     */
    public TreeNode createDT(List<ArrayList<String>> data,List<String> attributeList){

        if(data == null || data.isEmpty()){
            throw new IllegalArgumentException("data must contain at least one row");
        }
        printState(data, attributeList);

        // All rows agree on the target: nothing left to decide.
        String result = InfoGain.IsPure(InfoGain.getTarget(data));
        if(result != null){
            return newLeaf(result);
        }

        InfoGain gain = new InfoGain(data,attributeList);
        String majority = majorityTarget(gain.getTargetValue());

        // No attribute left, or none that separates the rows: fall back to
        // the most frequent target value instead of a null-valued node.
        int attrIndex = selectAttribute(gain, attributeList.size());
        if(attrIndex < 0){
            return newLeaf(majority);
        }

        TreeNode node = new TreeNode();
        String attrName = attributeList.get(attrIndex);
        System.out.println("选择出的最大增益率属性为: " + attrName);
        node.setAttributeValue(attrName);

        Map<String,Long> attrvalueMap = gain.getAttributeValue(attrIndex);
        for(Map.Entry<String, Long> entry : attrvalueMap.entrySet()){
            List<ArrayList<String>> resultData = gain.getData4Value(entry.getKey(), attrIndex);
            TreeNode childNode;
            System.out.println("当前为"+attrName+"的"+entry.getKey()+"分支。");
            if(resultData.size() == 0){
                // Defensive: a branch without rows inherits the parent's majority class.
                childNode = new TreeNode();
                childNode.setNodeName(attrName);
                childNode.setTargetFunValue(majority);
                childNode.setAttributeValue(entry.getKey());
            }else{
                // The rows returned by getData4Value are copies, so dropping the
                // used column here does not touch the caller's data.
                for (int j = 0; j < resultData.size(); j++) {
                    resultData.get(j).remove(attrIndex);
                }
                ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
                resultAttr.remove(attrIndex);
                childNode = createDT(resultData,resultAttr);
            }
            node.getChildTreeNode().add(childNode);
            node.getPathName().add(entry.getKey());
        }
        return node;
    }

    /** Prints the current rows and attribute names as a trace of the recursion. */
    private void printState(List<ArrayList<String>> data,List<String> attributeList){

        System.out.println("当前的DATA为");
        for(ArrayList<String> row : data){
            for(String value : row){
                System.out.print(value+ " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for(String name : attributeList){
            System.out.print(name+ " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }

    /** Creates a leaf node carrying the given target value. */
    private TreeNode newLeaf(String targetValue){

        TreeNode leaf = new TreeNode();
        leaf.setNodeName("leafNode");
        leaf.setTargetFunValue(targetValue);
        return leaf;
    }

    /**
     * Returns the index of the attribute with the largest positive gain
     * ratio, or -1 when no attribute qualifies.
     */
    private int selectAttribute(InfoGain gain,int attributeCount){

        double maxGain = 0.0;
        int attrIndex = -1;
        for(int i=0;i<attributeCount;i++){
            // An attribute with a single distinct value cannot split the rows
            // and has zero split information, so its ratio is undefined.
            if(gain.getAttributeValue(i).size() < 2){
                continue;
            }
            double tempGain = gain.getGainRatio(i);
            if(maxGain < tempGain){
                maxGain = tempGain;
                attrIndex = i;
            }
        }
        return attrIndex;
    }

    /**
     * Returns the most frequent target value; ties are broken by the
     * lexicographically smaller label so the result is deterministic.
     */
    private String majorityTarget(Map<String,Long> targetCounts){

        String best = null;
        long bestCount = -1L;
        for(Map.Entry<String, Long> entry : targetCounts.entrySet()){
            long count = entry.getValue().longValue();
            boolean wins = count > bestCount
                    || (count == bestCount && best != null && entry.getKey() != null
                        && entry.getKey().compareTo(best) < 0);
            if(wins){
                best = entry.getKey();
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * A node of the decision tree.
     *
     * <p>For an internal node {@code attributeValue} holds the name of the
     * attribute the node splits on, and {@code childTreeNode} and
     * {@code pathName} are parallel lists: the i-th path name is the
     * attribute value leading to the i-th child. For a leaf,
     * {@code targetFunValue} holds the predicted target value.
     */
    class TreeNode{

        private String attributeValue;
        private List<TreeNode> childTreeNode;
        private List<String> pathName;
        private String targetFunValue;
        private String nodeName;

        /** Creates a node with the given name and no children. */
        public TreeNode(String nodeName){

            this.nodeName = nodeName;
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        /** Creates an unnamed node with no children. */
        public TreeNode(){
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public String getAttributeValue() {
            return attributeValue;
        }

        public void setAttributeValue(String attributeValue) {
            this.attributeValue = attributeValue;
        }

        public List<TreeNode> getChildTreeNode() {
            return childTreeNode;
        }

        public void setChildTreeNode(List<TreeNode> childTreeNode) {
            this.childTreeNode = childTreeNode;
        }

        public String getTargetFunValue() {
            return targetFunValue;
        }

        public void setTargetFunValue(String targetFunValue) {
            this.targetFunValue = targetFunValue;
        }

        public String getNodeName() {
            return nodeName;
        }

        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        public List<String> getPathName() {
            return pathName;
        }

        public void setPathName(List<String> pathName) {
            this.pathName = pathName;
        }

    }
}

增益率计算类(取log的时候底用的是e,没有用2)

package C45Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Entropy, information gain and gain ratio (C4.5) over a tabular data set.
 *
 * <p>Each row holds one value per attribute followed by the target value in
 * the last column. All logarithms are natural logarithms; the base cancels
 * out in the gain ratio. Attribute values are compared case-sensitively
 * everywhere, so counting and filtering always agree.
 */
public class InfoGain {

    private List<ArrayList<String>> data;
    private List<String> attribute;

    /**
     * Takes a private copy of the rows and attribute names, so later changes
     * to the caller's lists do not affect this instance.
     *
     * @param data      rows of attribute values followed by the target value
     * @param attribute attribute names, index-aligned with the row columns
     */
    public InfoGain(List<ArrayList<String>> data,List<String> attribute){

        this.data = new ArrayList<ArrayList<String>>(data.size());
        for(List<String> row : data){
            this.data.add(new ArrayList<String>(row));
        }
        this.attribute = new ArrayList<String>(attribute);
    }

    /**
     * Entropy of the target column over all rows.
     *
     * @return {@code -sum(p * ln p)} over the target values; 0 for no rows
     */
    public double getEntropy(){
        return entropyOf(getTargetValue(), data.size());
    }

    /**
     * Expected entropy of the target column after splitting on an attribute:
     * the entropy of each value's subset, weighted by the subset's share of
     * the rows.
     *
     * @param attributeIndex column index of the attribute
     */
    public double getInfoAttribute(int attributeIndex){

        double rowCount = data.size();
        double infoA = 0.0;
        for(Map.Entry<String, Long> entry : getAttributeValue(attributeIndex).entrySet()){
            double weight = entry.getValue().doubleValue() / rowCount;
            Map<String,Long> targetCounts = getAttributeValueTargetValue(entry.getKey(),attributeIndex);
            long subsetSize = 0L;
            for(Long count : targetCounts.values()){
                subsetSize += count.longValue();
            }
            infoA += weight * entropyOf(targetCounts, subsetSize);
        }
        return infoA;
    }

    /**
     * Counts the target values among the rows whose attribute at
     * {@code attributeIndex} equals {@code attributeName}.
     *
     * @param attributeName  the attribute value to select rows by
     * @param attributeIndex column index of the attribute
     * @return target value to number of selected rows having it
     */
    public Map<String,Long> getAttributeValueTargetValue(String attributeName,int attributeIndex){

        Map<String,Long> targetValueMap = new HashMap<String,Long>();
        for(List<String> row : data){
            if(attributeName.equals(row.get(attributeIndex))){
                increment(targetValueMap, row.get(row.size() - 1));
            }
        }
        return targetValueMap;
    }

    /**
     * Counts how many rows carry each distinct value of an attribute.
     *
     * @param attributeIndex column index of the attribute
     * @return attribute value to number of rows having it
     */
    public Map<String,Long> getAttributeValue(int attributeIndex){

        Map<String,Long> attributeValueMap = new HashMap<String,Long>();
        for(List<String> row : data){
            increment(attributeValueMap, row.get(attributeIndex));
        }
        return attributeValueMap;
    }

    /**
     * Returns copies of the rows whose attribute at {@code attrIndex} equals
     * {@code attrValue}; the copies can be modified without affecting this
     * instance.
     */
    public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){

        List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
        for(ArrayList<String> row : data){
            if(row.get(attrIndex).equals(attrValue)){
                resultData.add(new ArrayList<String>(row));
            }
        }
        return resultData;
    }

    /**
     * Gain ratio of an attribute: information gain divided by split
     * information.
     *
     * @return the gain ratio, or 0 when the split information is 0 (the
     *         attribute has at most one distinct value and cannot split the
     *         rows), instead of failing on a division by zero
     */
    public double getGainRatio(int attributeIndex){

        double splitInfo = getSplitInfo(attributeIndex);
        if(splitInfo <= 0.0){
            return 0.0;
        }
        return getGain(attributeIndex) / splitInfo;
    }

    /** Information gain of an attribute: entropy minus expected entropy after the split. */
    public double getGain(int attributeIndex){
        return getEntropy() - getInfoAttribute(attributeIndex);
    }

    /**
     * Split information of an attribute: the entropy of the distribution of
     * its values over the rows. It penalises attributes with many values.
     */
    public double getSplitInfo(int attributeIndex){
        return entropyOf(getAttributeValue(attributeIndex), data.size());
    }

    /** Counts how many rows carry each distinct target value. */
    public Map<String,Long> getTargetValue(){

        Map<String,Long> targetValueMap = new HashMap<String,Long>();
        for(List<String> row : data){
            increment(targetValueMap, row.get(row.size() - 1));
        }
        return targetValueMap;
    }

    /** Extracts the target column (the last element of every row). */
    public static List<String> getTarget(List<ArrayList<String>> data){

        List<String> list = new ArrayList<String>(data.size());
        for(ArrayList<String> row : data){
            list.add(row.get(row.size() - 1));
        }
        return list;
    }

    /**
     * Returns the single value shared by every element of {@code list}, or
     * {@code null} when the list is empty or holds more than one distinct
     * value.
     */
    public static String IsPure(List<String> list){

        Set<String> distinct = new HashSet<String>(list);
        if(distinct.size() != 1){
            return null;
        }
        return distinct.iterator().next();
    }

    /** Adds one to the counter stored under {@code key}, starting at 1. */
    private static void increment(Map<String,Long> counts,String key){

        Long current = counts.get(key);
        counts.put(key, current == null ? 1L : current.longValue() + 1L);
    }

    /**
     * Entropy {@code -sum(p * ln p)} of a count distribution, where
     * {@code p = count / total}. Counts are always at least 1, so
     * {@code ln p} is finite.
     */
    private static double entropyOf(Map<String,Long> counts,double total){

        double entropy = 0.0;
        for(Long count : counts.values()){
            double p = count.doubleValue() / total;
            entropy -= p * Math.log(p);
        }
        return entropy;
    }

}

测试类:将上面的属性列表和数据集分别装入两个List中。

package C45Test;

import java.util.ArrayList;
import java.util.List;

import C45Test.DecisionTree.TreeNode;

/**
 * Entry point that builds a decision tree for the sample data set
 * (attributes age, income, student and credit_rating, with the target value
 * in the last column of every row).
 */
public class MainC45 {

    private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
    private static final List<String> attributeList = new ArrayList<String>();

    /** Builds the tree for the sample data; the construction trace is printed by the builder. */
    public static void main(String args[]){

        DecisionTree dt = new DecisionTree();
        TreeNode node = dt.createDT(configData(),configAttribute());
        System.out.println();
    }

    /**
     * Fills {@code dataList} with the twelve sample rows on first use and
     * returns it. Columns: age, income, student, credit_rating, target.
     */
    private static List<ArrayList<String>> configData(){

        // Populate only once so repeated calls do not duplicate rows.
        if(dataList.isEmpty()){
            addRow("youth", "high", "no", "fair", "no");
            addRow("youth", "high", "no", "excellent", "no");
            addRow("middle_aged", "high", "no", "fair", "yes");
            addRow("senior", "low", "yes", "fair", "yes");
            addRow("senior", "low", "yes", "excellent", "no");
            addRow("middle_aged", "low", "yes", "excellent", "yes");
            addRow("youth", "medium", "no", "fair", "no");
            addRow("youth", "low", "yes", "fair", "yes");
            addRow("senior", "medium", "yes", "fair", "yes");
            addRow("youth", "medium", "yes", "excellent", "yes");
            addRow("middle_aged", "high", "yes", "fair", "yes");
            addRow("senior", "medium", "no", "excellent", "no");
        }
        return dataList;
    }

    /** Appends one row made of the given column values to {@code dataList}. */
    private static void addRow(String... values){

        ArrayList<String> row = new ArrayList<String>(values.length);
        for(String value : values){
            row.add(value);
        }
        dataList.add(row);
    }

    /**
     * Fills {@code attributeList} with the attribute names on first use and
     * returns it; the order matches the row columns.
     */
    private static List<String> configAttribute(){

        if(attributeList.isEmpty()){
            attributeList.add("age");
            attributeList.add("income");
            attributeList.add("student");
            attributeList.add("credit_rating");
        }
        return attributeList;
    }
}

大数运算工具类

package C45Test;
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * Decimal arithmetic helpers backed by {@link BigDecimal}.
 *
 * <p>Operands given as {@code double} are converted through their shortest
 * decimal representation (as {@code Double.toString} prints them), so
 * {@code add(0.1, 0.2)} yields {@code 0.3}. Every result is converted back
 * to {@code double} and is therefore limited to {@code double} precision.
 * NaN and infinite operands are rejected with a
 * {@link NumberFormatException}.
 */
public abstract class MathUtils {

    // Number of decimal places kept by div().
    private static final int DIV_SCALE = 10;

    /** Returns {@code value1 + value2} computed in decimal arithmetic. */
    public static double add(double value1,double value2){

        return BigDecimal.valueOf(value1).add(BigDecimal.valueOf(value2)).doubleValue();
    }

    /**
     * Adds two numbers given as decimal strings.
     *
     * <p>The sum is exact internally, but the returned {@code double} keeps
     * only {@code double} precision.
     *
     * @throws NumberFormatException if either string is not a valid decimal number
     */
    public static double add(String value1,String value2){

        return new BigDecimal(value1).add(new BigDecimal(value2)).doubleValue();
    }

    /**
     * Returns {@code value1 / value2} rounded half-up to 10 decimal places.
     *
     * @throws ArithmeticException if {@code value2} is zero
     */
    public static double div(double value1,double value2){

        return BigDecimal.valueOf(value1)
                .divide(BigDecimal.valueOf(value2), DIV_SCALE, RoundingMode.HALF_UP)
                .doubleValue();
    }

    /** Returns {@code value1 * value2} computed in decimal arithmetic. */
    public static double mul(double value1,double value2){

        return BigDecimal.valueOf(value1).multiply(BigDecimal.valueOf(value2)).doubleValue();
    }

    /** Returns {@code value1 - value2} computed in decimal arithmetic. */
    public static double sub(double value1,double value2){

        return BigDecimal.valueOf(value1).subtract(BigDecimal.valueOf(value2)).doubleValue();
    }

    /** Returns the larger of the two values. */
    public static double returnMax(double value1, double value2) {

        // Converted the same way as in the other methods rather than through
        // the exact binary expansion produced by new BigDecimal(double).
        return BigDecimal.valueOf(value1).max(BigDecimal.valueOf(value2)).doubleValue();
    }
}
原文地址:https://www.cnblogs.com/lixusign/p/2548124.html