Spark Data Analysis

//use the k-means algorithm to cluster Weibo posts

//Scala version
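
The listing below reads Weibo posts from ./data/original.txt (one post per line: a Weibo id and the post text, separated by a tab), segments the text with the IK Analyzer Chinese tokenizer, hashes the tokens into 1000-dimensional term-frequency vectors with MLlib's HashingTF, weights them with IDF, and clusters the resulting tf-idf vectors into 20 groups with MLlib's KMeans. It ends by printing the 30 highest-weighted words of one cluster as a rough topic summary. Besides Spark, the code depends on the IK Analyzer library for Chinese word segmentation.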

package com.swust.machine.line.kmeans

import org.apache.lucene.analysis.TokenStream
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.feature.{HashingTF, IDF, IDFModel}
import org.apache.spark.mllib.linalg
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.wltea.analyzer.lucene.IKAnalyzer

import scala.collection.mutable.{ArrayBuffer, ListBuffer}

object ScalaKmeans {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("kmeans")
    val sc = new SparkContext(conf)
    //sc.setLogLevel("error")
    val input: RDD[String] = sc.textFile("./data/original.txt")
    //tokenize each post; the result is key-value pairs: key = Weibo id, value = the post's tokens
    val wordRDD: RDD[(String, ArrayBuffer[String])] = input.mapPartitions(iterator => {
      val list = new ListBuffer[(String, ArrayBuffer[String])]
      //create one tokenizer per partition and reuse it (IKAnalyzer in smart-segmentation mode)
      val analyzer = new IKAnalyzer(true)
      while (iterator.hasNext) {
        val line = iterator.next()
        val textData: Array[String] = line.split("\t")
        val id: String = textData(0)
        val text: String = textData(1)
        //get the token stream over the text and its term attribute
        val ts: TokenStream = analyzer.tokenStream("", text)
        val term: CharTermAttribute = ts.getAttribute(classOf[CharTermAttribute])
        ts.reset()
        val arr = new ArrayBuffer[String]
        //iterate over the tokens and collect them
        while (ts.incrementToken()) {
          arr += term.toString
        }
        //release the token stream so the analyzer can be reused safely
        ts.end()
        ts.close()
        list.append((id, arr))
      }
      list.iterator
    })
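    //wordRDD is consumed twice below (TF vectors and wordTfIdf), so cache it to avoid re-tokenizing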
    wordRDD.cache()
    //compute term frequencies (TF) with the hashing trick (1000 features)
    val hashingTF: HashingTF = new HashingTF(1000)
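    //note: with only 1000 buckets, distinct words can collide into the same
    //index and then share a single tf-idf weight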
    val TFRdd: RDD[(String, linalg.Vector)] = wordRDD.map(one => {
      val value: ArrayBuffer[String] = one._2
      (one._1, hashingTF.transform(value))
    })
    //compute the inverse document frequency (IDF) over all posts
    val idf: IDFModel = new IDF().fit(TFRdd.map(_._2))
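    //(fit() runs a Spark job: it scans all TF vectors to count document frequencies)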
    //compute the TF-IDF vector of each post
    val tf_Idfs: RDD[(String, linalg.Vector)] = TFRdd.mapValues(one => {
      idf.transform(one)
    })
    //map each position in the tf-idf vector back to the word it represents:
    //sort each post's distinct words by their HashingTF bucket index, so the word
    //order matches the order of the entries in the tf-idf vector
    val wordTfIdf: RDD[(String, ArrayBuffer[String])] = wordRDD.mapValues(one => {
      one.distinct.sortBy(item => {
        hashingTF.indexOf(item)
      })
    })
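    //(strictly speaking, the top-word scan below looks words up via hashingTF.indexOf,
    //so this sort is for inspection convenience rather than correctness)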
    //run k-means clustering
    //create the KMeans object
    val kmeans = new KMeans()
    //set the number of cluster centers
    val cluster = 20
    kmeans.setK(cluster)
    //use the k-means|| initialization (a parallel variant of k-means++)
    kmeans.setInitializationMode("k-means||")
    //set the maximum number of iterations
    kmeans.setMaxIterations(1000)
    //train the model on the tf-idf vectors
    val model: KMeansModel = kmeans.run(tf_Idfs.map(_._2))

    //print the model's 20 cluster centers
    model.clusterCenters.foreach(println)

    //use the trained k-means model to assign each post to a cluster
    //ship the model to the executors as a broadcast variable
    val modelBroadcast: Broadcast[KMeansModel] = sc.broadcast(model)
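    //(a broadcast ships the model once per executor instead of once per task closure)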
    //predict the cluster of each post
    val prediction: RDD[(String, Int)] = tf_Idfs.mapValues(vector => {
      //fetch the model from the broadcast variable
      val kmeansModel: KMeansModel = modelBroadcast.value
      kmeansModel.predict(vector)
    })
    //combine the predictions with the tf-idf vectors and word lists
    val result: RDD[(String, (linalg.Vector, ArrayBuffer[String]))] = tf_Idfs.join(wordTfIdf)
    val res: RDD[(String, (Int, (linalg.Vector, ArrayBuffer[String])))] = prediction.join(result)
    //inspect the high tf-idf words of cluster 1; they characterize the cluster's topic
    val firstRes: RDD[(String, (Int, (linalg.Vector, ArrayBuffer[String])))] = res.filter(one => {
      one._2._1 == 1
    })
    val flatRes: RDD[(Double, String)] = firstRes.flatMap(line => {
      val tfIdf: linalg.Vector = line._2._2._1
      val words: ArrayBuffer[String] = line._2._2._2
      val list = new ListBuffer[(Double, String)]
      for (i <- 0 until words.length) {
        //collect each word together with its tf-idf weight
        //indexOf gives the word's bucket among the 1000 hash dimensions;
        //that position in the vector holds the word's tf-idf value
        val value: Int = hashingTF.indexOf(words(i))
        list.append((tfIdf(value), words(i)))
      }
      list
    })
    flatRes.sortBy(_._1, false) //sort by tf-idf weight, descending
      .map(_._2) //keep just the word
      .filter(_.length > 1) //drop single-character tokens
      .distinct() //remove duplicates
      .take(30)
      .foreach(println)

    sc.stop()
  }
}
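
The choice of 20 clusters above is arbitrary. A common way to sanity-check k is the elbow method: train a model for several candidate values of k and look for the point where the within-set sum of squared errors (WSSSE) stops dropping sharply. A minimal sketch, reusing the tf_Idfs RDD from the listing above (the candidate k values here are illustrative, not recommendations):

//sweep candidate cluster counts and print each model's WSSSE
val vectors = tf_Idfs.map(_._2).cache()
for (k <- Seq(5, 10, 20, 40)) {
  val m = new KMeans()
    .setK(k)
    .setInitializationMode("k-means||")
    .setMaxIterations(100)
    .run(vectors)
  println(s"k=$k WSSSE=${m.computeCost(vectors)}")
}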

  

Original post: https://www.cnblogs.com/walxt/p/12814863.html