Gibbs LDA java实现

1.偏文、偏理的故事


    某学校高一年级有6个班级,每个班级各有一定数量的学生,3班有几个同学数学成绩很好,拿过省奥赛奖。现在教育局要来该校听数学课,学校应该安排听课老师听哪个班的课?显然是3班,因为3班有几个数学特别厉害的同学,所以3班数学强一点,至少看起来数学强一点.这里,我们把"偏理"称为3班的特点。同样,2班和4班有很多同学的语文成绩很好,他们的作文都曾被文学报刊发表过,我们可以说”偏文“是2班和4班的特点。又假如5班和6班的同学在校篮球赛上进了决赛,我们可以说5班和6班”偏体育“。如果教育局来该校听某种课程,我们就可以安排他们去有该课程"特点"的班级里听。
    在这里,原来的班级结构是只有两层,即学生层,和班级层,每个学生都有指定的班级。我们为了区分每个班级的特点,在学生和班级之间又加了一层,特点层,即”偏文“,”偏理“,”偏体育“。这个特点层就是对LDA最直观的理解。接着上面偏文偏理的故事,3班除了几个同学数学好,另外还有一部分同学思想品德很好,因多次扶老奶奶过马路而上新闻,我们同样也可以说3班同学思想品德很好。这样3班的特点就不只1个了,这里我们提出分布的概念,即每个班级可能有多个特点,只是有的特点对应的学生多,有的特点对应的学生少,我们选对应学生多的特点作为这个班级的主要特点。每个特点同样也对应多个同学也是一种分布,比如,”偏理“包含拿过奥赛奖的同学,也包含期末考试数学考满分的同学。到这里,班级包含多个特点,每个特点又包含多个学生,LDA的主要结构就是这样。
    对于一片文档,我们怎么区分这篇文档是属于那个类别?参照偏文偏理的例子, 我们可以把文档想象成班级,word想象成学生。例如某篇文档的单词中,银行,汇率,股票,下跌等次大量重复出现,那么该篇文档很有可能就是写经济的,我们可以把这篇文档归为经济类。如果某篇文档里面含有,詹姆斯,科比,扣篮,犯规等词,那么这篇文章很有可能是体育类。当然这种分类不一定是单一的,有可能一个文章有多个主题。
三层结构如下:

doct: |           doc1                     doc2                     doc3 ...

topic:|      t1     t2      t3...      t1   t2    t3            t1  t2  t3...

word: |w1 w5 w8... w6 w2 w3....      w3,w5... 

最终要求的,doc下面的topics分布和topic下的words分布.LDA原理见原论文,不赘述.
输出文件中有各种分布:
topic ~ words
doc ~ topics
topic ~ docs
详见JGibbLDA的输出文件

2.Gibbs LDA代码结构


    第一次读代码时把lda分成了两部分,即训练部分和推测部分,训练部分训练出来模型,即topic下面的words分布等,推测部分是用训练出的模型推测新的文章。后来发现推测部分也是一种训练,只是参考了已训练好的结果再训练.如果推测的文件数据量大于参考的数据量,那么这个推测集推测出来的结果,可以当成新的模型,更为准确。训练过程和推测过程的结果类型是完全相同的,包含各个完整的分布,详见JGibbLDA的输出文件

代码除了读入,保存之类的,核心代码不到200行.LDA 结构与代码如下:

  • 1.预处理:
    去停词表,去noise词,低频词等等.
  • 2.Estimate:推测过程
package jgibblda;

import java.io.File;
import java.util.Vector;

public class Estimator {
	
	// output model
	protected Model trnModel;
	LDACmdOption option;
	
	public boolean init(LDACmdOption option){
		this.option = option;
		trnModel = new Model();
		
		if (option.est){
			if (!trnModel.initNewModel(option))
				return false;
			trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
		}
		else if (option.estc){
			if (!trnModel.initEstimatedModel(option))
				return false;
		}
		
		return true;
	}
	
	public void estimate(){
		System.out.println("Sampling " + trnModel.niters + " iteration!");
		
		int lastIter = trnModel.liter;
		for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++){
			System.out.println("Iteration " + trnModel.liter + " ...");
			
			// for all z_i
			for (int m = 0; m < trnModel.M; m++){				
				for (int n = 0; n < trnModel.data.docs[m].length; n++){
					// z_i = z[m][n]
					// sample from p(z_i|z_-i, w)
					int topic = sampling(m, n);
					trnModel.z[m].set(n, topic);
				}// end for each word
			}// end for each document
			
			if (option.savestep > 0){
				if (trnModel.liter % option.savestep == 0){
					System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
					computeTheta();
					computePhi();
					trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
				}
			}
		}// end iterations		
		
		System.out.println("Gibbs sampling completed!
");
		System.out.println("Saving the final model!
");
		computeTheta();
		computePhi();
		trnModel.liter--;
		trnModel.saveModel("model-final");
	}
	
	/**
	 * Do sampling
	 * @param m document number
	 * @param n word number
	 * @return topic id
	 */
	public int sampling(int m, int n){
		// remove z_i from the count variable
		int topic = trnModel.z[m].get(n);
		int w = trnModel.data.docs[m].words[n];
		
		trnModel.nw[w][topic] -= 1;
		trnModel.nd[m][topic] -= 1;
		trnModel.nwsum[topic] -= 1;
		trnModel.ndsum[m] -= 1;
		
		double Vbeta = trnModel.V * trnModel.beta;
		double Kalpha = trnModel.K * trnModel.alpha;
		
		//do multinominal sampling via cumulative method
		for (int k = 0; k < trnModel.K; k++){
			trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta)/(trnModel.nwsum[k] + Vbeta) *
					(trnModel.nd[m][k] + trnModel.alpha)/(trnModel.ndsum[m] + Kalpha);
		}
		
		// cumulate multinomial parameters
		for (int k = 1; k < trnModel.K; k++){
			trnModel.p[k] += trnModel.p[k - 1];
		}
		
		// scaled sample because of unnormalized p[]
		double u = Math.random() * trnModel.p[trnModel.K - 1];              // 这一段没懂
		
		for (topic = 0; topic < trnModel.K; topic++){
			if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
				break;
		}
		
		// add newly estimated z_i to count variables
		
		trnModel.nw[w][topic] += 1;
		trnModel.nd[m][topic] += 1;
		trnModel.nwsum[topic] += 1;
		trnModel.ndsum[m] += 1;
 		return topic;
	}
	
	public void computeTheta(){
		for (int m = 0; m < trnModel.M; m++){
			for (int k = 0; k < trnModel.K; k++){
				trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
			}
		}
	}
	
	public void computePhi(){
		for (int k = 0; k < trnModel.K; k++){
			for (int w = 0; w < trnModel.V; w++){
				trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
			}
		}
	}
}

Sampling()部分里面,以下代码没懂。每个word所属的topic初始化时是随机分配的,中间迭代的时候,为什么还是随机的?
p[k]在这是所有topic分布之和,然后随机一个数乘以这个和,得到u。这里u可以理解成word可以取到topic的范围。
然后返回第一个比u大的p[k]的下标k,这里k代表第k个topic,还是前k个topics?
最终要求的不是word只对应某个topic,而是word下的topic分布,和topic下的分布,下一遍看代码要参考分布理解这一段。

// cumulate multinomial parameters
		for (int k = 1; k < trnModel.K; k++){
			trnModel.p[k] += trnModel.p[k - 1];
		}
		
		// scaled sample because of unnormalized p[]
		double u = Math.random() * trnModel.p[trnModel.K - 1];              // 这一段没懂
		
		for (topic = 0; topic < trnModel.K; topic++){
			if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
				break;
		}

3.Inference: 推测过程


package jgibblda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
import java.util.Vector;

public class Inferencer {	
	// Train model
	public Model trnModel;
	public Dictionary globalDict;
	private LDACmdOption option;
	
	private Model newModel;
	public int niters = 100;
	
	//-----------------------------------------------------
	// Init method
	//-----------------------------------------------------
	public boolean init(LDACmdOption option){
		this.option = option;
		trnModel = new Model();
		
		if (!trnModel.initEstimatedModel(option))
			return false;		
		
		globalDict = trnModel.data.localDict;
		computeTrnTheta();
		computeTrnPhi();
		
		return true;
	}
	
	//inference new model ~ getting data from a specified dataset
	public Model inference( LDADataset newData){
		System.out.println("init new model");
		Model newModel = new Model();		
		
		newModel.initNewModel(option, newData, trnModel);		
		this.newModel = newModel;		
		
		System.out.println("Sampling " + niters + " iteration for inference!");		
		for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
			//System.out.println("Iteration " + newModel.liter + " ...");
			
			// for all newz_i
			for (int m = 0; m < newModel.M; ++m){
				for (int n = 0; n < newModel.data.docs[m].length; n++){
					// (newz_i = newz[m][n]
					// sample from p(z_i|z_-1,w)
					int topic = infSampling(m, n);
					newModel.z[m].set(n, topic);
				}
			}//end foreach new doc
			
		}// end iterations
		
		System.out.println("Gibbs sampling for inference completed!");
		
		computeNewTheta();
		computeNewPhi();
		newModel.liter--;
		return this.newModel;
	}
	
	public Model inference(String [] strs){
		//System.out.println("inference");
		Model newModel = new Model();
		
		//System.out.println("read dataset");
		LDADataset dataset = LDADataset.readDataSet(strs, globalDict);
		
		return inference(dataset);
	}
	
	//inference new model ~ getting dataset from file specified in option
	public Model inference(){	
		//System.out.println("inference");
		
		newModel = new Model();
		if (!newModel.initNewModel(option, trnModel)) return null;
		
		System.out.println("Sampling " + niters + " iteration for inference!");
		
		for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
			//System.out.println("Iteration " + newModel.liter + " ...");
			
			// for all newz_i
			for (int m = 0; m < newModel.M; ++m){
				for (int n = 0; n < newModel.data.docs[m].length; n++){
					// (newz_i = newz[m][n]
					// sample from p(z_i|z_-1,w)
					int topic = infSampling(m, n);
					newModel.z[m].set(n, topic);
				}
			}//end foreach new doc
			
		}// end iterations
		
		System.out.println("Gibbs sampling for inference completed!");		
		System.out.println("Saving the inference outputs!");
		
		computeNewTheta();
		computeNewPhi();
		newModel.liter--;
		newModel.saveModel(newModel.dfile + "." + newModel.modelName);		
		
		return newModel;
	}
	
	/**
	 * do sampling for inference
	 * m: document number
	 * n: word number?
	 */
	protected int infSampling(int m, int n){
		// remove z_i from the count variables
		int topic = newModel.z[m].get(n);
		int _w = newModel.data.docs[m].words[n];
		int w = newModel.data.lid2gid.get(_w);
		newModel.nw[_w][topic] -= 1;
		newModel.nd[m][topic] -= 1;
		newModel.nwsum[topic] -= 1;
		newModel.ndsum[m] -= 1;
		
		double Vbeta = trnModel.V * newModel.beta;
		double Kalpha = trnModel.K * newModel.alpha;
		
		// do multinomial sampling via cummulative method		
		for (int k = 0; k < newModel.K; k++){			
			newModel.p[k] = (trnModel.nw[w][k] + newModel.nw[_w][k] + newModel.beta)/(trnModel.nwsum[k] +  newModel.nwsum[k] + Vbeta) *
					(newModel.nd[m][k] + newModel.alpha)/(newModel.ndsum[m] + Kalpha);
		}
		
		// cummulate multinomial parameters
		for (int k = 1; k < newModel.K; k++){
			newModel.p[k] += newModel.p[k - 1];
		}
		
		// scaled sample because of unnormalized p[]
		double u = Math.random() * newModel.p[newModel.K - 1];     
		
		for (topic = 0; topic < newModel.K; topic++){
			if (newModel.p[topic] > u)
				break;
		}
		
		// add newly estimated z_i to count variables
		newModel.nw[_w][topic] += 1;
		newModel.nd[m][topic] += 1;
		newModel.nwsum[topic] += 1;
		newModel.ndsum[m] += 1;
		
		return topic;
	}
	
	protected void computeNewTheta(){
		for (int m = 0; m < newModel.M; m++){
			for (int k = 0; k < newModel.K; k++){
				newModel.theta[m][k] = (newModel.nd[m][k] + newModel.alpha) / (newModel.ndsum[m] + newModel.K * newModel.alpha);
			}//end foreach topic
		}//end foreach new document
	}
	
	protected void computeNewPhi(){
		for (int k = 0; k < newModel.K; k++){
			for (int _w = 0; _w < newModel.V; _w++){
				Integer id = newModel.data.lid2gid.get(_w);
				
				if (id != null){
					newModel.phi[k][_w] = (trnModel.nw[id][k] + newModel.nw[_w][k] + newModel.beta) / (newModel.nwsum[k] + newModel.nwsum[k] + trnModel.V * newModel.beta);
				}
			}//end foreach word
		}// end foreach topic
	}
	
	protected void computeTrnTheta(){
		for (int m = 0; m < trnModel.M; m++){
			for (int k = 0; k < trnModel.K; k++){
				trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
			}
		}
	}
	
	protected void computeTrnPhi(){
		for (int k = 0; k < trnModel.K; k++){
			for (int w = 0; w < trnModel.V; w++){
				trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
			}
		}
	}
}

4.数据可视化和输出
完整代码可参考原版JGibblda

原文地址:https://www.cnblogs.com/cyno/p/4451804.html