使用AOP统一验签和校参

一、需求背景

  对外提供服务的接口需要统一做验签和参数合法性校验。每个接口的加签算法相同,不同的是参数的不为空的要求不同。

  要求,在controller层外做校验,校验不通过直接返回,不进入controller层。

二、需求实现前代码

 在这之前已经对每个请求做了AOP拦截,对每个请求植入了线程号。以及统计每个接口的执行耗时,打印每个接口的返回结果,捕获接口的未检查异常并打印和封装返回结果。

 如:

/**
 * 为每一个的HTTP请求添加线程号
 *
 * @author yangyongjie
 * @date 2019/9/2
 * @desc
 */
@Order(1)
@Aspect
@Component
public class LogAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(LogAspect.class);

    @Pointcut(value = "@annotation(org.springframework.web.bind.annotation.RequestMapping)")
    private void webPointcut() {
        // doNothing
    }

    /**
     * 为所有的HTTP请求添加线程号
     *
     * @param joinPoint
     * @throws Throwable
     */
    @Around(value = "webPointcut()")
    public Object around(ProceedingJoinPoint joinPoint) {
        // 执行开始的时间
        Long beginTime = System.currentTimeMillis();
        // 方法执行前加上线程号,并将线程号放到线程本地变量中
        MDCUtil.init();
        // 获取切点的方法名
        String methodName = joinPoint.getSignature().getName();
        // 执行拦截的方法
        Object result = null;
        try {
            result = joinPoint.proceed();
        } catch (Throwable throwable) {
            LOGGER.error("{}方法执行异常:" + throwable.getMessage(), methodName, throwable);
            LogUtil.sendErrorLogMail("系统异常", throwable);
            result = new CommonResult(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
        } finally {
            LOGGER.info("{}方法返回结果:{}", methodName, JacksonJsonUtil.toString(result));
            Long endTime = System.currentTimeMillis();
            LOGGER.info("{}方法耗时{}毫秒", methodName, endTime - beginTime);
            // 方法执行结束移除线程号,并移除线程本地变量,防止内存泄漏
            MDCUtil.remove();
        }
        return result;
    }
}

@Order(1) :为多个AOP切面排序,数字越小,先执行谁。

MDCUtil:

/**
 * 日志相关工具类
 *
 * @author yangyongjie
 * @date 2019/9/17
 * @desc
 */
public class MDCUtil {
    private MDCUtil() {
    }

    private static final String STR_THREAD_ID = "threadId";

    /**
     * 初始化日志参数并保存在线程副本中
     */
    public static void init() {
        String uuid = UUID.randomUUID().toString().replaceAll("-", "");
        MDC.put(STR_THREAD_ID, uuid);
        ThreadContext.currentThreadContext().setThreadId(uuid);
    }

    /**
     * 初始化日志参数
     */
    public static void initWithOutContext() {
        String uuid = UUID.randomUUID().toString().replaceAll("-", "");
        MDC.put(STR_THREAD_ID, uuid);
    }

    /**
     * 移除线程号和线程副本
     */
    public static void remove() {
        MDC.remove(STR_THREAD_ID);
        ThreadContext.remove();
    }

    /**
     * 移除线程号
     */
    public static void removeWithOutContext() {
        MDC.remove(STR_THREAD_ID);
    }
}

线程上下文ThreadContext:

/**
 * 线程上下文,一个线程内所需的上下文变量参数,使用ThreadLocal保存副本
 *
 * @author yangyongjie
 * @date 2019/9/12
 * @desc
 */
public class ThreadContext {
    /**
     * 每个线程的私有变量,每个线程都有独立的变量副本,所以使用private static final修饰,因为都需要复制进入本地线程
     */
    private static final ThreadLocal<ThreadContext> THREAD_LOCAL = new ThreadLocal<ThreadContext>() {
        @Override
        protected ThreadContext initialValue() {
            return new ThreadContext();
        }
    };

    public static ThreadContext currentThreadContext() {
        /*ThreadContext threadContext = THREAD_LOCAL.get();
        if (threadContext == null) {
            THREAD_LOCAL.set(new ThreadContext());
            threadContext = THREAD_LOCAL.get();
        }
        return threadContext;*/
        return THREAD_LOCAL.get();
    }

    public static void remove() {
        THREAD_LOCAL.remove();
    }

    /**
     * 线程号
     */
    private String threadId;

    /**
     * 请求参数
     */
    private Object requestParam;

    public String getThreadId() {
        return threadId;
    }

    public void setThreadId(String threadId) {
        this.threadId = threadId;
    }

    public Object getRequestParam() {
        return requestParam;
    }

    public void setRequestParam(Object requestParam) {
        this.requestParam = requestParam;
    }

    @Override
    public String toString() {
        return JacksonJsonUtil.toString(this);
    }
}

 公共返回结果类:

/**
 * 用于返回给调用方执行结果的公共结果类
 * 自定义返回结果继承此类即可
 *
 * @author yangyongjie
 * @date 2019/9/25
 * @desc
 */
public class CommonResult {
    /**
     * 返回码,0000表示成功,其余都是失败,9998表示入参不符合要求,9999表示系统异常
     */
    private String code = "0000";
    /**
     * 返回信息
     */
    private String msg = "success";

    public CommonResult() {
    }


    public CommonResult(String code, String msg) {
        this.code = code;
        this.msg = msg;
    }

    /**
     * 失败情况
     */
    public void fail(String code, String msg) {
        this.code = code;
        this.msg = msg;
    }

    /**
     * 判断是否成功
     */
    @JsonIgnore
    public boolean isSuccess() {
        return StringUtils.equals("0000", code);
    }

    public String getCode() {
        return code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getMsg() {
        return msg;
    }

    public void setMsg(String msg) {
        this.msg = msg;
    }
}

三、需求具体实现

  1、现在需要再增加一个切面,对需要做验签和参数校验的接口拦截并校验

    1)自定义注解,作用在controller层的方法上,标识此接口需要验签和验参,其有两个属性,一个是方法返回类型,一个是接收参数的实体类。

    方法返回类型用来切面校验不通过封装返回数据,接收参数的实体类对需要验不为空的方法标志了注解,需在切面中进行校验。

/**
 * 对外请求参数校验注解
 *
 * @author yangyongjie
 * @date 2019/11/5
 * @desc
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Check {

    /**
     * 方法的返回值类型,继承了CommonResult
     */
    Class<? extends CommonResult> value();

    /**
     * 校验的目标实体类
     */
    Class<?> paramBean();

}

如接收参数的实体类定义:

public class AuthTokenRequest extends BaseRequest {

    /**
     * 值为authorization_code
     */
    @ParamVerify(nullable = CheckEnum.NOTNULL)
    private String grant_type;
}

public class BaseRequest {
    /**
     * 签名
     */
    @ParamVerify(nullable = CheckEnum.NOTNULL)
    private String sign;

    /**
     * 分配的接入id
     */
    @ParamVerify(nullable = CheckEnum.NOTNULL)
    private String partnerId;
}

属性校验注解:

/**
 * 字段校验注解,目前只进行非空校验,可扩展
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface ParamVerify {
    /**
     * 是否允许为空
     */
    CheckEnum nullable() default CheckEnum.NULL;
}

验签验参切面:

/**
 * 对外同步接口参数校验切面
 *
 * @author yangyongjie
 * @date 2019/11/5
 * @desc
 */
@Order(2)
@Aspect
@Component
public class CheckAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(CheckAspect.class);

    /**
     * 验签公钥
     */
    @Value("${fx.publicKey}")
    private String fxPublicKey;

    @Autowired
    private OutgoingPartnerInfoDao outgoingPartnerInfoDao;

    @Pointcut("@annotation(com.xiaomi.mitv.outgoing.common.annotation.Check)")
    private void webPointcut() {
        // donothing
    }

    @Around(value = "webPointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        // 获取被增强的方法的相关信息
        MethodSignature ms = (MethodSignature) joinPoint.getSignature();
        // 获取被增强的方法
        Method pointcutMethod = ms.getMethod();
        String methodName = pointcutMethod.getName();
        // 对于对外接口,统一进行参数校验
        CommonResult commonResult = null;
        // 判断方法上有没有@Check注解
        if (pointcutMethod.isAnnotationPresent(Check.class)) {
            // 获取到拦截方法的HttpServletRequest
            // 获取当前方法执行的上下文的request
            HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
            // 获取body请求参数
            String bodyString = HttpUtil.getRequestBody(request);
//            Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
            Map<String, Object> originMap = HttpUtil.fromJsonToObject(bodyString, Map.class);
            // 将请求参数放到线程本地拷贝中
            ThreadContext.currentThreadContext().setRequestParam(originMap);

            // 得到方法上的Check注解
            Check check = pointcutMethod.getAnnotation(Check.class);
            // 获取切点方法的返回类型
            Class<?> returnType = check.value();
            // 创建对象
            commonResult = (CommonResult) returnType.newInstance();
            // 获取参数签名
            String sign = request.getParameter("sign");
            LOGGER.info("{}-sign={}", methodName, sign);

            // 参数校验
            Class<?> beanType = check.paramBean();
            originMap.put("sign", sign);
            if(!HttpUtil.paramCheck(originMap, beanType)){
                commonResult.fail(ResponseEnum.ERROR_PARAM.getCode(), ResponseEnum.ERROR_PARAM.getMsg());
                return commonResult;
            }

            String partnerId = String.valueOf(originMap.get("partnerId"));
            if (!StringUtil.areNotEmpty(partnerId, sign)) {
                commonResult.fail(ResponseEnum.ERROR_PARAM_NULL.getCode(), ResponseEnum.ERROR_PARAM_NULL.getMsg());
                return commonResult;
            }
            // 校验partnerId的有效性
            if (!checkPartnerId(partnerId)) {
                commonResult.fail(ResponseEnum.ERROR_APP_INVALID.getCode(), ResponseEnum.ERROR_APP_INVALID.getMsg());
                return commonResult;
            }
            // 组装加签串
            String paramBody = HttpUtil.getAssembleParam(originMap);
            // 验签
            boolean pass;
            try {
                pass = RSAUtil.rsa256CheckContent(paramBody, sign, fxPublicKey);
            } catch (BssException e) {
                LogUtil.LogAndMail("验签异常", e);
                commonResult.fail(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
                return commonResult;
            }
            if (!pass) {
                commonResult.fail(ResponseEnum.ERROR_CHECK_SIGN_FAIL.getCode(), ResponseEnum.ERROR_CHECK_SIGN_FAIL.getMsg());
                return commonResult;
            }
        }
        // 执行增强方法
        Object result = joinPoint.proceed();
        return result;
    }

    /**
     * 校验partnerId的有效性,先查缓存,缓存中没有的话再查询数据库,使用互斥锁
     *
     * @param partnerId
     * @return
     */
    private boolean checkPartnerId(String partnerId) {
        // 先查询缓存,值为1表示存在且有效,值为0表示存在但无效,值为null表示不存在
        String val = RedisUtil.get(CommonConstants.PARTNER_ID + partnerId);
        if (StringUtils.isEmpty(val)) {
            // 缓存中不存在,先拿到互斥锁,再查询数据库,并放进缓存中
            // 获取互斥锁
            String mutexKey = CommonConstants.NX_PARTNER_ID + partnerId;
            boolean flag = RedisUtil.setex(mutexKey, CommonConstants.STR_ONE, 60);
            // 拿到锁
            if (flag) {
                // 查询数据库
                OutgoingPartnerInfoDto partnerInfoDto = outgoingPartnerInfoDao.getByPartnerId(partnerId);
                if (partnerInfoDto != null && StringUtils.equals(CommonConstants.STR_ONE, partnerInfoDto.getStatus())) {
                    // partnerId 存在且有效
                    RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ONE);
                    // 删除锁
                    RedisUtil.del(mutexKey);
                    return true;
                } else {
                    // partnerId 不存在或无效
                    RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ZERO);
                    return false;
                }
            } else {
                //休息50毫秒后重试
                try {
                    Thread.sleep(50);
                } catch (InterruptedException e) {
                    LOGGER.error("获取partnerId互斥锁异常" + e.getMessage(), e);
                }
                return checkPartnerId(partnerId);
            }
            // val 不为空
        } else {
            return StringUtils.equals(CommonConstants.STR_ONE, val);
        }
    }

}

HttpUtil工具类:

public class HttpUtil {

    private HttpUtil() {
    }

    private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class);

    /**
     * 获取request中的body信息 JSON格式
     *
     * @param request
     * @return
     */
    public static String getRequestBody(HttpServletRequest request) {
        BufferedReader br = null;
        StringBuilder bodyDataBuilder = new StringBuilder();
        try {
            br = request.getReader();
            String str;
            while ((str = br.readLine()) != null) {
                bodyDataBuilder.append(str);
            }
            br.close();
        } catch (IOException e) {
            LOGGER.error(e.getMessage(), e);
        } finally {
            if (null != br) {
                try {
                    br.close();
                } catch (IOException e) {
                    LOGGER.error(e.getMessage(), e);
                }
            }
        }
        String bodyString = bodyDataBuilder.toString();
        LOGGER.info("bodyString={}", bodyString);
        return bodyString;
    }

    /**
     * 获取request中的body信息,并组装好按“参数=参数值”的格式
     *
     * @param request
     * @return
     */
    public static String getAssembleRequestBody(HttpServletRequest request) {
        String bodyString = getRequestBody(request);
        Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
        Map<String, Object> sortedParams = getSortedMap(originMap);
        String assembleBody = getSignContent(sortedParams);
        return assembleBody;
    }

    /**
     * 根据requestBody中的原始map获取解析后并组装的参数字符串,根据&符拼接
     *
     * @param originMap
     * @return
     */
    public static String getAssembleParam(Map<String, Object> originMap) {
        return getSignContent(getSortedMap(originMap));
    }


    /**
     * 将body转成按key首字母排好序
     *
     * @return
     */
    public static Map<String, Object> getSortedMap(Map<String, Object> originMap) {
        Map<String, Object> sortedParams = new TreeMap<String, Object>();
        if (originMap != null && originMap.size() > 0) {
            sortedParams.putAll(originMap);
        }
        return sortedParams;
    }

    /**
     * 将排序好的map的key和value拼接成字符串
     *
     * @param sortedParams
     * @return
     */
    public static String getSignContent(Map<String, Object> sortedParams) {
        StringBuffer content = new StringBuffer();
        List<String> keys = new ArrayList<String>(sortedParams.keySet());
        Collections.sort(keys);
        int index = 0;
        for (int i = 0; i < keys.size(); i++) {
            String key = keys.get(i);
            Object value = sortedParams.get(key);
            if (StringUtils.isNotEmpty(key) && value != null) {
                content.append((index == 0 ? "" : "&") + key + "=" + value);
                index++;
            }
        }
        return content.toString();
    }

    /**
     * Json转实体对象
     *
     * @param jsonStr
     * @param clazz 目标生成实体对象
     * @return
     */
    public static <T> T fromJsonToObject(String jsonStr, Class clazz) {
        T results = null;
        try {
            results = (T) JacksonJsonUtil.toObject(jsonStr, clazz);
        } catch (Exception e) {
        }
        return results;
    }

    /**
     * 对请求参数进行校验,目前只进行非空校验
     *
     * @param srcData body数据
     * @param tarClass 校验规则
     * @return 校验成功返回true
     */
    public static <T> boolean paramCheck(Map<String, Object> srcData, Class<T> tarClass){
        try {
            Field[] fields = tarClass.getDeclaredFields();
            for(Field field : fields){
                ParamVerify verify = field.getAnnotation(ParamVerify.class);
                if(verify != null){
                    //非空校验,后续若需增加校验类型,应抽离
                    if(verify.nullable() == CheckEnum.NOTNULL){
                        String fn = field.getName();
                        Object val = srcData.get(fn);
                        if(val == null || "".equals(val.toString())){
                            return false;
                        }
                    }
                }
            }
        }catch (Exception ex){
            LOGGER.info("Param verify error");
            return false;
        }
        return true;
    }

}

日志工具类:

 /**
     * 打印日志并发送错误邮件
     *
     * @param msg
     * @param t
     */
    public static void LogAndMail(String msg, Throwable t) {
        // 获取调用此工具类的该方法 的调用方信息
        // 查询当前线程的堆栈信息
        StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
        // 按照规则,此方法的上一级调用类为
        StackTraceElement ste = stackTrace[2];
        String className = ste.getClassName();
        String methodName = ste.getMethodName();
        LOGGER.error("{}#{},{}," + t.getMessage(), className, methodName, msg, t);
        // 异步发送邮件
        String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg;
        executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
    }


    /**
     * 只发送错误邮件不打印日志
     *
     * @param msg
     */
    public static void sendErrorLogMail(String msg, Throwable t) {
        // 异步发送邮件
        String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg + assembleStackTrace(t);
        executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
    }

    /**
     * 组装异常堆栈
     *
     * @param t
     * @return
     */
    public static String assembleStackTrace(Throwable t) {
        StringWriter sw = new StringWriter();
        PrintWriter ps = new PrintWriter(sw);
        t.printStackTrace(ps);
        return sw.toString();
    }

有关两个切面的执行顺序问题,请参考:https://www.cnblogs.com/yangyongjie/p/11800862.html

END

原文地址:https://www.cnblogs.com/yangyongjie/p/12535938.html