rest framework 节流

一、简单节流示例

所谓节流就是控制用户访问频率,这里分为匿名用户(非登录用户)和登录用户的限制。

  • 匿名用户:根据其 IP 限制其频率
  • 登录用户:IP、用户名都 OK

获取用户请求 IP:request.META.get('REMOTE_ADDR')

需求:一分钟内匿名用户最多只能访问三次

1、app/utils/throttle.py

import time

# 用户保存用户访问时间记录,以用户 IP 为键,值为当前时间组成的列表,最多只能有三个
VISIT_RECORD = {}

# {'127.0.0.1': [1561284634.3823733, 1561284633.0358386, 1561284631.2291381]}


class VisitThrottle(object):
    """节流,限制访问频率"""
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        """允许过去的请求"""
        # 获取 ip
        remote_addr = request.META.get('REMOTE_ADDR')
        ctime = time.time()
        # 第一次访问
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime, ]
            return True
        
        # 获取记录
        history = VISIT_RECORD[remote_addr]
        self.history = history

        # 若列表最后一个和当前时间的差大于 1 分钟,就删除最后一个
        while history and history[-1] < ctime - 60:
            history.pop()   # 弹出最后一个

        if len(history) < 3:
            history.insert(0, ctime)		# 插入到列表第一个位置
            return True

    def wait(self):
        """剩余计时"""
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

2、app/views.py

from .utils.throttle import VisitThrottle


class UserInfo(APIView):
    """用户个人信息"""
    authentication_classes = []
    permission_classes = []
    throttle_classes = [VisitThrottle, ]

    def get(self, request, *args, **kwargs):
        # user_name = request.user.username
        # return HttpResponse(user_name)

        # self.dispatch

        return HttpResponse('UserInfo')

3、访问:http://127.0.0.1:8000/api/v1/info/,这里我以匿名用户身份访问,三次之后:

上面是以用户 IP 作为键来限制用户访问频率,若是想要以用户名限制,只需将 IP 更换为用户名即可。

二、全局配置

# 设置全局认证
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ['app.utils.throttle.VisitThrottle', ],
}

局部设置:

throttle_classes = [VisitThrottle, ]

三、内置节流

1、与认证、权限一样,rest framework 也有内置的节流相关实现,一般地自定义节流类需继承 BaseThrottle

import time
from rest_framework.throttling import BaseThrottle


VISIT_RECORD = {}


class VisitThrottle(BaseThrottle):
    """节流,限制访问频率"""
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        """允许过去的请求"""
        # 获取 ip
        ...

2、类 BaseThrottle 源码:

BaseThrottle 中实现了 allow_request()、get_indent() 以及 wait() 方法,但是除了 get_ident() 方法中获取用户 IP 外,另外两个函数都没有实现其功能。

class BaseThrottle(object):
    """
    Rate throttling of requests(对请求进行速率限制).
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')		# 获取 IP
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

3、我们来看看 BaseThrottle 下面的类 SimpleRateThrottle

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden. 一个简单的缓存实现,只需要重写 get_cache_key()方法

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.
    速率格式为:request/分钟

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    用于限制的先前请求信息存储在缓存中
    """
    cache = default_cache		# Django 内置的缓存
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES	# 从配置文件中加载限制速率

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

4、它继承 BaseThrottle ,并重写了 allow_request()wait() 方法,另外地还设置控制访问速率属性设置。因此我们自定义节流类时,只需继承 SimpleRateThrottle ,并重写 get_cache_key() 方法即可 app/utils/throttle.py

from rest_framework.throttling import BaseThrottle, SimpleRateThrottle

class VisitThrottle(SimpleRateThrottle):
    scope = 'Hubery'

    def get_cache_key(self, request, view):
        return self.get_ident(request)

5、settings.py

# 设置全局认证
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ['app.utils.throttle.VisitThrottle', ],
    "DEFAULT_THROTTLE_RATES": {
        "Hubery": '4/m',		# 匿名用户设置每分钟只能访问 4 次
    }
}


四、匿名用户和登录用户同时限制

1、app/utils/throttle.py

class VisitThrottle(SimpleRateThrottle):
    """匿名用户根据 IP 限制每分钟访问 4 次"""
    scope = 'Hubery'

    def get_cache_key(self, request, view):
        return self.get_ident(request)


class UserThrottle(SimpleRateThrottle):
    """登录用户限制每分钟可以访问 10 次"""
    scope = 'Jun'		# 随便定义

    def get_cache_key(self, request, view):
        return request.user.username

2、settings.py

# 设置全局认证
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ['app.utils.throttle.UserThrottle', ],
    "DEFAULT_THROTTLE_RATES": {
        "Hubery": '4/m',		# 匿名用户设置每分钟只能访问 4 次
        "Jun": '10/m',			# 登录用户设置每分钟只能访问 10 次
    }
}

全局设置 登录用户限制(UserThrottle),局部设置 匿名用户限制即可

3、app/views.py

from .utils.throttle import VisitThrottle

class OrderView(APIView):
    """订单管理"""
    # authentication_classes = [MyAuthentication, ]  # 添加认证

    # authentication_classes = []
    # permission_classes = [SVIPPermission, ]

    def get(self, request, *args, **kwargs):
        ret = {'code': 1000, 'msg': None, 'data': None, }
        ret['data'] = ORDER_DICT
        print(request.user)
        return JsonResponse(ret)


    
class UserInfo(APIView):
    """用户个人信息"""
    authentication_classes = []
    permission_classes = []
    throttle_classes = [VisitThrottle, ]		# 局部设置匿名用户访问频率

    def get(self, request, *args, **kwargs):

        return HttpResponse('UserInfo')


源码流程图

总结

  • 自定义节流,须继承 SimpleRateThrottle,并重写 get_cache_key() 方法
  • 可以全局和局部限制,不限制可设置 throttle_classes = []
原文地址:https://www.cnblogs.com/midworld/p/11075983.html