Spark MLlib classification template code

  • Requirement: build a classification model for the data

  • Development steps:

    • 1-Prepare the SparkSession environment
    • 2-Prepare the data
    • 3-Read and parse the data
    • 4-Inspect the basic information of the data
    • 5-Feature engineering
    • 6-Prepare the algorithm
    • 7-Train the model
    • 8-Make predictions with the model
    • 9-Evaluate the model
    • 10-Save the model
    • 11-Predict on new data
  • Code template:

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
  * DESC: Template code for a classification problem
  * Complete data processing and modeling process steps:
  *- 1-Prepare the SparkSession environment
  *- 2-Prepare the data
  *- 3-Read and parse the data
  *- 4-Inspect the basic information of the data
  *- 5-Feature engineering
  *- 6-Prepare the algorithm
  *- 7-Train the model
  *- 8-Make predictions with the model
  *- 9-Evaluate the model
  *- 10-Save the model
  *- 11-Predict on new data
  *
  */
object ClassificationModelTest {

  // local path to the iris CSV (backslashes must be escaped in Scala string literals)
  val datapath = "D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\iris.csv"

  def main(args: Array[String]): Unit = {
    //    - 1-Prepare the SparkSession environment
    val conf: SparkConf = new SparkConf().setAppName("ClassificationModelTest").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    //    - 2-Prepare the data
    val irisDF: DataFrame = spark.read.format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .option("sep", ",")
      .load(datapath)
    //    - 3-Read and parse the data
    irisDF.show(10, false)
    //    +------------+-----------+------------+-----------+-----------+
    //    |sepal_length|sepal_width|petal_length|petal_width|class      |
    //    +------------+-----------+------------+-----------+-----------+
    //    |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
    //    |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
    //    |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
    //    |4.6         |3.1        |1.5         |0.2        |Iris-setosa|
    //    - 4-Inspect the basic information of the data
    irisDF.printSchema()
    //    root
    //    |-- sepal_length: double (nullable = true)
    //    |-- sepal_width: double (nullable = true)
    //    |-- petal_length: double (nullable = true)
    //    |-- petal_width: double (nullable = true)
    //    |-- class: string (nullable = true)
    // Column names are referenced as strings in several places and are easy to misspell,
    // so define them once as constants
    val sepal_length_feature = "sepal_length"
    val sepal_width_feature = "sepal_width"
    val petal_length_feature = "petal_length"
    val petal_width_feature = "petal_width"
    val class_label = "class"
    //    - 5-Feature engineering
    //5-1 Encode the categorical label column "class" as a numeric index
    val stringIndexer: StringIndexer = new StringIndexer()
      .setInputCol(class_label)
      .setOutputCol("classlabel")
    val stringIndexerModel: StringIndexerModel = stringIndexer.fit(irisDF)
    val indexDF: DataFrame = stringIndexerModel.transform(irisDF)
    //5-2 Assemble the separate feature columns into a single feature vector
    val vectorAssembler: VectorAssembler = new VectorAssembler()
      .setInputCols(Array(sepal_length_feature, sepal_width_feature, petal_length_feature, petal_width_feature))
      .setOutputCol("features")
    val vecDF: DataFrame = vectorAssembler.transform(indexDF)
    //5-3 VectorIndexer indexes categorical feature values, which speeds up building the decision tree
    val vectorIndexer: VectorIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("vecindexFeatures")
      .setMaxCategories(20)
    val vectorIndexerModel: VectorIndexerModel = vectorIndexer.fit(vecDF)
    val vecindexerDF: DataFrame = vectorIndexerModel.transform(vecDF)
    vecindexerDF.show(10, false)
    //    - 6-Prepare the algorithm
    val classifier: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol("classlabel")
      .setPredictionCol("prediction")
      .setFeaturesCol("vecindexFeatures")
      .setMaxDepth(5)
      .setImpurity("gini")
    val Array(trainingSet, testSet): Array[Dataset[Row]] = vecindexerDF.randomSplit(Array(0.8, 0.2), seed = 1234L)
    //    - 7-Train the model
    val model: DecisionTreeClassificationModel = classifier.fit(trainingSet)
    //    - 8-Make predictions with the model
    val y_pred_train: DataFrame = model.transform(trainingSet)
    val y_pred_test: DataFrame = model.transform(testSet)
    y_pred_train.show(10, false)
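    // transform() appends the rawPrediction, probability and prediction columns to the input rows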
    //    - 9-Evaluate the model
    val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      //"(f1|weightedPrecision|weightedRecall|accuracy)"
      .setMetricName("accuracy")
      .setPredictionCol("prediction")
      .setLabelCol("classlabel")
    val acc_test: Double = evaluator.evaluate(y_pred_test)
    val acc_train: Double = evaluator.evaluate(y_pred_train)
    println(s"accuracy on the training set: $acc_train")
    println(s"accuracy on the test set: $acc_test")
    //    accuracy on the training set: 0.9920634920634921
    //    accuracy on the test set: 0.9583333333333334
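    // The same evaluator can report other metrics by switching the metric name
    // (setMetricName returns the evaluator, so the call can be chained), for example:
    // val f1_test: Double = evaluator.setMetricName("f1").evaluate(y_pred_test)
    // println(s"f1 on the test set: $f1_test")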
    //    - 10-Save the model
    //    val modelPath = "D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\model1"
    //    model.save(modelPath)
    //    - 11-Predict on new data
    //    DecisionTreeClassificationModel.load(modelPath)
    //    (see the Pipeline sketch after this listing for one way to fill in these two steps)

  }
}
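
Steps 10 and 11 are left commented out in the template above. One common way to cover them is to chain the feature-engineering stages and the classifier into a single Pipeline, so that the fitted PipelineModel is saved and reloaded as one unit and can be applied directly to raw rows. The following is a minimal sketch of that pattern, not part of the original template; the object name ClassificationPipelineTest, the relative CSV path, and the save path file:///tmp/iris_pipeline_model are placeholder assumptions to adjust for your environment.

import org.apache.spark.SparkConf
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.{DataFrame, SparkSession}

object ClassificationPipelineTest {

  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("ClassificationPipelineTest").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // same iris CSV as in the template above; adjust the path for your environment
    val irisDF: DataFrame = spark.read.format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .load("src/main/resources/iris.csv")

    // the same stages as in the template, declared but not fitted individually
    val stringIndexer = new StringIndexer().setInputCol("class").setOutputCol("classlabel")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("sepal_length", "sepal_width", "petal_length", "petal_width"))
      .setOutputCol("features")
    val vectorIndexer = new VectorIndexer()
      .setInputCol("features").setOutputCol("vecindexFeatures").setMaxCategories(20)
    val classifier = new DecisionTreeClassifier()
      .setLabelCol("classlabel").setFeaturesCol("vecindexFeatures")
      .setMaxDepth(5).setImpurity("gini")

    // chain all stages so they are fitted, saved and loaded as one unit
    val pipeline = new Pipeline().setStages(Array(stringIndexer, vectorAssembler, vectorIndexer, classifier))

    val Array(trainingSet, testSet) = irisDF.randomSplit(Array(0.8, 0.2), seed = 1234L)
    val pipelineModel: PipelineModel = pipeline.fit(trainingSet)

    // 10-Save the model: the whole pipeline (feature engineering + tree) is persisted together
    val modelPath = "file:///tmp/iris_pipeline_model" // placeholder path
    pipelineModel.write.overwrite().save(modelPath)

    // 11-Predict on new data: loading the PipelineModel replays the same feature engineering,
    // so raw rows with the four original columns can be scored directly
    val loaded: PipelineModel = PipelineModel.load(modelPath)
    loaded.transform(testSet).select("class", "prediction").show(10, false)

    spark.stop()
  }
}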
Original article: https://www.cnblogs.com/haojia/p/12396975.html