朴素贝叶斯原理和应用

上次去深圳招行面试。被问到了这个。中间讨论了几个关于贝叶斯的问题。可能我并不偏向知识图谱。然后就没有下文了。

结合李航的《统计学》和几篇博客,还有在凤凰网某位仁兄贡献新闻分类的源码。给自己复习一下。

为什么叫朴素贝叶斯和大学课本里的贝叶斯有什么不同?

朴素一词来源于==>假设各特征之间相互独立。这一假设使得朴素贝叶斯算法变得简单,但有时会牺牲一定的分类准确率。

招行的那位小姐姐有先验。说的就是这个。

大学里面的贝叶斯

算法使用的朴素贝叶斯(怎么我感觉叫条件特征独立贝叶斯更好呢):

条件独立假设:

就是说分类特征在类确定的条件下都是独立的。

朴素贝叶斯分类时,对于给定输出的x,通过学习得到的模型计算后验概率分布p(Y=ck|X=x),将后验概率最大的类作为x的类输出,后验概率计算根据贝叶斯定理进行:

把特征独立条件带入上面公式:

 所以贝叶斯分类器可以表示为:

因为分母对于所有的K都是相同的,公式可以简化为

朴素贝叶斯法的参数估计

学习就意味着估计,使用极大似然估计法估计相应的概率。

先验概率的极大似然估计是

 

条件概率的极大似然估计是

 

朴素贝叶斯的优缺点

优点:

  (1) 算法逻辑简单,易于实现(算法思路很简单,只要使用贝叶斯公式转化即可!)
(2)分类过程中时空开销小(假设特征相互独立,只会涉及到二维存储)
缺点:
      朴素贝叶斯假设属性之间相互独立,这种假设在实际过程中往往是不成立的。在属性之间相关性越大,分类误差也就越大。

朴素贝叶斯实战

    sklearn中有3种不同类型的朴素贝叶斯:

  高斯分布型:用于classification问题,假定属性/特征服从正态分布的。
  多项式型:用于离散值模型里。比如文本分类问题里面我们提到过,我们不光看词语是否在文本中出现,也得看出现次数。如果总词数为n,出现词数为m的话,有点像掷骰子n次出现m次这个词的场景。
  伯努利型:最后得到的特征只有0(没出现)和1(出现过)。


莺尾花Demo

 https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html 

from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
from sklearn import datasets
iris = datasets.load_iris()
gnb = GaussianNB()
scores=cross_val_score(gnb, iris.data, iris.target, cv=10)

print(scores)
[ 0.93333333  0.93333333  1.          0.93333333  0.93333333  0.93333333
  0.86666667  1.          1.          1.        ]
kaggle比赛中旧金山犯罪

1.数据观察
import pandas as pd  
import numpy as np  
from sklearn import preprocessing  
from sklearn.metrics import log_loss  
from sklearn.cross_validation import train_test_split
train = pd.read_csv('train.csv', parse_dates = ['Dates'])  
test = pd.read_csv('test.csv', parse_dates = ['Dates'])
train

特征为

Date: 日期
Category: 犯罪类型,比如 Larceny/盗窃罪 等.
Descript: 对于犯罪更详细的描述
DayOfWeek: 星期几
PdDistrict: 所属警区
Resolution: 处理结果『逮捕』『逃了』
Address: 发生街区位置
X and Y: GPS坐标

2.特征处理

sklearn.preprocessing模块中的 LabelEncoder函数可以对类别做编号,我们用它对犯罪类型做编号;

pandas中的get_dummies( )可以将变量进行二值化01向量,我们用它对”街区“、”星期几“、”时间点“进行因子化。

#对犯罪类别:Category; 用LabelEncoder进行编号  
leCrime = preprocessing.LabelEncoder()  
crime = leCrime.fit_transform(train.Category)   #39种犯罪类型  
#用get_dummies因子化星期几、街区、小时等特征  
days=pd.get_dummies(train.DayOfWeek)  
district = pd.get_dummies(train.PdDistrict)  
hour = train.Dates.dt.hour  
hour = pd.get_dummies(hour)  
#组合特征  
trainData = pd.concat([hour, days, district], axis = 1)  #将特征进行横向组合  
trainData['crime'] = crime   #追加'crime'列  
days = pd.get_dummies(test.DayOfWeek)  
district = pd.get_dummies(test.PdDistrict)  
hour = test.Dates.dt.hour  
hour = pd.get_dummies(hour)  
testData = pd.concat([hour, days, district], axis=1)  
trainData

 

3.建立贝叶斯模型

from sklearn.naive_bayes import BernoulliNB
import time
features=['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday', 'BAYVIEW', 'CENTRAL', 'INGLESIDE', 'MISSION',  
 'NORTHERN', 'PARK', 'RICHMOND', 'SOUTHERN', 'TARAVAL', 'TENDERLOIN']  
X_train, X_test, y_train, y_test = train_test_split(trainData[features], trainData['crime'], train_size=0.6)  
NB = BernoulliNB()  
nbStart = time.time()  
NB.fit(X_train, y_train)  
nbCostTime = time.time() - nbStart  
print(X_test.shape)  
propa = NB.predict_proba(X_test)   #X_test为263415*17; 那么该行就是将263415分到39种犯罪类型中,每个样本被分到每一种的概率  
print("朴素贝叶斯建模%.2f秒"%(nbCostTime))  
predicted = np.array(propa)  
logLoss=log_loss(y_test, predicted)  
print("朴素贝叶斯的log损失为:%.6f"%logLoss) 
输出:
(351220, 17)
朴素贝叶斯建模0.87秒
朴素贝叶斯的log损失为:2.615733




凤凰新闻的文章


package com.ifeng.classify.Util;

import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class NativeBayes implements Serializable {
    
   /**
    * 序列化ID
    */
    private static final long serialVersionUID = -5809782578272943999L;
    
     /**
     * 默认频率
     */
    private double defaultFreq = 0.1;

    /**
     * 训练数据的比例
     */
    private Double trainingPercent = 0.8;

    private Map<String, List<String>> files_all = new HashMap<String, List<String>>();

    private Map<String, List<String>> files_train = new HashMap<String, List<String>>();

    private Map<String, List<String>> files_test = new HashMap<String, List<String>>();

    public NativeBayes() {

    }

    /**
     * 每个分类的频数
     */
    private Map<String, Integer> classFreq = new HashMap<String, Integer>();

    /**
     * 每个分类所占的百分比     先验概率 p(yi)
     */
    private Map<String, Double> ClassProb = new HashMap<String, Double>();

    /**
     * 特征总数
     */
    private Set<String> WordDict = new HashSet<String>();

    /**
     * 每个分类中每个特征的频数
     */
    private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>();

    /**
     * 每个分类中每个特征的概率    p(xi/yi)
     */
    private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>();

    /**
     * 每个分类默认的概率
     */
    private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>();
    
    
    public double getDefaultFreq() {
        return defaultFreq;
    }

    public void setDefaultFreq(double defaultFreq) {
        this.defaultFreq = defaultFreq;
    }

    public Double getTrainingPercent() {
        return trainingPercent;
    }

    public void setTrainingPercent(Double trainingPercent) {
        this.trainingPercent = trainingPercent;
    }

    public Map<String, List<String>> getFiles_all() {
        return files_all;
    }

    public void setFiles_all(Map<String, List<String>> files_all) {
        this.files_all = files_all;
    }

    public Map<String, List<String>> getFiles_train() {
        return files_train;
    }

    public void setFiles_train(Map<String, List<String>> files_train) {
        this.files_train = files_train;
    }

    public Map<String, List<String>> getFiles_test() {
        return files_test;
    }

    public void setFiles_test(Map<String, List<String>> files_test) {
        this.files_test = files_test;
    }

    public Map<String, Integer> getClassFreq() {
        return classFreq;
    }

    public void setClassFreq(Map<String, Integer> classFreq) {
        this.classFreq = classFreq;
    }

    public Map<String, Double> getClassProb() {
        return ClassProb;
    }

    public void setClassProb(Map<String, Double> classProb) {
        ClassProb = classProb;
    }

    public Set<String> getWordDict() {
        return WordDict;
    }

    public void setWordDict(Set<String> wordDict) {
        WordDict = wordDict;
    }

    public Map<String, Map<String, Integer>> getClassFeaFreq() {
        return classFeaFreq;
    }

    public void setClassFeaFreq(Map<String, Map<String, Integer>> classFeaFreq) {
        this.classFeaFreq = classFeaFreq;
    }

    public Map<String, Map<String, Double>> getClassFeaProb() {
        return ClassFeaProb;
    }

    public void setClassFeaProb(Map<String, Map<String, Double>> classFeaProb) {
        ClassFeaProb = classFeaProb;
    }

    public Map<String, Double> getClassDefaultProb() {
        return ClassDefaultProb;
    }

    public void setClassDefaultProb(Map<String, Double> classDefaultProb) {
        ClassDefaultProb = classDefaultProb;
    }
}
util

package com.ifeng.classify.trainModel;

import com.ifeng.classify.Util.NativeBayes;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class TrainModel {
    
    private static String dataDir = "E:/data/data";
    
    /**
     * 将数据分为训练数据和测试数据
     * 
     * @param 
     */
    public static void splitData(NativeBayes nativeBayes) {
        // 用文件名区分类别
        File f = new File(dataDir);
        File[] files = f.listFiles();
        assert files != null;
        for (File file : files) {
            String fname = file.getName().replaceAll(".txt", "");
            ArrayList<String> list = new ArrayList<String>();
            Scanner scanner = null;
            try {
                scanner = new Scanner(file);
                while(scanner.hasNext()){
                    String line = scanner.nextLine().trim();
                    list.add(line);
                }
            } catch (FileNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
                if (nativeBayes.getFiles_all().containsKey(fname)) {
                    nativeBayes.getFiles_all().get(fname).addAll(list);
                } else {
                    nativeBayes.getFiles_all().put(fname, list);
                }
        }

        System.out.println("统计数据:");
        for (Entry<String, List<String>> entry : nativeBayes.getFiles_all().entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            // System.out.println(cname + " : " + value.size());

            // 训练集
            List<String> train = new ArrayList<String>();
            // 测试集
            List<String> test = new ArrayList<String>();

            for (String str : value) {
                if (Math.random() <= nativeBayes.getTrainingPercent()) {// 80%用来训练 , 20%测试
                    train.add(str);
                } else {
                    test.add(str);
                }
            }

            nativeBayes.getFiles_train().put(cname, train);
            nativeBayes.getFiles_test().put(cname, test);
        }

        System.out.println("所有文件数:");
        printStatistics(nativeBayes.getFiles_all());
        System.out.println("训练文件数:");
        printStatistics(nativeBayes.getFiles_train());
        System.out.println("测试文件数:");
        printStatistics(nativeBayes.getFiles_test());

    }
    
    
    
    /**
     * 将数据分为训练数据和测试数据
     * 
     * @param dataDir
     */
    public static void splitDataTwo(NativeBayes nativeBayes, String dataDir) {
        // 用文件名区分类别
        Pattern pat = Pattern.compile("\d+([a-z]+?)\.");
        dataDir = "testdata/allfiles";
        File f = new File(dataDir);
        File[] files = f.listFiles();
        assert files != null;
        for (File file : files) {
            String fname = file.getName();
            Matcher m = pat.matcher(fname);
            if (m.find()) {
                String cname = m.group(1);
                if (nativeBayes.getFiles_all().containsKey(cname)) {
                    nativeBayes.getFiles_all().get(cname).add(file.toString());
                } else {
                    List<String> tmp = new ArrayList<String>();
                    tmp.add(file.toString());
                    nativeBayes.getFiles_all().put(cname, tmp);
                }
            } else {
                System.out.println("err: " + file);
            }
        }

        System.out.println("统计数据:");
        for (Entry<String, List<String>> entry : nativeBayes.getFiles_all().entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            // System.out.println(cname + " : " + value.size());

            List<String> train = new ArrayList<String>();
            List<String> test = new ArrayList<String>();

            for (String str : value) {
                if (Math.random() <= nativeBayes.getTrainingPercent()) {// 80%用来训练 , 20%测试
                    train.add(str);
                } else {
                    test.add(str);
                }
            }

            nativeBayes.getFiles_train().put(cname, train);
            nativeBayes.getFiles_test().put(cname, test);
        }

        System.out.println("所有文件数:");
        printStatistics(nativeBayes.getFiles_all());
        System.out.println("训练文件数:");
        printStatistics(nativeBayes.getFiles_train());
        System.out.println("测试文件数:");
        printStatistics(nativeBayes.getFiles_test());

    }
    
    
    /**
     * 加载训练数据
     */
    public static void loadTrainData(NativeBayes nativeBayes){
        for (Entry<String, List<String>> entry : nativeBayes.getFiles_train().entrySet()) {
            //{体育:[11,,22,33]}
            String classname = entry.getKey();
            List<String> docs = entry.getValue();

            nativeBayes.getClassFreq().put(classname, docs.size());

            Map<String, Integer> feaFreq = new HashMap<String, Integer>();
            nativeBayes.getClassFeaFreq().put(classname, feaFreq);           //ClassFeaFreq 每个分类中每个特征的频数

            for (String doc : docs) {
                String[] words = doc.split(" ");
//                String[] words = null;
                for (String word : words) {
                    nativeBayes.getWordDict().add(word);
                    if(feaFreq.containsKey(word)){
                        int num = feaFreq.get(word) + 1;
                        feaFreq.put(word, num);
                    }else{
                        feaFreq.put(word, 1);
                    }
                }
            }    
        }
        System.out.println(nativeBayes.getClassFreq().size()+" 分类, " + nativeBayes.getWordDict().size()+" 特征词");
    }

    
    /**
     * 模型训练
     */
    public static void createModel(NativeBayes nativeBayes) {
        double sum = 0.0;

        //每个分类的频数相加
        for (Entry<String, Integer> entry : (nativeBayes.getClassFreq().entrySet())) {
            sum+=entry.getValue();
        }

        //每个分类的频率
        for (Entry<String, Integer> entry : nativeBayes.getClassFreq().entrySet()) {
            nativeBayes.getClassProb().put(entry.getKey(), entry.getValue()/sum);
        }

        //循环类--->Map<String, Map<String, Integer>> ClassFeaFreq
        for (Entry<String, Map<String, Integer>> entry : nativeBayes.getClassFeaFreq().entrySet()) {
            //sum是一个类下所有的特征总和数
            sum = 0.0;
            //
            String classname = entry.getKey();

            //循环一个类下的所有 特征map
            for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
                sum += entry_1.getValue();
            }

            //不做平滑处理
            double newsum = sum ;

            // 用于做平滑处理,防止分母为零
//            double newsum = sum + nativeBayes.getWordDict().size()*nativeBayes.getDefaultFreq();

            // 在训练集中每个分类中每个特征词出现的概率值          p(xi/yi)
            Map<String, Double> feaProb = new HashMap<String, Double>();

            nativeBayes.getClassFeaProb().put(classname, feaProb);
            
            for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
                String word = entry_1.getKey();
                //不做平滑处理
                feaProb.put(word, entry_1.getValue()/newsum);

                //做平滑处理
//                feaProb.put(word, (entry_1.getValue() + nativeBayes.getDefaultFreq()) /newsum);
            }

            nativeBayes.getClassDefaultProb().put(classname, nativeBayes.getDefaultFreq()/newsum);
        }
    }
    
    /**
     * 打印统计信息
     * 
     * @param m
     */
    public static void printStatistics(Map<String, List<String>> m) {
        for (Entry<String, List<String>> entry : m.entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            System.out.println(cname + " : " + value.size());
        }
        System.out.println("--------------------------------");
    }
    
    
    
}
trainModel
package com.ifeng.classify.trainModel;

import com.ifeng.classify.Util.NativeBayes;

import java.io.*;

public class NBModel {
    
    private static String path = "E:/data/NB.model.bin";
    
    /**
    * MethodName: SerializePerson 
    * Description: 序列化Person对象
    * @author 
    * @throws FileNotFoundException
    * @throws IOException
    */
    public static void SerializeNativeBayes(NativeBayes nativeBayes){
     // ObjectOutputStream 对象输出流,将Person对象存储到E盘的Person.txt文件中,完成对Person对象的序列化操作
     ObjectOutputStream oo = null;
    try {
        oo = new ObjectOutputStream(new FileOutputStream(
                 new File(path)));
        oo.writeObject(nativeBayes);
        System.out.println("NativeBayes对象序列化成功!");
        oo.close();
    } catch (FileNotFoundException e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
    } catch (IOException e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
    }
    }
    
    /**
    * MethodName: DeserializePerson 
    * Description: 反序列Perons对象
    * @author 
    * @return
    * @throws Exception
    * @throws IOException
    */
    public static NativeBayes DeserializeNativeBayes(){
        ObjectInputStream ois = null;
        NativeBayes nativeBayes = null;
        try{
            ois = new ObjectInputStream(new FileInputStream(
                    new File(path)));
            nativeBayes = (NativeBayes) ois.readObject();
            System.out.println("NativeBayes对象反序列化成功!");
        }catch(Exception e){
            e.printStackTrace();
        }
        
     return nativeBayes;
    }
}
trainModel

package com.ifeng.classify.evaluate;

import java.util.List;

public class CheckUp {
    /**
     * 计算准确率
     * @param reallist 真实类别
     * @param pridlist 预测类别
     */
    public static void Evaluate(List<String> reallist, List<String> pridlist){
        double correctNum = 0.0;
        for (int i = 0; i < reallist.size(); i++) {
            if(reallist.get(i).equals(pridlist.get(i))){
                correctNum += 1;
            }
        }
        double accuracy = correctNum / reallist.size();
        System.out.println("准确率为:" + accuracy);
    }

    /**
     * 计算精确率和召回率
     * @param reallist
     * @param pridlist
     * @param classname
     */
    public static void CalPreRec(List<String> reallist, List<String> pridlist, String classname){
        double correctNum = 0.0;
        double allNum = 0.0;//测试数据中,某个分类的文章总数
        double preNum = 0.0;//测试数据中,预测为该分类的文章总数

        for (int i = 0; i < reallist.size(); i++) {
            if(reallist.get(i).equals(classname)){
                allNum += 1;
                if(reallist.get(i).equals(pridlist.get(i))){
                    correctNum += 1;
                }
            }
            if(pridlist.get(i).equals(classname)){
                preNum += 1;
            }
        }
        System.out.println(classname + " 精确率(跟预测分类比较):" + correctNum / preNum + " 召回率(跟真实分类比较):" + correctNum / allNum);
    }
}
evaluate

package com.ifeng.classify.evaluate;

import com.ifeng.classify.Util.NativeBayes;
import com.ifeng.classify.trainModel.TrainModel;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map.Entry;

public class predict {
    /**
     * 用模型进行预测
     * 用于训练测试样本
     */
    public static void PredictTestData(NativeBayes nativeBayes) {
        List<String> reallist=new ArrayList<String>();
        List<String> pridlist=new ArrayList<String>();

        for (Entry<String, List<String>> entry : nativeBayes.getFiles_test().entrySet()) {
            String realclassname = entry.getKey();
            List<String> files = entry.getValue();

            for (String file : files) {
                reallist.add(realclassname);


                List<String> classnamelist=new ArrayList<String>();
                List<Double> scorelist=new ArrayList<Double>();
                for (Entry<String, Double> entry_1 : nativeBayes.getClassProb().entrySet()) {
                    String classname = entry_1.getKey();
                  //先验概率
                    Double score = Math.log(entry_1.getValue());

                    String[] words = file.split(" ");
//                    String[] words = null;
                    for (String word : words) {
                        //在全集则计算该Word权重
                        if(!nativeBayes.getWordDict().contains(word)){
                            continue;
                        }

                        if(nativeBayes.getClassFeaProb().get(classname).containsKey(word)){
                            score += Math.log(nativeBayes.getClassFeaProb().get(classname).get(word));
                        }else{
                            score += Math.log(nativeBayes.getClassDefaultProb().get(classname));
                        }
                    }

                    classnamelist.add(classname);
                    scorelist.add(score);
                    
                }

                Double maxProb = Collections.max(scorelist);
                int idx = scorelist.indexOf(maxProb);
                pridlist.add(classnamelist.get(idx));
            }
        }

        CheckUp.Evaluate(reallist, pridlist);

        for (String cname : nativeBayes.getFiles_test().keySet()) {
            CheckUp.CalPreRec(reallist, pridlist, cname);
        }

    }
    
    
    public static void main(String[] args) {
        NativeBayes bayes = new NativeBayes();
        TrainModel.splitData(bayes);
        TrainModel.loadTrainData(bayes);
        TrainModel.createModel(bayes);
        predict.PredictTestData(bayes);
//        NBModel.SerializeNativeBayes(bayes);
//        NBModel.DeserializeNativeBayes();

    }
    
    
}
evaluate


统计数据:
所有文件数:
科技 : 10000
社会 : 10000
娱乐 : 10000
汽车 : 10000
体育 : 10000
教育 : 10000
时政 : 10000
时尚 : 10000
游戏 : 10000
财经 : 10000
--------------------------------
训练文件数:
科技 : 7946
社会 : 8016
娱乐 : 8062
汽车 : 8041
体育 : 7995
教育 : 7962
时政 : 8004
时尚 : 7906
游戏 : 7922
财经 : 7955
--------------------------------
测试文件数:
科技 : 2054
社会 : 1984
娱乐 : 1938
汽车 : 1959
体育 : 2005
教育 : 2038
时政 : 1996
时尚 : 2094
游戏 : 2078
财经 : 2045
--------------------------------
10 分类, 325496 特征词
准确率为:0.9202119756327076
科技 精确率(跟预测分类比较):0.8898305084745762 召回率(跟真实分类比较):0.9201557935735151
社会 精确率(跟预测分类比较):0.8351111111111111 召回率(跟真实分类比较):0.9470766129032258
娱乐 精确率(跟预测分类比较):0.8614547253834736 召回率(跟真实分类比较):0.8983488132094943
汽车 精确率(跟预测分类比较):0.9768177028451 召回率(跟真实分类比较):0.9464012251148545
体育 精确率(跟预测分类比较):0.9811512990320937 召回率(跟真实分类比较):0.9605985037406484
教育 精确率(跟预测分类比较):0.945646703573226 召回率(跟真实分类比较):0.92198233562316
时政 精确率(跟预测分类比较):0.9029850746268657 召回率(跟真实分类比较):0.8486973947895792
时尚 精确率(跟预测分类比较):0.9110588235294118 召回率(跟真实分类比较):0.9245463228271251
游戏 精确率(跟预测分类比较):0.9805970149253731 召回率(跟真实分类比较):0.9485081809432147
财经 精确率(跟预测分类比较):0.9344346928239545 召回率(跟真实分类比较):0.8850855745721271

 
原文地址:https://www.cnblogs.com/wqbin/p/10235974.html