Spark SQL UDF 函数(四)

Spark 中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:

  • UDF(User-Defined-Function):即最基本的自定义函数,类似 to_char,to_date
  • UDAF(User-Defined Aggregate Function):用户自定义聚合函数,类似在 group by 之后使用的 sum、avg
  • UDTF(User-Defined Table-Generating Function):用户自定义表生成函数,一行输入可以生成多行输出,有点像 stream 里面的 flatMap

1. 初步使用 UDF 函数


scala> val df = spark.read.json("hdfs://hadoop1:9000/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

// 注册使用,toUpper 为函数名称
scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
res15: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("select toUpper(name), age from people").show
+-----------------+----+
|UDF:toUpper(name)| age|
+-----------------+----+
|          MICHAEL|null|
|             ANDY|  30|
|           JUSTIN|  19|
+-----------------+----+

2. 自定义UDAF 聚合函数

package top.midworld.spark1031.create_df

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

// Case class modelling one record of the input file: a name and an age.
// The age is stored as Double so it lines up with the DoubleType input
// schema declared by the aggregate functions in this file.
case class UserInfo(name: String, age: Double)

/**
 * Demo entry point: loads `people.txt`, exposes it as the temp view `user`,
 * registers the `MySum` UDAF under the SQL name `mySum`, and prints both the
 * aggregated sum and the raw rows.
 */
object UDF1 {
  /**
   * Runs the demo job.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("udf").master("local[2]").getOrCreate()

    // try/finally guarantees the session is released even when reading or
    // parsing the input fails part-way through.
    try {
      import spark.implicits._

      // Each input line is expected to look like "name, age".
      val df = spark.sparkContext
        .textFile("hdfs://hadoop1:9000/people.txt")
        .map(_.split(","))
        .map(fields => UserInfo(fields(0), fields(1).trim.toDouble))
        .toDF()
      df.createOrReplaceTempView("user")

      // Register the user-defined aggregate function under the SQL name "mySum".
      spark.udf.register("mySum", new MySum)

      spark.sql("select mySum(age) as age_sum from user").show()

      df.show()
    } finally {
      // Stopping the session also stops its underlying SparkContext, so a
      // separate sc.stop() call is unnecessary.
      spark.stop()
    }
  }
}

/**
 * User-defined aggregate function that sums a single Double column.
 *
 * The println calls trace how the aggregation buffer evolves during
 * initialize / update / merge; sample traces are kept in the comments below.
 */
class MySum extends UserDefinedAggregateFunction {
  // Type of the input data, e.g. 29 / 30 / 19.
  override def inputSchema: StructType = StructType(Seq(StructField("ele", DoubleType)))

  // Type of the aggregation buffer: one running sum.
  override def bufferSchema: StructType = StructType(Seq(StructField("sum", DoubleType)))

  // Type of the final aggregated result.
  override def dataType: DataType = DoubleType

  // Whether the same input always yields the same output; always true here.
  override def deterministic: Boolean = true

  // Initialise the buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    println(s"initialize===>$buffer") // initialize===>[null]
    // Start the running sum at zero.
    buffer.update(0, 0.0)
  }

  // Aggregation within a partition: fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println(s"update===>$buffer")
    println(s"input===>$input")
    /*
    update===>[0.0]
    update===>[0.0]
    input===>[19.0]
    input===>[29.0]
    update===>[29.0]
    input===>[30.0]
     */

    // Only a row holding exactly one Double contributes to the sum;
    // anything else (including a null value) is skipped.
    input match {
      case Row(value: Double) =>
        val runningSum = buffer.getDouble(0)
        buffer.update(0, runningSum + value)
      case _ => ()
    }
  }

  // Aggregation across partitions: fold buffer2 into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    println(s"merge buffer1 ===>$buffer1")
    println(s"merge buffer2 ===>$buffer2")
    /*
    merge buffer1 ===>[0.0]
    merge buffer2 ===>[59.0]
    merge buffer1 ===>[59.0]
    merge buffer2 ===>[19.0]
     */

    val left = buffer1.getDouble(0)
    val right = buffer2.getDouble(0)
    buffer1.update(0, left + right)
  }

  // Produce the final output value from the buffer.
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}

运行结果:

+-------+
|age_sum|
+-------+
|   78.0|
+-------+

+-------+----+
|   name| age|
+-------+----+
|Michael|29.0|
|   Andy|30.0|
| Justin|19.0|
+-------+----+

求平均值

/**
 * User-defined aggregate function that computes the arithmetic mean of a
 * single Double column.
 *
 * The buffer holds the running sum and the number of values folded in so far.
 * Rows that do not hold exactly one Double (for example a null value) are
 * ignored. When no value was aggregated the result is SQL NULL.
 */
class MyAvg extends UserDefinedAggregateFunction {
  // Type of the input data, e.g. 29 / 30 / 19.
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // Type of the aggregation buffer: running sum plus the count of summed values.
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

  // Type of the final aggregated result.
  override def dataType: DataType = DoubleType

  // Whether the same input always yields the same output; always true here.
  override def deterministic: Boolean = true

  // Initialise the buffer: sum = 0.0, count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D
    buffer(1) = 0L
  }

  // Aggregation within a partition: add the value and bump the count.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(age: Double) =>
        buffer(0) = buffer.getDouble(0) + age
        buffer(1) = buffer.getLong(1) + 1L
      // Anything else (including a null value) does not contribute.
      case _ =>
    }
  }

  // Aggregation across partitions: fold buffer2 into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    // The count must accumulate onto buffer1's own count. Reading it from
    // buffer2 would discard what buffer1 had already counted.
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Produce the final average, or SQL NULL when nothing was aggregated
  // (avoids a 0.0 / 0 division that would yield NaN).
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getDouble(0) / count
  }
}

3. 开窗函数

https://blog.csdn.net/sunxiaoju/article/details/103800028

https://blog.csdn.net/liangzelei/article/details/80608302?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-4.no_search_link&spm=1001.2101.3001.4242.3

原文地址:https://www.cnblogs.com/midworld/p/15647008.html