[Java] 数据分析--分类

ID3算法

  • 思路:分类算法的输入为训练集,输出为对数据进行分类的函数。ID3算法以信息增益为标准逐层选择划分属性,生成一棵决策树作为分类函数
  • 需求:对水果训练集的一个维度(是否甜)进行预测
  • 实现:决策树,熵函数,ID3,weka库 J48类

ComputeGain.java

 1 public class ComputeGain {
 2     public static void main(String[] args) {
 3         System.out.printf("h(11,16) = %.4f%n", h(11,16));
 4         System.out.println("Gain(Size):");
 5         System.out.printf("	h(3,5) = %.4f%n", h(3,5));
 6         System.out.printf("	h(6,7) = %.4f%n", h(6,7));
 7         System.out.printf("	h(2,4) = %.4f%n", h(2,4));
 8         System.out.printf("	g({3,6,2},{5,7,4}) = %.4f%n", 
 9                     g(new int[]{3,6,2},new int[]{5,7,4}));
10         System.out.println("Gain(Color):");
11         System.out.printf("	h(3,4) = %.4f%n", h(3,4));
12         System.out.printf("	h(3,5) = %.4f%n", h(3,5));
13         System.out.printf("	h(2,3) = %.4f%n", h(2,3));
14         System.out.printf("	h(2,4) = %.4f%n", h(2,4));
15         System.out.printf("	g({3,3,2,2},{4,5,3,4}) = %.4f%n", 
16                     g(new int[]{3,3,2,2},new int[]{4,5,3,4}));
17         System.out.println("Gain(Surface):");
18         System.out.printf("	h(4,7) = %.4f%n", h(4,7));
19         System.out.printf("	h(4,6) = %.4f%n", h(4,6));
20         System.out.printf("	h(3,3) = %.4f%n", h(3,3));
21         System.out.printf("	g({4,4,3},{7,6,3}) = %.4f%n", 
22                     g(new int[]{4,4,3},new int[]{7,6,3}));
23         System.out.println("Gain(Size|SMOOTH):");
24         System.out.printf("	h(1,3) = %.4f%n", h(1,3));
25         System.out.printf("	h(3,3) = %.4f%n", h(3,3));
26         System.out.printf("	g({1,3,0},{3,3,1}) = %.4f%n", 
27                     g(new int[]{1,3,0},new int[]{3,3,1}));
28         System.out.println("Gain(Color|SMOOTH):");
29         System.out.printf("	h(2,3) = %.4f%n", h(2,3));
30         System.out.printf("	g({2,2,0},{3,2,2}) = %.4f%n", 
31                     g(new int[]{2,2,0},new int[]{3,2,2}));
32         System.out.println("Gain(Size|ROUGH):");
33         System.out.printf("	h(3,6) = %.4f%n", h(3,6));
34         System.out.printf("	h(1,2) = %.4f%n", h(1,2));
35         System.out.printf("	g({2,1,1},{2,2,2}) = %.4f%n", 
36                     g(new int[]{2,1,1},new int[]{2,2,2}));
37         System.out.println("Gain(Color|ROUGH):");
38         System.out.printf("	h(4,6) = %.4f%n", h(4,6));
39         System.out.printf("	g({1,1,1},{2,2,2}) = %.4f%n", 
40                     g(new int[]{1,0,2,1},new int[]{1,2,2,1}));
41     }
42     
43     /*  Gain for the splitting {A1, A2, ...}, where Ai 
44         has n[i] points, m[i] of which are favorable.
45     */
46     public static double g(int[] m, int[] n) {
47         int sm = 0, sn = 0;
48         double nsh = 0.0;
49         for (int i = 0; i < m.length; i++) {
50             sm += m[i];
51             sn += n[i];
52             nsh += n[i]*h(m[i],n[i]);
53         }
54         return h(sm, sn) - nsh/sn;
55     }
56     
57     /*  Entropy for m favorable items out of n.
58     */
59     public static double h(int m, int n) {
60         if (m == 0 || m == n) {
61             return 0;
62         }
63         double p = (double)m/n, q = 1 - p;
64         return -p*lg(p) - q*lg(q);
65     }
66     
67     /*  Returns the binary logarithm of x.
68     */
69     public static double lg(double x) {
70         return Math.log(x)/Math.log(2);
71     }
72 }
View Code

h(11,16) = 0.8960
Gain(Size):
h(3,5) = 0.9710
h(6,7) = 0.5917
h(2,4) = 1.0000
g({3,6,2},{5,7,4}) = 0.0838
Gain(Color):
h(3,4) = 0.8113
h(3,5) = 0.9710
h(2,3) = 0.9183
h(2,4) = 1.0000
g({3,3,2,2},{4,5,3,4}) = 0.0260
Gain(Surface):
h(4,7) = 0.9852
h(4,6) = 0.9183
h(3,3) = 0.0000
g({4,4,3},{7,6,3}) = 0.1206
Gain(Size|SMOOTH):
h(1,3) = 0.9183
h(3,3) = 0.0000
g({1,3,0},{3,3,1}) = 0.5917
Gain(Color|SMOOTH):
h(2,3) = 0.9183
g({2,2,0},{3,2,2}) = 0.5917
Gain(Size|ROUGH):
h(3,6) = 1.0000
h(1,2) = 1.0000
g({2,1,1},{2,2,2}) = 0.2516
Gain(Color|ROUGH):
h(4,6) = 0.9183
g({1,1,1},{2,2,2}) = 0.9183

 1 import weka.classifiers.trees.J48;
 2 import weka.core.Instances;
 3 import weka.core.Instance;
 4 import weka.core.converters.ConverterUtils.DataSource;
 5 
 6 public class TestWekaJ48 {
 7     public static void main(String[] args) throws Exception {
 8         DataSource source = new DataSource("data/AnonFruit.arff");
 9         Instances instances = source.getDataSet();
10         instances.setClassIndex(3);  // target attribute: (Sweet)
11         
12         J48 j48 = new J48();  // an extension of ID3
13         j48.setOptions(new String[]{"-U"});  // use unpruned tree
14         j48.buildClassifier(instances);
15 
16         for (Instance instance : instances) {
17             double prediction = j48.classifyInstance(instance);
18             System.out.printf("%4.0f%4.0f%n", 
19                     instance.classValue(), prediction);
20         }
21     }
22 }
View Code

1 1
1 1
1 1
1 0
1 1
0 0
1 1
0 0
0 0
0 0
1 1
1 1
1 1
1 1
0 0
1 1

贝叶斯分类

  • 思路:利用训练集中统计出的比率(各类别的先验比例,以及各属性取值在该类别中的条件比例)估计样本属于每个类别的概率,取概率较大的类别作为预测结果

Fruit.java

 1 import java.io.File;
 2 import java.io.FileNotFoundException;
 3 import java.util.HashSet;
 4 import java.util.Scanner;
 5 import java.util.Set;
 6 
 7 public class Fruit {
 8     String name, size, color, surface;
 9     boolean sweet;
10 
11     public Fruit(String name, String size, String color, String surface, 
12             boolean sweet) {
13         this.name = name;
14         this.size = size;
15         this.color = color;
16         this.surface = surface;
17         this.sweet = sweet;
18     }
19 
20     @Override
21     public String toString() {
22         return String.format("%-12s%-8s%-8s%-8s%s", 
23                 name, size, color, surface, (sweet? "T": "F") );
24     }
25     
26     public static Set<Fruit> loadData(File file) {
27         Set<Fruit> fruits = new HashSet();
28         try {
29             Scanner input = new Scanner(file);
30             for (int i = 0; i < 7; i++) {  // read past metadata
31                 input.nextLine();
32             }
33             while (input.hasNextLine()) {
34                 String line = input.nextLine();
35                 Scanner lineScanner = new Scanner(line);
36                 String name = lineScanner.next();
37                 String size = lineScanner.next();
38                 String color = lineScanner.next();
39                 String surface = lineScanner.next();
40                 boolean sweet = (lineScanner.next().equals("T"));
41                 Fruit fruit = new Fruit(name, size, color, surface, sweet);
42                 fruits.add(fruit);
43             }
44         } catch (FileNotFoundException e) {
45             System.err.println(e);
46         }
47         return fruits;
48     }
49 
50     public static void print(Set<Fruit> fruits) {
51         int k=1;
52         for (Fruit fruit : fruits) {
53             System.out.printf("%2d. %s%n", k++, fruit);
54         }
55     }
56 }
View Code

BayesianTest.java

 1 import java.io.File;
 2 import java.util.Set;
 3 
 4 public class BayesianTest {
 5     private static Set<Fruit> fruits;
 6     
 7     public static void main(String[] args) {
 8         fruits = Fruit.loadData(new File("data/Fruit.arff"));
 9         Fruit fruit = new Fruit("cola", "SMALL", "RED", "SMOOTH", false);
10         double n = fruits.size();  // total number of fruits in training set
11         double sum1 = 0;           // number of sweet fruits
12         for (Fruit f : fruits) {
13             sum1 += (f.sweet? 1: 0);
14         }
15         double sum2 = n - sum1;    // number of sour fruits
16         double[][] p = new double[4][3];
17         for (Fruit f : fruits) {
18             if (f.sweet) {
19                 p[1][1] += (f.size.equals(fruit.size)? 1: 0)/sum1;
20                 p[2][1] += (f.color.equals(fruit.color)? 1: 0)/sum1;
21                 p[3][1] += (f.surface.equals(fruit.surface)? 1: 0)/sum1;
22             } else {
23                 p[1][2] += (f.size.equals(fruit.size)? 1: 0)/sum2;
24                 p[2][2] += (f.color.equals(fruit.color)? 1: 0)/sum2;
25                 p[3][2] += (f.surface.equals(fruit.surface)? 1: 0)/sum2;
26             }
27         }
28         double pc1 = p[1][1]*p[2][1]*p[3][1]*sum1/n;
29         double pc2 = p[1][2]*p[2][2]*p[3][2]*sum2/n;
30         System.out.printf("pc1 = %.4f, pc2 = %.4f%n", pc1, pc2);
31         System.out.printf("Predict %s is %s.%n", 
32                 fruit.name, (pc1 > pc2? "sweet": "sour"));
33     }
34 }
View Code

pc1 = 0.0186, pc2 = 0.0150
Predict cola is sweet.

TestWekaBayes.java

 1 import java.util.List;
 2 import weka.classifiers.Evaluation;
 3 import weka.classifiers.bayes.NaiveBayes;
 4 import weka.classifiers.evaluation.Prediction;
 5 import weka.core.Instance;
 6 import weka.core.Instances;
 7 import weka.core.converters.ConverterUtils;
 8 import weka.core.converters.ConverterUtils.DataSource;
 9 
10 public class TestWekaBayes {
11     public static void main(String[] args) throws Exception {
12 //        ConverterUtils.DataSource source = new ConverterUtils.DataSource("data/AnonFruit.arff");
13         DataSource source = new DataSource("data/AnonFruit.arff");
14         Instances train = source.getDataSet();
15         train.setClassIndex(3);  // target attribute: (Sweet)
16         //build model
17         NaiveBayes model=new NaiveBayes();
18         model.buildClassifier(train);
19 
20         //use
21         Instances test = train;
22         Evaluation eval = new Evaluation(test);
23         eval.evaluateModel(model,test);
24         List <Prediction> predictions = eval.predictions();
25         int k = 0;
26         for (Instance instance : test) {
27             double actual = instance.classValue();
28             double prediction = eval.evaluateModelOnce(model, instance);
29             System.out.printf("%2d.%4.0f%4.0f", ++k, actual, prediction);
30             System.out.println(prediction != actual? " *": "");
31         }
32     }
33 }
View Code

1. 1 1
2. 1 1
3. 1 1
4. 1 1
5. 1 1
6. 0 1 *
7. 1 1
8. 0 0
9. 0 0
10. 0 1 *
11. 1 1
12. 1 1
13. 1 1
14. 1 1
15. 0 0
16. 1 1

SVM算法

  • 思路:求出将两类数据分开的超平面方程,根据数据点位于超平面的哪一侧进行分类

逻辑回归

  • 思路:将目标属性为布尔值的问题转化为数值变量问题(对概率做logit变换),在转化后的问题上进行线性回归,再用sigmoid函数把回归结果变换回概率
  • 需求:某政党候选人想知道选举获胜的花费
  • 实现
 1 import org.apache.commons.math3.analysis.function.*;
 2 import org.apache.commons.math3.stat.regression.SimpleRegression;
 3 
 4 public class LogisticRegression {
 5     static int n = 6;
 6     static double[] x = {5, 15, 25, 35, 45, 55};
 7     static double[] p = {2./6,2./5, 4./8, 5./9, 3./5, 4./5};
 8     static double[] y = new double[n];    // y = logit(p)
 9 
10     public static void main(String[] args) {
11         
12         //  Transform p-values into y-values:
13         Logit logit = new Logit();
14         for (int i = 0; i < n; i++) {
15             y[i] = logit.value(p[i]);
16         }
17         
18         //  Set up input array for linear regression:
19         double[][] data = new double[n][n];
20         for (int i = 0; i < n; i++) {
21             data[i][0] = x[i];
22             data[i][1] = y[i];
23         }
24         
25         //  Run linear regression of y on x:
26         SimpleRegression sr = new SimpleRegression();
27         sr.addData(data);
28         
29         //  Print results:
30         for (int i = 0; i < n; i++) {
31             System.out.printf("x = %2.0f, y = %7.4f%n", x[i], sr.predict(x[i]));
32         }
33         System.out.println();
34         
35         //  Convert y-values back to p-values:
36         Sigmoid sigmoid = new Sigmoid();
37         for (int i = 0; i < n; i++) {
38             double p = sr.predict(x[i]);
39             System.out.printf("x = %2.0f, p = %6.4f%n", x[i], sigmoid.value(p));
40         }
41     }
42 }
View Code

x = 5, y = -0.7797
x = 15, y = -0.4067
x = 25, y = -0.0338
x = 35, y = 0.3392
x = 45, y = 0.7121
x = 55, y = 1.0851

x = 5, p = 0.3144
x = 15, p = 0.3997
x = 25, p = 0.4916
x = 35, p = 0.5840
x = 45, p = 0.6709
x = 55, p = 0.7475

k近邻

  • 思路:根据与待分类样本距离最近的k个训练样本的类别进行分类
 1 import weka.classifiers.lazy.IBk;  // K-Nearest Neighbors
 2 import weka.core.Instances;
 3 import weka.core.Instance;
 4 import weka.core.converters.ConverterUtils.DataSource;
 5 
 6 public class TestIBk {
 7     public static void main(String[] args) throws Exception {
 8         DataSource source = new DataSource("data/AnonFruit.arff");
 9         Instances instances = source.getDataSet();
10         instances.setClassIndex(3);  // target attribute: (Sweet)
11         
12         IBk ibk = new IBk();
13         ibk.buildClassifier(instances);
14 
15         for (Instance instance : instances) {
16             double prediction = ibk.classifyInstance(instance);
17             System.out.printf("%4.0f%4.0f%n", 
18                     instance.classValue(), prediction);
19         }
20     }
21 }
View Code

1 1
1 1
1 1
1 0
1 1
0 0
1 1
0 0
0 0
0 0
1 1
1 1
1 1
1 1
0 0
1 1

原文地址:https://www.cnblogs.com/cxc1357/p/14692228.html