Weka中EM算法详解

 1  private void EM_Init (Instances inst)
 2     throws Exception {
 3     int i, j, k;
 4 
 5     // 由于EM算法对初始值较敏感,故选择run k means 10 times and choose best solution
 6     SimpleKMeans bestK = null;
 7     double bestSqE = Double.MAX_VALUE;
 8     for (i = 0; i < 10; i++) {
 9       SimpleKMeans sk = new SimpleKMeans();
10       sk.setSeed(m_rr.nextInt());
11       sk.setNumClusters(m_num_clusters);
12       sk.setDisplayStdDevs(true);
13       sk.buildClusterer(inst);
14       //KMeans中各个cluster的平方误差
15       if (sk.getSquaredError() < bestSqE) {
16          
17           bestSqE = sk.getSquaredError();
18           bestK = sk;
19       }
20     }
21     
22     /*************** KMeans Finds the best cluster number *****************/
23     
24     
25     // initialize with best k-means solution
26     m_num_clusters = bestK.numberOfClusters();
27     // 每个样本所在各个集群的概率
28     m_weights = new double[inst.numInstances()][m_num_clusters];
29     // 评估每个集群所对应的离散型属性的相关取值
30
m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs]; 31 // 每个集群所对应的连续性属性数所对应的相关取值(均值,标准偏差,样本权值(进行归一化)) 32 m_modelNormal = new double[m_num_clusters][m_num_attribs][3]; 33 // 每个集群所对应的先验概率 34 m_priors = new double[m_num_clusters]; 35 // 每个集群所对应的中心点 36 Instances centers = bestK.getClusterCentroids(); 37 // 每个集群所对应的标准差 38 Instances stdD = bestK.getClusterStandardDevs(); 39 // ??? Returns for each cluster the frequency counts for the values of each nominal attribute 40 int [][][] nominalCounts = bestK.getClusterNominalCounts(); 41 // 得到每个集群所对应的样本数 42 int [] clusterSizes = bestK.getClusterSizes(); 43 44 for (i = 0; i < m_num_clusters; i++) { 45 Instance center = centers.instance(i); 46 for (j = 0; j < m_num_attribs; j++) { 47 48 // 样本属性是离散型 49 if (inst.attribute(j).isNominal()) 50 { 51 m_model[i][j] = new DiscreteEstimator(m_theInstances.attribute(j).numValues() 52 , true); 53 for (k = 0; k < inst.attribute(j).numValues(); k++) { 54 m_model[i][j].addValue(k, nominalCounts[i][j][k]); 55 } 56 } 57 //// 样本属性是连续型 58 else 59 { 60 double minStdD = (m_minStdDevPerAtt != null)? m_minStdDevPerAtt[j]: m_minStdDev; 61 double mean = (center.isMissing(j))? inst.meanOrMode(j): center.value(j); 62 m_modelNormal[i][j][0] = mean; 63 double stdv = (stdD.instance(i).isMissing(j))? ((m_maxValues[j] - 64 m_minValues[j]) / (2 * m_num_clusters)): stdD.instance(i).value(j); 65 if (stdv < minStdD) 66 { 67 stdv = inst.attributeStats(j).numericStats.stdDev; 68 if (Double.isInfinite(stdv)) { 69 stdv = minStdD; 70 } 71 if (stdv < minStdD) { 72 stdv = minStdD; 73 } 74 } 75 if (stdv <= 0) { 76 stdv = m_minStdDev; 77 } 78 79 m_modelNormal[i][j][1] = stdv; 80 m_modelNormal[i][j][2] = 1.0; 81 } 82 } 83 } 84 85 86 for (j = 0; j < m_num_clusters; j++) { 87 // 计算每个集群的先验概率 88 m_priors[j] = clusterSizes[j]; 89 } 90 Utils.normalize(m_priors); 91 }
原文地址:https://www.cnblogs.com/likai198981/p/3170568.html