[Spark]-结构化数据查询之自定义UDAF

1.自定义弱类型UDAF

  1.1 弱类型UDAF定义

    弱类型UDAF继承实现 UserDefinedAggregateFunction 抽象类

    override def inputSchema: StructType = 输入schema

    override def bufferSchema: StructType = 聚合过程schema

    override def dataType: DataType = 返回值类型

    override def deterministic: Boolean = 是否确定性(相同输入是否总是产生相同输出),而非"返回值类型是否固定"

    override def initialize(buffer: MutableAggregationBuffer): Unit = 初始化函数,用来初始化基准值

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = 分区内元素如何聚合

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 分区之间如何聚合

    override def evaluate(buffer: Row): Any = 聚合结果计算

    整个UDAF处理过程,非常类似RDD的aggregate算子

      aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U

    一个自定义求平均数UDAF例子

            object UDAFApp extends App {

              val spark = SparkSession.builder()
                .master("local[2]")
                .appName("UDAP-App")
                .getOrCreate()
              import spark.implicits._

              // Backslashes must be escaped in a Scala string literal:
              // "D:\data\..." does not compile ("\d" is an invalid escape sequence).
              val df = spark.read.format("json").load("D:\\data\\employees.json")

              // Register the UDAF for SQL: only a UserDefinedAggregateFunction
              // can be registered as a SQL function this way.
              spark.udf.register("cusAvg", MyAvgUDAF)
              // Expose the DataFrame as a temporary view so SQL can query it.
              df.createTempView("employees_view")
              spark.sql("select cusAvg(salary) as salary from employees_view").show()

              // Same aggregation via the DataFrame API.
              df.select(MyAvgUDAF.apply($"salary")).show()
              spark.close()
            }
            
            object MyAvgUDAF extends UserDefinedAggregateFunction {

              // Schema of the input column(s) fed to the UDAF.
              override def inputSchema: StructType =
                StructType(StructField("input", DoubleType) :: Nil)

              // Schema of the intermediate aggregation buffer: running sum and row count.
              override def bufferSchema: StructType =
                StructType(StructField("Sum", DoubleType) :: StructField("Count", LongType) :: Nil)

              // Type of the final aggregation result.
              override def dataType: DataType = DoubleType

              // True: the function always produces the same output for the same input.
              // (Deterministic-ness, NOT "whether the return type is fixed".)
              override def deterministic: Boolean = true

              // Initialize the buffer — the zero value, like aggregate's zeroValue part.
              override def initialize(buffer: MutableAggregationBuffer): Unit = {
                buffer(0) = 0D // running sum
                buffer(1) = 0L // row count
              }

              // Fold one input row into a partition-local buffer (aggregate's seqOp part).
              override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                // Skip null values so they do not poison the sum or inflate the count.
                if (!input.isNullAt(0)) {
                  buffer(0) = buffer.getDouble(0) + input.getDouble(0)
                  buffer(1) = buffer.getLong(1) + 1
                }
              }

              // Merge two partition buffers (aggregate's combOp part).
              override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
                buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
                buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
              }

              // Final result: sum / count. Return null — like Spark's built-in avg —
              // instead of NaN when no non-null rows were aggregated.
              override def evaluate(buffer: Row): Any =
                if (buffer.getLong(1) == 0) null
                else buffer.getDouble(0) / buffer.getLong(1)
            }

2.自定义强类型UDAF

  自定义强类型UDAF 继承实现 Aggregator[IN, BUF, OUT] 抽象类

  强类型UDAF通过Encoder与具体的输入对象类型(如case class)绑定,而不是基于Row的弱类型接口,
  所以这种定义方式不能通过spark.udf.register注册,也不能直接用在SQL中,只能以DS-API(toColumn)方式使用

  一个强类型UDAF定义如下:    

            object UDAFApp extends App {

              val spark = SparkSession.builder()
                .master("local[2]")
                .appName("UDAP-App")
                .getOrCreate()
              import spark.implicits._

              // Backslashes must be escaped in a Scala string literal:
              // "D:\data\..." does not compile ("\d" is an invalid escape sequence).
              val ds = spark.read.format("json").load("D:\\data\\employees.json").as[Employee]

              // Dataset API: a typed Aggregator is applied via toColumn.
              ds.select(MyAverage.toColumn.name("salary")).show()
              spark.close()
            }
            
            // Target (input) record type: one employee row deserialized from the JSON source.
            case class Employee(val name: String,val salary: Long)
            // Intermediate aggregation buffer: running sum and count.
            // Fields are vars because reduce/merge update the buffer in place.
            case class Average(var sum: Long, var count: Long)
            object MyAverage extends Aggregator[Employee, Average, Double]  {
              // Zero value of the aggregation buffer (sum = 0, count = 0).
              override def zero: Average = Average(0, 0)

              // Fold one input record into a partition-local buffer (mutates and returns b).
              override def reduce(b: Average, a: Employee): Average = {
                b.sum += a.salary
                b.count += 1
                b
              }

              // Combine two partition buffers (mutates and returns b1).
              override def merge(b1: Average, b2: Average): Average = {
                b1.sum += b2.sum
                b1.count += b2.count
                b1
              }

              // Final result: average salary.
              // Debug println removed — leftover diagnostics should not print to stdout.
              // NOTE(review): yields NaN when count == 0 (empty input); callers may
              // want to guard for that case upstream.
              override def finish(reduction: Average): Double =
                reduction.sum.toDouble / reduction.count

              // Encoder for the intermediate buffer (Average is a case class → product encoder).
              override def bufferEncoder: Encoder[Average] = Encoders.product

              // Encoder for the Double result.
              override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
            }
原文地址:https://www.cnblogs.com/NightPxy/p/9269171.html