hive自定义分段函数(分箱)

分段函数常用于分箱中,统计分组在指定的区间中的占比。

比如有如下例子:统计某个班级中考试分数在各个阶段的占比。

准备的数据如下:

使用如下文件在hive中建表。

class1,1,100
class1,2,88
class1,3,90
class1,4,23
class1,5,30
class1,6,55
class1,7,66
class1,8,99
class1,9,56
class1,10,34

这时候使用case when来计算每行记录分别在哪个区间如下:

with tmp_a as(
select 
clazz,name,
case when score <30 then '[0,30)'
when score <60 then '[30,60)'
when score < 80 then '[60,80)'
when score <= 100 then '[80,100]'
else 'none' end bins
from dt_dwd.score
)
select clazz,bins,count(1)/sum(count(1)) over (partition by clazz) as rate,count(1)
from tmp_a group by clazz,bins; 

最后是统计结果如下:

上述就是通常的分箱占比操作例子。


现在我有多组标签需要监控,每次写case when的,这里面的分段非常多,于是想到用hive udf来简化写法。

先看已经完成的自定义函数default.piecewise的sql写法如下:

select clazz,name,default.piecewise('[0,30)|[30,60)|[60,80)|[80,100]',score) as bins
from dt_dwd.score

  

完整的sql如下:

with tmp_a as(
	select 
	clazz,name,default.piecewise('[0,30)|[30,60)|[60,80)|[80,100]',score) as bins
	from dt_dwd.score
)
select clazz,bins,count(1)/sum(count(1)) over (partition by clazz) as rate,count(1)
from tmp_a group by clazz,bins;

这样我们可以将分箱抽象到变量中,在当做参数传入,就不要每次写很大段的case when了。

 default.piecewise的完整写法如下:

package com.demo.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaConstantStringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

import java.util.Objects;

/**
 * @Author: KingWang
 * @Date: 2021/9/21
 * @Desc: 自定义分段函数
 *   传入参数1:none|[0,30]|(30,60]|(60,90]|(90,+]
 *   传入参数2:值
 *   返回:根据参数2的值,判断在参数1的区间,返回参数1的区间值
 *   如: 参数2:45, 则返回(30,60]
 **/
@Description(name = "default.piecewise", value = "_FUNC_(piecewise, value) - Returns piecewise if the value mapped.", extended = "Example:
  > SELECT _FUNC_('[0,30)|[30,60)|(60,100]', 88) FROM table limit 1;
  '(60,100]'")
public class Piecewise extends GenericUDF {

    private transient StringObjectInspector piecewiseOI;
    private transient StringObjectInspector valOI;



    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 2) {
            throw new UDFArgumentException("The function piecewiseUDF accepts 2 arguments.");
        }
        if(null == arguments[0]){
            throw new UDFArgumentException("first arguments can not be null.");
        }

        this.piecewiseOI = (StringObjectInspector) arguments[0];
        this.valOI = (StringObjectInspector) (null == arguments[1] ? new JavaConstantStringObjectInspector("")  : arguments[1]);
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {

        String piecewise = piecewiseOI.getPrimitiveJavaObject(deferredObjects[0].get());
        String val = valOI.getPrimitiveJavaObject(null != deferredObjects[1] ? deferredObjects[1].get():"");

        String[] list = piecewise.split("\|");
        if(Objects.isNull(val)||"".equals(val)){
            return new Text("none");
        }
        boolean match = false;
        for(String str:list){
            if(str.indexOf(",")>0){
                try{
                    double value = Double.valueOf(val);
                    if(str.startsWith("(-,") || str.startsWith("[-,")){
                        match = minusAndValue(str,value);
                    }else if(str.endsWith(",+)") || str.endsWith(",+]")){
                        match = valueAndPlus(str,value);
                    }else{
                        match = valueAndValue(str,value);
                    }
                }catch (NumberFormatException e){

                }
            } else {
                if(str.equalsIgnoreCase(val)){
                    match = true;
                }
            }
            if(match) return new Text(str);
        }

        //未匹配上的返回到ERROR分组中
        return new Text("ERROR");
    }

    @Override
    public String getDisplayString(String[] strings) {
        return strings[0];
    }



    /**
     * 表达式类似于(-,60)或者[-,60)或者(-,60]或者[-,60]
     * @param express
     * @return
     */
    public static boolean minusAndValue(String express,Double value){
        boolean is_match = false;
        String endStr = express.split(",")[1];
        double end = Double.valueOf(endStr.substring(0,endStr.length()-1));
        if(express.endsWith(")")){
            if( value < end ){
                is_match = true;
            }
        }else if(express.endsWith("]")){
            if( value <= end){
                is_match = true;
            }
        }
        return is_match;
    }

    /**
     * 表达式类似于(80,+)或者[80,+)或者(80,+]或者[80,+]
     * @param express
     * @return
     */
    public static boolean valueAndPlus(String express,Double value){
        boolean is_match = false;
        String beginStr = express.split(",")[0];
        double begin = Double.valueOf(beginStr.substring(1));
        if(express.startsWith("(")){
            if( value > begin ){
                is_match = true;
            }
        }else if(express.startsWith("[")){
            if(value >= begin){
                is_match = true;
            }
        }
        return is_match;
    }

    /**
     * 表达式类似于(60,80)或者(60,80]或者[60,80)或者[60,80]
     * @param express
     * @return
     */
    public static boolean valueAndValue(String express,Double value){
        boolean is_match = false;
        String beginStr = express.split(",")[0];
        String endStr = express.split(",")[1];

        double begin = Double.valueOf(beginStr.substring(1));
        double end = Double.valueOf(endStr.substring(0,endStr.length()-1));
        if(express.startsWith("(") && express.endsWith(")")){
            if( value> begin && value < end){
                is_match = true;
            }
        }else if(express.startsWith("[") && express.endsWith("]")){
            if(value >= begin && value <= end){
                is_match = true;
            }
        }else if(express.startsWith("(") && express.endsWith("]")){
            if(value > begin && value <= end){
                is_match = true;
            }
        }else if(express.startsWith("[") && express.endsWith(")")){
            if(value >= begin && value < end){
                is_match = true;
            }
        }
        return is_match;
    }
}

实际使用中,shell脚本中定义分箱变量,通过参数传递给scrip.sql脚本

使用注意事项: 

当传入的值为null时,会报异常,需要使用nvl(nullfileld,'') 或者nvl(nulllfiled,'none')来处理,其结果将默认分配在none分段中。 

 然后在script.sql脚本中可以接收传入的分箱变量来灵活使用。


2021.09.26 优化:

将未匹配的行,使用ERROR分箱表示,因为执行程序过程中如果发现有不在给定的分箱里面的,会报错,到时候排错特别困难。

这样直接给个带ERROR的分箱组,可以很直接的在结果数据中可以观察到。

  

  

作者:尤灯塔
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利.
原文地址:https://www.cnblogs.com/30go/p/15321847.html