Java实现LSH(Locality Sensitive Hash )

  在对大批量数据进行图像处理的时候,比如说我提取SIFT特征,数据集为10W张图片,一个SIFT特征点是128维,一张图片提取出500个特征点,这样我们在处理的时候就是对5000万个128维的数据进行处理,这样处理所需要的耗时太长了,不符合实际生产的需要。我们需要用一种方法降低运算量,比如说降维。

  看了一些论文,提到的较多的方法是LSH(Locality Sensitive Hash),就是局部敏感哈希。我们利用LSH方法在5000万个特征点中筛选出极少量的我们需要的特征点,在对这些极少量的数据进行计算,就可以得到我们想要的结果啦。

  1 package com.demo.lsh;
  2 
  3 import com.demo.config.Constant;
  4 import com.demo.dao.FeatureDao;
  5 import com.demo.dao.FeatureTableDao;
  6 import com.demo.dao.HashTableDao;
  7 import com.demo.entity.HashTable;
  8 import com.demo.utils.MD5Util;
  9 import com.demo.utils.MathUtil;
 10 import org.opencv.core.Mat;
 11 import org.springframework.util.StringUtils;
 12 
 13 import java.io.*;
 14 import java.security.MessageDigest;
 15 import java.security.NoSuchAlgorithmException;
 16 import java.util.*;
 17 
 18 public class LSH {
 19     //维度大小,例如对于sift特征来说就是128
 20     private int dimention = Constant.DIMENTION;
 21     //所需向量中元素可能的上限,譬如对于RGB来说,就是255
 22     private int max = Constant.MAX;
 23     //哈希表的数量,用于更大程度地削减false positive
 24     private int hashCount = Constant.HASHCOUNT;
 25     //LSH随机选取的采样位数,该值越小,则近似查找能力越大,但相应的false positive也越大;若该值等于size,则为由近似查找退化为精确匹配
 26     private int bitCount = Constant.BITCOUNT;
 27     //转化为01字符串之后的位数,等于max乘以dimensions
 28     private int size = dimention * max;
 29     //LSH哈希族,保存了随机采样点的INDEX
 30     private int[][] hashFamily;
 31     private HashTableDao hashTableDao;
 32     /**
 33      * 构造函数
 34      */
 35     public LSH(HashTableDao hashTableDao) {
 36         this.hashTableDao = hashTableDao;
 37         dimention = Constant.DIMENTION;
 38         max = Constant.MAX;
 39         hashCount = Constant.HASHCOUNT;
 40         bitCount = Constant.BITCOUNT;
 41         size = dimention * max;
 42         hashFamily = new int[hashCount][bitCount];
 43         generataHashFamily();
 44     }
 45 
 46     /**
 47      * 生成随机的投影点 ,在程序第一次执行时生成。投影点可以理解为后面去数组的索引值
 48      */
 49     private void generataHashFamily() {
 50         if (new File("/home/fanxuan/data/1.txt").exists()) {
 51             try {
 52                 InputStream in = new FileInputStream("/home/fanxuan/data/1.txt");
 53                 ObjectInputStream oin = new ObjectInputStream(in);
 54                 hashFamily = (int[][]) (oin.readObject());
 55             } catch (IOException e) {
 56                 e.printStackTrace();
 57             } catch (ClassNotFoundException e) {
 58                 e.printStackTrace();
 59             }
 60         }else {
 61             Random rd = new Random();
 62             for (int i = 0; i < hashCount; i++) {
 63                 for (int j = 0; j < bitCount; j++) {
 64                     hashFamily[i][j] = rd.nextInt(size);
 65                 }
 66             }
 67             try {
 68                 OutputStream out = new FileOutputStream("/home/fanxuan/data/1.txt");
 69                 ObjectOutputStream oout = new ObjectOutputStream(out);
 70                 oout.writeObject(hashFamily);
 71             } catch (FileNotFoundException e) {
 72                 e.printStackTrace();
 73             } catch (IOException e) {
 74                 e.printStackTrace();
 75             }
 76         }
 77     }
 78 
 79     //将向量转化为二进制字符串,比如元素的最大范围255,则元素65就被转化为65个1以及190个0
 80     private int[] unAray(int[] data) {
 81         int unArayData[] = new int[size];
 82         for (int i = 0; i < data.length; i++) {
 83             for (int j = 0; j < data[i]; j++) {
 84                 unArayData[i * max + j] = 1;
 85             }
 86         }
 87         return unArayData;
 88     }
 89 
 90     /**
 91      * 将向量映射为LSH中的key
 92      */
 93     private String generateHashKey(int[] list, int hashNum) {
 94         StringBuilder sb = new StringBuilder();
 95         int[] tempData = unAray(list);
 96         int[] hashedData = new int[bitCount];
 97         //首先将向量转为二进制字符串
 98         for (int i = 0; i < bitCount; i++) {
 99             hashedData[i] = tempData[hashFamily[hashNum][i]];
100             sb.append(hashedData[i]);
101         }
102         //再用常规hash函数比如MD5对key进行压缩
103         MessageDigest messageDigest = null;
104         try{
105             messageDigest = MessageDigest.getInstance("MD5");
106         }catch (NoSuchAlgorithmException e) {
107 
108         }
109         byte[] binary = sb.toString().getBytes();
110         byte[] hash = messageDigest.digest(binary);
111         String hashV = MD5Util.bufferToHex(hash);
112         return hashV;
113     }
114 
115     /**
116      * 将Sift特征点转换为Hash存表
117      */
118     public void generateHashMap(String id, int[] vercotr, int featureId) {
119         for (int j = 0; j < hashCount; j++) {
120             String key = generateHashKey(vercotr, j);
121             HashTable hashTableUpdateOrAdd = new HashTable();
122             HashTable hashTable = hashTableDao.findHashTableByBucketId(key);
123             if (hashTable != null) {
124                 String featureIdValue = hashTable.getFeatureId() + "," + featureId;
125                 hashTableUpdateOrAdd.setFeatureId(featureIdValue);
126                 hashTableUpdateOrAdd.setBucketId(key);
127                 hashTableDao.updateHashTableFeatureId(hashTableUpdateOrAdd);
128             } else {
129                 hashTableUpdateOrAdd.setBucketId(key);
130                 hashTableUpdateOrAdd.setFeatureId(String.valueOf(featureId));
131                 hashTableDao.insertHashTable(hashTableUpdateOrAdd);
132             }
133         }
134     }
135 
136     // 查询与输入向量最接近(海明空间)的向量
137     public List<String> queryList(int[] data) {
138         List<String> result = new ArrayList<>();
139         for (int j = 0; j < hashCount; j++) {
140             String key = generateHashKey(data, j);
141             result.add(key);
142             HashTable hashTable = hashTableDao.findHashTableByBucketId(key);
143             if (!StringUtils.isEmpty(hashTable.getFeatureId())) {
144                 String[] str = hashTable.getFeatureId().split(",");
145                 for (String string : str) {
146                     result.add(string);
147                 }
148             }
149         }
150         return result;
151     }
152 
153 }

   

 1 package com.demo.config;
 2 
 3 public class Constant {
 4     //维度大小,例如对于sift特征来说就是128
 5     public static final int DIMENTION = 128;
 6     //所需向量中元素可能的上限,譬如对于RGB来说,就是255
 7     public static final int MAX = 255;
 8     //哈希表的数量,用于更大程度地削减false positive
 9     public static final int HASHCOUNT = 12;
10     //LSH随机选取的采样位数,该值越小,则近似查找能力越大,但相应的false positive也越大;若该值等于size,则为由近似查找退化为精确匹配
11     public static final int BITCOUNT = 32;
12 }

  简单的介绍下代码,构造函数LSH()用来建立LSH对象,hashTableDao为数据表操作对象,不多说;因为局部敏感哈希依赖与一套随机数,每次产生的结果都不一致,所以我们需要在程序第一次运行的时候将随机数生成并固定下来,我采用的方法是存放在本地磁盘中,也可以存放在数据库中。generateHashMap()方法为数据训练函数,int[] vercotr为特征向量,其他两个参数为我需要的标志位。queryList()方法是筛选方法。

  感谢http://grunt1223.iteye.com/blog/944894的文章。

原文地址:https://www.cnblogs.com/fx-blog/p/8227988.html