package com.XX.udf; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; public class UDAFTest extends AbstractGenericUDAFResolver{ //判断 @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)//字段的描述信息参数parameters throws SemanticException { if(info.length !=2){ throw new UDFArgumentTypeException(info.length-1, "Exactly two argument is expected."); } //返回处理逻辑的类 return new GenericEvaluate(); } public static class GenericEvaluate extends GenericUDAFEvaluator{ private LongWritable result; private PrimitiveObjectInspector inputIO1; private PrimitiveObjectInspector inputIO2; //这个方法map与reduce阶段都需要执行 /** * map阶段:parameters长度与udaf输入的参数个数有关 * reduce阶段:parameters长度为1 */ //初始化 @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { super.init(m, parameters); //返回最终的结果 result = new LongWritable(0); inputIO1 = (PrimitiveObjectInspector) parameters[0]; if (parameters.length>1) { inputIO2 = (PrimitiveObjectInspector) parameters[1]; } return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector; } //map阶段 iterate函数处理读入的行数据 @Override public void iterate(AggregationBuffer agg, Object[] parameters)//agg缓存结果值 throws HiveException { assert(parameters.length==2); if(parameters==null || parameters[0]==null || parameters[1]==null){ return; } double base = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputIO1); double tmp = PrimitiveObjectInspectorUtils.getDouble(parameters[1], inputIO2); if(base > tmp){ ((CountAgg)agg).count++; } } //获得一个聚合的缓冲对象,每个map执行一次 @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { CountAgg agg = new CountAgg(); reset(agg); return agg; } //自定义类用于计数 public static class CountAgg implements AggregationBuffer{ long count;//计数,保存每次临时的结果 } //重置 @Override public void reset(AggregationBuffer countagg) throws HiveException { CountAgg agg = (CountAgg)countagg; agg.count=0; } //该方法当做iterate执行后,部分结果返回。 terminatePartial 返回iterate处理的中间结果 @Override public Object terminatePartial(AggregationBuffer agg) throws HiveException { result.set(((CountAgg)agg).count); return result; } @Override //合并处理结果 public void merge(AggregationBuffer agg, Object partial) throws HiveException { if(partial != null){ long p = PrimitiveObjectInspectorUtils.getLong(partial, inputIO1); ((CountAgg)agg).count += p; } } @Override //返回最终值 public Object terminate(AggregationBuffer agg) throws HiveException { result.set(((CountAgg)agg).count); return result; } } }