手撸orm

ORM 的作用就是在类与数据库表之间建立映射关系。

一个类代表的就是一张表,一个类实例化出来的对象就是一条记录。

from orm_demo import  mysql_control

class Field:
    """Description of a single table column.

    Attributes:
        name: Column name; also the key under which a row stores the value.
        column_type: SQL type of the column as a string (e.g. ``"int"``).
        primary_key: Truthy when this column is the table's primary key.
        default: Fallback value associated with the column.
    """

    def __init__(self, name, column_type, primary_key, default):
        # Identity of the column.
        self.name, self.column_type = name, column_type
        # Constraints / fallback value.
        self.primary_key, self.default = primary_key, default


class IntegerField(Field):
    """Column whose SQL type defaults to ``int`` and whose default value is 0."""

    def __init__(self, name, column_type="int", primary_key=False, default=0):
        # Forward everything to Field by keyword so the mapping is explicit.
        super().__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )

class StringField(Field):
    """Column whose SQL type defaults to ``varchar(64)`` with no default value."""

    def __init__(self, name, column_type='varchar(64)', primary_key=False, default=None):
        # Forward everything to Field by keyword so the mapping is explicit.
        super().__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )

class OrmMetaClass(type):
    """Metaclass that turns ``Field`` class attributes into table metadata.

    For every class except the one named ``Models`` it:

    * collects all ``Field`` attributes into a ``mappings`` dict and removes
      them from the class namespace,
    * requires exactly one of them to be flagged as primary key,
    * stores ``table_name`` (explicit attribute, else the class name),
      ``primary_key`` (the primary field's ``name``) and ``mappings`` on the
      class.

    Raises:
        TypeError: if more than one, or no, primary-key field is declared.
    """

    def __new__(cls, class_name, class_base, class_attr):
        # The base class named 'Models' declares no columns and is created as is.
        if class_name != 'Models':
            resolved_table = class_attr.get('table_name', class_name)

            # Every Field declared in the class body, keyed by attribute name.
            columns = {
                attr: value
                for attr, value in class_attr.items()
                if isinstance(value, Field)
            }

            pk_name = None
            for column in columns.values():
                if not column.primary_key:
                    continue
                if pk_name:
                    raise TypeError('只能有一个主键')
                pk_name = column.name

            # Drop the Field objects from the namespace; they live on in
            # ``mappings`` only.
            for attr in columns:
                del class_attr[attr]

            if not pk_name:
                raise TypeError('必须要有一个主键')

            class_attr.update(
                table_name=resolved_table,
                primary_key=pk_name,
                mappings=columns,
            )

        return type.__new__(cls, class_name, class_base, class_attr)

class Models(dict, metaclass=OrmMetaClass):
    """Base class for table-backed records.

    An instance is a ``dict`` holding one row; keys are also reachable as
    attributes. Subclasses get ``table_name``, ``primary_key`` and
    ``mappings`` class attributes from ``OrmMetaClass``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # __getattr__ runs only when normal attribute lookup fails.
    def __getattr__(self, item):
        # Expose dict[key] as obj.key; a missing key yields None rather than
        # raising AttributeError.
        return self.get(item)

    # __setattr__ runs on every ``obj.attr = value`` assignment.
    def __setattr__(self, key, value):
        # Store the value as a dict item instead of an instance attribute.
        self[key] = value

    @classmethod
    def select(cls, **kwargs):
        """Fetch rows of this table and wrap each one in an instance.

        With no keyword arguments all rows are selected. Otherwise only the
        FIRST keyword argument is used, as an equality filter
        ``<key> = <value>``; any further keyword arguments are ignored.

        Returns:
            A list of ``cls`` instances, one per row returned by
            ``mysql_control.Mysql().select``.
        """
        # NOTE(review): a new Mysql() object is created per call and never
        # explicitly released here; whether it needs closing depends on
        # mysql_control, which is not visible in this file.
        mysql_obj = mysql_control.Mysql()

        if not kwargs:
            sql = 'select * from %s' % cls.table_name
            res = mysql_obj.select(sql)
        else:
            key, value = next(iter(kwargs.items()))
            # The column name cannot be a bound parameter, so it is formatted
            # into the statement; the value is passed separately for '%s'.
            sql = 'select * from %s where %s=%%s' % (cls.table_name, key)
            res = mysql_obj.select(sql, value)

        return [cls(**r) for r in res]

    def save(self):
        """Insert this record as a new row.

        Every mapped field contributes one column, named by ``field.name``.
        The value is taken from this dict; if the key is absent the field's
        ``default`` is used.
        """
        mysql = mysql_control.Mysql()

        columns = []
        values = []
        for field in self.mappings.values():
            columns.append(field.name)
            # Use dict.get rather than getattr: __getattr__ never raises
            # AttributeError, so getattr's default would never apply, and a
            # field named like a dict method would resolve to that method.
            values.append(self.get(field.name, field.default))

        placeholders = ','.join(['%s'] * len(columns))
        sql = 'insert into %s(%s) values(%s)' % (
            self.table_name, ','.join(columns), placeholders
        )
        mysql.execute(sql, values)

    def sql_update(self):
        """Update the row whose primary key equals this record's key value.

        All non-primary-key fields are written. The primary-key value is
        bound as a parameter, not formatted into the SQL text.
        """
        mysql = mysql_control.Mysql()

        assignments = []
        values = []
        pk_value = None

        for field in self.mappings.values():
            if field.primary_key:
                pk_value = self.get(field.name)
            else:
                assignments.append(field.name + '=%s')
                values.append(self.get(field.name))

        # The WHERE placeholder comes last, so its value goes last as well.
        values.append(pk_value)
        sql = 'update %s set %s where %s=%%s' % (
            self.table_name, ','.join(assignments), self.primary_key
        )
        mysql.execute(sql, values)
原文地址:https://www.cnblogs.com/chanyuli/p/11638402.html