S11 day 96 RestFramework 之认证权限

一、设计一个简易的登录

 

1. 建立一个模型  

class UserInfo(models.Model):
    username =models.CharField(max_length=16)
    password =models.CharField(max_length=16)

    type =models.SmallIntegerField(
        choices=((1,"普通用户"),(2,"vip用户")),
        default=1
    )


class Token(models.Model):
    token =models.CharField(max_length=128)
    user=models.OneToOneField(to="UserInfo")

makemigration ,migrate ,添加数据 

2.设计url

    url(r'login/$', views.LoginView.as_view()),

3.制作一个视图类进行简单登录 

lass LoginView(APIView):
    def post(self,request):
        print(request.data)
        res={"code":0}
        username =request.data.get("username")
        password =request.data.get("password")
        #去数据库查询
        user_obj = models.UserInfo.objects.filter(
            username =username ,
            password=password,
        ).first()
        if user_obj:
            #登录成功
            #发token
            res["data"] ="登录成功"

        else:
            #登录失败
            res["code"] = 1
            res["error"] = "用户名或密码错误"
        return Response(res)

输出结果:

 二 、认证

4. 为视图增加一个生成token的功能

#生成Token的函数
def get_token_code(username):
    """
    根据用户名和时间戳生成用户登录成功的随机字符串
    :param username: 字符串格式的用户名
    :return: 返回字符串格式的token
    """
    import  time,hashlib
    timestamp =str(time.time())
    m =hashlib.md5(bytes(username,encoding="utf-8"))
    m.update(bytes(timestamp,encoding="utf-8"))
    return m.hexdigest()


#登录视图
class LoginView(APIView):
    def post(self,request):
        print(request.data)
        res={"code":0}
        username =request.data.get("username")
        password =request.data.get("password")
        #去数据库查询
        user_obj = models.UserInfo.objects.filter(
            username =username ,
            password=password,
        ).first()
        if user_obj:
            #登录成功
            #生成token
            token= get_token_code(username)
            #将token保存
            #用户user =user_obj 这个条件去token表里查询,如果有的话就更新default里的参数,没有记录就创建
            models.Token.objects.update_or_create(defaults={"token":token},user =user_obj)
            res["data"] ="登录成功"

        else:
            #登录失败
            res["code"] = 1
            res["error"] = "用户名或密码错误"
        return Response(res)

5. 测试token能否生成。

 6. get请求里的token和数据库里的token进行对比.

 

 7.认证的代码

红色部分为局部认证

from app01.utils1.auth import MyAuth
from rest_framework.viewsets import  ModelViewSet
class CommentView(ModelViewSet):
        queryset = models.Comment.objects.all()
        serializer_class = app01_serializers.CommentSerializer
        authentication_classes = [MyAuth,]  #视图认证,属于局部认证.

全局认证写在settings里

REST_FRAMEWORK = {
#     #关于认证的全局配置.
    "DEFAULT_AUTHENTICATION_CLASSES":["app01.utils.MyAuth",]
}
#登录视图
class LoginView(APIView):
    def post(self,request):
        print(request.data)
        res={"code":0}
        username =request.data.get("username")
        password =request.data.get("password")
        #去数据库查询
        user_obj = models.UserInfo.objects.filter(
            username =username ,
            password=password,
        ).first()
        if user_obj:
            #登录成功
            #生成token
            token= get_token_code(username)
            #将token保存
            #用户user =user_obj 这个条件去token表里查询,如果有的话就更新default里的参数,没有记录就创建
            models.Token.objects.update_or_create(defaults={"token":token},user =user_obj)
            res["data"] ="登录成功"

        else:
            #登录失败
            res["code"] = 1
            res["error"] = "用户名或密码错误"
        return Response(res)
"""
自定义的认证都放在这里
"""
from  rest_framework.authentication import BaseAuthentication
from app01 import models
from rest_framework.exceptions import AuthenticationFailed


class MyAuth(BaseAuthentication):
    def authenticate(self, request):
        if request.method in ["POST","PUT","DELETE"]:
            token =request.data.get("token")
            #去数据库中查找有没有这个这个token
            token_obj =models.Token.objects.filter(token=token).first()
            if token_obj:
                return token_obj.user,token
            else :
                raise AuthenticationFailed("无效的token")

        else:return  None,None

 三、权限

"""
自定义的权限类

# """

from rest_framework.permissions import BasePermission
class Mypermission(BasePermission):
    message={}
    def has_permission(self, request, view):
        """
        判断该用户有没有权限
        """
        #判断用户是不是vip用户
        #如果是vip用户就返回trun
        #普通用户就返回false
        print("requet--->",request.user.username)
        print("requet--->",request.user.type)
        if request.user.type ==2:#是vip用户
            return True
        else:
            return False


    def has_object_permission(self, request, view, obj):
        """
        判断当前评论的作者是不是当前的用户
        只有作者自己才能评论自己的文章

        """
        print("这是在自定义权限类里的")
        print(obj)
        if request.method in ["PUT","DELETE"]:
            if obj.user == request.user:
                #当前要删除评论的作者就是当前登录的账号
                return True
            else:
                return False
        else:
            return True

 四、限制  频率 

方法1 . 

'''
自定义的访问限制类 第一种方法
'''
from rest_framework.throttling import BaseThrottle
import time
VISIT = {

}#{ "127.0.0.1":[12213323,2342424,2424324332]}

class MyThrottle(BaseThrottle):
    def allow_request(self, request, view):
        """
        返回true 就表示放行,返回false表示被限制....
        :param request:
        :param view:
        :return:
        """
        # 1. 获取当前的访问ip
        ip=request.META.get("REMOTE_ADDR")
        print("这是自定义限制类中的allow_request")
        print(ip)
        #2. 获取当前的时间
        now =time.time
        #判断当前ip是否有访问记录
        if ip not in VISIT:
            VISIT[ip]=[] #初始化一个空的访问列表
            #把这一次的访问时间交到访问历史列表的第一位
        history =VISIT[ip]
        while history and now - history[-1]>60:
            history.pop()
        #判断最近一分钟的访问次数是否超过了阈值(3次)
            if len(history)>=3:
                return False
            else:
                VISIT[ip].insert(0, now)
                return True

方法2. 

'''
自定义的访问限制类 第二种方法 ,采用源码的代码 
'''
# VISIT2={
#     "XXX":{
#         ip:[]
#     }
# }
from rest_framework.throttling import SimpleRateThrottle
class VisitThrottle(SimpleRateThrottle):
    scope ="xxx"
    def get_cache_key(self, request, view):
        return  self.get_ident(request) #求当前访问的ip

原文地址:https://www.cnblogs.com/mengbin0546/p/9412218.html