Spring AOP + Redis 实现针对用户的接口访问频率限制

根据请求参数或请求头中的字段进行频率限制

限制类型:

package com.seliote.fr.config.api;

/**
 * Source from which an API call-frequency limit reads the value that identifies the caller.
 *
 * @author seliote
 */
public enum ApiFrequencyType {
    // Request argument; with this type the limited method must have exactly one parameter
    ARG,
    // Request header
    HEADER
}

限制实际使用的注解:

package com.seliote.fr.annotation;

import com.seliote.fr.config.api.ApiFrequencyType;
import com.seliote.fr.config.auth.TokenFilter;

import java.lang.annotation.*;
import java.time.temporal.ChronoUnit;

/**
 * Declares a call-frequency limit on an API method; meant to be evaluated by the matching aspect.
 * <p>
 * With {@link ApiFrequencyType#ARG} the annotated method must take exactly one argument, whose
 * properties named in {@link #key()} identify the caller. With {@link ApiFrequencyType#HEADER}
 * the names in {@link #key()} are request header names.
 * <p>
 * Defaults: identify the caller by the Token request header, at most five calls per minute.
 *
 * @author seliote
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface ApiFrequency {

    /**
     * Where the identifying value is read from.
     */
    ApiFrequencyType type() default ApiFrequencyType.HEADER;

    /**
     * Name(s) of the value used to identify the caller; join several names with {@code &&}.
     */
    String key() default TokenFilter.TOKEN_HEADER;

    /**
     * Maximum number of calls allowed within one time window.
     */
    int frequency() default 5;

    /**
     * Length of the time window, expressed in {@link #unit()}.
     */
    long time() default 1;

    /**
     * Unit of {@link #time()}.
     */
    ChronoUnit unit() default ChronoUnit.MINUTES;
}

切面代码:

package com.seliote.fr.config.api;

import com.seliote.fr.annotation.stereotype.ApiComponent;
import com.seliote.fr.exception.FrequencyException;
import com.seliote.fr.service.RedisService;
import com.seliote.fr.util.CommonUtils;
import com.seliote.fr.util.TextUtils;
import lombok.extern.log4j.Log4j2;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.lang.reflect.InvocationTargetException;
import java.time.Instant;
import java.util.Optional;

import static com.seliote.fr.util.ReflectUtils.getClassName;

/**
 * API 调用频率限制 AOP
 *
 * @author seliote
 */
@Log4j2
@Order(1)
@ApiComponent
@Aspect
public class ApiFrequency {

    private static final String redisNameSpace = "frequency";

    private final RedisService redisService;

    @Autowired
    public ApiFrequency(RedisService redisService) {
        this.redisService = redisService;
        log.debug("Construct {}", getClassName(this));
    }

    /**
     * API 调用频率限制
     *
     * @param joinPoint AOP JoinPoint 对象
     */
    @Before("com.seliote.fr.config.api.ApiAop.api() && @annotation(com.seliote.fr.annotation.ApiFrequency)")
    public void apiFrequency(JoinPoint joinPoint) {
        Optional<String> uri = CommonUtils.getUri();
        if (uri.isEmpty()) {
            log.error("Frequency check error, uri is null");
            throw new FrequencyException("URI is empty");
        }
        var method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        var annotation = method.getAnnotation(com.seliote.fr.annotation.ApiFrequency.class);
        var identifier = (annotation.type() == ApiFrequencyType.ARG ?
                getArg(uri.get(), joinPoint, annotation) : getHeader(uri.get(), annotation));
        if (identifier.isEmpty()) {
            log.error("Frequency check error, identifier is empty for: {}", uri.get());
            // 过滤器在切面前执行,如果获取到为空说明代码有问题
            throw new FrequencyException("Identifier is empty");
        }
        // Token 或者参数可能会很长,所以 SHA-1 一下
        var sha1 = TextUtils.sha1(identifier.get());
        // 单位时长
        var unitSeconds = CommonUtils.time2Seconds(annotation.time(), annotation.unit());
        var redisKey = getRedisKey(uri.get(), sha1, unitSeconds);
        var current = frequency(redisKey, unitSeconds);
        var frequency = annotation.frequency();
        if (current <= frequency) {
            log.debug("Frequency pass for: {}, current: {}, identifier: {}, sha1: {}",
                    uri.get(), current, identifier.get(), sha1);
        } else {
            log.warn("Frequency too high: {}, current: {}, identifier: {}, sha1: {}",
                    uri.get(), current, identifier.get(), sha1);
            throw new FrequencyException("Frequency too high");
        }
    }

    /**
     * 获取请求参数中的标识符
     *
     * @param joinPoint  JoinPoint 对象注解
     * @param annotation @ApiFrequency 对象
     * @return 请求参数中的标识符
     */
    private Optional<String> getArg(String uri, JoinPoint joinPoint, com.seliote.fr.annotation.ApiFrequency annotation) {
        var args = joinPoint.getArgs();
        if (args == null || args.length != 1) {
            log.error("Args length incorrect: {}, {}", uri, args);
            return Optional.empty();
        }
        var arg = args[0];
        var keys = annotation.key().split(getKeySeparator());
        var identifiers = new String[keys.length];
        for (var i = 0; i < keys.length; ++i) {
            try {
                final var pd = new PropertyDescriptor(keys[i], arg.getClass());
                var result = pd.getReadMethod().invoke(arg);
                if (result == null) {
                    log.error("Argument getter return null: {}, argument: {}, getter: {}", uri, arg, keys[i]);
                    throw new FrequencyException("Argument getter return null");
                }
                identifiers[i] = result.toString();
            } catch (IntrospectionException | IllegalAccessException | InvocationTargetException exception) {
                log.error("Get identifier args error: {}, {}, exception: {}, message: {}, exception at: {}",
                        uri, arg, getClassName(exception), exception.getMessage(), keys[i]);
                return Optional.empty();
            }
        }
        return Optional.of(String.join(getIdentifierSeparator(), identifiers));
    }

    /**
     * 获取请求头中的标识符
     *
     * @param annotation @ApiFrequency 对象
     * @return 请求头中的标识符
     */
    private Optional<String> getHeader(String uri, com.seliote.fr.annotation.ApiFrequency annotation) {
        var keys = annotation.key().split(getKeySeparator());
        var identifiers = new String[keys.length];
        var servletAttr = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (servletAttr != null) {
            var httpAttr = servletAttr.getRequest();
            for (var i = 0; i < keys.length; ++i) {
                var header = httpAttr.getHeader(keys[i]);
                if (header == null || header.length() <= 0) {
                    log.error("Header is null for: {}, header: {}", uri, keys[i]);
                    throw new FrequencyException("Get header return null");
                }
                identifiers[i] = header;
            }
        }
        return Optional.of(String.join(getIdentifierSeparator(), identifiers));
    }

    /**
     * 获取 Redis 中存储的 Key
     *
     * @param uri        请求 URI
     * @param identifier 本次请求的标识符
     * @param seconds    单位时间秒数
     * @return Redis 中存储的 Key
     */
    private String getRedisKey(String uri, String identifier, long seconds) {
        var now = Instant.now().getEpochSecond();
        // 计算单位起始时间,单位时间内访问次数限制
        var start = (now - (now % seconds)) + "";
        return redisService.formatKey(redisNameSpace, uri, identifier, start);
    }

    /**
     * 增加本次的频率并获取单位时间内的访问次数
     *
     * @param key     Redis Key
     * @param seconds 单位时间秒数
     * @return 访问次数
     */
    private long frequency(String key, long seconds) {
        if (!redisService.exists(key)) {
            redisService.setex(key, (int) seconds, "0");
        }
        return redisService.incr(key);
    }

    /**
     * 获取 Redis Key 的分隔符
     *
     * @return 分隔符
     */
    private String getKeySeparator() {
        return "&&";
    }

    /**
     * 获取标识符间的分隔符
     *
     * @return 分隔符
     */
    private String getIdentifierSeparator() {
        return ".";
    }
}

实际使用的例子:

package com.seliote.fr.controller;

import com.seliote.fr.annotation.ApiFrequency;
import com.seliote.fr.annotation.stereotype.ApiController;
import com.seliote.fr.config.api.ApiFrequencyType;
import com.seliote.fr.exception.ApiException;
import com.seliote.fr.model.ci.user.LoginCi;
import com.seliote.fr.model.co.Co;
import com.seliote.fr.model.co.user.LoginCo;
import com.seliote.fr.model.si.user.LoginSi;
import com.seliote.fr.service.UserService;
import com.seliote.fr.util.ReflectUtils;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;

import javax.validation.Valid;

import static com.seliote.fr.util.ReflectUtils.getClassName;

/**
 * Controller for user account endpoints, mapped under {@code user} and accepting POST only.
 *
 * @author seliote
 */
@Log4j2
@ApiController
@RequestMapping(value = "user", method = {RequestMethod.POST})
public class UserController {

    private final UserService userService;

    @Autowired
    public UserController(UserService userService) {
        this.userService = userService;
        log.debug("Construct {}", getClassName(this));
    }

    /**
     * Logs a user account in; an account that is not registered yet is registered automatically.
     * Calls are frequency-limited per {@code countryCode} and {@code telNo} of the request body.
     *
     * @param ci login request body
     * @return login response wrapped in {@link Co}
     * @throws ApiException when the service reports a login result other than 0 or 1
     */
    @ApiFrequency(type = ApiFrequencyType.ARG, key = "countryCode&&telNo")
    @RequestMapping("login")
    @ResponseBody
    public Co<LoginCo> login(@Valid @RequestBody LoginCi ci) {
        var si = ReflectUtils.copy(ci, LoginSi.class);
        var so = userService.login(si);
        // Only the results 0 and 1 are accepted; anything else is treated as a service error
        if (so.getLoginResult() != 0 && so.getLoginResult() != 1) {
            log.error("login for: {}, service return: {}", ci, so);
            throw new ApiException("service return value error");
        }
        return Co.cco(ReflectUtils.copy(so, LoginCo.class));
    }

    /**
     * Endpoint limited with the default {@code @ApiFrequency} settings.
     *
     * @return the result of {@code Co.cco(null)}
     */
    @ApiFrequency
    @RequestMapping("info")
    @ResponseBody
    public Co<Void> info() {
        return Co.cco(null);
    }
}
原文地址:https://www.cnblogs.com/seliote/p/14458006.html