LibSvm流程及java代码测试

使用libSvm实现文本分类的基本过程,此文参考 使用libsvm实现文本分类 对前期数据准备及后续的分类测试进行了验证,同时对文中作者的分词组件修改成hanLP分词,对数字进行过滤,仅保留长度大于1的词进行处理。

转上文作者写的分类流程:

  1. 选择文本训练数据集和测试数据集:训练集和测试集都是类标签已知的;
  2. 训练集文本预处理:这里主要包括分词、去停用词、建立词袋模型(倒排表);
  3. 选择文本分类使用的特征向量(词向量):最终的目标是使得最终选出的特征向量在多个类别之间具有一定的类别区分度,可以使用相关有效的技术去实现特征向量的选择,由于分词后得到大量的词,通过选择降维技术能很好地减少计算量,还能维持分类的精度;
  4. 输出libsvm支持的量化的训练样本集文件:类别名称、特征向量中每个词元素分别到数字编号的映射转换,以及基于类别和特征向量来量化文本训练集,能够满足使用libsvm训练所需要的数据格式;
  5. 测试数据集预处理:同样包括分词(需要和训练过程中使用的分词器一致)、去停用词、建立词袋模型(倒排表),但是这时需要加载训练过程中生成的特征向量,用特征向量去排除多余的不在特征向量中的词(也称为降维);
  6. 输出libsvm支持的量化的测试样本集文件:格式和训练数据集的预处理阶段的输出相同;
  7. 使用libsvm训练文本分类器:使用训练集预处理阶段输出的量化的数据集文件,这个阶段也需要做很多工作(后面会详细说明),最终输出分类模型文件;
  8. 使用libsvm验证分类模型的精度:使用测试集预处理阶段输出的量化的数据集文件,和分类模型文件来验证分类的精度;
  9. 分类模型参数寻优:如果经过libsvm训练出来的分类模型精度很差,可以通过libsvm自带的交叉验证(Cross Validation)功能来实现参数的寻优,通过搜索参数取值空间来获取最佳的参数值,使分类模型的精度满足实际分类需要。

文本预处理阶段,增加了基于hanLP的分词,代码如下:

/**
 * 使用hanlp进行分词
 * Created by zhouyh on 2018/5/30.
 */
public class HanLPDocumentAnalyzer extends AbstractDocumentAnalyzer implements DocumentAnalyzer {

    private static final Log LOG = LogFactory.getLog(HanLPDocumentAnalyzer.class);

    public HanLPDocumentAnalyzer(ConfigReadable configuration) {
        super(configuration);
    }

    @Override
    public Map<String, Term> analyze(File file) {
        String doc = file.getAbsolutePath();
        LOG.debug("Process document: file=" + doc);
        Map<String, Term> terms = Maps.newHashMap();
        BufferedReader br = null;
        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(file), charSet));
            String line = null;
            while((line = br.readLine()) != null) {
                LOG.debug("Process line: " + line);
                List<com.hankcs.hanlp.seg.common.Term> termList = HanLP.segment(line);
                if (termList!=null && termList.size()>0){
                    for (com.hankcs.hanlp.seg.common.Term hanLPTerm : termList){
                        String word = hanLPTerm.word;
                        if (!word.isEmpty() && !super.isStopword(word)){
                            if (word.trim().length()>1){
                                Pattern compile = Pattern.compile("(\d+\.\d+)|(\d+)|([\uFF10-\uFF19]+)");
                                Matcher matcher = compile.matcher(word);
                                if (!matcher.find()){
                                    Term term = terms.get(word);
                                    if (term == null){
                                        term = new TermImpl(word);
                                        terms.put(word, term);
                                    }
                                    term.incrFreq();
                                }
                            }
                        } else {
                            LOG.debug("Filter out stop word: file=" + file + ", word=" + word);
                        }
                    }
                }
            }
        } catch (IOException e) {
            throw new RuntimeException("", e);
        } finally {
            try {
                if(br != null) {
                    br.close();
                }
            } catch (IOException e) {
                LOG.warn(e);
            }
            LOG.debug("Done: file=" + file + ", termCount=" + terms.size());
        }
        return terms;
    }

    public static void main(String[] args){
        String filePath = "/Users/zhouyh/work/yanfa/xunlianji/UTF8/train/ClassFile/C000008/0.txt";
        HanLPDocumentAnalyzer hanLPDocumentAnalyzer = new HanLPDocumentAnalyzer(new Configuration());
        hanLPDocumentAnalyzer.analyze(new File(filePath));
        String str = "测试hanLP分词";
        System.out.println(str);
//        Pattern compile = Pattern.compile("(\d+\.\d+)|(\d+)|([\uFF10-\uFF19]+)");
//        Matcher matcher = compile.matcher("9402");
//        if (matcher.find()){
//            System.out.println(matcher.group());
//        }
    }
}
View Code

这里对原作者提供的训练集资源做了合并,将训练集扩大到10个类别,每个类别的8000文本中,前6000文本作为训练集,后2000文本作为测试集,文本结构如下图所示:

 测试集中是同样的结构。

生成的特征向量与libsvm需要的训练集格式如下面所示:

libsvm训练集格式文档:

针对测试集也通过上述方式处理。

使用libSvm训练分类文本

文本转换:

./svm-scale -l 0 -u 1 /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/train.txt > /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/train-scale.txt

测试集也做同样转换:

./svm-scale -l 0 -u 1 /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/test.txt > /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/test-scale.txt

进行模型训练,此部分耗时较长:

./svm-train -h 0 -t 0 /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/train-scale.txt /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/model.txt

训练过程如下图所示:

训练完成会生成model文件

采用预先处理好的测试文本进行分类测试:

./svm-predict /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/test-scale.txt /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/model.txt /Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/predict.txt

得到结果为:Accuracy = 81.6568% (16333/20002) (classification) 

整体流程做完,得到文件如下图所列:

至此,仿照原作者的思路,对libsvm的分类流程做了一次实践。

JAVA代码测试

建立相关java项目,引入libsvm的jar包,我这里采用maven搭建,引入jar包:

<!-- https://mvnrepository.com/artifact/tw.edu.ntu.csie/libsvm -->
      <!-- libsvm jar包 -->
      <dependency>
          <groupId>tw.edu.ntu.csie</groupId>
          <artifactId>libsvm</artifactId>
          <version>3.17</version>
      </dependency>

同时要把libsvm包中的svm_predict.java及svm_train.java引入,并对svm_predict.java的类做简单改动,将预测的结果值返回,测试代码如下:

public class LibSvmAlgorithm {

    public static void main(String[] args){
        String[] testArgs = {"/Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/test-scale.txt", "/Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/model.txt", "/Users/zhouyh/work/yanfa/xunlianji/UTF8/heji/predict1.txt"};
        try {
            Double accuracy = svm_predict.main(testArgs);
            System.out.println(accuracy);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
原文地址:https://www.cnblogs.com/yhzhou/p/9114958.html