Kafka: ZK + Kafka + Spark Streaming Cluster Setup (Part 15): Writing UDF, UDAF, and Aggregator Functions in Spark

Spark SQL ships with a rich set of built-in functions, but real business logic can be complex enough that the built-ins do not cover every need. For that reason Spark SQL also lets you register your own, user-defined functions.

UDF: an ordinary scalar function that takes one or more arguments and returns a single value, for example length() or isnull().

UDAF: an aggregate function that takes a group of values and returns one aggregated result, for example max(), avg(), or sum().
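
As a point of reference, here is a minimal sketch (not from the original post; the class name, view name and sample data are placeholders) showing those built-in functions in use before any custom function is written:

import java.util.Arrays;

import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class BuiltinFunctionDemo {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[2]").appName("builtin function demo").getOrCreate();
        spark.createDataset(Arrays.asList("zhangsan", "lisi", "wangwu"), Encoders.STRING())
             .toDF("name")
             .createOrReplaceTempView("user");

        // length() and isnull() are built-in scalar (UDF-style) functions
        spark.sql("select name, length(name) as name_len, isnull(name) as name_missing from user").show();
        // max() and avg() are built-in aggregate (UDAF-style) functions
        spark.sql("select max(length(name)) as max_len, avg(length(name)) as avg_len from user").show();

        spark.stop();
    }
}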

Writing UDF functions in Spark

The first example targets Spark versions before 2.0 and shows a UDF with a single input argument and a single output value.

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {        
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext=new SQLContext(javaSparkContext);
        JavaRDD<String> javaRDD = javaSparkContext.parallelize(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));

        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Pick the interface matching the number of input arguments: UDF1, UDF2, ... up to UDF22
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();

        javaSparkContext.stop();
    }
}

Output:

+---+--------+------+
| id|    name|length|
+---+--------+------+
|  1|zhangsan|     8|
|  2|    lisi|     4|
|  3|  wangwu|     6|
|  4| zhaoliu|     7|
+---+--------+------+

The UDF above takes a single input and produces a single output. The next example uses the Spark 2.0 API (SparkSession) and adds a UDF with three inputs and one output; it also shows two ways of building the DataFrame, from a JavaBean and from an explicit schema.

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Pick the interface matching the number of input arguments: UDF1, UDF2, ... up to UDF22
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);
        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));

        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();

        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}

Person.java

package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable{
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }
    
    public static Person parsePerson(String line)  {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}

Note that a UDF only needs to be registered once on the SparkSession; after that it can be called any number of times, as both showByStruct and showBySchema above demonstrate.

Output:

+---+--------+
| id|    name|
+---+--------+
|  1|zhangsan|
|  2|    lisi|
|  3|  wangwu|
|  4| zhaoliu|
+---+--------+

+---+--------+------+----------+
| id|    name|length|       str|
+---+--------+------+----------+
|  1|zhangsan|     8|1-zhangsan|
|  2|    lisi|     4|    2-lisi|
|  3|  wangwu|     6|  3-wangwu|
|  4| zhaoliu|     7| 4-zhaoliu|
+---+--------+------+----------+

==========================================

+---+--------+
| id|    name|
+---+--------+
|  1|zhangsan|
|  2|    lisi|
|  3|  wangwu|
|  4| zhaoliu|
+---+--------+

+---+--------+------+----------+
| id|    name|length|       str|
+---+--------+------+----------+
|  1|zhangsan|     8|1+zhangsan|
|  2|    lisi|     4|    2+lisi|
|  3|  wangwu|     6|  3+wangwu|
|  4| zhaoliu|     7| 4+zhaoliu|
+---+--------+------+----------+
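
Since the registered names are session-wide, the same UDFs can also be invoked from the DataFrame API instead of SQL text. The following fragment is only a sketch: it assumes the Dataset<Row> ds (columns id and name) and the strLength/strConcat registrations from TestUDF2 above, and would sit inside showBySchema:

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

// callUDF looks the functions up by the names they were registered under
Dataset<Row> withUdfs = ds
        .withColumn("length", callUDF("strLength", col("name")))
        .withColumn("str", callUDF("strConcat", lit("-"), col("id"), col("name")));
withUdfs.show();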

With a careful read of the examples above, basic UDF usage should be clear.

Writing UDAF functions in Spark

A custom aggregate function extends UserDefinedAggregateFunction. The abstract class is defined as follows:

package org.apache.spark.sql.expressions
 
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental
 
/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {
 
  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
  // Schema (data types) of the input arguments
  def inputSchema: StructType
 
  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
  // Schema of the intermediate aggregation buffer
  def bufferSchema: StructType
 
  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
  // Data type of the final aggregated result
  def dataType: DataType
 
  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
  // Determinism flag: if true, the function always returns the same output for the same input
  def deterministic: Boolean
 
  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  // Initializes the aggregation buffer, i.e. its "zero value". The contract: merging two initial buffers must yield the initial buffer again.
  // For example, if the initial value were 1 and merge added values together, merging two initial buffers would give 2, which is no longer the initial buffer; such an initial value is wrong, hence the name "zero value".
  def initialize(buffer: MutableAggregationBuffer): Unit
  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
  // Updates the buffer with one input row, similar to the combine step of combineByKey
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
  // Merges buffer2 into buffer1; called when combining partial aggregates from different partitions, similar to reduceByKey.
  // Note that the method returns nothing: the merged result must be written back into buffer1.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
  // Computes and returns the final aggregation result
  def evaluate(buffer: Row): Any
  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
  // Aggregates over all input values
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }
 
  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
  // Aggregates over the distinct input values
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}
 
/**
 * :: Experimental ::
 * A [[Row]] representing an mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {
 
  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}

An aggregate function that averages a single column:

package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType().add("myinput", DataTypes.DoubleType);
        return structType;
    }
    
    
    @Override
    public StructType bufferSchema() {
        StructType structType= new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {        
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Initializes the aggregation buffer, i.e. its "zero value". The contract: merging two initial buffers must yield the initial buffer again.
    // For example, if the initial value were 1 and merge added values together, merging two initial buffers would give 2, so that initial value would be wrong; hence the name "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: 0L is the Long zero for the row count
        buffer.update(1, 0d); // mysum: 0d is the Double zero for the running sum
    }

    /**
     * Per-partition combine.
     */
    // Updates the buffer with one input row, similar to the combine step of combineByKey
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0)+1);                     // row count + 1
        buffer.update(1, buffer.getDouble(1)+input.getDouble(0)); // accumulate the input value
    }

    /**
     * Merge across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1; used when combining partial aggregates from different partitions, similar to reduceByKey.
    // Note that the method returns nothing: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0)+buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1)+buffer2.getDouble(1)); // merge running sums
    }
    
    // Computes and returns the final aggregation result
    @Override
    public Object evaluate(Row buffer) {
        // compute the average, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));

        return avgFormat;
    }
}

The following shows how to use this custom UDAF:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {

    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80",
                "2,zhangsan,History,87",
                "3,zhangsan,Chinese,88",
                "4,zhangsan,Chemistry,96",
                "5,lisi,English,70",
                "6,lisi,Chinese,74",
                "7,lisi,History,75",
                "8,lisi,Chemistry,77",
                "9,lisi,Physics,79",
                "10,lisi,Biology,82",
                "11,wangwu,English,96",
                "12,wangwu,Chinese,98",
                "13,wangwu,History,91",
                "14,zhaoliu,English,68",
                "15,zhaoliu,Chinese,66"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id=Integer.parseInt(fields[0]);
                String name=fields[1];
                String subject=fields[2];
                Double achieve=Double.parseDouble(fields[3]);
                return RowFactory.create(id,name,subject,achieve);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));

        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf=new SimpleAvg();
        sparkSession.udf().register("avg_format", udaf);
        
        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
        rows2.show();
    }

}

Output:

+---+--------+---------+-------+
| id|    name|  subject|achieve|
+---+--------+---------+-------+
|  1|zhangsan|  English|   80.0|
|  2|zhangsan|  History|   87.0|
|  3|zhangsan|  Chinese|   88.0|
|  4|zhangsan|Chemistry|   96.0|
|  5|    lisi|  English|   70.0|
|  6|    lisi|  Chinese|   74.0|
|  7|    lisi|  History|   75.0|
|  8|    lisi|Chemistry|   77.0|
|  9|    lisi|  Physics|   79.0|
| 10|    lisi|  Biology|   82.0|
| 11|  wangwu|  English|   96.0|
| 12|  wangwu|  Chinese|   98.0|
| 13|  wangwu|  History|   91.0|
| 14| zhaoliu|  English|   68.0|
| 15| zhaoliu|  Chinese|   66.0|
+---+--------+---------+-------+

+--------+-----------------+
|    name|      avg_achieve|
+--------+-----------------+
|  wangwu|             95.0|
| zhaoliu|             67.0|
|zhangsan|            87.75|
|    lisi|76.16666666666667|
+--------+-----------------+

+--------+-----------+
|    name|avg_achieve|
+--------+-----------+
|  wangwu|       95.0|
| zhaoliu|       67.0|
|zhangsan|      87.75|
|    lisi|      76.17|
+--------+-----------+
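
Besides the SQL form, the registered instance can also be used directly from the DataFrame API through the apply(Column...) method shown in the abstract class above. A minimal sketch, assuming the udaf instance and the Dataset ds from TestUDAF1:

import static org.apache.spark.sql.functions.col;

// Reuses SimpleAvg (udaf) and the "user" Dataset (ds) built in TestUDAF1;
// udaf.apply(...) produces a Column, so no SQL text is needed.
Dataset<Row> byName = ds.groupBy("name")
        .agg(udaf.apply(col("achieve")).as("avg_achieve"));
byName.show();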

A UDAF that sums several columns within each row and then averages those per-row sums:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {

    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id=Integer.parseInt(fields[0]);
                String name=fields[1];
                String subject=fields[2];
                Double achieve1=Double.parseDouble(fields[3]);
                Double achieve2=Double.parseDouble(fields[4]);
                return RowFactory.create(id,name,subject,achieve1,achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));

        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf=new MutilAvg(2);
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
        rows2.show();
    }
}

The code above creates a Dataset with the columns id, name, subject, achieve1 and achieve2; MutilAvg adds up the score columns of each row and then averages those per-row sums.

MutilAvg.java (the UDAF):

package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutilAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize=1;
    
    public MutilAvg(int columnSize){
        this.columnSize=columnSize;
    }
    
    @Override
    public StructType inputSchema() {
        // StructType is immutable: add() returns a new instance, so the result must be reassigned
        StructType structType = new StructType();
        for (int i = 0; i < columnSize; i++) {
            structType = structType.add("myinput" + i, DataTypes.DoubleType);
        }
        return structType;
    }
        
    @Override
    public StructType bufferSchema() {
        StructType structType= new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {        
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Initializes the aggregation buffer, i.e. its "zero value". The contract: merging two initial buffers must yield the initial buffer again.
    // For example, if the initial value were 1 and merge added values together, merging two initial buffers would give 2, so that initial value would be wrong; hence the name "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: 0L is the Long zero for the row count
        buffer.update(1, 0d); // mysum: 0d is the Double zero for the running sum
    }

    /**
     * Per-partition combine.
     */
    // Updates the buffer with one input row, similar to the combine step of combineByKey
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0)+1);                     // row count + 1
        
        // each input row contains several columns, so first sum the columns of this same row
        Double currentLineSumValue= 0d;
        for(int i=0;i<columnSize;i++){
            currentLineSumValue+=input.getDouble(i);
        }
        
        buffer.update(1, buffer.getDouble(1)+currentLineSumValue); // accumulate the per-row sum
    }

    /**
     * Merge across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1; used when combining partial aggregates from different partitions, similar to reduceByKey.
    // Note that the method returns nothing: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0)+buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1)+buffer2.getDouble(1)); // merge running sums
    }
    
    // Computes and returns the final aggregation result
    @Override
    public Object evaluate(Row buffer) {
        // compute the average, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));

        return avgFormat;
    }
}

Output:

        +---+--------+---------+--------+--------+
        | id|    name|  subject|achieve1|achieve2|
        +---+--------+---------+--------+--------+
        |  1|zhangsan|  English|    80.0|    89.0|
        |  2|zhangsan|  History|    87.0|    88.0|
        |  3|zhangsan|  Chinese|    88.0|    87.0|
        |  4|zhangsan|Chemistry|    96.0|    95.0|
        |  5|    lisi|  English|    70.0|    75.0|
        |  6|    lisi|  Chinese|    74.0|    67.0|
        |  7|    lisi|  History|    75.0|    80.0|
        |  8|    lisi|Chemistry|    77.0|    70.0|
        |  9|    lisi|  Physics|    79.0|    80.0|
        | 10|    lisi|  Biology|    82.0|    83.0|
        | 11|  wangwu|  English|    96.0|    84.0|
        | 12|  wangwu|  Chinese|    98.0|    64.0|
        | 13|  wangwu|  History|    91.0|    92.0|
        | 14| zhaoliu|  English|    68.0|    80.0|
        | 15| zhaoliu|  Chinese|    66.0|    69.0|
        +---+--------+---------+--------+--------+

        +--------+-----------+
        |    name|avg_achieve|
        +--------+-----------+
        |  wangwu|      175.0|
        | zhaoliu|      141.5|
        |zhangsan|      177.5|
        |    lisi|      152.0|
        +--------+-----------+

        +--------+-----------+
        |    name|avg_achieve|
        +--------+-----------+
        |  wangwu|      175.0|
        | zhaoliu|      141.5|
        |zhangsan|      177.5|
        |    lisi|      152.0|
        +--------+-----------+

A UDAF that takes the maximum of each column separately and then returns the largest of those per-column maxima:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF2 {

    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id=Integer.parseInt(fields[0]);
                String name=fields[1];
                String subject=fields[2];
                Double achieve1=Double.parseDouble(fields[3]);
                Double achieve2=Double.parseDouble(fields[4]);
                return RowFactory.create(id,name,subject,achieve1,achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));

        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();

        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf=new MutilMax(2,0);
        sparkSession.udf().register("max_vals", udaf);

        Dataset<Row> rows1 = sparkSession.sql(""
                + "select name,max(achieve) as max_achieve "
                + "from "
                + "("
                + "select name,max(achieve1) achieve from user group by name "
                + "union all "
                + "select name,max(achieve2) achieve from user group by name "
                + ") t10 "
                + "group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
        rows2.show();
    }
}

Again the Dataset has the columns id, name, subject, achieve1 and achieve2; MutilMax computes the maximum of each input column and then returns the largest of those per-column maxima as the result.

MutilMax.java (the UDAF):

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MutilMax extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;
    private Double defaultValue;

    public MutilMax(int columnSize, double defaultValue) {
        this.columnSize = columnSize;
        this.defaultValue = defaultValue;
    }

    @Override
    public StructType inputSchema() {
        List<StructField> inputFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
        }
        StructType inputSchema = DataTypes.createStructType(inputFields);
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> bufferFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
        }
        StructType bufferSchema = DataTypes.createStructType(bufferFields);
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    // Initializes the aggregation buffer, i.e. its "zero value". The contract: merging two initial buffers must yield the initial buffer again.
    // For example, if the initial value were 1 and merge added values together, merging two initial buffers would give 2, so that initial value would be wrong; hence the name "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        for (int i = 0; i < this.columnSize; i++) {
            buffer.update(i, 0d);
        }
    }

    /**
     * Per-partition combine.
     */
    // Updates the buffer with one input row, similar to the combine step of combineByKey
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        for (int i = 0; i < this.columnSize; i++) {
            if( buffer.getDouble(i) >input.getDouble(i)){
                buffer.update(i, buffer.getDouble(i));
            }else{
                buffer.update(i, input.getDouble(i));
            }
        }
    }

    /**
     * Merge across partitions; MutableAggregationBuffer extends Row.
     */
    // Merges buffer2 into buffer1; used when combining partial aggregates from different partitions, similar to reduceByKey.
    // Note that the method returns nothing: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        for (int i = 0; i < this.columnSize; i++) {
            if( buffer1.getDouble(i) >buffer2.getDouble(i)){
                buffer1.update(i, buffer1.getDouble(i));
            }else{
                buffer1.update(i, buffer2.getDouble(i));
            }
        }
    }

    // Computes and returns the final aggregation result
    @Override
    public Object evaluate(Row buffer) {
        // pick the largest value among the per-column maxima
        Double max = Double.MIN_VALUE;
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > max) {
                max = buffer.getDouble(i);
            }
        }

        if (max == Double.MIN_VALUE) {
            max = this.defaultValue;
        }

        return max;
    }

}

Output:

        +---+--------+---------+--------+--------+
        | id|    name|  subject|achieve1|achieve2|
        +---+--------+---------+--------+--------+
        |  1|zhangsan|  English|    80.0|    89.0|
        |  2|zhangsan|  History|    87.0|    88.0|
        |  3|zhangsan|  Chinese|    88.0|    87.0|
        |  4|zhangsan|Chemistry|    96.0|    95.0|
        |  5|    lisi|  English|    70.0|    75.0|
        |  6|    lisi|  Chinese|    74.0|    67.0|
        |  7|    lisi|  History|    75.0|    80.0|
        |  8|    lisi|Chemistry|    77.0|    70.0|
        |  9|    lisi|  Physics|    79.0|    80.0|
        | 10|    lisi|  Biology|    82.0|    83.0|
        | 11|  wangwu|  English|    96.0|    84.0|
        | 12|  wangwu|  Chinese|    98.0|    64.0|
        | 13|  wangwu|  History|    91.0|    92.0|
        | 14| zhaoliu|  English|    68.0|    80.0|
        | 15| zhaoliu|  Chinese|    66.0|    69.0|
        +---+--------+---------+--------+--------+

        +--------+-----------+
        |    name|max_achieve|
        +--------+-----------+
        |  wangwu|       98.0|
        | zhaoliu|       80.0|
        |zhangsan|       96.0|
        |    lisi|       83.0|
        +--------+-----------+

        +--------+-----------+
        |    name|max_achieve|
        +--------+-----------+
        |  wangwu|       98.0|
        | zhaoliu|       80.0|
        |zhangsan|       96.0|
        |    lisi|       83.0|
        +--------+-----------+

Writing Aggregator (Agg) functions in Spark

Implementing an avg function with a typed Aggregator:

Step 1: define an Average class that holds the running count and sum.

import java.io.Serializable;

public class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }

    public Average() {

    }

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }
}

Step 2: define an Employee class that holds the employee information: name and salary.

import java.io.Serializable;

public class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getSalary() {
        return salary;
    }

    public void setSalary(long salary) {
        this.salary = salary;
    }

    public Employee() {
    }

    public Employee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
}

Step 3: define the Aggregator that averages the employees' salaries.

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

Step 4: call the Aggregator from Spark and verify the result.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;

import java.util.ArrayList;
import java.util.List;

public class SparkClient {
    public static void main(String[] args) {
        final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<Employee> employeeList = new ArrayList<Employee>();
        employeeList.add(new Employee("Michael", 3000L));
        employeeList.add(new Employee("Andy", 4500L));
        employeeList.add(new Employee("Justin", 3500L));
        employeeList.add(new Employee("Berta", 4000L));

        JavaRDD<Employee> rows = ctx.parallelize(employeeList);
        Dataset<Employee> ds = spark.createDataFrame(rows, Employee.class).map(new MapFunction<Row, Employee>() {
            @Override
            public Employee call(Row row) throws Exception {
                return new Employee(row.getString(0), row.getLong(1));
            }
        }, Encoders.bean(Employee.class));

        ds.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // +-------+------+

        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // |        3750.0|
        // +--------------+
    }
}

Output:

+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+

+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
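
As a side note not covered in the original post: starting with Spark 3.0 a typed Aggregator can also be registered as an untyped UDAF via functions.udaf, which makes it callable from SQL just like the UserDefinedAggregateFunction examples earlier. The sketch below assumes Spark 3.0 or later and reuses the Average and Employee beans above; the class and function names (ColumnAverage, my_avg) are invented for illustration:

import java.util.Arrays;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public class Spark3AggregatorAsUdaf {
    // Same buffer/output logic as MyAverage, but typed on the salary column (Long)
    // instead of the whole Employee bean, so it can be registered as a SQL function.
    public static class ColumnAverage extends Aggregator<Long, Average, Double> {
        @Override public Average zero() { return new Average(0L, 0L); }
        @Override public Average reduce(Average b, Long salary) {
            b.setSum(b.getSum() + salary);
            b.setCount(b.getCount() + 1);
            return b;
        }
        @Override public Average merge(Average b1, Average b2) {
            b1.setSum(b1.getSum() + b2.getSum());
            b1.setCount(b1.getCount() + b2.getCount());
            return b1;
        }
        @Override public Double finish(Average r) { return ((double) r.getSum()) / r.getCount(); }
        @Override public Encoder<Average> bufferEncoder() { return Encoders.bean(Average.class); }
        @Override public Encoder<Double> outputEncoder() { return Encoders.DOUBLE(); }
    }

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("agg_as_udaf").getOrCreate();
        spark.createDataFrame(Arrays.asList(
                new Employee("Michael", 3000L),
                new Employee("Andy", 4500L),
                new Employee("Justin", 3500L),
                new Employee("Berta", 4000L)), Employee.class)
            .createOrReplaceTempView("employees");

        // functions.udaf (Spark 3.0+) wraps the typed Aggregator as an untyped UDF
        spark.udf().register("my_avg", functions.udaf(new ColumnAverage(), Encoders.LONG()));
        spark.sql("select my_avg(salary) as average_salary from employees").show();

        spark.stop();
    }
}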


Original post: https://www.cnblogs.com/yy3b2007com/p/9294345.html