Spark SQL UDAF示例

UDAF:用户自定义聚合函数

Scala 2.10.7,spark 2.0.0

package UDF_UDAF

import java.util

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, RowFactory, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

class UDAF extends UserDefinedAggregateFunction {
  /**
    * 指定输入字段的字段及类型
    */
  override def inputSchema: StructType =
    DataTypes.createStructType(Array(DataTypes.createStructField("namexxx",DataTypes.StringType,true)))

  /**
    * 在进行聚合操作的时候所要处理的数据的结果的类型
    * */
  override def bufferSchema: StructType =
    DataTypes.createStructType(Array(DataTypes.createStructField("buffer",DataTypes.IntegerType,true)))

  /**
    * 指定UDAF计算后返回的结果类型
    * @return
    */
  override def dataType: DataType = DataTypes.IntegerType

  /**
    * 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
    */
  override def deterministic: Boolean = true

  /**
    * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0,0)

  /**
    * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
    * buffer.getInt(0)获取的是上一次聚合后的值
    * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
    * 大聚和发生在reduce端.
    * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = buffer.update(0, buffer.getInt(0)+1)

  /**
    * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
    * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
    * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值
    * buffer2.getInt(0) : 这次计算传入进来的update的结果
    * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = buffer1.update(0, buffer1.getInt(0)+buffer2.getInt(0))

  /**
    * 最后返回一个和dataType方法的类型要一致的类型,返回UDAF最后的计算结果
    */
  override def evaluate(buffer: Row): Any = buffer.getInt(0)
}

object UDAF{
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("udaf")
    val sparkSession = SparkSession.builder().config(conf).config("spark.sql.warehouse.dir","/test/warehouse").getOrCreate()
    val sc = sparkSession.sparkContext

    val parallelize = sc.parallelize(Array("zhangsan","lisi","wanger","zhaosi","zhangsan","lisi"))
    val rowRDD = parallelize.map(s=>RowFactory.create(s))

    val fields = new util.ArrayList[StructField]()
    fields.add(DataTypes.createStructField("name",DataTypes.StringType,true))
    val schema = DataTypes.createStructType(fields)

    val df = sparkSession.createDataFrame(rowRDD, schema)
    df.createOrReplaceTempView("user")

    sparkSession.udf.register("StringCount",new UDAF())

    sparkSession.sql("select name, StringCount(name) as StrCount from user group by name").show()

    sparkSession.stop()

  }
}

原文地址:https://www.cnblogs.com/144823836yj/p/10769222.html