drf权限认证, drf频率认证

drf三大认证之权限认证

权限认证流程源码

'''
# ...\Lib\site-packages\rest_framework\views.py
from rest_framework.settings import api_settings
...
# Excerpt of DRF's APIView reduced to the permission-check path; the bare
# "..." lines stand for code omitted from the quotation.
class APIView(View):
	...
    # Read from the DEFAULT_PERMISSION_CLASSES setting unless a view overrides it.
    # NOTE(review): the original annotation said the default is "[]"; DRF
    # documents it as [AllowAny] -- confirm against the installed DRF version.
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
	...

    def get_permissions(self):
        # Instantiate each configured permission class; returns a list of
        # permission objects.
        return [permission() for permission in self.permission_classes]

    ...
    def check_permissions(self, request):
        # Call has_permission() on every permission object; for any that
        # returns False, call permission_denied(), forwarding the permission's
        # optional `message` attribute (None if it has none).
        for permission in self.get_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )
        
    ...
    def initial(self, request, *args, **kwargs):
        ...
        self.check_permissions(request)  # run the permission check
        ...

    def dispatch(self, request, *args, **kwargs):
        ...
        # Wrap the incoming request before running the checks.
        request = self.initialize_request(request, *args, **kwargs)
        ...
        try:
            self.initial(request, *args, **kwargs)  # enter DRF's three request checks
            ...
'''

drf内置的权限认证类

'''
# ...\Lib\site-packages\rest_framework\permissions.py
...
# Excerpt of DRF's built-in permission classes (quoted verbatim).
class IsAuthenticated(BasePermission):
    """Allow only requests whose ``request.user`` is set and authenticated."""
    def has_permission(self, request, view):
        return bool(request.user and request.user.is_authenticated)
        
        
class IsAuthenticatedOrReadOnly(BasePermission):
    """Allow any request whose method is in SAFE_METHODS; other methods need an authenticated user."""
    def has_permission(self, request, view):
        # Parsed as: method in SAFE_METHODS or (user and user.is_authenticated),
        # because `and` binds tighter than `or`.
        return bool(
            request.method in SAFE_METHODS or
            request.user and
            request.user.is_authenticated
        )        
'''

局部配置自定义的权限认证类

'''
# ...\d_proj\api\urls.py
...
# Register the ViewSet with the router; its generated routes are appended
# after the explicit login route below.
router.register('user', views.SingleUserInfoViewSet, basename='myuser')

urlpatterns = [
    url(r'^login/$', views.JwtLoginAPIView.as_view()),
]
urlpatterns += router.urls


# ...\d_proj\api\views.py
from rest_framework.viewsets import ViewSet
from rest_framework.response import Response
from . import my_permissions, my_serializers


class SingleUserInfoViewSet(ViewSet):
    """Expose the requesting user's own profile, guarded by ``my_permissions.IsVip``."""

    permission_classes = [my_permissions.IsVip]

    def retrieve(self, request, *args, **kwargs):
        """Serialize ``request.user`` inside a ``{'code': 0, 'res': ...}`` envelope.

        The lookup value captured from the URL is ignored: the response
        always describes the user making the request.
        """
        serializer = my_serializers.UserModelSerializer(instance=request.user)
        body = {'code': 0, 'res': serializer.data}
        return Response(data=body, status=200)
        
        
# ...\d_proj\api\my_permissions.py
from rest_framework.permissions import BasePermission


class IsVip(BasePermission):
    """Allow the request only if ``request.user`` belongs to the VIP group.

    Setup: create a group named ``vip`` in the ``auth_group`` table (e.g. via
    the Django admin) and add the privileged users (such as ``tank``) to it.
    """

    # Name of the auth group that grants access; subclasses may override it.
    group_name = 'vip'

    def has_permission(self, request, view):
        user = request.user
        # request.user is None when DRF's UNAUTHENTICATED_USER setting is None;
        # treat that like any user outside the group instead of crashing.
        if user is None:
            return False
        # Ask the database for membership instead of loading every group of
        # the user into Python. Anonymous users have an empty groups manager,
        # so this is simply False for them.
        # NOTE: exact-match lookup; case sensitivity follows the DB collation.
        return user.groups.filter(name=self.group_name).exists()
'''

drf三大认证之频率认证

频率认证流程源码

'''
# ...\Lib\site-packages\rest_framework\views.py
from rest_framework.settings import api_settings
...
# Excerpt of DRF's APIView reduced to the throttling path; the bare "..."
# lines stand for code omitted from the quotation.
class APIView(View):
	...
    # Read from the DEFAULT_THROTTLE_CLASSES setting unless a view overrides it.
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
	...

    def get_throttles(self):
        # Instantiate each configured throttle class; returns a list of
        # throttle objects.
        return [throttle() for throttle in self.throttle_classes]

    ...
    def check_throttles(self, request):
        # Every throttle is consulted; each one that refuses the request
        # contributes the value of its wait() to throttle_durations.
        throttle_durations = []  # wait times reported by refusing throttles
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())  # throttled: record how long the client should wait

        if throttle_durations:
            ...  # handling of the collected waits is elided in this quote
        
    def initial(self, request, *args, **kwargs):
        ...
        self.check_throttles(request)  # run the throttling check
        ...

    def dispatch(self, request, *args, **kwargs):
        ...
        # Wrap the incoming request before running the checks.
        request = self.initialize_request(request, *args, **kwargs)
        ...
        try:
            self.initial(request, *args, **kwargs)  # enter DRF's three request checks
            ...
'''

drf内置的频率认证类

'''
# ...\Lib\site-packages\rest_framework\throttling.py
...
from django.core.cache import cache as default_cache
from rest_framework.settings import api_settings


# Excerpt of DRF's SimpleRateThrottle, a rate-based throttle whose request
# history lives in the Django cache. Bare "..." / "raise ..." lines stand for
# code omitted from the quotation.
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache  # where per-client request histories are stored
    timer = time.time  # clock used to timestamp requests
    cache_format = 'throttle_%(scope)s_%(ident)s'  # template for per-client cache keys
    scope = None  # set by subclasses; selects the entry in THROTTLE_RATES
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES  # {'user': None, 'anon': None}

    def __init__(self):
        # A `rate` attribute already set on the throttle takes precedence;
        # otherwise the rate is looked up from settings by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        # Subclasses must return a key identifying the client, or None to
        # leave the request unthrottled (see allow_request).
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        # Return the rate string configured for self.scope. The elided
        # branches raise when scope is unset or missing from THROTTLE_RATES.
        if not getattr(self, 'scope', None):
            ...
            raise ...

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            ...
            raise ...

    def parse_rate(self, rate):
        # Parse "<num>/<period>" into (num_requests, duration_in_seconds).
        # Only the first character of the period is used, so '3/min' and
        # '3/m' both give (3, 60). A None rate yields (None, None).
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        # Sliding-window check; True lets the request through.
        if self.rate is None:
            return True  # no rate configured: never throttle

        self.key = self.get_cache_key(request, view)
        if self.key is None:	
            return True  # no client key: this request is not throttled

        # history holds request timestamps, newest first (throttle_success
        # inserts at index 0), so the oldest entries sit at the end.
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop timestamps that have fallen outside the current window.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        # Record this request at the front of the history and store it, with
        # `duration` as the cache timeout.
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        return False  # False means the request is throttled

    def wait(self):
        # Suggested wait before retrying (presumably in seconds, like the
        # timer and duration); the computation is elided in this quote.
        ...
        return remaining_duration / float(available_requests)  
        

class AnonRateThrottle(SimpleRateThrottle):
    """Throttle unauthenticated clients using the rate configured for 'anon'."""
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        # Fill the named %-placeholders of cache_format,
        # producing 'throttle_anon_<ident>'.
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)  # get_ident() is not shown here (presumably client-address based); other ident rules are possible
        }
'''

局部配置自定义的频率认证类

'''
# ...\d_proj\api\urls.py
...
# Register the ViewSet with the router; its generated routes are appended
# after the explicit login route below.
router.register('user', views.UserModelViewSet, basename='user')

urlpatterns = [
    url(r'^login/$', views.JwtLoginAPIView.as_view()),
]
urlpatterns += router.urls


# ...\d_proj\api\views.py
...
class UserModelViewSet(ModelViewSet):
    """ModelViewSet over active users, throttled by ``EmailRateThrottle``."""

    # Only users flagged is_active=True are exposed.
    queryset = models.User.objects.filter(is_active=True)
    serializer_class = my_serializers.UserModelSerializer
    throttle_classes = [my_throttling.EmailRateThrottle]
  
  
# ...\d_proj\api\my_throttling.py
from rest_framework.throttling import SimpleRateThrottle


class EmailRateThrottle(SimpleRateThrottle):
    """Throttle authenticated users per e-mail address.

    Uses the ``email`` scope, so its rate is ``DEFAULT_THROTTLE_RATES['email']``.
    Requests with no usable e-mail (no user, anonymous user, or a missing or
    empty ``email``) are not throttled.
    """

    scope = 'email'

    def get_cache_key(self, request, view):
        user = request.user
        # Returning None makes SimpleRateThrottle.allow_request let the
        # request through. request.user may be None when DRF's
        # UNAUTHENTICATED_USER setting is None.
        if user is None or user.is_anonymous:
            return None
        # Custom user models are not guaranteed to define an ``email`` field.
        email = getattr(user, 'email', None)
        if not email:
            return None

        return self.cache_format % {
            'scope': self.scope,
            'ident': email,
        }
        
        
# ...\d_proj\d_proj\settings.py
...
REST_FRAMEWORK = {
    ...,
    # Throttle rates keyed by a throttle's `scope`. Format is "<num>/<period>";
    # only the period's first letter (s/m/h/d) matters, so 'min' means minute.
    'DEFAULT_THROTTLE_RATES': {
        'mobile': '1/min',  # presumably for a phone-number throttle not shown here
        'email': '3/min'  # used by throttles whose scope is 'email'
    },
}
'''
原文地址:https://www.cnblogs.com/-406454833/p/12708926.html