OpenCV SVM

 1 #include <opencv2/core/core.hpp>  
 2 #include <opencv2/highgui/highgui.hpp>  
 3 #include <opencv2/ml/ml.hpp>  
 4 
 5 using namespace cv;
 6 
 7 int main()
 8 {
 9     // Data for visual representation  
10     int width = 512, height = 512;
11     Mat image = Mat::zeros(height, width, CV_8UC3);
12 
13     // Set up training data  
14     float labels[5] = { 1.0, -1.0, -1.0, -1.0, 1.0 };
15     Mat labelsMat(5, 1, CV_32FC1, labels);
16 
17 
18     float trainingData[5][2] = { { 501, 10 }, { 255, 10 }, { 501, 255 }, { 10, 501 }, { 501, 128 } };
19     Mat trainingDataMat(5, 2, CV_32FC1, trainingData);
20 
21     //设置支持向量机的参数  
22     CvSVMParams params;
23     params.svm_type = CvSVM::C_SVC;//SVM类型:使用C支持向量机  
24     params.kernel_type = CvSVM::LINEAR;//核函数类型:线性  
25     params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);//终止准则函数:当迭代次数达到最大值时终止  
26 
27     //训练SVM  
28     //建立一个SVM类的实例  
29     CvSVM SVM;
30     //训练模型,参数为:输入数据、响应、XX、XX、参数(前面设置过)  
31     SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);
32 
33     Vec3b green(0, 255, 0), blue(255, 0, 0);
34     //显示判决域  
35     for (int i = 0; i < image.rows; ++i)
36     for (int j = 0; j < image.cols; ++j)
37     {
38         Mat sampleMat = (Mat_<float>(1, 2) << i, j);
39         //predict是用来预测的,参数为:样本、返回值类型(如果值为ture而且是一个2类问题则返回判决函数值,否则返回类标签)、  
40         float response = SVM.predict(sampleMat);
41 
42         if (response == 1)
43             image.at<Vec3b>(j, i) = green;
44         else if (response == -1)
45             image.at<Vec3b>(j, i) = blue;
46     }
47 
48     //画出训练数据  
49     int thickness = -1;
50     int lineType = 8;
51     circle(image, Point(501, 10), 5, Scalar(0, 0, 0), thickness, lineType);//画圆  
52     circle(image, Point(255, 10), 5, Scalar(255, 255, 255), thickness, lineType);
53     circle(image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType);
54     circle(image, Point(10, 501), 5, Scalar(255, 255, 255), thickness, lineType);
55     circle(image, Point(501, 128), 5, Scalar(0, 0, 0), thickness, lineType);
56 
57     //显示支持向量  
58     thickness = 2;
59     lineType = 8;
60     //获取支持向量的个数  
61     int c = SVM.get_support_vector_count();
62 
63     for (int i = 0; i < c; ++i)
64     {
65         //获取第i个支持向量  
66         const float* v = SVM.get_support_vector(i);
67         //支持向量用到的样本点,用灰色进行标注  
68         circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(128, 128, 128), thickness, lineType);
69     }
70 
71     imwrite("result.png", image);        // save the image   
72 
73     imshow("SVM Simple Example", image); // show it to the user  
74     waitKey(0);
75 
76 }
原文地址:https://www.cnblogs.com/hsy1941/p/8260648.html