spring boot gateway自定义限流

参考:https://blog.csdn.net/ErickPang/article/details/84680132

若采用Gateway自带的默认限流器,请参照《微服务架构spring cloud - gateway网关限流》;本文的参数与其唯一的区别是请求header中多了参数userLevel,值为A或者B。

此处实现按传入的userLevel取到不同的限流配置。配置key的格式为"路由id.userLevel值.参数名",其中路由id即gateway中routes配置的id(下文为userSev),例如:

userSev.A.replenishRate: 10
userSev.A.burstCapacity: 100
userSev.B.replenishRate: 20
userSev.B.burstCapacity: 1000

自定义限流器
package com.gatewayaop.filter;

import com.iot.crm.gatewayaop.common.config.UserLevelRateLimiterConf;
import org.springframework.beans.BeansException;
import org.springframework.cloud.gateway.filter.ratelimit.AbstractRateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.Validator;
import org.springframework.validation.annotation.Validated;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.validation.constraints.Min;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
 * Redis-backed rate limiter whose limits depend on the route and on the value produced by the
 * key resolver (in this project the {@code userLevel} request header, e.g. {@code A} or {@code B}).
 *
 * <p>Limits are read from {@link UserLevelRateLimiterConf#getRateLimitMap()} under the keys
 * {@code <routeId>.<id>.replenishRate} and {@code <routeId>.<id>.burstCapacity}. When a key has no
 * value, the {@code default.replenishRate} / {@code default.burstCapacity} entries are used. The
 * limiting decision itself is made by the Lua {@link RedisScript} passed to the constructor.
 */
public class UserLevelRedisRateLimiter extends AbstractRateLimiter<UserLevelRedisRateLimiter.Config> implements ApplicationContextAware {
    // These constants are copied from Spring Cloud Gateway's RedisRateLimiter and are all used here.
    public static final String REPLENISH_RATE_KEY = "replenishRate";

    public static final String BURST_CAPACITY_KEY = "burstCapacity";

    public static final String CONFIGURATION_PROPERTY_NAME = "sys-redis-rate-limiter";
    public static final String REDIS_SCRIPT_NAME = "redisRequestRateLimiterScript";
    public static final String REMAINING_HEADER = "X-RateLimit-Remaining";
    public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";
    public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";

    /** Map key of the fallback replenish rate. */
    private static final String DEFAULT_REPLENISHRATE = "default.replenishRate";
    /** Map key of the fallback burst capacity. */
    private static final String DEFAULT_BURSTCAPACITY = "default.burstCapacity";

    private static final Logger LOG = Logger.getLogger(UserLevelRedisRateLimiter.class.getName());

    private ReactiveRedisTemplate<String, String> redisTemplate;
    private RedisScript<List<Long>> script;
    private final AtomicBoolean initialized = new AtomicBoolean(false);

    /** The name of the header that returns the number of tokens left. */
    private String remainingHeader = REMAINING_HEADER;

    /** The name of the header that returns the replenish rate configuration. */
    private String replenishRateHeader = REPLENISH_RATE_HEADER;

    /** The name of the header that returns the burst capacity configuration. */
    private String burstCapacityHeader = BURST_CAPACITY_HEADER;

    /** Set only by the {@code (int, int)} constructor; not consulted by {@link #isAllowed}. */
    private Config defaultConfig;

    /** Limit table, looked up from the application context in {@link #setApplicationContext}. */
    private UserLevelRateLimiterConf rateLimiterConf;

    /**
     * Creates a limiter that is ready to evaluate requests.
     *
     * @param redisTemplate template used to execute {@code script}
     * @param script        rate-limiting script; its result is read as {@code [allowed (1 = yes), tokensLeft]}
     * @param validator     validator handed to {@link AbstractRateLimiter}
     */
    public UserLevelRedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate,
                                     RedisScript<List<Long>> script, Validator validator) {
        super(Config.class, CONFIGURATION_PROPERTY_NAME, validator);
        this.redisTemplate = redisTemplate;
        this.script = script;
        initialized.compareAndSet(false, true);
    }

    /**
     * Creates a limiter that only records a default {@link Config}.
     *
     * <p>No Redis template or script is set, so the instance stays uninitialized and
     * {@link #isAllowed} throws {@link IllegalStateException}.
     */
    public UserLevelRedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity) {
        super(Config.class, CONFIGURATION_PROPERTY_NAME, null);
        defaultConfig = new Config()
                .setReplenishRate(defaultReplenishRate)
                .setBurstCapacity(defaultBurstCapacity);
    }

    /**
     * Decides whether a request for {@code id} on {@code routeId} may proceed, by running the Redis script.
     *
     * <p>If the Redis call fails, the request is allowed and the remaining-tokens header is {@code -1}.
     *
     * @param routeId gateway route id; first segment of the rate-limit map keys
     * @param id      value produced by the key resolver (e.g. the {@code userLevel} header)
     * @return the decision together with informational rate-limit headers
     * @throws IllegalStateException    if this instance was built without Redis support
     * @throws IllegalArgumentException if the configuration bean is missing, or no limit value
     *                                  (specific or default) is configured
     */
    @Override
    public Mono<Response> isAllowed(String routeId, String id) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        }
        if (ObjectUtils.isEmpty(rateLimiterConf)) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId);
        }
        Map<String, Integer> rateLimitMap = rateLimiterConf.getRateLimitMap();
        if (rateLimitMap == null) {
            rateLimitMap = Collections.emptyMap();
        }
        // Keys look like "userSev.A.replenishRate"; missing entries fall back to "default.*".
        int replenishRate = resolveLimit(rateLimitMap,
                routeId + "." + id + "." + REPLENISH_RATE_KEY, DEFAULT_REPLENISHRATE);
        int burstCapacity = resolveLimit(rateLimitMap,
                routeId + "." + id + "." + BURST_CAPACITY_KEY, DEFAULT_BURSTCAPACITY);

        try {
            List<String> keys = getKeys(routeId, id);

            // Script arguments: rate, capacity, current epoch second, tokens requested.
            List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "",
                    Instant.now().getEpochSecond() + "", "1");
            Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);

            return flux
                    .doOnError(throwable -> LOG.log(Level.SEVERE,
                            "Error determining if user allowed from redis", throwable))
                    // On a Redis error: allowed, tokens unknown.
                    .onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
                    // Fresh accumulator per subscription; a shared seed would mix results on re-subscription.
                    .reduceWith(ArrayList<Long>::new, (longs, l) -> {
                        longs.addAll(l);
                        return longs;
                    })
                    .map(results -> {
                        boolean allowed = results.get(0) == 1L;
                        Long tokensLeft = results.get(1);
                        return new RateLimiter.Response(allowed,
                                getHeaders(replenishRate, burstCapacity, tokensLeft));
                    });
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Error determining if user allowed from redis", e);
        }

        return Mono.just(new RateLimiter.Response(true, getHeaders(replenishRate, burstCapacity, -1L)));
    }

    /**
     * Returns the value stored under {@code key}, or under {@code defaultKey} when {@code key} has none.
     *
     * @throws IllegalArgumentException if neither key has a value
     */
    private static int resolveLimit(Map<String, Integer> limits, String key, String defaultKey) {
        Integer value = limits.get(key);
        if (value == null) {
            value = limits.get(defaultKey);
        }
        if (value == null) {
            throw new IllegalArgumentException(
                    "No rate limit configured for '" + key + "' and no '" + defaultKey + "' fallback");
        }
        return value;
    }

    /** Stores the {@link UserLevelRateLimiterConf} bean that holds the limit table. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.rateLimiterConf = applicationContext.getBean(UserLevelRateLimiterConf.class);
    }

    /**
     * Builds the informational response headers.
     *
     * @param replenishRate configured replenish rate
     * @param burstCapacity configured burst capacity
     * @param tokensLeft    tokens left as reported by the script, or {@code -1} if unknown
     * @return a new mutable header map
     */
    public HashMap<String, String> getHeaders(Integer replenishRate, Integer burstCapacity , Long tokensLeft) {
        HashMap<String, String> headers = new HashMap<>();
        headers.put(this.remainingHeader, tokensLeft.toString());
        headers.put(this.replenishRateHeader, String.valueOf(replenishRate));
        headers.put(this.burstCapacityHeader, String.valueOf(burstCapacity));
        return headers;
    }

    /**
     * Returns the two Redis keys (tokens, timestamp) for the bucket identified by {@code id}.
     *
     * <p>The {@code {}} around the id is a Redis hash tag. It keeps both keys in the same Redis
     * Cluster slot. The {@code request_sys_rate_limiter} prefix can be customised here.
     */
    static List<String> getKeys(String id) {
        String prefix = "request_sys_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /**
     * Returns the Redis keys for the bucket of {@code id} on route {@code routeId}.
     *
     * <p>The route id is part of the key because limits are configured per route. Without it,
     * routes with the same {@code id} would share one bucket while passing different rate and
     * capacity arguments to the script.
     */
    static List<String> getKeys(String routeId, String id) {
        return getKeys(routeId + "." + id);
    }

    /** Replenish rate / burst capacity pair, validated with {@code @Min(1)} on each field. */
    @Validated
    public static class Config{
        @Min(1)
        private int replenishRate;
        @Min(1)
        private int burstCapacity = 1;

        public int getReplenishRate() {
            return replenishRate;
        }

        /** Sets the replenish rate and returns {@code this} for chaining. */
        public Config setReplenishRate(int replenishRate) {
            this.replenishRate = replenishRate;
            return this;
        }

        public int getBurstCapacity() {
            return burstCapacity;
        }

        /** Sets the burst capacity and returns {@code this} for chaining. */
        public Config setBurstCapacity(int burstCapacity) {
            this.burstCapacity = burstCapacity;
            return this;
        }

        @Override
        public String toString() {
            return "Config{" +
                    "replenishRate=" + replenishRate +
                    ", burstCapacity=" + burstCapacity +
                    '}';
        }
    }
}

读取自定义配置类

package com.gatewayaop.common.config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;


// Initialised from the application configuration file.

/**
 * Rate-limit table bound from {@code comsumer.ratelimiter-conf.rateLimitMap}.
 *
 * <p>Keys have the form {@code <routeId>.<userLevel>.replenishRate} /
 * {@code <routeId>.<userLevel>.burstCapacity}. The {@code default.replenishRate} (10) and
 * {@code default.burstCapacity} (100) entries are always present unless explicitly overridden.
 */
@Component
@ConfigurationProperties(prefix = "comsumer.ratelimiter-conf")
//@EnableConfigurationProperties(UserLevelRateLimiterConf.class)
public class UserLevelRateLimiterConf {
    /** Map key of the fallback replenish rate. */
    private static final String DEFAULT_REPLENISHRATE="default.replenishRate";
    /** Map key of the fallback burst capacity. */
    private static final String DEFAULT_BURSTCAPACITY="default.burstCapacity";

    /** Current limit table; always contains the two default entries. */
    private Map<String, Integer> rateLimitMap = withDefaults(null);

    public Map<String, Integer> getRateLimitMap() {
        return rateLimitMap;
    }

    /**
     * Replaces the limit table. The built-in defaults are kept unless {@code rateLimitMap} overrides
     * them. A {@code null} map, null keys and null values are ignored.
     */
    public void setRateLimitMap(Map<String, Integer> rateLimitMap) {
        this.rateLimitMap = withDefaults(rateLimitMap);
    }

    /** Returns a new concurrent map holding the built-in defaults overlaid with {@code overrides}. */
    private static Map<String, Integer> withDefaults(Map<String, Integer> overrides) {
        Map<String, Integer> merged = new ConcurrentHashMap<>();
        merged.put(DEFAULT_REPLENISHRATE, 10);
        merged.put(DEFAULT_BURSTCAPACITY, 100);
        if (overrides != null) {
            // ConcurrentHashMap rejects nulls; a null value just means "use the default".
            overrides.forEach((key, value) -> {
                if (key != null && value != null) {
                    merged.put(key, value);
                }
            });
        }
        return merged;
    }
}

定义限流器种类

package com.gatewayaop.common.config;

import com.iot.crm.gatewayaop.filter.UserLevelRedisRateLimiter;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.validation.Validator;
import reactor.core.publisher.Mono;

import java.util.List;


/**
 * Declares the {@link KeyResolver}s available to the gateway's {@code RequestRateLimiter} filter.
 * It also declares the custom {@link UserLevelRedisRateLimiter}.
 *
 * <p>Resolvers emit an empty key, rather than failing with a {@code NullPointerException}, when
 * the request lacks the value they read.
 */
@Configuration
public class RequestRateLimiterConfig {

    /** Limits by request path. {@code @Primary}, so it is the default {@link KeyResolver} bean. */
    @Bean
    @Primary
    KeyResolver apiKeyResolver() {
        return exchange -> Mono.just(exchange.getRequest().getPath().toString());
    }

    /** Limits by the {@code user} query parameter; emits nothing when the parameter is absent. */
    @Bean
    KeyResolver userKeyResolver() {
        // justOrEmpty: Mono.just(null) throws NPE when the parameter is missing.
        return exchange -> Mono.justOrEmpty(exchange.getRequest().getQueryParams().getFirst("user"));
    }

    /** Limits by client IP; emits nothing when the remote address is unknown. */
    @Bean
    KeyResolver ipKeyResolver() {
        // getHostString() avoids the blocking reverse-DNS lookup that getHostName() may perform.
        return exchange -> Mono.justOrEmpty(exchange.getRequest().getRemoteAddress())
                .map(address -> address.getHostString());
    }

    /** Limits by the {@code userLevel} request header; emits nothing when the header is absent. */
    @Bean
    KeyResolver userLevelKeyResolver() {
        return exchange -> Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst("userLevel"));
    }

    /** The custom limiter, referenced from YAML as {@code #{@userLevelRedisRateLimiter}}. */
    @Bean
    @Primary
    UserLevelRedisRateLimiter userLevelRedisRateLimiter(
            ReactiveRedisTemplate<String, String> redisTemplate,
            @Qualifier(UserLevelRedisRateLimiter.REDIS_SCRIPT_NAME) RedisScript<List<Long>> script,
            @Qualifier("defaultValidator") Validator validator) {
        return new UserLevelRedisRateLimiter(redisTemplate, script, validator);
    }

}

yml配置

server:
  port: 9701


spring:
  application:
    name: gateway-aop-dev
  profiles:
    active: dev
  index: 62
  cloud:
    gateway:
      discovery:
        locator:
          enabled: true
          # 服务名小写
          lower-case-service-id: true
      routes:
        # 与comsumer.ratelimiter-conf.rateLimitMap中key的第一段相同,即java代码中的routeId
        - id: userSev
          # lb代表从注册中心获取服务,且已负载均衡方式转发
          uri: lb://hello-dev
          predicates:
            - Path=/hello-dev/**
          # 加上StripPrefix=1,否则转发到后端服务时会带上/hello-dev前缀
          filters:
            - StripPrefix=1
            # 限流过滤器,使用gateway内置令牌算法
            - name: RequestRateLimiter
              args:
#                # 令牌桶每秒填充的平均速率,即允许用户平均每秒处理的请求数
#                redis-rate-limiter.replenishRate: 10
#                # 令牌桶的容量,允许在一秒钟内完成的最大请求数
#                redis-rate-limiter.burstCapacity: 20
                # 用于限流的键的解析器的 Bean 对象的名字。它使用 SpEL 表达式根据#{@beanName}从 Spring 容器中获取 Bean 对象。
                key-resolver: "#{@userLevelKeyResolver}"
                rate-limiter: "#{@userLevelRedisRateLimiter}"
comsumer:
  ratelimiter-conf:
    #配置限流参数与RateLimiterConf类映射
    rateLimitMap:
      #格式为:routeId(gateway配置routes时指定的id).请求头userLevel的值.replenishRate(令牌填充速率)/burstCapacity(令牌桶容量)
      userSev.A.replenishRate: 10
      userSev.A.burstCapacity: 100
      userSev.B.replenishRate: 20
      userSev.B.burstCapacity: 1000
 
原文地址:https://www.cnblogs.com/pu20065226/p/11449260.html