// 每日一题 为了工作 2020-04-29 第五十八题

//Java版本的线性回归的预测代码

package com.swust.machine;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

import java.util.List;
import java.util.regex.Pattern;


/**
 * Trains a linear-regression model with stochastic gradient descent (Spark MLlib)
 * and evaluates it on a held-out split of the same data.
 *
 * <p>Each input line is expected to look like {@code label,f1 f2 ... fn}: a numeric
 * label, a comma, then whitespace-separated numeric features. The data is split
 * 80/20 (fixed seed, so the split is reproducible) into a training set and a test set.
 * The program prints the fitted weights and intercept, up to 20 sample
 * (prediction, label) pairs, and the test-set MAE and RMSE.
 *
 * @author 雪瞳
 * @Slogan Even the clock keeps moving forward — how could one stop here!
 * @Function Linear regression implementation
 */
public class LinearRegression {

    /** Sample-data path used when no path is passed on the command line. */
    private static final String DEFAULT_DATA_PATH = "./data/lpsa.data";

    /** Number of SGD iterations. */
    private static final int NUM_ITERATIONS = 100;

    /**
     * Gradient-descent step size. A {@code double}, so that fractional values such as
     * 0.1, 0.2, 0.3 or 0.4 can be tried.
     */
    private static final double STEP_SIZE = 1.0;

    /** Fraction of samples used per SGD iteration; 1.0 means every sample (the default). */
    private static final double MINI_BATCH_FRACTION = 1.0;

    /** Seed for the train/test split, so repeated runs use the same split. */
    private static final long SPLIT_SEED = 1L;

    /** Number of (prediction, label) pairs printed for inspection. */
    private static final int SAMPLE_ROWS = 20;

    /** Separator between features; runs of whitespace count as one separator. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} overrides the data path
     *             (default {@code ./data/lpsa.data})
     */
    public static void main(String[] args) {
        String dataPath = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_PATH;

        SparkConf conf = new SparkConf();
        conf.setMaster("local").setAppName("line");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // Spark 1.x only accepts upper-case level names; later versions accept any case.
            jsc.setLogLevel("ERROR");
            run(jsc, dataPath);
        } finally {
            // Release the context even if loading, training or evaluation throws.
            jsc.stop();
        }
    }

    /** Loads the data, splits it, trains the model and evaluates it on the held-out split. */
    private static void run(JavaSparkContext jsc, String dataPath) {
        JavaRDD<LabeledPoint> examples = loadLabeledPoints(jsc, dataPath);

        // 80% training, 20% test.
        JavaRDD<LabeledPoint>[] splits = examples.randomSplit(new double[]{0.8, 0.2}, SPLIT_SEED);
        // Cache both splits: SGD scans the training set once per iteration and the test
        // set is read by several actions; without caching every pass re-reads the file.
        JavaRDD<LabeledPoint> training = splits[0].cache();
        JavaRDD<LabeledPoint> test = splits[1].cache();

        if (training.count() == 0) {
            System.err.println("No training data found in " + dataPath + "; nothing to do.");
            return;
        }

        LinearRegressionModel model = train(training);
        System.err.println(model.weights());
        System.err.println(model.intercept());

        evaluate(model, test);
    }

    /**
     * Reads {@code label,f1 f2 ...} lines from {@code path}, skipping blank lines.
     *
     * @param jsc  active Spark context
     * @param path text file to read
     * @return one {@link LabeledPoint} per non-blank line
     */
    private static JavaRDD<LabeledPoint> loadLabeledPoints(JavaSparkContext jsc, String path) {
        return jsc.textFile(path)
                .filter(new Function<String, Boolean>() {
                    @Override
                    public Boolean call(String line) {
                        // A blank line would otherwise fail to parse.
                        return !line.trim().isEmpty();
                    }
                })
                .map(new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) {
                        return parseLine(line);
                    }
                });
    }

    /**
     * Parses one {@code label,f1 f2 ... fn} line into a {@link LabeledPoint}.
     *
     * @param line a non-blank input line
     * @return the parsed point
     * @throws IllegalArgumentException if the line has no comma-separated feature part
     * @throws NumberFormatException    if the label or a feature is not a number
     */
    private static LabeledPoint parseLine(String line) {
        String[] parts = line.split(",");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Expected 'label,features' but got: " + line);
        }
        // Splitting on runs of whitespace keeps repeated spaces from producing empty tokens.
        String[] tokens = WHITESPACE.split(parts[1].trim());
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return new LabeledPoint(Double.parseDouble(parts[0].trim()), Vectors.dense(features));
    }

    /**
     * Fits a linear model, including an intercept term, using mini-batch SGD.
     *
     * @param training training samples; should be cached since SGD iterates over them
     * @return the fitted model
     */
    private static LinearRegressionModel train(JavaRDD<LabeledPoint> training) {
        LinearRegressionWithSGD sgd = new LinearRegressionWithSGD();
        // Fit an intercept in addition to the feature weights.
        sgd.setIntercept(true);
        sgd.optimizer().setStepSize(STEP_SIZE);
        sgd.optimizer().setNumIterations(NUM_ITERATIONS);
        // 1.0 = use every sample in each iteration.
        sgd.optimizer().setMiniBatchFraction(MINI_BATCH_FRACTION);
        return sgd.run(training.rdd());
    }

    /**
     * Prints up to {@code SAMPLE_ROWS} (prediction, label) pairs, then the test-set mean
     * absolute error (MAE) and root-mean-square error (RMSE). Evaluation is skipped
     * if the test set is empty.
     *
     * @param model fitted model
     * @param test  held-out samples
     */
    private static void evaluate(final LinearRegressionModel model, JavaRDD<LabeledPoint> test) {
        long testCount = test.count();
        if (testCount == 0) {
            System.err.println("Test set is empty; skipping evaluation.");
            return;
        }

        // Predict and pair with the label in one pass, so no zip() of two separately
        // computed RDDs is needed.
        JavaPairRDD<Double, Double> predictionAndLabel = test.mapToPair(
                new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint point) {
                        return new Tuple2<Double, Double>(
                                model.predict(point.features()), point.label());
                    }
                });

        // Header and rows go to the same stream so they are not interleaved out of order.
        List<Tuple2<Double, Double>> sample = predictionAndLabel.take(SAMPLE_ROWS);
        System.out.println("prediction" + "\t" + "label");
        for (Tuple2<Double, Double> pair : sample) {
            System.out.println(pair._1() + "\t" + pair._2());
        }

        // Accumulate (|error|, error^2) in a single pass over the test set.
        Tuple2<Double, Double> errorSums = predictionAndLabel.map(
                new Function<Tuple2<Double, Double>, Tuple2<Double, Double>>() {
                    @Override
                    public Tuple2<Double, Double> call(Tuple2<Double, Double> pair) {
                        double err = pair._1() - pair._2();
                        return new Tuple2<Double, Double>(Math.abs(err), err * err);
                    }
                }).reduce(
                new Function2<Tuple2<Double, Double>, Tuple2<Double, Double>, Tuple2<Double, Double>>() {
                    @Override
                    public Tuple2<Double, Double> call(Tuple2<Double, Double> a,
                                                       Tuple2<Double, Double> b) {
                        return new Tuple2<Double, Double>(a._1() + b._1(), a._2() + b._2());
                    }
                });

        double mae = errorSums._1() / testCount;
        double rmse = Math.sqrt(errorSums._2() / testCount);
        System.err.println("Test MAE " + mae);
        System.err.println("Test RMSE " + rmse);
    }
}

  

// 由于样本数据本身只有约 100 条,预测效果相对较差;但重要的是掌握其中的思路。

// 有道无术术可求 有术无道止于术 学会一个思想更为重要

// 原文地址: https://www.cnblogs.com/walxt/p/12803045.html