Redis 限流器的设计

1. 定义注解

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a Redis-backed rate limit on a method.
 *
 * <p>The annotation is repeatable (its container is {@link RateLimits}), so a single method may declare
 * several independent limits.
 */
@Documented
@Repeatable(RateLimits.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RedisRateLimitAttribute {

    /**
     * Documented as an alias for {@link #key()}.
     * NOTE(review): no {@code @AliasFor} is declared here, so whether this value is honored depends on the
     * code that reads the annotation.
     *
     * @return the alias value; empty by default
     */
    String value() default "";

    /**
     * The counter key; the value is evaluated as a SpEL expression.
     *
     * @return the key expression; empty by default
     */
    String key() default "";

    /**
     * Priority of this limit relative to the other limits declared on the same method.
     *
     * @return the ordering value; 0 by default
     */
    int order() default 0;

    /**
     * SpEL condition deciding whether a call is counted: the counter is incremented only when it evaluates
     * to {@code true}.
     *
     * @return the condition expression; {@code "true"} by default
     */
    String incrCondition() default "true";

    /**
     * Maximum allowed count. Kept as a string so that it can reference configuration.
     *
     * @return the limit; {@code "1"} by default
     */
    String limit() default "1";

    /**
     * Length of the counting window in milliseconds. Kept as a string so that it can reference configuration.
     *
     * @return the window length; {@code "1000"} by default
     */
    String intervalInMilliseconds() default "1000";

    /**
     * Name of the degradation (fallback) method. Its parameters must either match the annotated method's
     * parameters exactly, or match them followed by one extra parameter of the annotated method's return type.
     *
     * @return the fallback method name; empty by default
     */
    String fallbackMethod() default "";
}
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Container annotation that holds repeated {@link RedisRateLimitAttribute} declarations on one method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimits {

    /**
     * The contained rate-limit declarations.
     *
     * @return the declarations; empty by default
     */
    RedisRateLimitAttribute[] value() default {};
}

2. 切面方法

import com.google.common.base.Strings;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import redis.clients.jedis.JedisCluster;

/**
 * Spring AOP aspect that enforces {@link RedisRateLimitAttribute} limits with counters kept in Redis.
 *
 * <p>Every intercepted call runs two passes over the method's limits, sorted by {@code order}:
 * <ol>
 *   <li><b>Pre-check</b>: if a limit's stored counter is already above the configured maximum, the target
 *   method is skipped and the limit's fallback is invoked instead (a trailing return-value argument, if the
 *   fallback declares one, is {@code null}).</li>
 *   <li><b>Post-check</b>: after the target method returns, the counter is incremented whenever
 *   {@code incrCondition} evaluates to {@code true}. If the new value exceeds the maximum, the fallback is
 *   invoked and receives the actual return value.</li>
 * </ol>
 * Counters expire {@code intervalInMilliseconds} milliseconds after they are created.
 */
// Enables AspectJ auto-proxying; proxyTargetClass = true forces class-based (CGLIB) proxies
// (the attribute defaults to false).
@EnableAspectJAutoProxy(proxyTargetClass = true)
@Component
@Order(-1)
@Aspect
public class RedisRateLimitAspect {
    /**
     * Logger.
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(RedisRateLimitAspect.class);

    /**
     * SpEL expression parser (thread-safe, shared).
     */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /**
     * Discovers method parameter names so they can be exposed as SpEL variables.
     */
    private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();

    /**
     * Lua script that atomically increments a counter and arms its expiry.
     *
     * <p>KEYS[1] is the counter key; ARGV[1] is the window length in milliseconds. The expiry is set when the
     * counter is created, and it is also repaired when the key exists without a TTL. As a result, a counter
     * can never outlive its window and block a caller forever. The script returns the incremented value.
     */
    private static final String INCR_WITH_EXPIRE_SCRIPT =
            "local current = redis.call('incr', KEYS[1]) "
                    + "if current == 1 or redis.call('pttl', KEYS[1]) == -1 then "
                    + "redis.call('pexpire', KEYS[1], ARGV[1]) "
                    + "end "
                    + "return current";

    /**
     * Redis cluster client.
     */
    @Autowired
    private JedisCluster jedisCluster;

    /**
     * Spring environment, used to resolve ${...} placeholders in annotation attributes.
     */
    @Autowired
    private Environment environment;

    /**
     * Pointcut matching methods that carry one or more {@link RedisRateLimitAttribute}s.
     *
     * <p>When the annotation is repeated, javac stores the occurrences inside the {@link RateLimits} container.
     * The method is then no longer directly annotated with {@code RedisRateLimitAttribute}, so the container
     * must be matched as well.
     */
    // NOTE(review): assumes RateLimits lives in the same package as RedisRateLimitAttribute — confirm.
    @Pointcut("@annotation(com.g2.order.server.annotation.RedisRateLimitAttribute)"
            + " || @annotation(com.g2.order.server.annotation.RateLimits)")
    public void rateLimit() {
    }

    /**
     * Around advice applying the pre-check, the target invocation, and the post-check.
     *
     * @param proceedingJoinPoint the intercepted call
     * @return the target's result, or the fallback's result when a limit is exceeded
     * @throws Throwable whatever the target or the fallback throws
     */
    @Around("rateLimit()")
    public Object handleControllerMethod(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        Method method = ((MethodSignature) proceedingJoinPoint.getSignature()).getMethod();

        // Collects both direct and container (@RateLimits) occurrences, lowest order first.
        List<RedisRateLimitAttribute> redisRateLimitAttributes =
                AnnotatedElementUtils.findMergedRepeatableAnnotations(method, RedisRateLimitAttribute.class)
                        .stream()
                        .sorted(Comparator.comparingInt(RedisRateLimitAttribute::order))
                        .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList));
        if (CollectionUtils.isEmpty(redisRateLimitAttributes)) {
            return proceedingJoinPoint.proceed();
        }

        // The fallback is invoked reflectively on the same instance with the same arguments.
        Object target = proceedingJoinPoint.getTarget();
        Object[] args = proceedingJoinPoint.getArgs();
        if (args == null) {
            args = new Object[0];
        }

        // Pre-check: reject the call when a counter is already above its limit.
        for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
            String key = computeKey(rateLimit, proceedingJoinPoint, null);
            long limitV = Long.parseLong(formatKey(rateLimit.limit()));
            long currentV = readCount(key);
            // A counter equal to the limit still passes here; the post-check handles the call that pushes it over.
            if (currentV > limitV) {
                LOGGER.warn("Rate limit exceeded before invocation, key: {}, count: {}, limit: {}",
                        key, currentV, limitV);
                return invokeFallback(proceedingJoinPoint, rateLimit, method, target, args, null);
            }
        }

        Object result = proceedingJoinPoint.proceed();

        // Post-check: count the finished call and degrade its result if the limit is now exceeded.
        for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
            long limitV = Long.parseLong(formatKey(rateLimit.limit()));
            long interval = Long.parseLong(formatKey(rateLimit.intervalInMilliseconds()));
            if (!match(proceedingJoinPoint, rateLimit, result)) {
                continue;
            }
            String key = computeKey(rateLimit, proceedingJoinPoint, result);
            long newCount = incrementCount(key, interval);
            if (newCount > limitV) {
                LOGGER.warn("Rate limit exceeded after invocation, key: {}, count: {}, limit: {}",
                        key, newCount, limitV);
                return invokeFallback(proceedingJoinPoint, rateLimit, method, target, args, result);
            }
        }

        return result;
    }

    /**
     * Resolves the Redis key of a limit.
     *
     * <p>Uses {@code key()}, or its documented alias {@code value()} when {@code key()} is empty. Placeholders
     * are resolved first, and the result is then evaluated as SpEL.
     *
     * @param rateLimit   the limit declaration
     * @param context     the intercepted call (parameters become SpEL variables)
     * @param returnValue the call's result, or {@code null} before invocation
     * @return the evaluated, non-empty key
     * @throws IllegalStateException if no key is configured or the expression evaluates to null/empty
     */
    private String computeKey(RedisRateLimitAttribute rateLimit, JoinPoint context, Object returnValue) {
        String rawKey = Strings.isNullOrEmpty(rateLimit.key()) ? rateLimit.value() : rateLimit.key();
        String key = computeExpress(formatKey(rawKey), context, String.class, returnValue);
        if (Strings.isNullOrEmpty(key)) {
            throw new IllegalStateException("rate limit key evaluated to empty, expression: " + rawKey);
        }
        return key;
    }

    /**
     * Reads the current counter value; a missing key counts as 0.
     */
    private long readCount(String key) {
        String currentValue = jedisCluster.get(key);
        return Strings.isNullOrEmpty(currentValue) ? 0L : Long.parseLong(currentValue);
    }

    /**
     * Atomically increments the counter and ensures it expires after the window.
     *
     * @param key                    counter key
     * @param intervalInMilliseconds window length in milliseconds; must be positive
     * @return the counter value after the increment
     */
    private long incrementCount(String key, long intervalInMilliseconds) {
        if (intervalInMilliseconds <= 0) {
            throw new IllegalStateException("intervalInMilliseconds must be positive, key: " + key
                    + ", value: " + intervalInMilliseconds);
        }
        Object count = jedisCluster.eval(INCR_WITH_EXPIRE_SCRIPT,
                Collections.singletonList(key),
                Collections.singletonList(String.valueOf(intervalInMilliseconds)));
        return ((Number) count).longValue();
    }

    /**
     * Resolves and invokes the fallback configured on {@code rateLimit}.
     *
     * <p>The fallback receives the original arguments. If it declares one parameter more than the intercepted
     * method, {@code returnValue} is appended as that extra argument.
     */
    private static Object invokeFallback(JoinPoint context, RedisRateLimitAttribute rateLimit, Method method,
                                         Object target, Object[] args, Object returnValue) throws Throwable {
        Method fallbackMethod = getFallbackMethod(context, rateLimit.fallbackMethod());
        Object[] fallbackArgs = args;
        if (fallbackMethod.getParameterCount() != method.getParameterCount()) {
            fallbackArgs = Arrays.copyOf(args, args.length + 1);
            fallbackArgs[fallbackArgs.length - 1] = returnValue;
        }
        return invokeFallbackMethod(fallbackMethod, target, fallbackArgs);
    }

    /**
     * Evaluates a SpEL expression against the call's parameters and return value.
     *
     * @param expression  the expression
     * @param context     the intercepted call
     * @param tClass      the desired result type
     * @param returnValue exposed as {@code #returnValue} and {@code #response}
     * @return the evaluated value (may be null)
     */
    private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass, Object returnValue) {
        EvaluationContext evaluationContext = buildEvaluationContext(returnValue, context);
        return EXPRESSION_PARSER.parseExpression(expression).getValue(evaluationContext, tClass);
    }

    /**
     * Decides whether this call should be counted for {@code rateLimit}.
     *
     * @return true only when {@code incrCondition} evaluates to {@code Boolean.TRUE}; a null result counts as false
     */
    private boolean match(JoinPoint context, RedisRateLimitAttribute rateLimit, Object returnValue) {
        return Boolean.TRUE.equals(computeExpress(rateLimit.incrCondition(), context, Boolean.class, returnValue));
    }

    /**
     * Resolves ${...} placeholders in an annotation attribute.
     *
     * @param v the raw attribute value
     * @return the value with placeholders resolved
     * @throws IllegalStateException if the value is null or empty
     */
    private String formatKey(String v) {
        if (Strings.isNullOrEmpty(v)) {
            throw new IllegalStateException("key配置不能为空");
        }
        return environment.resolvePlaceholders(v);
    }

    /**
     * Registers each method parameter as a SpEL variable named after the parameter.
     */
    private static void addParameterVariable(StandardEvaluationContext evaluationContext, JoinPoint context) {
        MethodSignature methodSignature = (MethodSignature) context.getSignature();
        Method method = methodSignature.getMethod();
        String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        if (parameterNames != null && parameterNames.length > 0) {
            Object[] args = context.getArgs();
            for (int i = 0; i < parameterNames.length; i++) {
                evaluationContext.setVariable(parameterNames[i], args[i]);
            }
        }
    }

    /**
     * Registers the return value as the SpEL variables {@code returnValue} and {@code response}.
     */
    private static void addReturnValue(StandardEvaluationContext evaluationContext, Object returnValue) {
        evaluationContext.setVariable("returnValue", returnValue);
        evaluationContext.setVariable("response", returnValue);
    }

    /**
     * Builds the evaluation context holding the parameters and the return value.
     */
    private static EvaluationContext buildEvaluationContext(Object returnValue, JoinPoint context) {
        StandardEvaluationContext evaluationContext = new StandardEvaluationContext();
        addParameterVariable(evaluationContext, context);
        addReturnValue(evaluationContext, returnValue);
        return evaluationContext;
    }

    /**
     * Looks up the fallback method on the target class.
     *
     * <p>The method is first looked up with the intercepted method's parameter types. If that fails, it is
     * looked up with those types followed by the intercepted method's return type.
     *
     * @param context        the intercepted call
     * @param fallbackMethod the fallback method's name
     * @return the accessible fallback method
     * @throws RuntimeException if neither signature is declared on the target class
     */
    private static Method getFallbackMethod(JoinPoint context, String fallbackMethod) {
        MethodSignature methodSignature = (MethodSignature) context.getSignature();
        Class<?>[] parameterTypes = Optional.ofNullable(methodSignature.getParameterTypes()).orElse(new Class<?>[0]);
        Class<?> targetClass = context.getTarget().getClass();
        try {
            Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypes);
            method.setAccessible(true);
            return method;
        } catch (NoSuchMethodException ignored) {
            // Not declared with the original signature; try the variant with a trailing return-value parameter.
        }

        try {
            Class<?>[] parameterTypes2 = Arrays.copyOf(parameterTypes, parameterTypes.length + 1);
            parameterTypes2[parameterTypes2.length - 1] = methodSignature.getReturnType();

            Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypes2);
            method.setAccessible(true);
            return method;
        } catch (NoSuchMethodException ignored) {
            // Neither variant exists; reported below.
        }

        String message = String.format("获取fallbackMethod失败, context: %s, fallbackMethod: %s",
                context, fallbackMethod);
        throw new RuntimeException(message);
    }

    /**
     * Invokes the fallback, unwrapping reflection's {@link InvocationTargetException} so callers see the
     * fallback's own exception.
     *
     * @param fallbackMethod the fallback
     * @param fallbackTarget instance to invoke it on
     * @param fallbackArgs   arguments
     * @return the fallback's return value
     * @throws Throwable the fallback's exception, or a reflection error
     */
    private static Object invokeFallbackMethod(Method fallbackMethod, Object fallbackTarget, Object[] fallbackArgs)
            throws Throwable {
        try {
            return fallbackMethod.invoke(fallbackTarget, fallbackArgs);
        } catch (InvocationTargetException e) {
            if (e.getCause() != null) {
                throw e.getCause();
            }
            throw e;
        }
    }
}

3. 调用示例

/**
 * Demo controller showing {@link RedisRateLimitAttribute} applied to a login endpoint.
 */
@Api(value = "HomeController", description = "用户登录登出接口")
@RestController
@RequestMapping("/home")
public class HomeController {
    private static final Logger logger = LoggerFactory.getLogger(HomeController.class);

    /**
     * Logs a user in. This demo implementation always returns a successful response for the given user id.
     *
     * <p>Rate limit: the counter key is {@code "login" + userId}. The maximum defaults to 3 and the window
     * defaults to 3600 (the attribute is {@code intervalInMilliseconds}). When the limit is hit, the aspect
     * calls {@code loginFallback}.
     * NOTE(review): incrCondition counts <em>successful</em> responses although the limit property is named
     * {@code login.maxFailedTimes}. Confirm which outcome should be counted, and confirm that 3600 is really
     * meant in milliseconds.
     *
     * @param req the login request
     * @return a successful response carrying the user's model
     */
    @ApiOperation(value = "用户登录", notes = "用户登录接口")
    @RequestMapping(value = "/login",
            method = RequestMethod.POST,
            consumes = MediaType.APPLICATION_JSON_VALUE,
            produces = MediaType.APPLICATION_JSON_VALUE)
    @RedisRateLimitAttribute(key = "'login'+#req.userId"
            , limit = "${login.maxFailedTimes:3}"
            , incrCondition = "#response.success == true"
            , intervalInMilliseconds = "${login.limit.millseconds:3600}"
            , fallbackMethod = "loginFallback"
    )
    public UserLoginResp login(@RequestBody UserLoginReq req) {
        logger.info("进入登陆业务");

        UserModel userModel = new UserModel();
        userModel.setRoleId(123);
        userModel.setUserId(req.getUserId());
        userModel.setMustValidateCode(false);

        return new UserLoginResp(userModel);
    }

    /**
     * Fallback invoked by the rate-limit aspect. The aspect calls it reflectively after setAccessible(true),
     * so it may remain private.
     *
     * @param req  the original request
     * @param resp the response produced by {@code login}, or {@code null} when the call was rejected before
     *             the login logic ran
     * @return an empty response when {@code resp} is null; otherwise {@code resp}, with the verification-code
     *         flag switched on when it has a payload
     */
    private UserLoginResp loginFallback(UserLoginReq req, UserLoginResp resp) {
        if (resp == null) {
            return new UserLoginResp();
        }
        UserModel payload = resp.getPayload();
        // The payload can be absent (e.g. a response built with the no-arg constructor); avoid an NPE.
        if (payload != null) {
            payload.setMustValidateCode(true);
        }
        return resp;
    }
}
/**
 * User information returned in the login response payload.
 * Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 */
@Data
public class UserModel {
    /**
     * User id.
     */
    private String userId;

    /**
     * Role name.
     */
    private String roleName;

    /**
     * Role id.
     */
    private Integer roleId;

    /**
     * Whether the login must be accompanied by a verification code.
     * Design intent: once the failure count reaches the threshold, a verification code is required so that
     * repeated submissions become harder.
     */
    private Boolean mustValidateCode;
}
import lombok.Data;

/**
 * Generic API response envelope: a success flag, an error message and an optional payload.
 * Accessors, equals/hashCode and toString are generated by Lombok's {@code @Data}; the field order below
 * determines the generated toString order.
 *
 * @param <T> payload type
 */
@Data
public class Response<T> {
    /** Whether the call succeeded. */
    private Boolean success;
    /** Error description; the empty string when none was supplied. */
    private String errorMessage;
    /** Response body; may be null. */
    private T payload;

    /** Successful response, empty message, no payload. */
    public Response() {
        this(true, "", null);
    }

    /** Response with the given success flag, empty message, no payload. */
    public Response(boolean succ) {
        this(succ, "", null);
    }

    /** Response with the given success flag and message, no payload. */
    public Response(boolean succ, String msg) {
        this(succ, msg, null);
    }

    /** Successful response carrying {@code data}, empty message. */
    public Response(T data) {
        this(true, "", data);
    }

    /** Canonical constructor; every other constructor delegates here. */
    public Response(boolean succ, String msg, T data) {
        this.success = succ;
        this.errorMessage = msg;
        this.payload = data;
    }
}
/**
 * Login response whose payload is the logged-in user's {@link UserModel}.
 */
public class UserLoginResp extends Response<UserModel> {

    /** Successful response without a payload. */
    public UserLoginResp() {
        super();
    }

    /** Successful response carrying {@code userModel} as its payload. */
    public UserLoginResp(UserModel userModel) {
        super(userModel);
    }

    /** Uses the inherited representation unchanged. */
    @Override
    public String toString() {
        return super.toString();
    }
}
原文地址:https://www.cnblogs.com/zhshlimi/p/11835401.html