Logistic Regression (Recommendation System)
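
The Scala program below trains a logistic regression model with Spark MLlib on tab-separated samples and writes each feature's learned weight to a file. As a rough illustration (these feature strings are hypothetical, not taken from the original data set), each input line is assumed to be a label, a tab, then a ';'-separated feature list:

1	gender=male:1;age=25:1;item=book:1
0	gender=female:1;age=32:1;item=food:1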

import java.io.PrintWriter

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD}

import scala.collection.Map

/**
  * Created by root on 2016/9/17.
  */
object Recommonder {
  def main(args: Array[String]) {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val sc = new SparkContext(new SparkConf().setAppName("rcmd").setMaster("local"))
//    Read the file. Each line is tab-separated: index 0 is the label and
//    index 1 is the feature string, with individual features separated by ';'.
    val data: RDD[Array[String]] = sc.textFile("D:\\Program Files\\feiq\\Recv Files\\spark20160827\\推荐系统\\DataGenerator\\000000_0").map(_.split("\t"))
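//    e.g. one element of data, for the hypothetical input sketched above:
//    Array("1", "gender=male:1;age=25:1;item=book:1")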
//    Build one global feature dictionary so that every sample can be mapped to a
//    sparse vector: flatten out all features, keep the part before ':', de-duplicate,
//    and attach an index to each distinct feature.
    val dict: Map[String, Long] = data.flatMap(_(1).split(";")).map(_.split(":")(0)).distinct().zipWithIndex().collectAsMap()
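//    For the hypothetical input above, dict would look something like
//    Map("gender=male" -> 0L, "age=25" -> 1L, "item=book" -> 2L, ...);
//    the exact index assignment from zipWithIndex depends on partitioning.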
//    Build the training set; each sample carries a label and its features.
    val traindata: RDD[LabeledPoint] = data.map(sample => {
//  MLlib binary classification only accepts labels of 1.0 and 0.0,
//  so pattern-match the raw label into those two values.
      val label = sample(0) match {
        case "1" => 1.0
        case _ => 0.0
      }
//  Find the non-zero indices: look up each feature of the current sample in the
//  dictionary. SparseVector requires its indices in ascending order, hence the sort.
      val indices: Array[Int] = sample(1).split(";").map(_.split(":")(0)).map(feature => {
        val index: Long = dict.getOrElse(feature, -1L)
        index.toInt
      }).sorted
//  These are boolean features, so every non-zero entry has value 1.0.
      val values = Array.fill(indices.length)(1.0)
      new LabeledPoint(label, new SparseVector(dict.size, indices, values))
    })
//  Train the logistic regression model with SGD: 10 is the number of iterations
//  and 0.9 is the step size; both are tunable parameters.
    val model: LogisticRegressionModel = LogisticRegressionWithSGD.train(traindata, 10, 0.9)
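//  Note: LogisticRegressionWithSGD is deprecated since Spark 2.0 in favor of
//  LogisticRegressionWithLBFGS or the spark.ml LogisticRegression estimator.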
//   Extract the feature weights.
    val weights: Array[Double] = model.weights.toArray
//    Invert the dictionary so indices map back to feature names.
    val map: Map[Long, String] = dict.map(x => (x._2, x._1))
    val pt = new PrintWriter("D:\\output\\a.txt")
//    The weight indices correspond one-to-one with the dictionary indices,
//    so iterating over the weights yields (feature, weight) pairs.
    for (i <- 0 until weights.length) {
      val str = map.getOrElse(i.toLong, "") + "\t" + weights(i)
      println(str)
      pt.write(str)
      pt.println()
    }
    pt.flush()
    pt.close()
  }
}
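
Once trained, the model can also score new samples directly. The snippet below is a minimal sketch meant to run inside main after training; the feature string is hypothetical, and features not seen during training are simply dropped:

    // Map a new sample through the same dictionary and predict.
    val newFeatures = "gender=male:1;item=book:1".split(";").map(_.split(":")(0))
    val idx = newFeatures.map(f => dict.getOrElse(f, -1L).toInt).filter(_ >= 0).sorted
    val score = model.predict(new SparseVector(dict.size, idx, Array.fill(idx.length)(1.0)))
    // predict returns the class label (0.0 or 1.0); call model.clearThreshold()
    // beforehand to get the raw probability instead.
    println(s"predicted: $score")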

Original article: https://www.cnblogs.com/TendToBigData/p/10501372.html