Spark线性回归实现优化

 1 import org.apache.log4j.{Level, Logger}
 2 import org.apache.spark.ml.feature.VectorAssembler
 3 import org.apache.spark.ml.regression.LinearRegression
 4 import org.apache.spark.sql.SparkSession
 5 
 /**
  * Linear regression example on Spark ML.
  *
  * Reads a comma-separated text file whose lines hold eight numeric values, in order:
  * Population, Income, Illiteracy, LifeExp, Murder, HSGrad, Frost, Area.
  * "Murder" is the label; the other seven columns are assembled into a "features"
  * vector. The data is split 80/20 into training and test sets, an elastic-net
  * regularized linear regression is fitted on the training set, and the model
  * parameters, training/test metrics and test-set predictions are printed on the driver.
  *
  * Usage: LinearRegression [inputPath]
  * (falls back to E:/BDS/newsparkml/src/train.txt when no path is given).
  *
  * Created by zhen on 2018/11/12.
  */
object LinearRegression {
  Logger.getLogger("org").setLevel(Level.WARN) // only WARN and above from org.* loggers

  /** Input file used when no path is passed on the command line. */
  private val DefaultInputPath = "E:/BDS/newsparkml/src/train.txt"

  /** Column names, in the order the values appear on each input line. */
  private val ColumnNames =
    Seq("Population", "Income", "Illiteracy", "LifeExp", "Murder", "HSGrad", "Frost", "Area")

  /** Column the model predicts. */
  private val LabelCol = "Murder"

  /** Every non-label column is a feature; file order is preserved. */
  private val FeatureCols: Array[String] = ColumnNames.filterNot(_ == LabelCol).toArray

  /** Train/test proportions passed to `randomSplit`. */
  private val SplitWeights = Array(0.8, 0.2)

  /** Fixed seed so the split, and therefore the printed results, are reproducible. */
  private val SplitSeed = 42L

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the input path
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("LinearRegression")
      .master("local[2]")
      .getOrCreate()
    // Release the SparkContext even when loading or training throws.
    try run(spark, args.headOption.getOrElse(DefaultInputPath))
    finally spark.stop()
  }

  /** Loads `path`, trains the model, and prints its parameters, metrics and test predictions. */
  private def run(spark: SparkSession, path: String): Unit = {
    val expectedFields = ColumnNames.length // local copy keeps the closure self-contained
    val parsed = spark.sparkContext.textFile(path)
      .filter(_.trim.nonEmpty) // tolerate blank lines (e.g. a trailing newline) instead of failing on "".toDouble
      .map { line =>
        val f = line.split(",")
        require(f.length >= expectedFields,
          s"expected $expectedFields comma-separated values, got ${f.length}: '$line'")
        (f(0).toDouble, f(1).toDouble, f(2).toDouble, f(3).toDouble,
          f(4).toDouble, f(5).toDouble, f(6).toDouble, f(7).toDouble)
      }
    val df = spark.createDataFrame(parsed).toDF(ColumnNames: _*)

    val assembler = new VectorAssembler()
      .setInputCols(FeatureCols)
      .setOutputCol("features")
    val vectDF = assembler.transform(df)
    val Array(trainData, testData) = vectDF.randomSplit(SplitWeights, SplitSeed)

    // `new LinearRegression()` is the imported Spark estimator class: this object only
    // defines a term of that name, and `new` looks up a type.
    val linearRegression = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol(LabelCol)
      .setFitIntercept(true)
      .setMaxIter(10)
      .setRegParam(0.3) // regularization strength
      .setElasticNetParam(0.8) // L1/L2 mixing of the penalty (see Spark ML docs)
    val lrModel = linearRegression.fit(trainData)
    println(s"Coefficients:${lrModel.coefficients} Intercept:${lrModel.intercept}")

    // Training-set metrics: optimistic, since they are measured on the data the model was fitted to.
    val trainingSummary = lrModel.summary
    println(s"objectiveHistoryList:${trainingSummary.objectiveHistory.toList}")
    println(s"r2:${trainingSummary.r2}")

    // Held-out metrics; skipped when the random split happened to leave the test set empty.
    if (testData.head(1).nonEmpty) {
      val testSummary = lrModel.evaluate(testData)
      println(s"test r2:${testSummary.r2} test RMSE:${testSummary.rootMeanSquaredError}")
    }

    val predictions = lrModel.transform(testData)
    val predictResult =
      predictions.selectExpr("features", LabelCol, "round(prediction,1) as prediction") // one decimal place
    println("测试集数据------------------------------真实值--预测值")
    // collect() so rows are printed on the driver; Dataset.foreach would run println on the executors.
    predictResult.collect().foreach(println)
  }
}

结果:(原文的运行输出未包含在本文中)

原文地址:https://www.cnblogs.com/yszd/p/9952268.html