【LDA】LDA 模型和 Java 代码

几个问题:

1、停用词应该去除到什么程度?

2、比如我选了参数 topicNumber=100,结果中有 80 个 topic 的前几个词很好地描述了一个主题,而另外 20 个 topic 的前几个词没有描述好。这是否说明 topicNumber=100 已经足够了?

3、LDA 在多大程度上考虑了文档之间的关系?

4、参数 alpha、beta 怎么取?常见的经验取值是 alpha = 50/K(K 为主题数)、beta = 0.1(或 0.01)吗?

========================================

看了几篇介绍 LDA 的文章,写得实在太好了,我只能贴点代码,表示我做过 LDA 了。

/**
 * Latent Dirichlet Allocation (LDA) topic model estimated with collapsed Gibbs sampling.
 *
 * <p>Usage: construct with the hyper-parameters, call {@link #initializeModel(Documents)} to
 * assign random initial topics, then {@link #inferenceModel(Documents)} to run the sampler.
 * After inference, {@code phi} (topic-word, K*V) and {@code theta} (doc-topic, M*K) hold the
 * estimates computed from the final sampler state.
 */
public class LdaModel {

    int[][] doc;// word index array: for each document, the index of each word in indexToTermMap
    int V, K, M;// vocabulary size, topic number, document number
    int[][] z;// topic label array: the topic assigned to each word of each document
    float alpha; // doc-topic dirichlet prior parameter
    float beta; // topic-word dirichlet prior parameter
    int[][] nmk;// given document m, count times of topic k. M*K
    int[][] nkt;// given topic k, count times of term t. K*V
    int[] nmkSum;// Sum for each row in nmk: nmkSum[m] is the number of words in document m
    int[] nktSum;// Sum for each row in nkt: nktSum[k] is the number of words assigned to topic k
    double[][] phi;// Parameters for topic-word distribution K*V
    double[][] theta;// Parameters for doc-topic distribution M*K
    int iterations;// Times of iterations
    int saveStep;// The number of iterations between two saving
    int beginSaveIters;// Begin save model at this iteration

    /** Number of top words per topic written to the .twords file and saved by saveTopic. */
    private static final int TOP_WORDS_PER_TOPIC = 15;

    /** Number of top topics per document saved by saveDisTopic. */
    private static final int TOP_TOPICS_PER_DOC = 3;

    /**
     * Creates a model from the given hyper-parameters. No data is allocated until
     * {@link #initializeModel(Documents)} is called.
     *
     * @param modelparam alpha, beta, iteration count, topic count and save schedule
     */
    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    /**
     * Allocates the count and parameter matrices, copies the word indices of every document and
     * gives each word a uniformly random initial topic, updating nmk, nkt, nmkSum and nktSum.
     *
     * @param docSet corpus whose docWords are indices into its vocabulary (size V)
     */
    public void initializeModel(Documents docSet) {
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        doc = new int[M][];
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            // Private copy of the word indices (notice the memory cost for large corpora)
            doc[m] = docSet.docs.get(m).docWords.clone();
            int N = doc[m].length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int initTopic = (int) (Math.random() * K);// From 0 to K - 1
                z[m][n] = initTopic;
                nmk[m][initTopic]++;
                nkt[initTopic][doc[m][n]]++;
                nktSum[initTopic]++;
            }
            nmkSum[m] = N;
        }
    }

    /**
     * Runs {@code iterations} sweeps of collapsed Gibbs sampling over every word.
     *
     * <p>At iteration beginSaveIters and every saveStep iterations after it, phi/theta are
     * re-estimated and a snapshot is written (before that iteration's sweep). After the last
     * sweep phi/theta are re-estimated once more, so they reflect the final sampler state.
     *
     * @param docSet the corpus given to initializeModel; used when writing snapshots
     * @throws IllegalArgumentException if saveStep is not positive, or iterations is smaller
     *     than saveStep + beginSaveIters
     * @throws IOException if writing a snapshot fails
     */
    public void inferenceModel(Documents docSet) throws IOException {
        if (saveStep <= 0) {
            // The save schedule below takes a modulo by saveStep
            throw new IllegalArgumentException("Error: saveStep should be positive, got " + saveStep);
        }
        if (iterations < saveStep + beginSaveIters) {
            throw new IllegalArgumentException(
                    "Error: the number of iterations should be at least "
                            + (saveStep + beginSaveIters));
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                System.out.println("Saving model at iteration " + i + " ... ");
                updateEstimatedParameters();
                saveIteratedModel(i, docSet);
            }

            // One Gibbs sweep: resample z[m][n] from p(z_i | z_-i, w) for every word
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < z[m].length; n++) {
                    z[m][n] = sampleTopicZ(m, n);
                }
            }
        }
        // Without this, phi/theta would still hold the last snapshot, which was taken before
        // one or more of the sweeps above
        updateEstimatedParameters();
    }

    /** Recomputes phi and theta from the current counts (posterior-mean estimates). */
    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }

        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    /**
     * Draws a new topic for word n of document m from the collapsed conditional
     * p(z_i = k | z_-i, w), keeping all counts consistent.
     *
     * @return the sampled topic, in [0, K)
     */
    private int sampleTopicZ(int m, int n) {
        int term = doc[m][n];

        // Remove the current assignment of w_{m,n} from the counts
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][term]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        // Unnormalised p(z_i = k | z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][term] + beta) / (nktSum[k] + V * beta)
                    * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        // Roulette-wheel selection over the cumulative (unnormalised) distribution
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1];
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }
        if (newTopic == K) {
            // Defensive: never index past the last topic
            newTopic = K - 1;
        }

        // Record the new assignment
        nmk[m][newTopic]++;
        nkt[newTopic][term]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    /**
     * Writes a snapshot of the model under LdaConfig.OUTPUTFILE_PATH with base name
     * "lda_" + iters: .params, .phi, .theta, .tassign (word:topic pairs) and .twords (top words
     * per topic). All files except .params (written by FileUtil) are encoded as UTF-8.
     *
     * @param iters iteration number used in the file names
     * @param docSet corpus providing the index-to-term dictionary
     * @throws IOException if any file cannot be written
     */
    public void saveIteratedModel(int iters, Documents docSet)
            throws IOException {
        String resPath = LdaConfig.OUTPUTFILE_PATH;
        String modelName = "lda_" + iters;
        String base = resPath + modelName;

        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(base + ".params", lines);

        writeMatrix(base + ".phi", phi);     // K*V
        writeMatrix(base + ".theta", theta); // M*K
        writeTopicAssignments(base + ".tassign");
        writeTopWords(base + ".twords", docSet);
    }

    /** Opens a buffered UTF-8 writer, so non-ASCII terms do not depend on the platform charset. */
    private static BufferedWriter openWriter(String path) throws IOException {
        return new BufferedWriter(new java.io.OutputStreamWriter(
                new java.io.FileOutputStream(path), "UTF-8"));
    }

    /** Writes one line per matrix row, each value followed by a tab. */
    private static void writeMatrix(String path, double[][] matrix) throws IOException {
        BufferedWriter writer = openWriter(path);
        try {
            for (double[] row : matrix) {
                for (double value : row) {
                    writer.write(value + "\t");
                }
                writer.write("\n");
            }
        } finally {
            writer.close();
        }
    }

    /** Writes one line per document of "termIndex:topic" pairs, each followed by a tab. */
    private void writeTopicAssignments(String path) throws IOException {
        BufferedWriter writer = openWriter(path);
        try {
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < doc[m].length; n++) {
                    writer.write(doc[m][n] + ":" + z[m][n] + "\t");
                }
                writer.write("\n");
            }
        } finally {
            writer.close();
        }
    }

    /** Writes, for every topic, its TOP_WORDS_PER_TOPIC most probable terms with their phi value. */
    private void writeTopWords(String path, Documents docSet) throws IOException {
        BufferedWriter writer = openWriter(path);
        try {
            for (int k = 0; k < K; k++) {
                writer.write("topic " + k + "\t:\t");
                for (Integer t : topIndices(phi[k], TOP_WORDS_PER_TOPIC)) {
                    writer.write(docSet.indexToTermMap.get(t) + " " + phi[k][t] + "\t");
                }
                writer.write("\n");
            }
        } finally {
            writer.close();
        }
    }

    /**
     * Returns the indices of the {@code count} largest entries of {@code prob}, largest first
     * (ties keep ascending index order). Returns fewer if {@code prob} is shorter than count.
     */
    private List<Integer> topIndices(double[] prob, int count) {
        List<Integer> indices = new ArrayList<Integer>(prob.length);
        for (int i = 0; i < prob.length; i++) {
            indices.add(Integer.valueOf(i));
        }
        Collections.sort(indices, new ArrayDoubleComparator(prob));
        return indices.subList(0, Math.min(count, indices.size()));
    }

    /**
     * Saves every topic as a TbTopic whose id is the topic number and whose words are
     * "word1:p1;word2:p2;..." for its most probable terms.
     */
    public void saveTopic(Documents docSet) {
        for (int k = 0; k < K; k++) {
            TbTopic tbTopic = new TbTopic();
            tbTopic.setId(k);
            StringBuilder words = new StringBuilder();
            for (Integer t : topIndices(phi[k], TOP_WORDS_PER_TOPIC)) {
                words.append(docSet.indexToTermMap.get(t))
                        .append(":")
                        .append(phi[k][t])
                        .append(";");
            }
            tbTopic.setWords(words.toString());
            DocDBUtil.saveTbTopic(tbTopic);
        }
    }

    /**
     * Saves, for every document, a TbDisTopic whose id is the document name parsed as an int
     * and whose topic field is "topicId1:p1;topicId2:p2;..." for its most probable topics.
     *
     * @throws NumberFormatException if a document name is not an integer
     */
    public void saveDisTopic(Documents docSet) {
        for (int m = 0; m < M; m++) {
            int disId = Integer.parseInt(docSet.docs.get(m).docName);
            TbDisTopic tbDisTopic = new TbDisTopic();
            tbDisTopic.setId(disId);
            StringBuilder topicIds = new StringBuilder();
            for (Integer k : topIndices(theta[m], TOP_TOPICS_PER_DOC)) {
                topicIds.append(k)
                        .append(":")
                        .append(theta[m][k])
                        .append(";");
            }
            tbDisTopic.setTopic(topicIds.toString());
            DocDBUtil.saveTbDisTopic(tbDisTopic);
        }
    }

    /**
     * Orders indices by descending {@code sortProb[index]}. It is used to rank the words within
     * a topic (a phi row) or the topics within a document (a theta row).
     */
    public class ArrayDoubleComparator implements Comparator<Integer> {
        private final double[] sortProb; // values to rank by, e.g. probability of each word in topic k

        public ArrayDoubleComparator(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            // Arguments swapped so that larger values sort first
            return Double.compare(sortProb[o2], sortProb[o1]);
        }
    }
}

核心代码还是写得不错的。

/**
 * A corpus of documents loaded from the database. It also holds the shared vocabulary (the
 * term/index maps and term counts), which is built up while the documents are read.
 */
public class Documents {

    ArrayList<Document> docs;// all loaded documents; docs.size() == M
    Map<String, Integer> termToIndexMap;// every distinct word -> its index; size() == V
    ArrayList<String> indexToTermMap;// index -> word (list position is the index); the dictionary
    Map<String, Integer> termCountMap;// word -> number of occurrences across all loaded documents

    /** Creates an empty corpus with an empty vocabulary. */
    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    /**
     * Appends one document per record returned by DocDBUtil.getTbDiseases(). Each record's
     * symptom detail is the document text, and its id (as a string) is the document name.
     * Words not seen before are added to the shared vocabulary.
     */
    public void getDocsFromDB() {
        List<TbDiseases> diseases = DocDBUtil.getTbDiseases();
        for (TbDiseases disea : diseases) {
            String content = disea.getSymptomDetail();
            String docName = disea.getId() + "";
            Document doc = new Document(content, termToIndexMap, indexToTermMap, termCountMap, docName);
            docs.add(doc);
        }
    }

    /** One document: its name and the vocabulary index of each of its words, in order. */
    public static class Document {
        String docName;
        // Index in indexToTermMap of each word of the document. The original author notes that
        // stop words and noise words are removed beforehand (presumably inside
        // DocDBUtil.getWordsFromSentence -- not verified here).
        int[] docWords;

        /**
         * Splits content into words via DocDBUtil.getWordsFromSentence and maps every word to its
         * vocabulary index. Unseen words are registered in termToIndexMap/indexToTermMap, and
         * termCountMap is incremented for every occurrence.
         *
         * @param content document text
         * @param termToIndexMap shared word -> index map (mutated)
         * @param indexToTermMap shared index -> word list (mutated)
         * @param termCountMap shared word -> occurrence count map (mutated)
         * @param docName name stored on the document
         */
        public Document(String content, Map<String, Integer> termToIndexMap,
                ArrayList<String> indexToTermMap,
                Map<String, Integer> termCountMap, String docName) {
            this.setDocName(docName);
            ArrayList<String> words = DocDBUtil.getWordsFromSentence(content);
            this.docWords = new int[words.size()];
            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                Integer index = termToIndexMap.get(word);
                if (index == null) {
                    // First occurrence in the corpus: assign the next free index
                    index = Integer.valueOf(termToIndexMap.size());
                    termToIndexMap.put(word, index);
                    indexToTermMap.add(word);
                    termCountMap.put(word, Integer.valueOf(1));
                } else {
                    termCountMap.put(word, Integer.valueOf(termCountMap.get(word).intValue() + 1));
                }
                docWords[i] = index.intValue();
            }
            words.clear();
        }

        public void setDocName(String docName) {
            this.docName = docName;
        }

        public String getDocName() {
            return docName;
        }
    }//Document
}

文档数据来自数据库。

原文地址:https://www.cnblogs.com/549294286/p/3019473.html