Spark开发-Spark中类型安全UDAF开发示例

Spark开发UDAF

 通过实际运行 Spark 源码中的示例代码,了解各个功能的用法,并排查开发中遇到的错误。
  在 UDAF 中可以用 System.out.println() 打印中间值,辅助调试和判断(local 模式下直接输出到控制台;在集群上运行时,Aggregator 各方法中的输出会出现在 Executor 的日志里)。

开发示例代码

`
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
public class MeanTypeUDAF  implements Serializable{
/**
 * Input row type (IN) of the aggregator: one employee with a name and a salary.
 *
 * <p>Written as a JavaBean (getters/setters plus a public no-argument constructor) so that
 * {@code Encoders.bean(MyEmployee.class)} can encode it. Per the author's notes, adding the
 * no-argument constructor is what fixed the generated-code error
 * "No applicable constructor/method found for zero actual parameters".
 * The salary uses the primitive {@code long} type.
 */
public static class MyEmployee implements Serializable {
    private String name;
    private long salary;

    /** Creates an empty bean (name {@code null}, salary 0); required for bean encoding. */
    public MyEmployee() {
        // intentionally empty
    }

    /** Convenience constructor used when building the sample data in this file. */
    private MyEmployee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }

    /** @return the employee's name */
    public String getName() {
        return this.name;
    }

    /** @param name the employee's name */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the employee's salary */
    public long getSalary() {
        return this.salary;
    }

    /** @param salary the employee's salary */
    public void setSalary(long salary) {
        this.salary = salary;
    }

}

/**
 * Aggregation buffer type (BUF) of the average aggregator: the running salary sum and the number
 * of rows seen so far.
 *
 * <p>This is the intermediate buffer, not the output type (the aggregator's output is a
 * {@code Double}). It is a JavaBean so that {@code Encoders.bean(AverageBuffer.class)} can encode
 * it. Per the author's notes, the public no-argument constructor is what fixed the
 * generated-code compile error.
 */
public static class AverageBuffer implements Serializable {
    private long sum;
    private long count;

    /** Creates an empty buffer (sum 0, count 0); required for bean encoding. */
    public AverageBuffer() {
        // intentionally empty
    }

    /** Creates a buffer with the given running sum and row count. */
    private AverageBuffer(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }

    /** @return the running sum of salaries */
    public long getSum() {
        return this.sum;
    }

    /** @return the number of rows accumulated */
    public long getCount() {
        return this.count;
    }

    /** @param sum the new running sum */
    public void setSum(long sum) {
        this.sum = sum;
    }

    /** @param count the new row count */
    public void setCount(long count) {
        this.count = count;
    }
}

/**
 * Type-safe aggregator computing the mean salary over {@code MyEmployee} rows.
 *
 * <p>Type parameters of {@code Aggregator<IN, BUF, OUT>} as used here:
 * <ul>
 *   <li>IN  = {@code MyEmployee}: one input row</li>
 *   <li>BUF = {@code AverageBuffer}: running salary sum and row count</li>
 *   <li>OUT = {@code Double}: the final average</li>
 * </ul>
 */
public static class MyAverage extends Aggregator<MyEmployee, AverageBuffer, Double> {

    /** Encoder for the intermediate buffer, derived from the {@code AverageBuffer} bean. */
    @Override
    public Encoder<AverageBuffer> bufferEncoder() {
        return Encoders.bean(AverageBuffer.class);
    }

    /** Encoder for the final {@code Double} result. */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }

    /**
     * Returns a fresh empty buffer (sum 0, count 0), the neutral element of {@link #merge}:
     * merging it with any buffer b leaves b's totals unchanged.
     */
    @Override
    public AverageBuffer zero() {
        return new AverageBuffer(0L, 0L);
    }

    /**
     * Folds one input row into the buffer: adds the employee's salary to the sum and bumps the
     * count by one. The buffer is modified in place and returned rather than copied.
     */
    @Override
    public AverageBuffer reduce(AverageBuffer buffer, MyEmployee employee) {
        buffer.setSum(buffer.getSum() + employee.getSalary());
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    /**
     * Combines two partial buffers by adding their sums and counts into {@code b1}, which is
     * returned.
     */
    @Override
    public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    /**
     * Produces the average as sum / count in floating point. When the count is 0 this evaluates
     * 0.0 / 0 and therefore yields NaN.
     */
    @Override
    public Double finish(AverageBuffer reduction) {
        double total = reduction.getSum();
        return total / reduction.getCount();
    }
}
/**
 * Demo entry point: builds a small Dataset of {@code MyEmployee} beans on a local SparkSession,
 * prints it, and computes the average salary with the typed {@code MyAverage} aggregator.
 *
 * <p>The SparkSession is stopped in a {@code finally} block so the local Spark context is
 * released even when one of the demo steps throws.
 *
 * @param args unused
 */
public static void main(String[] args) {
    SparkSession spark = SparkSession
            .builder()
            .appName("Java Spark SQL data sources example")
            .config("spark.some.config.option", "some-value")
            .master("local[2]")
            .getOrCreate();
    try {
        // Sample input rows created from bean instances.
        List<MyEmployee> employees = Arrays.asList(
                new MyEmployee("CFF", 30L),
                new MyEmployee("CFAF", 50L),
                new MyEmployee("ADD", 10L)
        );
        Encoder<MyEmployee> personEncoder = Encoders.bean(MyEmployee.class);
        Dataset<MyEmployee> itemsDataset = spark.createDataset(employees, personEncoder);
        itemsDataset.printSchema();
        itemsDataset.show();

        // head() is a Spark action; fetch the first row once and reuse it.
        MyEmployee first = itemsDataset.head();
        System.out.println(first.getName());
        System.out.println(first.getSalary());

        MyAverage myAverage = new MyAverage();
        System.out.println("############");
        // Convert the aggregator to a named TypedColumn so it can be used in Dataset.select.
        TypedColumn<MyEmployee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        itemsDataset.printSchema();
        Dataset<Double> result = itemsDataset.select(averageSalary);
        result.show();
    } finally {
        // Always release the local Spark context, even if a step above fails.
        spark.stop();
    }
}
}`

说明

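这是一个类型安全的 UDAF(继承 Aggregator)示例,同时也简单演示了如何用 Java Bean 作为 Dataset 的数据来源。
报错: Caused by: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 24, Column 87: 
 failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 24, Column 87:
  No applicable constructor/method found for zero actual parameters;  candidates are: 
原因与解决: Encoders.bean 生成的代码需要调用 Bean 类的无参构造函数来创建对象,而 MyEmployee 和 AverageBuffer 最初只有带参数的构造函数,所以生成的代码编译失败。为这两个类各添加一个 public 的无参构造函数(并为每个字段提供 getter/setter)后,问题解决。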

参考

  http://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html
原文地址:https://www.cnblogs.com/ytwang/p/14007331.html