libsvm代码阅读:关于Kernel类分析(转)

这一篇博文来分析下Kernel类,代码上很简单,一般都能看懂。Kernel类主要是为SVM的核函数服务的,里面实现了SVM常用的核函数,通过函数指针来使用这些核函数。

其中几个常用核函数如下所示:(一般情况下,使用RBF核函数能取得很好的效果)

关于基类QMatrix在Kernel中的作用并不明显,只是定义了一些纯虚函数,Kernel继承这些函数,Kernel只对swap_index进行了定义。其余的get_Q和get_QD在Kernel并没有用到。

[cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. class QMatrix {  
  2. public:  
  3.     virtual Qfloat *get_Q(int column, int len) const = 0;//纯虚函数,在子类中实现,important!  
  4.     virtual double *get_QD() const = 0;  
  5.     virtual void swap_index(int i, int j) const = 0;  
  6.     virtual ~QMatrix() {}  
  7. };  

Kernel类的定义函数,比较简单就不细说。

[cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. class Kernel: public QMatrix {  
  2. public:  
  3.     Kernel(int l, svm_node * const * x, const svm_parameter& param);  
  4.     virtual ~Kernel();  
  5.   
  6.     static double k_function(const svm_node *x, const svm_node *y,  
  7.                  const svm_parameter& param);  
  8.     virtual Qfloat *get_Q(int column, int len) const = 0;  
  9.     virtual double *get_QD() const = 0;  
  10.     virtual void swap_index(int i, int j) const // no so const...  
  11.     {  
  12.         swap(x[i],x[j]);  
  13.         if(x_square) swap(x_square[i],x_square[j]);  
  14.     }  
  15. protected:  
  16.   
  17.     double (Kernel::*kernel_function)(int i, int j) const;  
  18.   
  19. private:  
  20.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
  21.     double *x_square;//使用RBF 核才使用  
  22.   
  23.     // svm_parameter  
  24.     const int kernel_type;  
  25.     const int degree;  
  26.     const double gamma;  
  27.     const double coef0;  
  28.   
  29.     static double dot(const svm_node *px, const svm_node *py);  
  30.   
  31.     double kernel_linear(int i, int j) const  
  32.     {  
  33.         return dot(x[i],x[j]);  
  34.     }  
  35.     double kernel_poly(int i, int j) const  
  36.     {  
  37.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
  38.     }  
  39.     double kernel_rbf(int i, int j) const  
  40.     {  
  41.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
  42.     }  
  43.   
  44.     double kernel_sigmoid(int i, int j) const  
  45.     {  
  46.         return tanh(gamma*dot(x[i],x[j])+coef0);  
  47.     }  
  48.     double kernel_precomputed(int i, int j) const  
  49.     {  
  50.         return x[i][(int)(x[j][0].value)].value;  
  51.     }  
  52. };  

这个Kernel类的函数比较清晰,我直接把它的全部代码贴出。

全部代码如下:

[cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. //  
    2. // Kernel evaluation  
    3. //  
    4. // the static method k_function is for doing single kernel evaluation  
    5. // the constructor of Kernel prepares to calculate the l*l kernel matrix  
    6. // the member function get_Q is for getting one column from the Q Matrix  
    7. //  
    8. class QMatrix {  
    9. public:  
    10.     virtual Qfloat *get_Q(int column, int len) const = 0;  
    11.     virtual double *get_QD() const = 0;  
    12.     virtual void swap_index(int i, int j) const = 0;  
    13.     virtual ~QMatrix() {}  
    14. };  
    15.   
    16. class Kernel: public QMatrix {  
    17. public:  
    18.     Kernel(int l, svm_node * const * x, const svm_parameter& param);//构造函数  
    19.     virtual ~Kernel();  
    20.   
    21.     static double k_function(const svm_node *x, const svm_node *y,  
    22.                  const svm_parameter& param);  
    23.     virtual Qfloat *get_Q(int column, int len) const = 0;  
    24.     virtual double *get_QD() const = 0;  
    25.     virtual void swap_index(int i, int j) const // no so const...  
    26.     {  
    27.         swap(x[i],x[j]);  
    28.         if(x_square) swap(x_square[i],x_square[j]);  
    29.     }  
    30. protected:  
    31.   
    32.     double (Kernel::*kernel_function)(int i, int j) const;  
    33.   
    34. private:  
    35.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
    36.     double *x_square;//使用RBF 核才使用  
    37.   
    38.     // svm_parameter  
    39.     const int kernel_type;  
    40.     const int degree;  
    41.     const double gamma;  
    42.     const double coef0;  
    43.   
    44.     static double dot(const svm_node *px, const svm_node *py);  
    45.   
    46.     double kernel_linear(int i, int j) const  
    47.     {  
    48.         return dot(x[i],x[j]);  
    49.     }  
    50.     double kernel_poly(int i, int j) const  
    51.     {  
    52.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
    53.     }  
    54.     double kernel_rbf(int i, int j) const  
    55.     {  
    56.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
    57.     }  
    58.   
    59.     double kernel_sigmoid(int i, int j) const  
    60.     {  
    61.         return tanh(gamma*dot(x[i],x[j])+coef0);  
    62.     }  
    63.     double kernel_precomputed(int i, int j) const  
    64.     {  
    65.         return x[i][(int)(x[j][0].value)].value;  
    66.     }  
    67. };  
    68.   
    69. //构造函数,初始化类中的部分常量,指定核函数,克隆样本数据。如果使用RBF核函数,则计算x_square[i]  
    70. Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)  
    71. :kernel_type(param.kernel_type), degree(param.degree),  
    72.  gamma(param.gamma), coef0(param.coef0)  
    73. {  
    74.     switch(kernel_type)  
    75.     {  
    76.         case LINEAR:  
    77.             kernel_function = &Kernel::kernel_linear;  
    78.             break;  
    79.         case POLY:  
    80.             kernel_function = &Kernel::kernel_poly;  
    81.             break;  
    82.         case RBF:  
    83.             kernel_function = &Kernel::kernel_rbf;  
    84.             break;  
    85.         case SIGMOID:  
    86.             kernel_function = &Kernel::kernel_sigmoid;  
    87.             break;  
    88.         case PRECOMPUTED:  
    89.             kernel_function = &Kernel::kernel_precomputed;  
    90.             break;  
    91.     }  
    92.   
    93.     clone(x,x_,l);//void clone(T*& dst, S* src, int n)  
    94.   
    95.     if(kernel_type == RBF)  
    96.     {  
    97.         x_square = new double[l];  
    98.         for(int i=0;i<l;i++)  
    99.             x_square[i] = dot(x[i],x[i]);  
    100.     }  
    101.     else  
    102.         x_square = 0;  
    103. }  
    104.   
    105. Kernel::~Kernel()  
    106. {  
    107.     delete[] x;  
    108.     delete[] x_square;  
    109. }  
    110.   
    111. double Kernel::dot(const svm_node *px, const svm_node *py)  
    112. {  
    113.     double sum = 0;  
    114.     while(px->index != -1 && py->index != -1)  
    115.     {  
    116.         if(px->index == py->index)  
    117.         {  
    118.             sum += px->value * py->value;  
    119.             ++px;  
    120.             ++py;  
    121.         }  
    122.         else  
    123.         {  
    124.             if(px->index > py->index)  
    125.                 ++py;  
    126.             else  
    127.                 ++px;  
    128.         }             
    129.     }  
    130.     return sum;  
    131. }  
    132.   
    133. double Kernel::k_function(const svm_node *x, const svm_node *y,  
    134.               const svm_parameter& param)  
    135. {  
    136.     switch(param.kernel_type)  
    137.     {  
    138.         case LINEAR:  
    139.             return dot(x,y);  
    140.         case POLY:  
    141.             return powi(param.gamma*dot(x,y)+param.coef0,param.degree);  
    142.         case RBF:  
    143.         {  
    144.             double sum = 0;  
    145.             while(x->index != -1 && y->index !=-1)  
    146.             {  
    147.                 if(x->index == y->index)  
    148.                 {  
    149.                     double d = x->value - y->value;  
    150.                     sum += d*d;  
    151.                     ++x;  
    152.                     ++y;  
    153.                 }  
    154.                 else  
    155.                 {  
    156.                     if(x->index > y->index)  
    157.                     {     
    158.                         sum += y->value * y->value;  
    159.                         ++y;  
    160.                     }  
    161.                     else  
    162.                     {  
    163.                         sum += x->value * x->value;  
    164.                         ++x;  
    165.                     }  
    166.                 }  
    167.             }  
    168.   
    169.             while(x->index != -1)  
    170.             {  
    171.                 sum += x->value * x->value;  
    172.                 ++x;  
    173.             }  
    174.   
    175.             while(y->index != -1)  
    176.             {  
    177.                 sum += y->value * y->value;  
    178.                 ++y;  
    179.             }  
    180.               
    181.             return exp(-param.gamma*sum);  
    182.         }  
    183.         case SIGMOID:  
    184.             return tanh(param.gamma*dot(x,y)+param.coef0);  
    185.         case PRECOMPUTED:  //x: test (validation), y: SV  
    186.             return x[(int)(y->value)].value;  
    187.         default:  
    188.             return 0;  // Unreachable   
    189.     }  
    190. }  
原文地址:https://www.cnblogs.com/Miliery/p/4394138.html