代码学习discriminative patch(一)

 matlab code,对应论文:Unsupervised Discovery of Mid-Level Discriminative Patches Saurabh Singh, Abhinav Gupta, and Alexei A. Efros]

其中VisualEntityDetectors:是matlab中定义的类 

train for Discriminative patches for pascal 2007 subset

  1. train:
    • getPascalData() 从voc2011 xml文件中获取图像annotation存入PASCAL_DATA.mat
    • warpCrossValClusteringUnsup(1, USR, true);第三个参数在real train中应该是false,supposed to run on a cluster with shared file system
      • getdataset() load PASCAL_DATA.mat
      • getTrainValSplitForCategory(data, category)根据类别划分出正样本和负样本(负样本不包含category给出的类别)
      • processCategory()
        • getParamsForCategory(category) 获取所有可能用到的参数
        • getTrainValSplitUnsupervised(posData, negData)把正负样本各分一半 D1D2 N1N2用作交叉验证
        • processImages()
          • getRandomPatchesFromPyramid( ) 371x1984  742x1984
            •  sampleRandomPatches() 371x1984  742x1984

              convertToCanonicalSize() 将load的图像大小标准化 

              getGradientImage()对标准化后的图像进行梯度计算

              constructFeaturePyramidForImg()计算19个不同scale下hog特征 保存在pyramid中 19cell

              unentanglePyramid()依次将每个scale下的fea转换为1984=8x8x31维特征 得到11220x1984

              通过去除重叠部分多 没有梯度信息的patch最后得到正负(5/10)样本图像的patch371x1984  742x1984

          • clusterPatches() 正样本聚类得到95clusters(381/4)
          • refineClustersOverlap( )对聚类中心进行筛选得到42 剩余有效patch 193
            • selectClustAboveThresh()聚类中心下的patch少于一定数目 去除该cluster 57clusters
            • 去除每个聚类中心下patch中重复度大的patch 然后判断是否仍有足够数量的patch 决定是否留该聚类中心 42clusters

          • doTheIterations() 依次对D1之前每个cluster进行svmtrain 然后用D2的数据进行svmpredict 计算D2数据集下的得到的cluster下的patch(有某些限制)  并依次交换D1D2数据集进行重复运算 最后得到42x1984的svmtrain model
            • iterateTraining() 

初次进行firstTrainDetectors() 如果存在直接load

constructVisDetFromModels(firstTrainDir, trainingData.selectedClusters, params, '_det');

hardNegMineTrainDetectors(instanceId, thisIterDir, hardNegOut,detectors, trainingData)

detectPresenceUsingDetectors(instanceId, thisIterDir, detectionOut,detectors, trainingData.selectedClusters)

getTopNDetsPerCluster(detectionResult, maxOverlap, trainingData.trainSetPos, numTopN)

prepareDetectedPatchClusters()

calculateClusterCenters()

Prepare the data for the next iterations

    • assimilateCrossValBatchResults()

detect discriminative patches

  1. load model: 全部训练数据用上得到的模型svm训练的模型总积  5481clusters 1984dimension features
  2.  load data and get data info: size
  3. construct detection params: 初始化参数 
  4. detector.detectPresenceUsingEntDet(): 用model对该图像中patch属于哪个clusters计算score
    displayPatchBox: 画出patch box 
    • constructFeaturePyramid() 19cell 不同尺度
    • getDetectionsForEntDets ()
      • unentanglePyramid() 13011x1984 hog特征
      • Mysvmpredicted()在5841x1984model上计算出每个patch在cluster的score  
      • SelectTopN()根据score判断对pathch的取舍

 

原文地址:https://www.cnblogs.com/oudan/p/4160082.html