SVM+HOG特征训练分类器

#1,概念

    在机器学习领域,支持向量机SVM(Support Vector Machine)是一个有监督的学习模型,通常用来进行模式识别、分类、以及回归分析。

    SVM的主要思想可以概括为两点:⑴它是针对线性可分情况进行分析,对于线性不可分的情况,通过使用非线性映射算法将低维输入空间线性不可分的样本转化为高维特征空间使其线性可分,从而 使得高维特征空间采用线性算法对样本的非线性特征进行线性分析成为可能;

  方向梯度直方图(Histogram of Oriented Gradient, HOG)特征是一种在计算机视觉和图像处理中用来进行物体检测的特征描述子。HOG特征通过计算和统计图像局部区域的梯度方向直方图来构成特征。

#2,代码和输入文件截图

  1 #include <iostream>    
  2 #include <fstream>    
  3 #include <string>    
  4 #include <vector> 
  5 #include <stdlib.h>
  6 #include <tchar.h>
  7 #include <windows.h>
  8    
  9 #include <opencv2/core/core.hpp>
 10 #include <opencv2/ml/ml.hpp>
 11 #include <opencv2/highgui/highgui.hpp>
 12 #include <opencv2/objdetect/objdetect.hpp>
 13 #include <opencv2/imgproc/imgproc.hpp>
 14 
 15 using namespace cv;    
 16 using namespace std;    
 17    
 18 int getfilepath(string txtpath,vector<string>& img_path, vector<int>& roi_sample_class, vector<vector<Rect > >& roi_sample_rect);  
 19   
 20 int main(int argc, char** argv)  {  
 21     //winSize窗口大小
 22     int ImgWidht = 48;  
 23     int ImgHeight = 48; 
 24     
 25     //Sample总数
 26     int num_sample_roi=0;
 27 
 28     //图片路径,每张图片中的sample类别和每个sample的rect参数
 29     vector<string> img_path;
 30     vector<int> roi_sample_class;
 31     vector<vector<Rect > > roi_sample_rect;
 32     
 33     //标记文件txt名称和路径
 34     string filepath = "E:/svm/";
 35     string txtpath = string(filepath) + "TrainingData.txt";
 36     
 37     //获得图片路径,sample类别,每个sample的rect参数,返回Sample总数(ROI总数)
 38     num_sample_roi=getfilepath(txtpath, img_path, roi_sample_class, roi_sample_rect);
 39     
 40     //测试img_path, roi_sample_class, roi_sample_rect
 41     //cout << "img_path[0]= " << img_path[0] << "
 img_path[50]= " << img_path[50] << endl;
 42     //cout << "roi_sample_class[0]= " << roi_sample_class[0] << "
 roi_sample_class[150]= " << roi_sample_class[150] << endl;
 43     //cout << "roi_sample_rect[0][0]= " << roi_sample_rect[0][0] << endl;
 44     //system("Pause");
 45 
 46     //HOG特征矩阵,sample类别矩阵
 47     //cout << "num_sample_roi= " << num_sample_roi << endl;
 48     //system("Pause");
 49     Mat sample_feature_mat(num_sample_roi, 900, CV_32FC1);//900=(win_height/8-1)*(win_width/8-1)*(2*2)*9;
 50     Mat sample_class_mat(num_sample_roi, 1, CV_32SC1);    //样本类别       
 51 
 52     //原图片和训练图片
 53     Mat orig_img;    
 54     Mat train_img; //= Mat::zeros(ImgWidht, ImgHeight, CV_8UC3);//需要分析的图片 
 55 
 56     //sample指示子
 57     unsigned long n_sample = 0;
 58     
 59     //对图片循环 
 60     for( string::size_type i = 0; i != img_path.size(); i++ )  
 61     {    
 62         orig_img = imread(img_path[i].c_str(), 1); 
 63         if(orig_img.empty()){
 64             cout<<"Can not load the image: "<<img_path[i]<<endl;
 65             continue;
 66         }
 67         
 68         //端口的提示信息
 69         cout<<"***processing***"<<img_path[i].c_str()<<endl;        
 70         
 71         //每个sample都要计算hog特征
 72         for (size_t j = 0; j != roi_sample_rect[i].size(); j++){
 73             //取ROI,归一化
 74             Mat handle_src=orig_img(roi_sample_rect[i][j]);
 75             resize(handle_src, train_img, Size(ImgWidht, ImgHeight));
 76             
 77             //申明描述子,每个参数的含义见笔记
 78             HOGDescriptor hog(Size(ImgWidht,ImgHeight),Size(16,16),Size(8,8),Size(8,8), 9);
 79             
 80             //描述子申请内存并计算
 81             vector<float> descriptors; 
 82             hog.compute(train_img, descriptors);
 83             
 84             //为当前sample的所有hog descriptor申请内存
 85             //sample_feature_mat[n_sample].resize(descriptors.size(),CV_32FC1);
 86             
 87             //输出hog特征个数
 88             //cout<<"HOG dims: "<<descriptors.size()<<endl; 
 89             
 90             //每个sample的hog特征个数    
 91             for (vector<float>::size_type k = 0; k != descriptors.size(); k++)
 92                 sample_feature_mat.at<float>(n_sample, k) = descriptors[k]; 
 93             
 94             //int num_class=i/100;
 95             sample_class_mat.at<int>(n_sample, 0) =  roi_sample_class[i];    
 96             cout<<"***end processing***"<<img_path[i].c_str()<<" "<<roi_sample_class[i]<<endl;
 97             
 98             n_sample++;
 99         }
100         
101     }    
102     
103     //SVM参数
104     Ptr<ml::SVM> svm = ml::SVM::create();
105     svm->setType(ml::SVM::C_SVC);
106     svm->setKernel(ml::SVM::RBF);
107     svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
108     
109     //SVM训练    
110     svm->train(sample_feature_mat, ml::ROW_SAMPLE, sample_class_mat);
111     svm->save( filepath+"SVM_DATA.xml" );      
112     
113     system("Pause");
114     return 0;    
115 }  
116 
117 
118 int getfilepath(string txtpath,vector<string>& img_path, vector<int>& roi_sample_class, vector<vector<Rect > >& roi_sample_rect){
119     int nLine = 0;    //图片计数器,每个类别100张图片。
120     int numroi=0;            //sample计数器,总的roi个数
121     
122     ifstream svm_data( txtpath );
123     char output[1000];
124     
125     while( !svm_data.eof() ) {    
126         svm_data >> output;//是不是>>遇见空格会自动截断?
127         string s0= string(output);
128         if( s0.length()>10 ){
129             if( nLine < 100 ){    
130                 roi_sample_class.push_back(0); 
131                 
132                 size_t bufname=s0.find(" ");
133                 string bufname1 = s0.substr(0, bufname);
134                 img_path.push_back( bufname1 ); 
135                 
136                 vector<Rect> obj_list;
137                 svm_data >> output;
138                 int obj_num=atoi(output);
139                 
140                 numroi=numroi+obj_num;
141                 int obj[4];
142                 for(int i=0; i<obj_num;i++)
143                 {
144                     for(int j=0;j<4;j++)
145                     {
146                         svm_data >> output;
147                         int p=atoi(output);
148                         obj[j]=p;
149                     }
150                     Rect tmp_obj=Rect(obj[0],obj[1],obj[2],obj[3]);
151                     obj_list.push_back(tmp_obj);
152                 }
153                 roi_sample_rect.push_back(obj_list);
154             }    
155             else if( nLine < 200 ){    
156                 roi_sample_class.push_back(1); 
157                 size_t bufname=s0.find(" ");
158                 string bufname1 = s0.substr(0, bufname);
159                 img_path.push_back( bufname1 ); 
160                 
161                 vector<Rect> obj_list;
162                 svm_data>>output;
163                 int obj_num=atoi(output);
164                 
165                 numroi=numroi+obj_num;
166                 int obj[4];
167                 for(int i=0; i<obj_num;i++)
168                 {
169                     for(int j=0;j<4;j++)
170                     {
171                         svm_data >> output;
172                         int p=atoi(output);
173                         obj[j]=p;
174                     }
175                     Rect tmp_obj=Rect(obj[0],obj[1],obj[2],obj[3]);
176                     obj_list.push_back(tmp_obj);
177                 }
178                 roi_sample_rect.push_back(obj_list);
179             }
180             else if( nLine < 300 ){    
181                 roi_sample_class.push_back(2); 
182                 size_t bufname=s0.find(" ");
183                 string bufname1 = s0.substr(0, bufname);
184                 img_path.push_back( bufname1 ); 
185                 
186                 vector<Rect> obj_list;
187                 svm_data>>output;
188                 int obj_num=atoi(output);
189                 
190                 numroi=numroi+obj_num;
191                 int obj[4];
192                 for(int i=0; i<obj_num;i++)
193                 {
194                     for(int j=0;j<4;j++)
195                     {
196                         svm_data >> output;
197                         int p=atoi(output);
198                         obj[j]=p;
199                     }
200                     Rect tmp_obj=Rect(obj[0],obj[1],obj[2],obj[3]);
201                     obj_list.push_back(tmp_obj);
202                 }
203                 roi_sample_rect.push_back(obj_list);
204             }
205             else{//(nLine < 400)
206                 roi_sample_class.push_back(3); 
207                 size_t bufname=s0.find(" ");
208                 string bufname1 = s0.substr(0, bufname);
209                 img_path.push_back( bufname1 ); 
210                 
211                 vector<Rect> obj_list;
212                 svm_data>>output;
213                 int obj_num=atoi(output);
214                 
215                 numroi=numroi+obj_num;
216                 int obj[4];
217                 for(int i=0; i<obj_num;i++)
218                 {
219                     for(int j=0;j<4;j++)
220                     {
221                         svm_data >> output;
222                         int p=atoi(output);
223                         obj[j]=p;
224                     }
225                     Rect tmp_obj=Rect(obj[0],obj[1],obj[2],obj[3]);
226                     obj_list.push_back(tmp_obj);
227                 }
228                 roi_sample_rect.push_back(obj_list);
229             }
230             nLine ++;        //计数
231         }    
232     } 
233     
234     svm_data.close();
235     //cout << "numroi= " << numroi<< endl;
236     return numroi;
237 }
SVM+HOG

我的输入文件格式:

得到的分类器xml文件和输入的数据文件TrainingData.txt是放在同一个文件夹下:

图片源文件是给的绝对目录,看代码就知道了。

原文地址:https://www.cnblogs.com/sophia-hxw/p/5686538.html