import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

/**
 * Linear regression: fits a Spark ML linear-regression model that predicts the
 * `Murder` column from the seven other numeric columns of a comma-separated file,
 * then prints the model parameters, training metrics and predictions for a
 * held-out test split.
 *
 * Created by zhen on 2018/11/12.
 */
object LinearRegression {
  // Raise the log threshold for "org" loggers so the program's own output is not drowned in INFO logs.
  Logger.getLogger("org").setLevel(Level.WARN)

  /** Training file used when no path is given on the command line. */
  private val DefaultTrainPath = "E:/BDS/newsparkml/src/train.txt"

  /** Name of the label (target) column. */
  private val LabelColumn = "Murder"

  /** Column names in the order the fields appear on each input line (the label is the 5th field). */
  private val ColumnOrder =
    Seq("Population", "Income", "Illiteracy", "LifeExp", LabelColumn, "HSGrad", "Frost", "Area")

  /** Every column except the label, in file order; these are assembled into the feature vector. */
  private val FeatureColumns: Array[String] = ColumnOrder.filterNot(_ == LabelColumn).toArray

  /** Seed for the train/test split so repeated runs on the same data produce the same split. */
  private val SplitSeed = 42L

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the training-file path
   *             (defaults to `DefaultTrainPath`)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("LinearRegression")
      .master("local[2]")
      .getOrCreate()
    // Always stop the session, even if loading or training throws.
    try run(spark, args.headOption.getOrElse(DefaultTrainPath))
    finally spark.stop()
  }

  /**
   * Loads `trainPath`, trains the model on 80% of the rows and prints the model
   * parameters, training metrics and test-split predictions.
   */
  private def run(spark: SparkSession, trainPath: String): Unit = {
    // Load the data, skipping blank lines (e.g. a trailing newline) that would fail to parse.
    val rows = spark.sparkContext
      .textFile(trainPath)
      .filter(_.trim.nonEmpty)
      .map(parseLine)
    val trainDf = spark.createDataFrame(rows).toDF(ColumnOrder: _*)

    val vectDf = new VectorAssembler()
      .setInputCols(FeatureColumns)
      .setOutputCol("features")
      .transform(trainDf)

    // Split into training (80%) and test (20%) sets.
    val splits = vectDf.randomSplit(Array(0.8, 0.2), SplitSeed)
    val trainingData = splits(0)
    val testData = splits(1)

    // Resolves to the imported Spark estimator: this object defines only a term named
    // LinearRegression, not a type.
    val linearRegression = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol(LabelColumn)
      .setFitIntercept(true)
      .setMaxIter(10)
      .setRegParam(0.3)        // regularization strength
      .setElasticNetParam(0.8) // elastic-net mixing: 0.0 = pure L2, 1.0 = pure L1
    val lrModel = linearRegression.fit(trainingData)

    // Model parameters.
    println(s"Coefficients:${lrModel.coefficients} Intercept:${lrModel.intercept}")

    // Model evaluation (these metrics are computed on the training split).
    val trainingSummary = lrModel.summary
    println(s"objectiveHistoryList:${trainingSummary.objectiveHistory.toList}")
    println(s"r2:${trainingSummary.r2}")

    // Predict on the test split, rounding predictions to one decimal place.
    val predictResult = lrModel
      .transform(testData)
      .selectExpr("features", LabelColumn, "round(prediction,1) as prediction")
    println("测试集数据------------------------------真实值--预测值")
    // collect() so rows are printed by the driver; Dataset.foreach would run println on executors.
    // NOTE(review): assumes the test split is small enough to collect to the driver.
    predictResult.collect().foreach(println)
  }

  /**
   * Parses one comma-separated input line into the eight numeric fields listed in
   * `ColumnOrder`. Fields beyond the eighth are ignored.
   *
   * @throws IllegalArgumentException if the line has fewer than eight fields
   * @throws NumberFormatException    if one of the first eight fields is not a number
   */
  private def parseLine(line: String): (Double, Double, Double, Double, Double, Double, Double, Double) = {
    val fields = line.split(",").map(_.trim)
    require(
      fields.length >= ColumnOrder.length,
      s"expected ${ColumnOrder.length} comma-separated fields but found ${fields.length}: '$line'")
    (fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble,
      fields(4).toDouble, fields(5).toDouble, fields(6).toDouble, fields(7).toDouble)
  }
}
结果: