通过元类实现ORM

通过元类实现ORM

首先ORM全称叫对象关系映射,能够让不会数据库操作的程序员通过面向对象的方法简单快捷的操作数据库,ORM有三层映射关系

  • 类映射数据库的表
  • 对象映射成数据库的表中的一条条记录
  • 对象获取属性映射成数据库的表中的某条记录某个字段对应的值

具体做法就是在类创建过程中通过元类拦截它的创建,在类创建出来之前给类赋上表该有的属性表名,主键字段,其他普通字段。

一、pymysql封装

import pymysql

"""
@author RansySun
@create 2020-3-27-11:38
"""


class MySql:
    """
    对数据库、增删改查封装
    """
    __instance = None

    def __new__(cls, *args, **kwargs):
        """
        单例模式
        :param args:
        :param kwargs:
        :return:
        """
        if not cls.__instance:
            cls.__instance = object.__new__(cls)
        return cls.__instance

    def __init__(self):
        """
        创建数据连接
        """
        self.mysql = pymysql.connect(
            user='root',
            passwd='010101',
            charset='utf8',
            db='db_orm',
            autocommit=True
        )

        self.cursor = self.mysql.cursor(pymysql.cursors.DictCursor)

    def select(self, sql, args=None):
        """
        查询
        :param sql: 查询语句
        :param args: 条件
        :return: 返回查询结果集--->字典
        """
        print(f'33[31m sql: {sql}33[0m')
        print(f'33[34m sql-args: {args}33[0m')

        self.cursor.execute(sql, args)  # 提交sql语句
        res_data = self.cursor.fetchall()  # 查询结果
        return res_data

    def execute(self, sql, args=None):
        """
        增、删、改
        :param sql: 增、删、改sql
        :param args: 参数,防止sql注入
        """
        try:
            print(f'33[31m sql: {sql}33[0m')
            print(f'33[34m sql-args: {args}33[0m')
            # [None, 'test1', '123'],为什么会有None,因为有一个id字段,所以为None


            # insert into 表名(字段名) values(值)
            self.cursor.execute(sql, args)  # 提交sql语句

        except Exception as e:
            print(f'33[31m sql错误: {e}33[0m')

    def close(self):
        """
        关闭数据的连接
        """
        self.cursor.close()
        self.mysql.close()

二、元类实现ORM

from orm_db.orm_contromysql import MySql

"""
@author RansySun
@create 2019-3-27-12:10
"""


class Filed:
    """
    生成表中对应的字段约束
    """

    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        # print(default)
        self.default = default


class IntegetFiled(Filed):
    """
    与表中对应的整型类型字段
    """

    def __init__(self, name, column_type='int', primary_key=False, default=0):
        super().__init__(name, column_type, primary_key, default)


class StringFiled(Filed):
    """
    与表中对应的字符串类型字段
    """

    def __init__(self, name, column_type='varchar(250)', primary_key=False, default=None):
        super().__init__(name, column_type, primary_key, default)


class OrmMetaClass(type):
    def __new__(cls, class_name, class_base, class_dic):
        """
        控制数据库中表的约束和字段的封装
        :param class_name:
        :param class_base:
        :param class_dic:
        :return:
        """
        if class_name == 'Models':
            return type.__new__(cls, class_name, class_base, class_dic)
        # print(class_dic)

        # 数据库中的表名
        table_name = class_dic.get('table_name', class_name)

        # 数据库中主键字段的名称
        primary_key_name_id = None

        # 数据库中的所有字段
        mappings = {}

        # 与数据库中对应的字段封装起来,
        for field_name, filed_obj in class_dic.items():
            if isinstance(filed_obj, Filed):
                mappings[field_name] = filed_obj

                # 判断是否只有一个主键
                if filed_obj.primary_key:

                    # 判断只有一个主键
                    if primary_key_name_id:
                        raise TypeError("只能有一个主键")

                    primary_key_name_id = filed_obj.name

        # 判断必须有一个主键
        if not primary_key_name_id:
            raise TypeError('必须有一个主键')

        # 过滤重复的字段名
        for field_name in mappings.keys():
            class_dic.pop(field_name)

        # 将字段添加到名称空间中
        class_dic['table_name'] = table_name
        class_dic['primary_key_name_id'] = primary_key_name_id
        class_dic['mappings'] = mappings
        # print(class_dic)
        return type.__new__(cls, class_name, class_base, class_dic)


class Models(dict, metaclass=OrmMetaClass):
    # def __init__(self, **kwargs):
    #     super().__init__(**kwargs)

    def __setattr__(self, key, value):
        """
        重写,让字段的字段 = 赋值
        :param key:
        :param value:
        """
        self[key] = value

    def __getattr__(self, item):
        """
        通过.可以获取值
        :param item:
        :return:
        """
        return self.get(item)

    @classmethod
    def sql_select(cls, **kwargs):
        """
        sql查询语句,有条件和没有条件查询
        :param kwargs: 条件,支持一个条件
        :return:  返回结果对象
        """
        mysql_obj = MySql()

        # 判断查询是否有条件
        if not kwargs:
            # select * from 表名
            sql = 'select * from %s' % cls.table_name
            sql_data = mysql_obj.select(sql)

        else:

            # select * from 表名 where 字段名 = 值(添加) kwargs.keys()[0]==>不能直接取值,因为他是key对象

            filed_name = list(kwargs.keys())[0]
            filed_value = kwargs.get(filed_name)

            sql = 'select * from %s where %s= ?' % (cls.table_name, filed_name)
            sql = sql.replace('?', '%s')

            # 获取数据---->dict
            sql_data = mysql_obj.select(sql, filed_value)

        # 关闭连接
        mysql_obj.close()
        return [cls(**s) for s in sql_data]

    def sql_save(self):
        """
        数据库插入保存数据
        """
        mysql_obj = MySql()

        # 字段名
        filed_name_list = []

        # 字段值
        filed_value_list = []

        # ?占位符,防止sql注入
        replace_list = []

        # 获取字段名,字段值
        for filed_name, filed_obj in self.mappings.items():
            filed_name_list.append(filed_name)

            filed_value_list.append(
                # filed_obj.name:如果不存在他触发的是__getattr__,获取一个返回值None
                getattr(self, filed_obj.name, filed_obj.default)  # 通过反射获取字段名
            )

            replace_list.append('?')

        # 拼接sql insert into 表名(字段值) values(值)
        sql = 'insert into %s(%s) values (%s)' % (self.table_name, ','.join(filed_name_list), ','.join(replace_list))
        sql = sql.replace('?', '%s')

        # print(filed_name_list)
        # print(filed_value_list)
        # print(sql)

        mysql_obj.execute(sql, filed_value_list)
        mysql_obj.close()

    def sql_update(self):
        """
        修改内容
        """
        mysql_obj = MySql()

        # 字段名
        filed_name_list = []

        # 字段值
        filed_value_list = []

        # 获取条件名key的值
        primary_key_value = None

        # 获取字段名,字段值,条件
        for filed_name, filed_obj in self.mappings.items():
            filed_name_list.append(f'{filed_name}=?')
            filed_value_list.append(
                getattr(self, filed_obj.name, filed_obj.default)
            )

            # 获取主键值
            if filed_obj.primary_key:
                primary_key_value = getattr(self, filed_obj.name, filed_obj.default)

        # print(filed_name_list)
        # print(filed_value_list)
        # print(primary_key_value)

        # 拼接sql语句 update 表名 set 字段名=值 where id = 值
        # UserInfo set user_id=%s,user_name=%s,user_pwd=%s where user_id=None

        sql = 'update %s set %s where %s=%s' % (
            self.table_name,
            ','.join(filed_name_list),
            self.primary_key_name_id,
            primary_key_value
        )

        sql = sql.replace('?', '%s')
        # print(sql)

        # 修改数据
        mysql_obj.execute(sql, filed_value_list)

三、测试

class UserInfo(Models):
    table_name = 'UserInfo'
    user_id = IntegetFiled(name='user_id', primary_key=True)
    user_name = StringFiled('user_name')
    user_pwd = StringFiled('user_pwd')


if __name__ == '__main__':
    user = UserInfo(
        user_name='RandySun',
        user_pwd='0101')
    user.sql_save()
    user = UserInfo.sql_select(user_name='RandySun')
    print(user)

img

原文地址:https://www.cnblogs.com/randysun/p/12596814.html