A Brief Introduction to Using libsvm

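libsvm is an open-source implementation of the support vector machine (SVM) whose solver uses an SMO-type algorithm. Its source code is written in a distinctive way and is worth a look. The examples below use its Java version (package libsvm).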

Common structures

The svm_node structure

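Defines one element of an input feature vector: index is the feature's index and value is its value. Within a vector, indices should be in ascending order. In libsvm's C version a node with index = -1 marks the end of a vector. In the Java version used here the array length already marks the end, so no terminator node is needed.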

public class svm_node implements java.io.Serializable
{
    public int index;
    public double value;
}

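This borrows the sparse-matrix representation. An input vector is a one-dimensional array of svm_node, and features whose value is 0 can be left out.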

svm_node[] pa = {pa0, pa1};
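In practice a vector usually starts out dense, so a small conversion helper is handy. The sketch below assumes libsvm's usual conventions: indices ascend, start at 1, and zero entries are omitted. Node is a hypothetical stand-in with the same two fields as svm_node, so the snippet runs without the libsvm jar. In real code use libsvm.svm_node instead. toSparse is likewise a made-up name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in with the same fields as libsvm.svm_node.
class Node {
    int index;
    double value;

    Node(int index, double value) {
        this.index = index;
        this.value = value;
    }
}

class SparseVectors {
    // Converts a dense vector into a sparse node array.
    // Dense position i becomes index i + 1 (1-based), zero entries are skipped,
    // and the resulting indices are in ascending order.
    static Node[] toSparse(double[] dense) {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < dense.length; i++) {
            if (dense[i] != 0.0) {
                nodes.add(new Node(i + 1, dense[i]));
            }
        }
        return nodes.toArray(new Node[0]);
    }

    public static void main(String[] args) {
        Node[] pa = toSparse(new double[] {10.0, 0.0, 10.0});
        for (Node n : pa) {
            System.out.println(n.index + ":" + n.value);
        }
    }
}
```

For the input {10.0, 0.0, 10.0}, main prints 1:10.0 and 3:10.0. The zero in the middle is simply not stored.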

All input vectors together are represented by a two-dimensional array:

svm_node[][] datas = {pa, pb};
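Ascending indices matter because of how kernels are evaluated on these arrays. To compute a dot product between two sparse vectors, libsvm walks both arrays at once and multiplies only entries whose indices match, much like merging two sorted lists. The Java version does this inside its internal kernel code. The sketch below re-implements the idea on plain parallel index/value arrays. The class and method names are hypothetical.

```java
// Sparse dot product over two vectors stored as parallel arrays of
// ascending indices and values. Only indices present in both vectors contribute.
class SparseDot {
    static double dot(int[] ia, double[] va, int[] ib, double[] vb) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < ia.length && j < ib.length) {
            if (ia[i] == ib[j]) {
                sum += va[i++] * vb[j++];  // same feature in both: multiply
            } else if (ia[i] > ib[j]) {
                j++;                      // feature only in b: skip it
            } else {
                i++;                      // feature only in a: skip it
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        // pa = (10, 10) and pb = (-10, -10) with indices 1 and 2.
        int[] idx = {1, 2};
        double[] pa = {10.0, 10.0};
        double[] pb = {-10.0, -10.0};
        System.out.println(dot(idx, pa, idx, pb));
    }
}
```

main prints -200.0. If a vector's indices were out of order, the loop could step past a matching pair and return a wrong value without any error. That is why each vector's indices should increase from the first node to the last.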

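Label array

This is simply a double array with one label per input vector, that is, per row of datas: labels[i] is the class of datas[i].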

double[] labels = {1.0, -1.0};

The svm_problem structure

Defines the training set (X, Y):

public class svm_problem implements java.io.Serializable
{
    public int l;
    public double[] y;
    public svm_node[][] x;
}

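Here l is the number of samples, y holds the labels and x holds the input vectors. Both y and x must have length l.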
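The three fields must agree with each other. However, svm_check_parameter only validates param (plus nu-SVC feasibility), not the index order inside x. The sketch below is a consistency check one might run before training. It assumes indices must be strictly ascending within each vector. ProblemCheck.check and the SparseNode stand-in are hypothetical. check mimics svm_check_parameter's style of returning null on success.

```java
// Hypothetical stand-in with the same fields as libsvm.svm_node.
class SparseNode {
    int index;
    double value;

    SparseNode(int index, double value) {
        this.index = index;
        this.value = value;
    }
}

class ProblemCheck {
    // Returns null if (l, y, x) are consistent, otherwise a description
    // of the first problem found.
    static String check(int l, double[] y, SparseNode[][] x) {
        if (y.length != l || x.length != l) {
            return "l must equal y.length and x.length";
        }
        for (int i = 0; i < l; i++) {
            for (int k = 1; k < x[i].length; k++) {
                if (x[i][k].index <= x[i][k - 1].index) {
                    return "sample " + i + ": indices not strictly ascending";
                }
            }
        }
        return null;
    }

    public static void main(String[] args) {
        SparseNode[][] x = {
            {new SparseNode(1, 10.0), new SparseNode(2, 10.0)},
            {new SparseNode(1, -10.0), new SparseNode(2, -10.0)}
        };
        System.out.println(check(2, new double[] {1.0, -1.0}, x));
    }
}
```

For the pa/pb data, main prints null, meaning no problem was found.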

The svm_parameter structure

Defines the important training parameters:

public class svm_parameter implements Cloneable,java.io.Serializable
{
    /* svm_type */
    public static final int C_SVC = 0;
    public static final int NU_SVC = 1;
    public static final int ONE_CLASS = 2;
    public static final int EPSILON_SVR = 3;
    public static final int NU_SVR = 4;

    /* kernel_type */
    public static final int LINEAR = 0;
    public static final int POLY = 1;
    public static final int RBF = 2;
    public static final int SIGMOID = 3;
    public static final int PRECOMPUTED = 4;

    public int svm_type;
    public int kernel_type;
    public int degree;    // for poly
    public double gamma;    // for poly/rbf/sigmoid
    public double coef0;    // for poly/sigmoid

    // these are for training only
    public double cache_size; // in MB
    public double eps;    // stopping criteria
    public double C;    // for C_SVC, EPSILON_SVR and NU_SVR
    public int nr_weight;        // for C_SVC
    public int[] weight_label;    // for C_SVC
    public double[] weight;        // for C_SVC
    public double nu;    // for NU_SVC, ONE_CLASS, and NU_SVR
    public double p;    // for EPSILON_SVR
    public int shrinking;    // use the shrinking heuristics
    public int probability; // do probability estimates

    public Object clone() 
    {
        try 
        {
            return super.clone();
        } catch (CloneNotSupportedException e) 
        {
            return null;
        }
    }

}

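The parameters fall into two main groups:

1. The kernel: kernel_type plus degree, gamma and coef0.
2. The training settings: C, nu and p for the corresponding svm_type, plus SMO solver settings such as the stopping tolerance eps, cache_size and shrinking.

The Java class fills in no defaults. After new svm_parameter(), every numeric field is 0 and every array is null, so each field used by the chosen svm_type and kernel_type must be set explicitly.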
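For reference, below are the four built-in kernel formulas that kernel_type selects, written as a dense-array sketch using the same parameter names as svm_parameter. PRECOMPUTED is not included because it takes the kernel values directly from the input. The class Kernels is hypothetical. libsvm's own implementation works on sparse svm_node arrays.

```java
// Dense sketch of libsvm's built-in kernels, using svm_parameter's
// names for the hyperparameters (degree, gamma, coef0).
class Kernels {
    static double dot(double[] u, double[] v) {
        double s = 0.0;
        for (int i = 0; i < u.length; i++) {
            s += u[i] * v[i];
        }
        return s;
    }

    // LINEAR: u'v
    static double linear(double[] u, double[] v) {
        return dot(u, v);
    }

    // POLY: (gamma * u'v + coef0)^degree
    static double poly(double[] u, double[] v, double gamma, double coef0, int degree) {
        return Math.pow(gamma * dot(u, v) + coef0, degree);
    }

    // RBF: exp(-gamma * |u - v|^2)
    static double rbf(double[] u, double[] v, double gamma) {
        double d2 = 0.0;
        for (int i = 0; i < u.length; i++) {
            double diff = u[i] - v[i];
            d2 += diff * diff;
        }
        return Math.exp(-gamma * d2);
    }

    // SIGMOID: tanh(gamma * u'v + coef0)
    static double sigmoid(double[] u, double[] v, double gamma, double coef0) {
        return Math.tanh(gamma * dot(u, v) + coef0);
    }

    public static void main(String[] args) {
        double[] pa = {10.0, 10.0};
        double[] pc = {-0.1, 0.0};
        System.out.println(linear(pa, pc));
        System.out.println(rbf(pa, pa, 0.5));
    }
}
```

main prints two values:

- -1.0, the linear kernel between pa and the test point (-0.1, 0).
- 1.0, the RBF kernel of a vector with itself.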

Training

Train a model by calling svm.svm_train():

public static svm_model svm_train(svm_problem prob, svm_parameter param)

It returns an svm_model object representing the trained classifier.
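For C_SVC, the optimization problem that svm_train solves with its SMO-type solver is the standard SVM dual below. This is the formulation given in the libsvm documentation; the other svm_type values solve related problems. Here K is the kernel selected in svm_parameter, C is param.C, e is the all-ones vector, and l, x_i, y_i are the fields of svm_problem:

```latex
\begin{aligned}
\min_{\alpha}\quad & \tfrac{1}{2}\,\alpha^{T} Q\, \alpha - e^{T}\alpha \\
\text{subject to}\quad & y^{T}\alpha = 0,\qquad 0 \le \alpha_i \le C,\quad i = 1,\dots,l,\\
\text{where}\quad & Q_{ij} = y_i\, y_j\, K(x_i, x_j).
\end{aligned}
```

SMO works on two α_i at a time and solves each two-variable subproblem analytically. It repeats this until the optimality conditions hold within eps. The training vectors that end up with α_i > 0 are the support vectors stored in the returned svm_model.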

Prediction

Use svm.svm_predict() to classify a vector with the trained model:

public static double svm_predict(svm_model model, svm_node[] x)

For C_SVC it returns the predicted class label.
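Prediction evaluates a decision function built from the support vectors. For a two-class C_SVC model this is sgn(Σ_i y_i α_i K(x_i, x) − ρ). libsvm stores the products y_i α_i in the model's sv_coef field and ρ in rho. A positive value selects the label that appeared first in the training data.

The sketch below reproduces that rule for a linear kernel on dense arrays. Decision, decisionValue and predict are hypothetical names. The α = 1/400 and ρ = 0 used in main come from solving the pa/pb example at the end of this post by hand, not from running libsvm.

```java
// Sketch of the binary decision rule: sign(sum_i coef_i * K(sv_i, x) - rho),
// where coef_i = y_i * alpha_i. Uses a linear kernel on dense arrays.
class Decision {
    static double linear(double[] u, double[] v) {
        double s = 0.0;
        for (int i = 0; i < u.length; i++) {
            s += u[i] * v[i];
        }
        return s;
    }

    static double decisionValue(double[][] sv, double[] coef, double rho, double[] x) {
        double sum = 0.0;
        for (int i = 0; i < sv.length; i++) {
            sum += coef[i] * linear(sv[i], x);
        }
        return sum - rho;
    }

    // A positive decision value selects the first label, otherwise the second.
    static double predict(double[][] sv, double[] coef, double rho, double[] x,
                          double firstLabel, double secondLabel) {
        return decisionValue(sv, coef, rho, x) > 0 ? firstLabel : secondLabel;
    }

    public static void main(String[] args) {
        // Hand-derived optimum for pa = (10, 10), y = 1 and pb = (-10, -10), y = -1.
        double[][] sv = {{10.0, 10.0}, {-10.0, -10.0}};
        double[] coef = {0.0025, -0.0025};  // y_i * alpha_i with alpha = 1/400
        double[] x = {-0.1, 0.0};
        System.out.println(predict(sv, coef, 0.0, x, 1.0, -1.0));
    }
}
```

Here the decision value is 0.0025 · (−1) − 0.0025 · 1 = −0.005, so main prints -1.0.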

The example code follows. The training points are pa = (10.0, 10.0) with ya = 1.0 and pb = (-10.0, -10.0) with yb = -1.0.

The test point is (-0.1, 0).

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

public class SvmTest {
    // Builds one svm_node. Indices within a vector should ascend (conventionally from 1).
    private static svm_node node(int index, double value) {
        svm_node n = new svm_node();
        n.index = index;
        n.value = value;
        return n;
    }

    public static void main(String[] args) {
        // Training vectors pa = (10, 10) and pb = (-10, -10).
        // Java arrays carry their own length, so no index = -1 terminator node is needed.
        svm_node[] pa = {node(1, 10.0), node(2, 10.0)};
        svm_node[] pb = {node(1, -10.0), node(2, -10.0)};

        svm_node[][] datas = {pa, pb};
        double[] labels = {1.0, -1.0};  // labels[i] is the class of datas[i]

        svm_problem problem = new svm_problem();
        problem.l = datas.length;
        problem.x = datas;
        problem.y = labels;

        // svm_parameter has no defaults: set every field the chosen type and kernel use.
        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.LINEAR;
        param.cache_size = 100;  // kernel cache size in MB
        param.eps = 0.00001;     // stopping tolerance
        param.C = 1;

        // svm_check_parameter returns null if the parameters are valid,
        // otherwise an error message.
        String error = svm.svm_check_parameter(problem, param);
        if (error != null) {
            System.err.println("Invalid parameters: " + error);
            return;
        }
        svm_model model = svm.svm_train(problem, param);

        // Test vector (-0.1, 0). The zero-valued feature could also be omitted.
        svm_node[] pc = {node(1, -0.1), node(2, 0.0)};

        System.out.println(svm.svm_predict(model, pc));
    }
}
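The example is small enough to solve by hand, which gives a way to sanity-check the program's output. The C_SVC dual minimises ½αᵀQα − eᵀα subject to yᵀα = 0 and 0 ≤ α_i ≤ C, where Q_ij = y_i y_j K(x_i, x_j). With the linear kernel, K(pa, pa) = K(pb, pb) = 200 and K(pa, pb) = −200. Assuming the solver reaches the exact optimum:

```latex
\begin{aligned}
Q &= \begin{pmatrix} 200 & 200 \\ 200 & 200 \end{pmatrix},
  \qquad \alpha_a = \alpha_b = \alpha \quad (\text{from } y^{T}\alpha = 0),\\
\tfrac{1}{2}\,\alpha^{T} Q\, \alpha - e^{T}\alpha &= 400\,\alpha^{2} - 2\alpha
  \;\Rightarrow\; \alpha = \tfrac{1}{400} \le C = 1,\\
w &= \sum_i y_i \alpha_i x_i = \tfrac{1}{400}\big((10, 10) - (-10, -10)\big) = (0.05,\ 0.05),\\
b &= y_a - w^{T} p_a = 1 - 1 = 0 \qquad (\rho = -b = 0),\\
f(p_c) &= w^{T} p_c + b = 0.05 \cdot (-0.1) + 0.05 \cdot 0 = -0.005 < 0.
\end{aligned}
```

So the program should print -1.0, because the test point lies just on pb's side of the separating line x1 + x2 = 0. SMO stops once the tolerance eps is met, so the α and ρ that libsvm actually computes may differ slightly from these exact values.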
Original post: https://www.cnblogs.com/zjgtan/p/3305720.html