word2vec的Java源码【转】

一、核心代码：Word2VEC.java

  1 package com.ansj.vec;
  2 
  3 import java.io.*;
  4 import java.lang.reflect.Array;
  5 import java.util.ArrayList;
  6 import java.util.Arrays;
  7 import java.util.Collections;
  8 import java.util.HashMap;
  9 import java.util.List;
 10 import java.util.Map;
 11 import java.util.Map.Entry;
 12 import java.util.Set;
 13 import java.util.TreeSet;
 14 
 15 import com.ansj.vec.domain.WordEntry;
 16 import com.ansj.vec.util.WordKmeans;
 17 import com.ansj.vec.util.WordKmeans.Classes;
 18 
 19 public class Word2VEC {
 20 
 21     public static void main(String[] args) throws IOException {
 22 
 23          //Learn learn = new Learn();
 24         //learn.learnFile(new File("C:\Users\le\Desktop\0328-事件相关法律的算法进展\Result_Country.txt"));
 25         //learn.saveModel(new File("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1"));
 26         
 27         Word2VEC vec = new Word2VEC();
 28         vec.loadJavaModel("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1");
 29         System.out.println("中国" + "	" +Arrays.toString(vec.getWordVector("中国")));
 30         System.out.println("何润东" + "	" +Arrays.toString(vec.getWordVector("何润东")));
 31         System.out.println("足球" + "	" + Arrays.toString(vec.getWordVector("足球")));
 32 
 33         String str = "中国";
 34         System.out.println(vec.distance(str));
 35         WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 10);
 36         Classes[] explain = wordKmeans.explain();
 37         for (int i = 0; i < explain.length; i++) {
 38             System.out.println("--------" + i + "---------");
 39             System.out.println(explain[i].getTop(10));
 40         }
 41     }
 42 
 43     private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();
 44 
 45     private int words;
 46     private int size;
 47     private int topNSize = 40;
 48 
 49     /**
 50      * 鍔犺浇妯″瀷
 51      * 
 52      * @param path
 53      *            妯″瀷鐨勮矾寰�
 54      * @throws IOException
 55      */
 56     public void loadGoogleModel(String path) throws IOException {
 57         DataInputStream dis = null;
 58         BufferedInputStream bis = null;
 59         double len = 0;
 60         float vector = 0;
 61         try {
 62             bis = new BufferedInputStream(new FileInputStream(path));
 63             dis = new DataInputStream(bis);
 64             // //璇诲彇璇嶆暟
 65             words = Integer.parseInt(readString(dis));
 66             // //澶у皬
 67             size = Integer.parseInt(readString(dis));
 68             String word;
 69             float[] vectors = null;
 70             for (int i = 0; i < words; i++) {
 71                 word = readString(dis);
 72                 vectors = new float[size];
 73                 len = 0;
 74                 for (int j = 0; j < size; j++) {
 75                     vector = readFloat(dis);
 76                     len += vector * vector;
 77                     vectors[j] = (float) vector;
 78                 }
 79                 len = Math.sqrt(len);
 80 
 81                 for (int j = 0; j < size; j++) {
 82                     vectors[j] /= len;
 83                 }
 84 
 85                 wordMap.put(word, vectors);
 86                 dis.read();
 87             }
 88         } finally {
 89             bis.close();
 90             dis.close();
 91         }
 92     }
 93 
 94     /**
 95      * 鍔犺浇妯″瀷
 96      * 
 97      * @param path
 98      *            妯″瀷鐨勮矾寰�
 99      * @throws IOException
100      */
101     public void loadJavaModel(String path) throws IOException {
102         try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
103             words = dis.readInt();
104             size = dis.readInt();
105 
106             float vector = 0;
107 
108             String key = null;
109             float[] value = null;
110             for (int i = 0; i < words; i++) {
111                 double len = 0;
112                 key = dis.readUTF();
113                 value = new float[size];
114                 for (int j = 0; j < size; j++) {
115                     vector = dis.readFloat();
116                     len += vector * vector;
117                     value[j] = vector;
118                 }
119 
120                 len = Math.sqrt(len);
121 
122                 for (int j = 0; j < size; j++) {
123                     value[j] /= len;
124                 }
125                 wordMap.put(key, value);
126             }
127 
128         }
129     }
130 
131     private static final int MAX_SIZE = 50;
132 
133     /**
134      * 杩戜箟璇�
135      * 
136      * @return
137      */
138     public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
139         float[] wv0 = getWordVector(word0);
140         float[] wv1 = getWordVector(word1);
141         float[] wv2 = getWordVector(word2);
142 
143         if (wv1 == null || wv2 == null || wv0 == null) {
144             return null;
145         }
146         float[] wordVector = new float[size];
147         for (int i = 0; i < size; i++) {
148             wordVector[i] = wv1[i] - wv0[i] + wv2[i];
149         }
150         float[] tempVector;
151         String name;
152         List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
153         for (Entry<String, float[]> entry : wordMap.entrySet()) {
154             name = entry.getKey();
155             if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
156                 continue;
157             }
158             float dist = 0;
159             tempVector = entry.getValue();
160             for (int i = 0; i < wordVector.length; i++) {
161                 dist += wordVector[i] * tempVector[i];
162             }
163             insertTopN(name, dist, wordEntrys);
164         }
165         return new TreeSet<WordEntry>(wordEntrys);
166     }
167 
168     private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
169         // TODO Auto-generated method stub
170         if (wordsEntrys.size() < topNSize) {
171             wordsEntrys.add(new WordEntry(name, score));
172             return;
173         }
174         float min = Float.MAX_VALUE;
175         int minOffe = 0;
176         for (int i = 0; i < topNSize; i++) {
177             WordEntry wordEntry = wordsEntrys.get(i);
178             if (min > wordEntry.score) {
179                 min = wordEntry.score;
180                 minOffe = i;
181             }
182         }
183 
184         if (score > min) {
185             wordsEntrys.set(minOffe, new WordEntry(name, score));
186         }
187 
188     }
189 
190     public Set<WordEntry> distance(String queryWord) {
191 
192         float[] center = wordMap.get(queryWord);
193         if (center == null) {
194             return Collections.emptySet();
195         }
196 
197         int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
198         TreeSet<WordEntry> result = new TreeSet<WordEntry>();
199 
200         double min = Float.MIN_VALUE;
201         for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
202             float[] vector = entry.getValue();
203             float dist = 0;
204             for (int i = 0; i < vector.length; i++) {
205                 dist += center[i] * vector[i];
206             }
207 
208             if (dist > min) {
209                 result.add(new WordEntry(entry.getKey(), dist));
210                 if (resultSize < result.size()) {
211                     result.pollLast();
212                 }
213                 min = result.last().score;
214             }
215         }
216         result.pollFirst();
217 
218         return result;
219     }
220 
221     public Set<WordEntry> distance(List<String> words) {
222 
223         float[] center = null;
224         for (String word : words) {
225             center = sum(center, wordMap.get(word));
226         }
227 
228         if (center == null) {
229             return Collections.emptySet();
230         }
231 
232         int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
233         TreeSet<WordEntry> result = new TreeSet<WordEntry>();
234 
235         double min = Float.MIN_VALUE;
236         for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
237             float[] vector = entry.getValue();
238             float dist = 0;
239             for (int i = 0; i < vector.length; i++) {
240                 dist += center[i] * vector[i];
241             }
242 
243             if (dist > min) {
244                 result.add(new WordEntry(entry.getKey(), dist));
245                 if (resultSize < result.size()) {
246                     result.pollLast();
247                 }
248                 min = result.last().score;
249             }
250         }
251         result.pollFirst();
252 
253         return result;
254     }
255 
256     private float[] sum(float[] center, float[] fs) {
257         // TODO Auto-generated method stub
258 
259         if (center == null && fs == null) {
260             return null;
261         }
262 
263         if (fs == null) {
264             return center;
265         }
266 
267         if (center == null) {
268             return fs;
269         }
270 
271         for (int i = 0; i < fs.length; i++) {
272             center[i] += fs[i];
273         }
274 
275         return center;
276     }
277 
278     /**
279      * 寰楀埌璇嶅悜閲�
280      * 
281      * @param word
282      * @return
283      */
284     public float[] getWordVector(String word) {
285         return wordMap.get(word);
286     }
287 
288     public static float readFloat(InputStream is) throws IOException {
289         byte[] bytes = new byte[4];
290         is.read(bytes);
291         return getFloat(bytes);
292     }
293 
294     /**
295      * 璇诲彇涓�涓猣loat
296      * 
297      * @param b
298      * @return
299      */
300     public static float getFloat(byte[] b) {
301         int accum = 0;
302         accum = accum | (b[0] & 0xff) << 0;
303         accum = accum | (b[1] & 0xff) << 8;
304         accum = accum | (b[2] & 0xff) << 16;
305         accum = accum | (b[3] & 0xff) << 24;
306         return Float.intBitsToFloat(accum);
307     }
308 
309     /**
310      * 璇诲彇涓�涓�瓧绗︿覆
311      * 
312      * @param dis
313      * @return
314      * @throws IOException
315      */
316     private static String readString(DataInputStream dis) throws IOException {
317         // TODO Auto-generated method stub
318         byte[] bytes = new byte[MAX_SIZE];
319         byte b = dis.readByte();
320         int i = -1;
321         StringBuilder sb = new StringBuilder();
322         while (b != 32 && b != 10) {
323             i++;
324             bytes[i] = b;
325             b = dis.readByte();
326             if (i == 49) {
327                 sb.append(new String(bytes));
328                 i = -1;
329                 bytes = new byte[MAX_SIZE];
330             }
331         }
332         sb.append(new String(bytes, 0, i + 1));
333         return sb.toString();
334     }
335 
336     public int getTopNSize() {
337         return topNSize;
338     }
339 
340     public void setTopNSize(int topNSize) {
341         this.topNSize = topNSize;
342     }
343 
344     public HashMap<String, float[]> getWordMap() {
345         return wordMap;
346     }
347 
348     public int getWords() {
349         return words;
350     }
351 
352     public int getSize() {
353         return size;
354     }
355 
356 }

二、词向量模型学习代码：Learn.java

  1 package com.ansj.vec;
  2 
  3 import java.io.BufferedOutputStream;
  4 import java.io.BufferedReader;
  5 import java.io.DataOutputStream;
  6 import java.io.File;
  7 import java.io.FileInputStream;
  8 import java.io.FileNotFoundException;
  9 import java.io.FileOutputStream;
 10 import java.io.IOException;
 11 import java.io.InputStreamReader;
 12 import java.util.ArrayList;
 13 import java.util.HashMap;
 14 import java.util.List;
 15 import java.util.Map;
 16 import java.util.Map.Entry;
 17 
 18 import com.ansj.vec.util.MapCount;
 19 import com.ansj.vec.domain.HiddenNeuron;
 20 import com.ansj.vec.domain.Neuron;
 21 import com.ansj.vec.domain.WordNeuron;
 22 import com.ansj.vec.util.Haffman;
 23 
 24 public class Learn {
 25 
 26   private Map<String, Neuron> wordMap = new HashMap<>();
 27   /**
 28    * 训练多少个特征
 29    */
 30   private int layerSize = 200;
 31 
 32   /**
 33    * 上下文窗口大小
 34    */
 35   private int window = 5;
 36 
 37   private double sample = 1e-3;
 38   private double alpha = 0.025;
 39   private double startingAlpha = alpha;
 40 
 41   public int EXP_TABLE_SIZE = 1000;
 42 
 43   private Boolean isCbow = false;
 44 
 45   private double[] expTable = new double[EXP_TABLE_SIZE];
 46 
 47   private int trainWordsCount = 0;
 48 
 49   private int MAX_EXP = 6;
 50 
 51   public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha,
 52       Double sample) {
 53     createExpTable();
 54     if (isCbow != null) {
 55       this.isCbow = isCbow;
 56     }
 57     if (layerSize != null)
 58       this.layerSize = layerSize;
 59     if (window != null)
 60       this.window = window;
 61     if (alpha != null)
 62       this.alpha = alpha;
 63     if (sample != null)
 64       this.sample = sample;
 65   }
 66 
 67   public Learn() {
 68     createExpTable();
 69   }
 70 
 71   /**
 72    * trainModel
 73    * 
 74    * @throws IOException
 75    */
 76   private void trainModel(File file) throws IOException {
 77     try (BufferedReader br = new BufferedReader(new InputStreamReader(
 78         new FileInputStream(file)))) {
 79       String temp = null;
 80       long nextRandom = 5;
 81       int wordCount = 0;
 82       int lastWordCount = 0;
 83       int wordCountActual = 0;
 84       while ((temp = br.readLine()) != null) {
 85         if (wordCount - lastWordCount > 10000) {
 86           System.out.println("alpha:" + alpha + "	Progress: "
 87               + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
 88               + "%");
 89           wordCountActual += wordCount - lastWordCount;
 90           lastWordCount = wordCount;
 91           alpha = startingAlpha
 92               * (1 - wordCountActual / (double) (trainWordsCount + 1));
 93           if (alpha < startingAlpha * 0.0001) {
 94             alpha = startingAlpha * 0.0001;
 95           }
 96         }
 97         String[] strs = temp.split(" ");
 98         wordCount += strs.length;
 99         List<WordNeuron> sentence = new ArrayList<WordNeuron>();
100         for (int i = 0; i < strs.length; i++) {
101           Neuron entry = wordMap.get(strs[i]);
102           if (entry == null) {
103             continue;
104           }
105           // The subsampling randomly discards frequent words while keeping the
106           // ranking same
107           if (sample > 0) {
108             double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
109                 * (sample * trainWordsCount) / entry.freq;
110             nextRandom = nextRandom * 25214903917L + 11;
111             if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
112               continue;
113             }
114           }
115           sentence.add((WordNeuron) entry);
116         }
117 
118         for (int index = 0; index < sentence.size(); index++) {
119           nextRandom = nextRandom * 25214903917L + 11;
120           if (isCbow) {
121             cbowGram(index, sentence, (int) nextRandom % window);
122           } else {
123             skipGram(index, sentence, (int) nextRandom % window);
124           }
125         }
126 
127       }
128       System.out.println("Vocab size: " + wordMap.size());
129       System.out.println("Words in train file: " + trainWordsCount);
130       System.out.println("sucess train over!");
131     }
132   }
133 
134   /**
135    * skip gram 模型训练
136    * 
137    * @param sentence
138    * @param neu1
139    */
140   private void skipGram(int index, List<WordNeuron> sentence, int b) {
141     // TODO Auto-generated method stub
142     WordNeuron word = sentence.get(index);
143     int a, c = 0;
144     for (a = b; a < window * 2 + 1 - b; a++) {
145       if (a == window) {
146         continue;
147       }
148       c = index - window + a;
149       if (c < 0 || c >= sentence.size()) {
150         continue;
151       }
152 
153       double[] neu1e = new double[layerSize];// 误差项
154       // HIERARCHICAL SOFTMAX
155       List<Neuron> neurons = word.neurons;
156       WordNeuron we = sentence.get(c);
157       for (int i = 0; i < neurons.size(); i++) {
158         HiddenNeuron out = (HiddenNeuron) neurons.get(i);
159         double f = 0;
160         // Propagate hidden -> output
161         for (int j = 0; j < layerSize; j++) {
162           f += we.syn0[j] * out.syn1[j];
163         }
164         if (f <= -MAX_EXP || f >= MAX_EXP) {
165           continue;
166         } else {
167           f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
168           f = expTable[(int) f];
169         }
170         // 'g' is the gradient multiplied by the learning rate
171         double g = (1 - word.codeArr[i] - f) * alpha;
172         // Propagate errors output -> hidden
173         for (c = 0; c < layerSize; c++) {
174           neu1e[c] += g * out.syn1[c];
175         }
176         // Learn weights hidden -> output
177         for (c = 0; c < layerSize; c++) {
178           out.syn1[c] += g * we.syn0[c];
179         }
180       }
181 
182       // Learn weights input -> hidden
183       for (int j = 0; j < layerSize; j++) {
184         we.syn0[j] += neu1e[j];
185       }
186     }
187 
188   }
189 
190   /**
191    * 词袋模型
192    * 
193    * @param index
194    * @param sentence
195    * @param b
196    */
197   private void cbowGram(int index, List<WordNeuron> sentence, int b) {
198     WordNeuron word = sentence.get(index);
199     int a, c = 0;
200 
201     List<Neuron> neurons = word.neurons;
202     double[] neu1e = new double[layerSize];// 误差项
203     double[] neu1 = new double[layerSize];// 误差项
204     WordNeuron last_word;
205 
206     for (a = b; a < window * 2 + 1 - b; a++)
207       if (a != window) {
208         c = index - window + a;
209         if (c < 0)
210           continue;
211         if (c >= sentence.size())
212           continue;
213         last_word = sentence.get(c);
214         if (last_word == null)
215           continue;
216         for (c = 0; c < layerSize; c++)
217           neu1[c] += last_word.syn0[c];
218       }
219 
220     // HIERARCHICAL SOFTMAX
221     for (int d = 0; d < neurons.size(); d++) {
222       HiddenNeuron out = (HiddenNeuron) neurons.get(d);
223       double f = 0;
224       // Propagate hidden -> output
225       for (c = 0; c < layerSize; c++)
226         f += neu1[c] * out.syn1[c];
227       if (f <= -MAX_EXP)
228         continue;
229       else if (f >= MAX_EXP)
230         continue;
231       else
232         f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
233       // 'g' is the gradient multiplied by the learning rate
234       // double g = (1 - word.codeArr[d] - f) * alpha;
235       // double g = f*(1-f)*( word.codeArr[i] - f) * alpha;
236       double g = f * (1 - f) * (word.codeArr[d] - f) * alpha;
237       //
238       for (c = 0; c < layerSize; c++) {
239         neu1e[c] += g * out.syn1[c];
240       }
241       // Learn weights hidden -> output
242       for (c = 0; c < layerSize; c++) {
243         out.syn1[c] += g * neu1[c];
244       }
245     }
246     for (a = b; a < window * 2 + 1 - b; a++) {
247       if (a != window) {
248         c = index - window + a;
249         if (c < 0)
250           continue;
251         if (c >= sentence.size())
252           continue;
253         last_word = sentence.get(c);
254         if (last_word == null)
255           continue;
256         for (c = 0; c < layerSize; c++)
257           last_word.syn0[c] += neu1e[c];
258       }
259 
260     }
261   }
262 
263   /**
264    * 统计词频
265    * 
266    * @param file
267    * @throws IOException
268    */
269   private void readVocab(File file) throws IOException {
270     MapCount<String> mc = new MapCount<>();
271     try (BufferedReader br = new BufferedReader(new InputStreamReader(
272         new FileInputStream(file)))) {
273       String temp = null;
274       while ((temp = br.readLine()) != null) {
275         String[] split = temp.split(" ");
276         trainWordsCount += split.length;
277         for (String string : split) {
278           mc.add(string);
279         }
280       }
281     }
282     for (Entry<String, Integer> element : mc.get().entrySet()) {
283       wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
284           (double) element.getValue() / mc.size(), layerSize));
285     }
286   }
287 
288   /**
289    * 对文本进行预分类
290    * 
291    * @param files
292    * @throws IOException
293    * @throws FileNotFoundException
294    */
295   private void readVocabWithSupervised(File[] files) throws IOException {
296     for (int category = 0; category < files.length; category++) {
297       // 对多个文件学习
298       MapCount<String> mc = new MapCount<>();
299       try (BufferedReader br = new BufferedReader(new InputStreamReader(
300           new FileInputStream(files[category])))) {
301         String temp = null;
302         while ((temp = br.readLine()) != null) {
303           String[] split = temp.split(" ");
304           trainWordsCount += split.length;
305           for (String string : split) {
306             mc.add(string);
307           }
308         }
309       }
310       for (Entry<String, Integer> element : mc.get().entrySet()) {
311         double tarFreq = (double) element.getValue() / mc.size();
312         if (wordMap.get(element.getKey()) != null) {
313           double srcFreq = wordMap.get(element.getKey()).freq;
314           if (srcFreq >= tarFreq) {
315             continue;
316           } else {
317             Neuron wordNeuron = wordMap.get(element.getKey());
318             wordNeuron.category = category;
319             wordNeuron.freq = tarFreq;
320           }
321         } else {
322           wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
323               tarFreq, category, layerSize));
324         }
325       }
326     }
327   }
328 
329   /**
330    * Precompute the exp() table f(x) = x / (x + 1)
331    */
332   private void createExpTable() {
333     for (int i = 0; i < EXP_TABLE_SIZE; i++) {
334       expTable[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
335       expTable[i] = expTable[i] / (expTable[i] + 1);
336     }
337   }
338 
339   /**
340    * 根据文件学习
341    * 
342    * @param file
343    * @throws IOException
344    */
345   public void learnFile(File file) throws IOException {
346     readVocab(file);
347     new Haffman(layerSize).make(wordMap.values());
348 
349     // 查找每个神经元
350     for (Neuron neuron : wordMap.values()) {
351       ((WordNeuron) neuron).makeNeurons();
352     }
353 
354     trainModel(file);
355   }
356 
357   /**
358    * 根据预分类的文件学习
359    * 
360    * @param summaryFile
361    *          合并文件
362    * @param classifiedFiles
363    *          分类文件
364    * @throws IOException
365    */
366   public void learnFile(File summaryFile, File[] classifiedFiles)
367       throws IOException {
368     readVocabWithSupervised(classifiedFiles);
369     new Haffman(layerSize).make(wordMap.values());
370     // 查找每个神经元
371     for (Neuron neuron : wordMap.values()) {
372       ((WordNeuron) neuron).makeNeurons();
373     }
374     trainModel(summaryFile);
375   }
376 
377   /**
378    * 保存模型
379    */
380   public void saveModel(File file) {
381     // TODO Auto-generated method stub
382 
383     try (DataOutputStream dataOutputStream = new DataOutputStream(
384         new BufferedOutputStream(new FileOutputStream(file)))) {
385       dataOutputStream.writeInt(wordMap.size());
386       dataOutputStream.writeInt(layerSize);
387       double[] syn0 = null;
388       for (Entry<String, Neuron> element : wordMap.entrySet()) {
389         dataOutputStream.writeUTF(element.getKey());
390         syn0 = ((WordNeuron) element.getValue()).syn0;
391         for (double d : syn0) {
392           dataOutputStream.writeFloat(((Double) d).floatValue());
393         }
394       }
395     } catch (IOException e) {
396       // TODO Auto-generated catch block
397       e.printStackTrace();
398     }
399   }
400 
401   public int getLayerSize() {
402     return layerSize;
403   }
404 
405   public void setLayerSize(int layerSize) {
406     this.layerSize = layerSize;
407   }
408 
409   public int getWindow() {
410     return window;
411   }
412 
413   public void setWindow(int window) {
414     this.window = window;
415   }
416 
417   public double getSample() {
418     return sample;
419   }
420 
421   public void setSample(double sample) {
422     this.sample = sample;
423   }
424 
425   public double getAlpha() {
426     return alpha;
427   }
428 
429   public void setAlpha(double alpha) {
430     this.alpha = alpha;
431     this.startingAlpha = alpha;
432   }
433 
434   public Boolean getIsCbow() {
435     return isCbow;
436   }
437 
438   public void setIsCbow(Boolean isCbow) {
439     this.isCbow = isCbow;
440   }
441 
442   public static void main(String[] args) throws IOException {
443     Learn learn = new Learn();
444     long start = System.currentTimeMillis();
445     learn.learnFile(new File("library/xh.txt"));
446     System.out.println("use time " + (System.currentTimeMillis() - start));
447     learn.saveModel(new File("library/javaVector"));
448 
449   }
450 }

三、词向量的 k-means 聚类工具：WordKmeans.java

  1 package com.ansj.vec.util;
  2 
  3 import java.io.IOException;
  4 import java.util.ArrayList;
  5 import java.util.Arrays;
  6 import java.util.Collections;
  7 import java.util.Comparator;
  8 import java.util.HashMap;
  9 import java.util.Iterator;
 10 import java.util.List;
 11 import java.util.Map;
 12 import java.util.Map.Entry;
 13 
 14 import com.ansj.vec.Word2VEC;
 15 /*import com.ansj.vec.domain.WordEntry;
 16 import com.ansj.vec.util.WordKmeans.Classes;*/
 17 /**
 18  * keanmeans聚类
 19  * 
 20  * @author ansj
 21  * 
 22  */
 23 public class WordKmeans {
 24 
 25     public static void main(String[] args) {
 26         Word2VEC vec = new Word2VEC();
 27         try {
 28             
 29             vec.loadJavaModel("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1");
 30             System.out.println("中国" + "	" +Arrays.toString(vec.getWordVector("中国")));
 31             System.out.println("何润东" + "	" +Arrays.toString(vec.getWordVector("何润东")));
 32             System.out.println("足球" + "	" + Arrays.toString(vec.getWordVector("足球")));
 33         } catch (IOException e) {
 34             // TODO Auto-generated catch block
 35             e.printStackTrace();
 36         }
 37         System.out.println("load model ok!");
 38         WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 50);
 39         Classes[] explain = wordKmeans.explain();
 40 
 41         for (int i = 0; i < explain.length; i++) {
 42             System.out.println("--------" + i + "---------");
 43             System.out.println(explain[i].getTop(10));
 44         }
 45 
 46     }
 47 
 48     private HashMap<String, float[]> wordMap = null;
 49 
 50     private int iter;
 51 
 52     private Classes[] cArray = null;
 53 
 54     public WordKmeans(HashMap<String, float[]> wordMap, int clcn, int iter) {
 55         this.wordMap = wordMap;
 56         this.iter = iter;
 57         cArray = new Classes[clcn];
 58     }
 59 
 60     public Classes[] explain() {
 61         //first 取前clcn个点
 62         Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();
 63         for (int i = 0; i < cArray.length; i++) {
 64             Entry<String, float[]> next = iterator.next();
 65             cArray[i] = new Classes(i, next.getValue());
 66         }
 67 
 68         for (int i = 0; i < iter; i++) {
 69             for (Classes classes : cArray) {
 70                 classes.clean();
 71             }
 72 
 73             iterator = wordMap.entrySet().iterator();
 74             while (iterator.hasNext()) {
 75                 Entry<String, float[]> next = iterator.next();
 76                 double miniScore = Double.MAX_VALUE;
 77                 double tempScore;
 78                 int classesId = 0;
 79                 for (Classes classes : cArray) {
 80                     tempScore = classes.distance(next.getValue());
 81                     if (miniScore > tempScore) {
 82                         miniScore = tempScore;
 83                         classesId = classes.id;
 84                     }
 85                 }
 86                 cArray[classesId].putValue(next.getKey(), miniScore);
 87             }
 88 
 89             for (Classes classes : cArray) {
 90                 classes.updateCenter(wordMap);
 91             }
 92             System.out.println("iter " + i + " ok!");
 93         }
 94 
 95         return cArray;
 96     }
 97 
 98     public static class Classes {
 99         private int id;
100 
101         private float[] center;
102 
103         public Classes(int id, float[] center) {
104             this.id = id;
105             this.center = center.clone();
106         }
107 
108         Map<String, Double> values = new HashMap<>();
109 
110         public double distance(float[] value) {
111             double sum = 0;
112             for (int i = 0; i < value.length; i++) {
113                 sum += (center[i] - value[i])*(center[i] - value[i]) ;
114             }
115             return sum ;
116         }
117 
118         public void putValue(String word, double score) {
119             values.put(word, score);
120         }
121 
122         /**
123          * 重新计算中心点
124          * @param wordMap
125          */
126         public void updateCenter(HashMap<String, float[]> wordMap) {
127             for (int i = 0; i < center.length; i++) {
128                 center[i] = 0;
129             }
130             float[] value = null;
131             for (String keyWord : values.keySet()) {
132                 value = wordMap.get(keyWord);
133                 for (int i = 0; i < value.length; i++) {
134                     center[i] += value[i];
135                 }
136             }
137             for (int i = 0; i < center.length; i++) {
138                 center[i] = center[i] / values.size();
139             }
140         }
141 
142         /**
143          * 清空历史结果
144          */
145         public void clean() {
146             // TODO Auto-generated method stub
147             values.clear();
148         }
149 
150         /**
151          * 取得每个类别的前n个结果
152          * @param n
153          * @return 
154          */
155         public List<Entry<String, Double>> getTop(int n) {
156             List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(
157                 values.entrySet());
158             Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {
159                 @Override
160                 public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
161                     // TODO Auto-generated method stub
162                     return o1.getValue() > o2.getValue() ? 1 : -1;
163                 }
164             });
165             int min = Math.min(n, arrayList.size() - 1);
166             if(min<=1)return Collections.emptyList() ;
167             return arrayList.subList(0, min);
168         }
169 
170     }
171 
172 }

四、词向量的工具类：Haffman.java、MapCount.java

 1 package com.ansj.vec.util;
 2 
 3 import java.util.Collection;
 4 import java.util.List;
 5 import java.util.TreeSet;
 6 
 7 import com.ansj.vec.domain.HiddenNeuron;
 8 import com.ansj.vec.domain.Neuron;
 9 
10 /**
11  * 构建Haffman编码树
12  * 
13  * @author ansj
14  *
15  */
16 public class Haffman {
17   private int layerSize;
18 
19   public Haffman(int layerSize) {
20     this.layerSize = layerSize;
21   }
22 
23   private TreeSet<Neuron> set = new TreeSet<>();
24 
25   public void make(Collection<Neuron> neurons) {
26     set.addAll(neurons);
27     while (set.size() > 1) {
28       merger();
29     }
30   }
31 
32   private void merger() {
33     HiddenNeuron hn = new HiddenNeuron(layerSize);
34     Neuron min1 = set.pollFirst();
35     Neuron min2 = set.pollFirst();
36     hn.category = min2.category;
37     hn.freq = min1.freq + min2.freq;
38     min1.parent = hn;
39     min2.parent = hn;
40     min1.code = 0;
41     min2.code = 1;
42     set.add(hn);
43   }
44 
45 }
 1 //
 2 // Source code recreated from a .class file by IntelliJ IDEA
 3 // (powered by Fernflower decompiler)
 4 //
 5 
 6 package com.ansj.vec.util;
 7 
 8 import java.util.HashMap;
 9 import java.util.Iterator;
10 import java.util.Map.Entry;
11 
/**
 * Counts occurrences of keys of type {@code T}, backed by a {@link HashMap}.
 *
 * @param <T> type of the counted keys
 */
public class MapCount<T> {
    private final HashMap<T, Integer> hm;

    /** Creates an empty counter with the default {@link HashMap} capacity. */
    public MapCount() {
        this.hm = new HashMap<T, Integer>();
    }

    /**
     * Creates an empty counter.
     *
     * @param initialCapacity initial capacity of the backing map
     */
    public MapCount(int initialCapacity) {
        this.hm = new HashMap<T, Integer>(initialCapacity);
    }

    /**
     * Adds {@code n} to the count of {@code t}; a key seen for the first time starts at 0.
     *
     * @param t the key to count
     * @param n the amount to add (may be negative)
     */
    public void add(T t, int n) {
        Integer current = this.hm.get(t);
        this.hm.put(t, current == null ? n : current + n);
    }

    /**
     * Increments the count of {@code t} by one.
     *
     * @param t the key to count
     */
    public void add(T t) {
        this.add(t, 1);
    }

    /** @return the number of distinct keys counted */
    public int size() {
        return this.hm.size();
    }

    /**
     * Forgets {@code t} and its count.
     *
     * @param t the key to remove
     */
    public void remove(T t) {
        this.hm.remove(t);
    }

    /**
     * @return the live backing map (not a copy); changes to it are changes to this counter
     */
    public HashMap<T, Integer> get() {
        return this.hm;
    }

    /**
     * Renders every entry as {@code key<TAB>count<LF>}, in the backing map's iteration order.
     *
     * @return the rendered entries; empty when nothing has been counted
     */
    public String getDic() {
        StringBuilder sb = new StringBuilder();
        for (Entry<T, Integer> next : this.hm.entrySet()) {
            sb.append(next.getKey());
            sb.append('\t');
            sb.append(next.getValue());
            // Must be the '\n' escape: a raw line break inside a string literal does not compile.
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Leftover debug entry point: prints {@link Long#MAX_VALUE}. */
    public static void main(String[] args) {
        System.out.println(Long.MAX_VALUE);
    }
}

五、词向量的 domain 包：HiddenNeuron.java、Neuron.java、WordEntry.java、WordNeuron.java

 1 package com.ansj.vec.domain;
 2 
 3 public class HiddenNeuron extends Neuron{
 4     
 5     public double[] syn1 ; //hidden->out
 6     
 7     public HiddenNeuron(int layerSize){
 8         syn1 = new double[layerSize] ;
 9     }
10     
11 }
 1 package com.ansj.vec.domain;
 2 
 3 public abstract class Neuron implements Comparable<Neuron> {
 4   public double freq;
 5   public Neuron parent;
 6   public int code;
 7   // 语料预分类
 8   public int category = -1;
 9 
10   @Override
11   public int compareTo(Neuron neuron) {
12     if (this.category == neuron.category) {
13       if (this.freq > neuron.freq) {
14         return 1;
15       } else {
16         return -1;
17       }
18     } else if (this.category > neuron.category) {
19       return 1;
20     } else {
21       return -1;
22     }
23   }
24 }
 1 package com.ansj.vec.domain;
 2 
 3 
 4 public class WordEntry implements Comparable<WordEntry> {
 5     public String name;
 6     public float score;
 7 
 8     public WordEntry(String name, float score) {
 9         this.name = name;
10         this.score = score;
11     }
12 
13     @Override
14     public String toString() {
15         // TODO Auto-generated method stub
16         return this.name + "	" + score;
17     }
18 
19     @Override
20     public int compareTo(WordEntry o) {
21         // TODO Auto-generated method stub
22         if (this.score < o.score) {
23             return 1;
24         } else {
25             return -1;
26         }
27     }
28 
29 }
 1 package com.ansj.vec.domain;
 2 
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
 7 
 8 public class WordNeuron extends Neuron {
 9   public String name;
10   public double[] syn0 = null; // input->hidden
11   public List<Neuron> neurons = null;// 路径神经元
12   public int[] codeArr = null;
13 
14   public List<Neuron> makeNeurons() {
15     if (neurons != null) {
16       return neurons;
17     }
18     Neuron neuron = this;
19     neurons = new LinkedList<>();
20     while ((neuron = neuron.parent) != null) {
21       neurons.add(neuron);
22     }
23     Collections.reverse(neurons);
24     codeArr = new int[neurons.size()];
25 
26     for (int i = 1; i < neurons.size(); i++) {
27       codeArr[i - 1] = neurons.get(i).code;
28     }
29     codeArr[codeArr.length - 1] = this.code;
30 
31     return neurons;
32   }
33 
34   public WordNeuron(String name, double freq, int layerSize) {
35     this.name = name;
36     this.freq = freq;
37     this.syn0 = new double[layerSize];
38     Random random = new Random();
39     for (int i = 0; i < syn0.length; i++) {
40       syn0[i] = (random.nextDouble() - 0.5) / layerSize;
41     }
42   }
43 
44   /**
45    * 用于有监督的创造hoffman tree
46    * 
47    * @param name
48    * @param freq
49    * @param layerSize
50    */
51   public WordNeuron(String name, double freq, int category, int layerSize) {
52     this.name = name;
53     this.freq = freq;
54     this.syn0 = new double[layerSize];
55     this.category = category;
56     Random random = new Random();
57     for (int i = 0; i < syn0.length; i++) {
58       syn0[i] = (random.nextDouble() - 0.5) / layerSize;
59     }
60   }
61 
62 }
原文地址:https://www.cnblogs.com/Lxiaojiang/p/6644699.html