ID3决策树---Java

1)熵与信息增益:

2)以下是实现代码:

  1 //import java.awt.color.ICC_ColorSpace;
  2 import java.io.*;
  3 import java.util.ArrayList;
  4 import java.util.Collections;
  5 import java.util.Comparator;
  6 import java.util.HashMap;
  7 import java.util.HashSet;
  8 import java.util.Iterator;
  9 //import java.util.Iterator;
 10 import java.util.List;
 11 //import java.util.Locale.Category;
 12 import java.util.Map;
 13 import java.util.Map.Entry;
 14 import java.util.Set;
 15 class decisionTree{
 16 
 17     private static Map<String, Map<String, Integer>> featureValuesAndCounts=new HashMap<String, Map<String,Integer>>();
 18     private static ArrayList<String[]> dataSet=new ArrayList<String[]>();
 19     private static ArrayList<String> features=new ArrayList<String>();
 20     private static Set<String> category=new HashSet<String>();
 21     //public static DecisionNode root=new DecisionNode();
 22     //private static  ArrayList<ArrayList<String>> featureValue=new ArrayList<ArrayList<String>>();
 23     public static void GetDataSet()
 24     {
 25         File file = new File("C:\Users\hfz\workspace\decisionTree\src\loan.txt");
 26         try{
 27             BufferedReader br = new BufferedReader(new FileReader(file));//
 28             String s = null;
 29             s=br.readLine();//读取第一行的内容,即是各特征的名称
 30             String[] tempFeatures=s.split(",");
 31             for (String string1 : tempFeatures) {
 32                 features.add(string1);
 33             }
 34             s=br.readLine();        //开始读取特征值
 35             String[] tt=null;
 36             int flag=s.length();
 37             while(flag!=0){//英文文档读到结尾得到的值是null,而中文文档读到结尾得到的值却是""
 38                 tt=s.split(",");
 39                 dataSet.add(tt);    //将特征值存入
 40                 category.add(tt[tt.length-1]);//category为集合类型,用于存储类型值
 41 
 42                 s=br.readLine();
 43                 if (s!=null) {
 44                     flag = s.length();
 45                 }
 46                 else{
 47                     flag=0;
 48                 }
 49 
 50             }
 51 
 52             for (int j = 0; j < features.size(); j++) {//逻辑上模拟列优先的方式读取二维数组形式的数据集,就是首先读取一个特征名称,再遍历数据集
 53                 Map<String, Integer> ttt=new HashMap<String, Integer>();//将某特征的各个特征值存入Map中,然后再度第二个特征,再遍历数据集。。。
 54                 for (int i = 0; i < dataSet.size(); i++) {
 55                     String currentFeatureValue=dataSet.get(i)[j];
 56                     if(!(ttt.containsKey(currentFeatureValue)))
 57                         ttt.put(currentFeatureValue, 1);
 58                     else {
 59                         ttt.replace(currentFeatureValue, ttt.get(currentFeatureValue)+1);
 60                     }
 61 
 62                 }
 63                 featureValuesAndCounts.put(features.get(j), ttt);//嵌套形式的Map,第一层的key是特征名称,value是一个新的Map
 64                 // 新Map中key是特征的各个值,value是特征值在数据集中出现的次数。
 65 
 66             }
 67 
 68             br.close();
 69         }
 70 
 71         catch(Exception e){
 72             e.printStackTrace();
 73         }
 74     }
 75     public static DecisionNode treeGrowth(ArrayList<String[]> dataset,String currentFeatureName,
 76                                           String currentFeatureValue,ArrayList<String> current_features,
 77                                           Map<String,Map<String,Integer>> current_featureValuesCounts){
 78         /*
 79         dataset:用于split方法,从dataset数据集中去除掉具有某个特征值的对应的若干实例,生成一个新的新的数据集
 80         currentFeatureName:当前的特征名称,用于叶子节点,赋值给叶子节点的featureName字段
 81         currentFeatureValue:当前特征名称对应的特征值,也用于叶子节点,赋值给featureValue字段
 82         current_features:当前数据集中包含的所有特征名称,用于findBestAttribute方法,找到信息增益最大的的属性
 83         current_featureValuesCounts:当前数据集中所有特征的各个特征值出现的次数,用于findBestAttribute方法,用于计算条件熵,进而计算信息增益。
 84          */
 85         ArrayList<String> classList=new ArrayList<String>();
 86         int flag=0;
 87         for (String[] string : dataset) {
 88             //测试数据集中类型值的数量,flag表示数据集中的类型数量
 89             if (classList.contains(string[string.length-1])) {
 90 
 91             }
 92             else {
 93                 classList.add(string[string.length-1]);
 94                 flag++;//如果flag>1表示数据集
 95             }
 96 
 97         }
 98         if(1==flag){//如果只有一个类结果,则返回此叶子节点
 99             DecisionNode d=new DecisionNode();
100             d.init(currentFeatureName,classList.get(0),currentFeatureValue);
101             return d;
102         }
103         if (dataset.get(0).length==1) {//如果数据集已经没有属性了只剩下类结果,则返回占比最大的类结果,也是叶子节点
104             DecisionNode d=new DecisionNode();
105             d.init(currentFeatureName,classify(classList),currentFeatureValue);
106             return d;
107         }
108 
109         /*
110         DecisionNode是一个自定义的递归型的数据类型,类中一个children字段是DecisionNode类型的数组,
111         正好用这种类型来存储递归算法产生的结果(决策树),也就是用这种结构来存储一棵树。
112         */
113         //程序运行到这里就说明此节点不是叶子节点
114         DecisionNode root2=new DecisionNode();//那么root2就是一个决策属性节点(非叶子节点)了,非叶子节点就有孩子节点,下面就是计算它的孩子节点
115 
116         int bestFeatureIndex=findBestAttribute(dataset,current_features,current_featureValuesCounts);
117         String bestFeatureLabel=current_features.get(bestFeatureIndex);
118         //root.testCondition=bestFeatureLabel;
119         ArrayList<String> feature_values=new ArrayList<String>();
120         for (Entry<String, Integer> featureEntry : current_featureValuesCounts.get(bestFeatureLabel).entrySet()) {
121             feature_values.add(featureEntry.getKey());
122 
123         }
124         //给非叶子节点,也就是特征节点仅仅赋特征名称值
125         root2.init(currentFeatureName,currentFeatureValue);//java中不能是使用像C++中默认参数的函数,只能通过重载来实现同样的目的。
126         for (String values : feature_values) {
127             //DecisionNode tempRoot=new DecisionNode();
128 
129             ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);//生成子数据集,即去除了包含values的实例,
130                                                                     // 接下来就是计算对此数据集利用决策树进行决策,又需要调用treeGrow方法
131                                                                     //所以,接下来需要得到对应这个子数据集的特征名称以及每个特征值在数据集中出现的次数
132             ArrayList<String> currentAttibutes=new ArrayList<>();
133             Iterator item1=current_features.iterator();
134             while(item1.hasNext()){
135                 currentAttibutes.add(item1.next().toString());//这个子数据集的特征名称
136             }
137 
138             Map<String,Map<String,Integer>> currentAttributeValuesCounts=new HashMap<String, Map<String, Integer>>();
139             //ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);
140             currentAttibutes.remove(bestFeatureLabel);
141             for (int j = 0; j < currentAttibutes.size(); j++) {
142                 Map<String, Integer> ttt=new HashMap<String, Integer>();
143                 for (int i = 0; i <subDataSet.size(); i++) {
144                     String currentFeatureValueXX=subDataSet.get(i)[j];
145                     if(!(ttt.containsKey(currentFeatureValueXX)))
146                         ttt.put(currentFeatureValueXX, 1);
147                     else {
148                         ttt.replace(currentFeatureValueXX, ttt.get(currentFeatureValueXX)+1);
149                     }
150 
151                 }
152                 currentAttributeValuesCounts.put(currentAttibutes.get(j), ttt);//每个特征值在数据集中出现的次数
153 
154             }
155 
156             root2.add(treeGrowth(subDataSet, bestFeatureLabel, values, currentAttibutes, currentAttributeValuesCounts));
157 
158         }
159 
160 
161         return root2;
162 
163     }
164 
165     public static void main(String[] agrs){
166         decisionTree.GetDataSet();
167         DecisionNode dd=decisionTree.treeGrowth(dataSet,"oo","xx",features,featureValuesAndCounts);
168         System.out.print(dd);
169 
170 
171 
172     }
173 
174     public static double calEntropy(ArrayList<String[]> dataset){//熵表示随机变量X不确定性的度量,在决策树中计算的熵就是决策结果这个变量的熵。
175         int sampleCounts=dataset.size();
176         Map<String, Integer> categoryCounts=new HashMap<String, Integer>();
177         for (String[] strings : dataset) {
178 
179             if(categoryCounts.containsKey(strings[strings.length-1]))
180                 categoryCounts.replace(strings[strings.length-1], categoryCounts.get(strings[strings.length-1])+1);
181             else {
182                 categoryCounts.put(strings[strings.length-1],1);
183             }
184 
185         }
186         double shannonEnt=0.0;
187         for (Integer value: categoryCounts.values()) {
188             double probability=value.doubleValue()/sampleCounts;
189             shannonEnt-=probability*(Math.log10(probability)/Math.log10(2));
190 
191         }
192         return shannonEnt;
193     }
194 
195     public static  int findBestAttribute(ArrayList<String[]> dataset,ArrayList<String> currentFeatures,
196                                          Map<String,Map<String,Integer>> currentFeatureValuesCounts){
197         double baseEntroy=calEntropy(dataset);//计算基础熵,就是在不划分出某个特征的情况下。
198         double bestInfoGain=0.0;
199         int bestFeatureIndex=-1;
200 
201         for (int i = 0; i <currentFeatures.size(); i++) {//遍历当前数据集的每个特征,计算每个特征的信息增益
202             double conditionalEntroy=0.0;
203             Map<String,Integer> tempFeatureCounts=currentFeatureValuesCounts.get(currentFeatures.get(i));
204             //Map类型有一个entrySet方法,此方法返回一个Map.Entry类型的集合,其中集合中的每个元素就是一个键值对,利用增强型的for循环可以遍历Map中
205             //key(entry.getkey)和value(entry.getValue)
206             for (Entry<String, Integer> entry : tempFeatureCounts.entrySet()) {
207                 //计算条件熵,就是根据某个具体特征值划分出新的数据集,计算新的数据集的基础熵,再乘以权值,累加得到某个特征的条件熵。
208                 conditionalEntroy+=(entry.getValue().doubleValue()/dataset.size())*calEntropy(splitDataSet(dataset, i, entry.getKey()));
209             }
210             if (baseEntroy-conditionalEntroy>bestInfoGain) {
211                 bestInfoGain=baseEntroy-conditionalEntroy;
212                 bestFeatureIndex=i;
213 
214             }
215         }
216         if (-1==bestFeatureIndex){
217             System.out.print("cannot find best attribute!");
218             return -1;
219         }
220         else {
221             return bestFeatureIndex;//返回信息增益最大的特征的索引,在当前特征(currentFeatures)中的索引。
222         }
223     }
224     public static String classify(ArrayList<String> dataset) {
225 
226         Map<String, Integer> categoryCount = new HashMap<String, Integer>();
227         for (String s1 : dataset) {
228             if (categoryCount.containsKey(s1)) {
229                 categoryCount.replace(s1, categoryCount.get(s1) + 1);
230             } else {
231                 categoryCount.put(s1, 1);
232             }
233         }
234         int maxCounts=-1;
235         String maxCountsCategory=null;
236         for (Entry<String,Integer> entry:categoryCount.entrySet()){//利用Map.Entry得到Map中的Value最大的键值对。
237             if (entry.getValue()>maxCounts){
238                 maxCounts=entry.getValue();
239                 maxCountsCategory=entry.getKey();
240             }
241         }
242         return  maxCountsCategory;
243 
244     }
245 
246     public static ArrayList<String[]> splitDataSet(ArrayList<String[]> dataset,int featureIndex,String featureValue
247     ){
248         ArrayList<String[]> tempDataSet=new ArrayList<String[]>();
249         for (String[] strings : dataset) {
250             if (strings[featureIndex].equals(featureValue)) {
251 
252                 String[] xx=strings.clone();//数组的clone方法实现的是浅拷贝,实质就是以下的过程
253                 /*
254                 for (int i = featureIndex; i < strings.length-1; i++) {
255                     xx[i]=strings[i];//就是把引用的值(地址)复制了一份,指向了同一个对象。
256                 }
257 
258                 */
259                 for (int i = featureIndex; i < strings.length-1; i++) {//xx中各个元素的值与strings中各个元素的值完全相等。
260                     xx[i]=xx[i+1];//只是复制了引用的值而已,跟引用指向的对象没一点关系。Java将基本类型和引用类型变量都看成是值而已·
261                 }
262                 //最最最需要注意的一点,以上代码不能以下面这种形式实现
263                 /*
264                 for (int i = featureIndex; i < strings.length-1; i++) {//
265                     strings[i]=strings[i+1];//这样会改变strings指向的对象,进而影响到dataset,改变了函数的参数dataset,
266                     这样就在函数内“无意间”修改了dataset的值,集合类型,其实所有引用类型都是,以参数形式传入函数的话,可能会“无意间”就被修改了
267                 }
268                  */
269                 String[] tempStrings=new String[xx.length-1];
270                 for (int i = 0; i < tempStrings.length; i++) {
271                     tempStrings[i]=xx[i];
272 
273                 }
274                 tempDataSet.add(tempStrings);
275             }
276 
277 
278         }
279         return tempDataSet;
280     }
281 
282 }
283 class DecisionNode{
284     public String featureName;
285     public String result;
286     public String featureValue;
287     public List<DecisionNode> children=new ArrayList<DecisionNode>();
288     public void add(DecisionNode node){
289         children.add(node);
290     }
291     public void init(String featureName,String result,String featureValue){
292         this.featureName=featureName;
293         this.result=result;
294         this.featureValue=featureValue;
295     }
296     public void init(String featureName,String featureValue){
297         this.featureName=featureName;
298         this.featureValue=featureValue;
299     }
300 }

 参考:

http://www.blogjava.net/zhenandaci/archive/2009/03/24/261701.html

http://my.oschina.net/xinyi/blog/116014

http://www.cnblogs.com/zhangchaoyang/articles/2196631.html

 http://blog.csdn.net/u011067360/article/details/21861989?utm_source=tuicool

原文地址:https://www.cnblogs.com/lz3018/p/4820144.html