java 朴素贝叶斯

网上能找到的贝叶斯(Bayes)源码大多是面向具体应用的,本人才疏学浅,看不太懂,于是花了两天时间自己写了一份比较粗糙的实现(基于李航《统计学习方法》中朴素贝叶斯一章的例题)。由于只是初学,若有错误,请指出,大家一起学习!

  1 import java.io.BufferedReader;
  2 import java.io.File;
  3 import java.io.FileNotFoundException;
  4 import java.io.FileReader;
  5 import java.io.IOException;
  6 import java.util.ArrayList;
  7 import java.util.HashMap;
  8 import java.util.List;
  9 import java.util.Map;
 10 
 11 public class Bayes {
 12     public static void main(String[] args){
 13         List<List<String>> filelist = new ArrayList<List<String>>();
 14         Map<String,Double> prioriP = new HashMap<String,Double>();
 15         Map<String,Integer> prioriNo = new HashMap<String,Integer>();
 16         Map<String,Double> result = new HashMap<String,Double>();
 17         String s1 = "D://1.txt";
 18         String s2 = "D://2.txt";
 19         filelist = read(filelist,s1);
 20         prioriP = computepirior(filelist,prioriP,prioriNo);
 21         List<List<String>> testlist = new ArrayList<List<String>>();
 22         testlist = read(testlist,s2);
 23         result = decide(prioriP,filelist,testlist,prioriNo);
 24         print(result,testlist);
 25     }
 26     //第4步、打印结构
 27     private static void print(Map<String, Double> result,List<List<String>> testlist) {
 28         System.out.print("测试数据:" + "   ");
 29         for(int i=0;i<testlist.size();i++){
 30             System.out.print("特征" + (i+1) +" :");
 31             for(int j=0;j<testlist.get(i).size();j++){
 32                 System.out.print(testlist.get(i).get(j) + "   ");
 33             }
 34         }
 35         System.out.print("所属类别:" + result.keySet().iterator().next());
 36     }
 37     //第3.1步、把元数据根据所属类别分开处理
 38     private static Map<String, Double> decide(Map<String, Double> prioriP, List<List<String>> filelist, List<List<String>> testlist, Map<String, Integer> prioriNo) {
 39         List<Map<String,Integer>> map = new ArrayList<Map<String,Integer>>();
 40         List<List<List<String>>> fc = new ArrayList<List<List<String>>>();
 41         
 42         for(Map.Entry<String, Integer> entry : prioriNo.entrySet()){
 43             List<List<String>> filecopy = new ArrayList<List<String>>();
 44             for(int i=0;i<filelist.size();i++){
 45                 List<String> list = new ArrayList<String>();
 46                 for(int j=0;j<filelist.get(i).size();j++){
 47                     if(filelist.get(filelist.size()-1).get(j).equals(entry.getKey())){
 48                         list.add(filelist.get(i).get(j));
 49                     }
 50                 }
 51                 filecopy.add(list);
 52             }
 53             fc.add(filecopy);
 54         }
 55 
 56         //有几组测试数据,本来想实现的是测试数据是多对,自己写不出来,这段代码有待改进
 57         //第3.2步、测试数据在条件下出现的次数
 58         List<Map<String,Integer>> l = new ArrayList<Map<String,Integer>>();
 59         for(int i=0;i<fc.size();i++){
 60             Map<String,Integer> mapdecide = new HashMap<String,Integer>();
 61             for(int k=0;k<fc.get(i).size()-1;k++){
 62                 for(int j=0;j<fc.get(i).get(k).size();j++){                                //需要和元数据比较的次数
 63                     if(testlist.get(k).get(0).equals(fc.get(i).get(k).get(j))){
 64                         if(mapdecide.containsKey(testlist.get(k).get(0))){
 65                             mapdecide.put(testlist.get(k).get(0), mapdecide.get(testlist.get(k).get(0)) + 1);
 66                         }
 67                         else{
 68                             mapdecide.put(testlist.get(k).get(0), 1);
 69                         }
 70                     }
 71                 }
 72             }
 73             l.add(mapdecide);
 74         }
 75         
 76         //第3.3步、求后验概率,并比较哪个类别的概率大即所属类别
 77         Map<String,Double> m = new HashMap<String,Double>();
 78         for(int i=0;i<l.size();i++){
 79             double d = 1.0;
 80             for(Map.Entry<String, Integer> entry : l.get(i).entrySet()){
 81                 d *= (entry.getValue()/(double)fc.get(i).get(fc.get(i).size()-1).size());
 82             }
 83             m.put(fc.get(i).get(fc.get(i).size()-1).get(0), prioriP.get(fc.get(i).get(fc.get(i).size()-1).get(0)) * d);
 84         }
 85         
 86         Double max = 0.0;
 87         for(Map.Entry<String, Double> e : m.entrySet()){
 88             if(max <= e.getValue()){
 89                 max = e.getValue();
 90             }
 91         }
 92         
 93         Map<String,Double> result = new HashMap<String,Double>();
 94         for(Map.Entry<String, Double> e:m.entrySet()){
 95             if(max == e.getValue()){
 96                 result.put(e.getKey(), e.getValue());
 97             }
 98         }
 99         return result;
100     }
101     
102     //第2步、求先验概率
103     private static Map<String, Double> computepirior(List<List<String>> list, Map<String, Double> prioriP, Map<String, Integer> m) {
104         
105         for(int i=0;i<list.get(list.size()-1).size();i++){
106             if(m.containsKey(list.get(list.size()-1).get(i))){
107                 m.put(list.get(list.size()-1).get(i),m.get(list.get(list.size()-1).get(i)) + 1);
108             }
109             else{
110                 m.put(list.get(list.size()-1).get(i),1);
111             }
112         }
113         for (Map.Entry<String,Integer> entry : m.entrySet()) {
114             prioriP.put(entry.getKey(),(entry.getValue()/(double)list.get(list.size()-1).size()));
115         }
116         return prioriP;
117     }
118     //第1步、读取测试数据和训练数据
119     private static List<List<String>> read(List<List<String>> list, String sread) {
120         try {
121             FileReader fr = new FileReader(new File(sread));
122             BufferedReader br = new BufferedReader(fr);
123             String string = br.readLine();
124             while(string != null){
125                 List<String> l = new ArrayList<String>();
126                 String[] str = string.split(" ");
127                 for (String s : str) {
128                     l.add(s);
129                 }
130                 list.add(l);
131                 string = br.readLine();
132             }
133         } catch (FileNotFoundException e) {
134             e.printStackTrace();
135         } catch (IOException e) {
136             e.printStackTrace();
137         }
138         return list;
139     }
140 }

训练数据(D://1.txt;每行以空格分隔,前两行各是一个特征,最后一行是类别,每一列对应一个训练样本):

1 1 1 1 1 2 2 2 2 2 3 3 3 3 3
S M M S S S M M L L L M M L L
-1 -1 1 1 -1 -1 -1 1 1 1 1 1 1 1 -1

测试数据(D://2.txt;每行一个特征值,顺序与训练数据的特征行一致):

2

S

实现结果:

测试数据:   特征1 :2   特征2 :S   所属类别:-1

原文地址:https://www.cnblogs.com/wn19910213/p/3329590.html