贝叶斯算法Java实现

前言:朴素贝叶斯分类算法是一种基于贝叶斯定理的简单概率分类算法。贝叶斯分类的基础是概率推理,就是在各种条件的存在不确定,仅知其出现概率的情况下,如何完成推理和决策任务。概率推理是与确定性推理相对应的。而朴素贝叶斯分类器是基于独立假设的,即假设样本每个特征与其他特征都不相关。

朴素贝叶斯分类器依靠精确的自然概率模型,在有监督学习的样本集中能获取得非常好的分类效果。在许多实际应用中,朴素贝叶斯模型参数估计使用最大似然估计方法,换言之朴素贝叶斯模型能工作并没有用到贝叶斯概率或者任何贝叶斯模型。

尽管是带着这些朴素思想和过于简单化的假设,但朴素贝叶斯分类器在很多复杂的现实情形中仍能够取得相当好的效果。

贝叶斯算法基础讲解:http://www.cnblogs.com/skyme/p/3564391.html

package Bayes;

import java.util.ArrayList;  
import java.util.HashMap;  
import java.util.Map;  
import java.math.BigDecimal;
public class Bayes {  

    //将训练集按巡逻集合的最后一个值进行分类  
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){  
        Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();  
        ArrayList<String> t = null;  
        String c = "";  
        for (int i = 0; i < datas.size(); i++) {  
            t = datas.get(i);  
            c = t.get(t.size() - 1);  
            if (map.containsKey(c)) {  
                map.get(c).add(t);  
            } else {  
                ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>();  
                nt.add(t);  
                map.put(c, nt);  
            }  
        }  
        return map;  
    }  

    //在训练数据的基础上预测测试元组的类别 ,testT的各个属性在结果集里面出现的概率相乘最高的,即是结果
    public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {  
        Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);  
        //将训练集元素划分保存在数据里
        Object classes[] = doc.keySet().toArray();  
        double maxP = 0.00;  
        int maxPIndex = -1;  
      //testT的各个属性在结果集里面出现的概率相乘最高的,即使结果集
        for (int i = 0; i < doc.size(); i++) {  
            String c = classes[i].toString();   
            ArrayList<ArrayList<String>> d = doc.get(c);  
            BigDecimal b1 = new BigDecimal(Double.toString(d.size()));
            BigDecimal b2 = new BigDecimal(Double.toString(datas.size()));
            //b1除以b2得到一个精度为3的双浮点数
            double pOfC = b1.divide(b2,3,BigDecimal.ROUND_HALF_UP).doubleValue(); 
            for (int j = 0; j < testT.size(); j++) {  
                double pv = this.pOfV(d, testT.get(j), j);
                BigDecimal b3 = new BigDecimal(Double.toString(pOfC));   
                BigDecimal b4 = new BigDecimal(Double.toString(pv));
                //b3乘以b4得到一个浮点数
                pOfC=b3.multiply(b4).doubleValue(); 
            }  
            if(pOfC > maxP){  
                maxP = pOfC;  
                maxPIndex = i;  
            }  
        }  
        return classes[maxPIndex].toString();  
    } 

    // 计算指定属性到训练集出现的频率  
    private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {  
        double p = 0.00;  
        int count = 0;  
        int total = d.size();  
        for (int i = 0; i < total; i++) {  
            if(d.get(i).get(index).equals(value)){  
                count++;  
            }  
        }  
        BigDecimal b1 = new BigDecimal(Double.toString(count));
        BigDecimal b2 = new BigDecimal(Double.toString(total));
        //b1除以b2得到一个精度为3的双浮点数
        p = b1.divide(b2,3,BigDecimal.ROUND_HALF_UP).doubleValue(); 
        return p;  
    }  
}  
package Bayes;

import java.io.BufferedReader;  
import java.io.IOException;  
import java.io.InputStreamReader;  
import java.util.ArrayList;  

public class TestBayes {  

    //读取测试元组
    public ArrayList<String> readTestData() throws IOException{  
        ArrayList<String> candAttr = new ArrayList<String>();  
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
        String str = "";  
        while (!(str = reader.readLine()).equals("")) {
            //string分析器
            String[] tokenizer = str.split(" ");
            for(int i=0;i<tokenizer.length;i++){
                candAttr.add(tokenizer[i]);
            } 
        }  
        return candAttr;  
    }  

    //读取训练集
    public ArrayList<ArrayList<String>> readData() throws IOException {  
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();  
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
        String str = "";  
        while (!(str = reader.readLine()).equals("")) {  
            String[] tokenizer = str.split(" ");  
            ArrayList<String> s = new ArrayList<String>();  
            for(int i=0;i<tokenizer.length;i++){
                s.add(tokenizer[i]);
            } 
            datas.add(s);  
        }  
        return datas;  
    }  

    public static void main(String[] args) {  
        TestBayes tb = new TestBayes();  
        ArrayList<ArrayList<String>> datas = null;  
        ArrayList<String> testT = null;  
        Bayes bayes = new Bayes();  
        try {  
            System.out.println("请输入训练数据");  
            datas = tb.readData();  
            while (true) {  
                System.out.println("请输入测试元组");  
                testT = tb.readTestData();  
                String c = bayes.predictClass(datas, testT);  
                System.out.println("The class is: " + c);  
            }  
        } catch (IOException e) {  
            e.printStackTrace();  
        }  
    }  
}  

测试结果:

请输入训练数据
youth high no fair no  
youth high no excellent no  
middle_aged high no fair yes  
senior medium no fair yes  
senior low yes fair yes  
senior low yes excellent no  
middle_aged low yes excellent yes  
youth medium no fair no  
youth low yes fair yes  
senior medium yes fair yes  
youth medium yes excellent yes  
middle_aged medium no excellent yes  
middle_aged high yes fair yes  
senior medium no excellent no  

贝叶斯扩展:
《数学之美》贝叶斯网络http://www.cricode.com/1078.html
《数学之美》贝叶斯分类方法http://www.cricode.com/1098.html

原文地址:https://www.cnblogs.com/yankang/p/6399032.html