OpenCV KNN 之 使用方法

http://blog.csdn.net/WL2002200/article/details/43149229

OpenCV 中KNN构造函数如下。

[cpp] view plain copy
 
  1. C++: CvKNearest::CvKNearest()  
  2. C++: CvKNearest::CvKNearest(const Mat& trainData, const Mat& responses, const Mat& sam-  
  3. pleIdx=Mat(), bool isRegression=false, int max_k=32 )  

训练函数为:

[cpp] view plain copy
 
  1. C++: bool CvKNearest::train(  
  2.     const Mat& trainData, //训练数据  
  3.     const Mat& responses,//对应的响应值  
  4.     const Mat& sampleIdx=Mat(),//样本索引  
  5.     bool isRegression=false,//是否是回归,否则是分类问题  
  6.     int maxK=32, //最大K值  
  7.     bool updateBase=false//是否更新数据,是,则maxK需要小于原数据大小 )  

查找函数:

[cpp] view plain copy
 
  1. C++: float CvKNearest::find_nearest(  
  2. const Mat& samples,//按行存储的测试数据  
  3.  int k, //K 值  
  4. Mat* results=0,//预测结果  
  5. const float** neighbors=0, //近邻指针向量  
  6. Mat* neighborResponses=0, //近邻值  
  7. Mat* dist=0 //距离矩阵) const  
  8.   
  9. C++: float CvKNearest::find_nearest(  
  10. const Mat& samples,  
  11. int k,  
  12. Mat& results,  
  13. Mat& neighborResponses,  
  14. Mat& dists) const  

还有一些其他辅助函数,无关紧要,略去了。


opencv 有KNN 的示例,改写成C++ 版本如下:

[html] view plain copy
 
    1. #include <opencv2/core/core.hpp>  
    2. #include <opencv2/highgui/highgui.hpp>  
    3. #include <opencv2/ml/ml.hpp>  
    4.   
    5. int main( )  
    6. {  
    7.     const int K = 10;  
    8.     int i, j, k, accuracy;  
    9.     float response;  
    10.     int train_sample_count = 100;  
    11.     cv::RNG rng_state(-1);  
    12.     cv::Mat trainData(train_sample_count,2,CV_32FC1);  
    13.     cv::Mat trainClasses(train_sample_count,1,CV_32FC1);  
    14.     cv::Mat img(cv::Size(500,500),CV_8UC3,cv::Scalar::all (0));  
    15.     float _sample[2];  
    16.     cv::Mat sample(1,2,CV_32FC1,_sample);  
    17.   
    18.     cv::Mat trainData1, trainData2, trainClasses1, trainClasses2;  
    19.   
    20.     // form the training samples  
    21.     trainData1 = trainData.rowRange (0,train_sample_count/2);  
    22.     rng_state.fill (trainData1,CV_RAND_NORMAL,cv::Scalar(200,200),cv::Scalar(50,50));  
    23.   
    24.     trainData2 = trainData.rowRange (train_sample_count/2,train_sample_count);  
    25.     rng_state.fill (trainData2,CV_RAND_NORMAL,cv::Scalar(300,300),cv::Scalar(50,50));  
    26.   
    27.     trainClasses1 = trainClasses.rowRange (0,train_sample_count/2);  
    28.     trainClasses1.setTo (1);  
    29.   
    30.     trainClasses2 = trainClasses.rowRange (train_sample_count/2,train_sample_count);  
    31.     trainClasses2.setTo (2);  
    32.   
    33.     // learn classifier  
    34.     CvKNearest knn( trainData, trainClasses, cv::Mat(), false, K );  
    35.     cv::Mat nearests( 1, K, CV_32FC1);  
    36.   
    37.     for( i = 0; i img.rows; i++ )  
    38.     {  
    39.         for( j = 0; j img.cols; j++ )  
    40.         {  
    41.             sample.at<float>(0,0) = (float)j;  
    42.             sample.at<float>(0,1) = (float)i;  
    43.   
    44.             // estimate the response and get the neighbors' labels  
    45.             response = knn.find_nearest(sample,K,0,0,&nearests,0);  
    46.   
    47.             // compute the number of neighbors representing the majority  
    48.             for( k = 0, accuracy = 0; k K; k++ )  
    49.             {  
    50.                 if( nearests.at<float>(0,k) == response)  
    51.                     accuracy++;  
    52.             }  
    53.             // highlight the pixel depending on the accuracy (or confidence)  
    54.             img.at<cv::Vec3b>(i,j) = response == 1 ?  
    55.                         (accuracy > 5 ? cv::Vec3b(0,0,180) : cv::Vec3b(0,120,180)) :  
    56.                         (accuracy > 5 ? cv::Vec3b(0,180,0) : cv::Vec3b(0,120,120));  
    57.         }  
    58.     }  
    59.   
    60.     // display the original training samples  
    61.     for( i = 0; i train_sample_count/2; i++ )  
    62.     {  
    63.         cv::Point pt;  
    64.         pt.x = cvRound(trainData1.at<float>(i,0));  
    65.         pt.y = cvRound(trainData1.at<float>(i,1));  
    66.         cv::circle (img,pt,2,cv::Scalar(0,0,255),1,CV_FILLED);  
    67.         pt.x = cvRound(trainData2.at<float>(i,0));  
    68.         pt.y = cvRound(trainData2.at<float>(i,1));  
    69.         cv::circle (img,pt,2,cv::Scalar(0,255,0),1,CV_FILLED);  
    70.     }  
    71.   
    72.     cv::namedWindow( "classifier result", 1 );  
    73.     cv::imshow( "classifier result", img );  
    74.     cv::waitKey(0);  
    75.   
    76.     return 0;  
    77. }  
原文地址:https://www.cnblogs.com/jukan/p/7279521.html