15、编写ORM

概述

在一个Web App中,所有数据,包括用户信息、发布的日志、评论等,都存储在数据库中。

Web App里面有很多地方都要访问数据库。访问数据库需要创建数据库连接、游标对象,然后执行SQL语句,最后处理异常,清理资源。这些访问数据库的代码如果分散到各个函数中,势必无法维护,也不利于代码复用。

所以,我们要首先把常用的SELECT、INSERT、UPDATE和DELETE操作用函数封装起来。

由于Web框架使用了基于asyncio的aiohttp,这是基于协程的异步模型。在协程中,不能调用普通的同步IO操作,因为所有用户都是由一个线程服务的,协程的执行速度必须非常快,才能处理大量用户的请求。而耗时的IO操作不能在协程中以同步的方式调用,否则,等待一个IO操作时,系统无法响应任何其他用户。

这就是异步编程的一个原则:一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。

幸运的是aiomysql为MySQL数据库提供了异步IO的驱动。

 简单的orm实现技术原理可参考先前写的博文:12、元类(metaclass)实现精简ORM框架

一、创建连接池

我们需要创建一个全局的连接池,每个HTTP请求都可以从连接池中直接获取数据库连接。使用连接池的好处是不必频繁地打开和关闭数据库连接,而是能复用就尽量复用。

连接池由全局变量__pool存储,缺省情况下将编码设置为utf8,自动提交事务:

@asyncio.coroutine
def create_pool(loop, **kwargs):
    logging.info('create database connection pool...')
    global __pool 
    __pool = yield from aiomysql.create_pool(
        host=kwargs.get('host', 'localhost'),
        port=kwargs.get('port', 3306),
        user=kwargs['user'],
        password=kwargs['password'],
        db=kwargs['db'],
        charset=kwargs.get('charset', 'utf8'),  
        autocommit=kwargs.get('autocommit', True), 
        maxsize=kwargs.get('maxsize', 10),
        minsize=kwargs.get('minsize', 1),
        loop=loop
    )

 关于aiomysql.create_pool的详细讲述,请参考博文:16、【翻译】aiomysql-Pool

 create_pool方法中的kwargs是关键字参数,保存着连接数据库所必须的host、port、user、password等信息,这些关键字参数在函数内部自动组装为一个dict。

二、封装select语句

# 该协程封装的是查询事务,第一个参数为sql语句,第二个为sql语句中占位符的参数列表,第三个参数是要查询数据的数量
@asyncio.coroutine
def select(sql, args, size=None):
    log(sql, args)  #显示sql语句和参数
    global __pool   #引用全局变量
    with (yield from __pool) as conn:   # 以上下文方式打开conn连接,无需再调用conn.close()  或写成 with await __pool as conn:
        cur = yield from conn.cursor(aiomysql.DictCursor)   # 创建一个DictCursor类指针,返回dict形式的结果集
        yield from cur.execute(sql.replace('?', '%s'), args or ())  # 替换占位符,SQL语句占位符为?,MySQL为%s。
        if size:
            rs = yield from cur.fetchmany(size) #接收size条返回结果行.
        else:
            rs = yield from cur.fetchall()  #接收全部的返回结果行.
        yield from cur.close()  #关闭游标
        logging.info('rows returned: %s' % len(rs)) #打印返回结果行数
        return rs   #返回结果

SQL语句的占位符是?,而MySQL的占位符是%sselect()函数在内部自动替换。注意要始终坚持使用带参数的SQL,而不是自己拼接SQL字符串,这样可以防止SQL注入攻击。

注意到yield from将调用一个子协程(也就是在一个协程中调用另一个协程)并直接获得子协程的返回结果。

如果传入size参数,就通过fetchmany()获取最多指定数量的记录,否则,通过fetchall()获取所有记录。

三、封装INSERT、UPDATE、DELETE语句

#执行update,insert,delete语句,可以统一用一个execute函数执行,
# 因为它们所需参数都一样,而且都只返回一个整数表示影响的行数。
@asyncio.coroutine
def execute(sql, args, autocommit=True):
    log(sql)
    with (yield from __pool) as conn:
        if not autocommit:
            yield from conn.begin()
        try:
             cur = yield from conn.cursor()
             yield from cur.execute(sql.replace('?', '%s'), args)
             affected = cur.rowcount
             yield from cur.close()
             if not autocommit:
                 yield from conn.commit()
        except BaseException as e:  #如果事务处理出现错误,则回退
            if not autocommit:
                yield from conn.rollback()
            raise
        return affected

execute()函数和select()函数所不同的是,cursor对象不返回结果集,而是通过rowcount返回结果数。

四、ORM

设计ORM需要从上层调用者角度来设计。

我们先考虑如何定义一个User对象,然后把数据库表users和它关联起来。

from orm import Model, StringField, IntegerField

class User(Model):
    __table__ = 'users'

    id = IntegerField(primary_key=True)
    name = StringField()

注意到定义在User类中的__table__idname是类的属性,不是实例的属性,类的所有示例都可以访问!!!所以,在类级别上定义的属性用来描述User对象和表的映射关系,而实例属性用来描述数据库表中的一行数据,必须通过__init__()方法去初始化,所以两者互不干扰:

# 创建实例:
user = User(id=123, name='Michael')
# 存入数据库:
user.insert()
# 查询所有User对象:
users = User.findAll()

五、Field以及各种Field子类

用来描述数据库中表字段的属性(字段名、类型、是否主键等等)。

首先定义基类Field:

class Field(object):

    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

    def __str__(self):
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

__str__()是Python中有特殊用途的函数,用来定制类。当我们print(Field或Field子类对象)时,会打印该对象(字段)的类名,字段类别以及字段名称。

 然后在Field的基础上,进一步定义各种类型的Field:

# 字符串类型字段,继承自父类Field
class StringField(Field):
    #如果一个函数的参数中含有默认参数,则这个默认参数后的所有参数都必须是默认参数 ,
    # 否则会抛出:SyntaxError: non-default argument follows default argument的异常。
    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
        super(StringField, self).__init__(name, ddl, primary_key, default)

# 布尔值类型字段,继承自父类Field
class BooleanField(Field):

    def __init__(self, name=None, default=False):
        super(BooleanField, self).__init__(name, 'boolean', False, default)

# 整数类型字段,继承自父类Field
class IntegerField(Field):

    def __init__(self, name=None, primary_key=False, default=0):
        super(IntegerField, self).__init__(name, 'bigint', primary_key, default)

# 浮点数类型字段,继承自父类Field
class FloatField(Field):

    def __init__(self, name=None, primary_key=False, default=0.0):
        super(FloatField, self).__init__(name, 'real', primary_key, default)

# 文本类型字段,继承自父类Field
class TextField(Field):

    def __init__(self, name=None, default=None):
        super(TextField, self).__init__(name, 'text', False, default)

上述子类生成对象时,均会调用父类的Init方法初始化。

可见,数据库表字段共4个属性:字段名、字段类型、是否主键、默认值。

六、编写元类—ModelMetaclass

 1 class ModelMetaclass(type):
 2 
 3     def __new__(cls, name, bases, attrs):
 4         # 排除Model类本身:
 5         if name=='Model':
 6             return type.__new__(cls, name, bases, attrs)
 7         # 获取table名称:
 8         tableName = attrs.get('__table__', None) or name
 9         logging.info('found model: %s (table: %s)' % (name, tableName))
10         # 获取所有的Field和主键名:
11         mappings = dict()
12         fields = []
13         primaryKey = None
14         for k, v in attrs.items():
15             if isinstance(v, Field):
16                 logging.info('  found mapping: %s ==> %s' % (k, v))
17                 mappings[k] = v
18                 if v.primary_key:
19                     # 找到主键:
20                     if primaryKey:
21                         raise RuntimeError('Duplicate primary key for field: %s' % k)
22                     primaryKey = k
23                 else:
24                     fields.append(k)
25         if not primaryKey:
26             raise RuntimeError('Primary')
27         for k in mappings.keys():
28             attrs.pop(k)
29         escaped_fields = list(map(lambda f: '`%s`' % f, fields))
30         attrs['__mappings__'] = mappings    # 保存属性和列的映射关系
31         attrs['__table__'] = tableName
32         attrs['__primary_key__'] = primaryKey   # 主键属性名
33         attrs['__fields__'] = fields    # 除主键外的属性名
34         # 构造默认的SELECT, INSERT, UPDATE和DELETE语句:
35         ##以下四种方法保存了默认了增删改查操作,其中添加的反引号``,是为了避免与sql关键字冲突的,否则sql语句会执行出错
36         attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
37         attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
38         attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
39         attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
40         return type.__new__(cls, name, bases, attrs)

1、首先进行判断,如果将要创建的类是Model,无需做个性化定制,直接通过type创建,排除对Model类的修改;

2、获取table名称,即类名;

3、mappings保存类属性和表字段的映射关系primaryKey保存映射表主键的类属性;fields保存映射其余表字段的类属性;

4、注意匿名函数【 lambda f: '`%s`' % f, fields 】 的用法,实际上第29行代码就是这个意思: 

fields = ['one', 'two', 'three']

def fun(f):
    return '`%s`' % f

escaped_fields = list(map(fun, fields))

 使用匿名函数lambda,针对fields中每个元素,如 name,加上反引号后:`name`后返回;

为何要加上反引号?它是为了区分MYSQL的保留字与普通字符而引入的符号。

 5、下面是一系列为定制类动态添加的属性:

(1) attrs['__mappings__'] = mappings    -》 保存类属性和表字段的映射关系;

(2)attrs['__table__'] = tableName    -》 保存该类对应的表名;

(3)attrs['__primary_key__'] = primaryKey    -》 保存映射表中主键字段的类属性;

(4)attrs['__fields__'] = fields    -》  保存映射非主键字段的类属性;

接着是SQL语句模板,届时调用时只需要将参数传递给Mysql占位符?即可:

(5)attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)

  例:

(6)attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))

  例:

(7)attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)

  例:

(8)attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)

   例:

6、在模块加载时使用type动态地定制类。

7、注意:

(1)以上属性都是类的属性,属类所有,所有实例对象共享一个类属性。实例属性属于各个实例所有,互不干扰;在编写程序的时候,千万不要对实例属性和类属性使用相同的名字,因为相同名称的实例属性将屏蔽掉类属性,但是当你删除实例属性后,再使用相同的名称,访问到的将是类属性。

(2)表的字段名使用类属性名,即名字相同!

七、编写基类——Model

 1 class Model(dict, metaclass=ModelMetaclass):
 2 
 3     def __init__(self, **kwargs):
 4         super(Model, self).__init__(**kwargs)
 5 
 6     def __getattr__(self, key):
 7         try:
 8             return self[key]
 9         except KeyError:
10             raise AttributeError(r"'Model' object has no attribute '%s'" % key)
11 
12     def __setattr__(self, key, value):
13         self[key] = value
14 
15     def getValue(self, key):
16         return getattr(self, key, None)
17 
18     def getValueOrDefault(self, key):
19         value = getattr(self, key, None)
20         if value is None:
21             field = self.__mappings__[key]
22             if field.default is not None:
23                 value = field.default() if callable(field.default) else field.default
24                 logging.debug('using default value for %s: %s' % (key, str(value)))
25                 setattr(self, key, value)
26         return value
27 
28     @classmethod
29     @asyncio.coroutine
30     def findAll(cls, where=None, args=None, **kwargs):
31         'find objects by where clause'
32         sql = [cls.__select__]  #sql是list类型,元素是定制类的类属性——select查询语句模板
33         if where:
34             sql.append('where')
35             sql.append(where)
36         if args is None:
37             args = []
38         orderBy = kwargs.get('orderBy', None)
39         if orderBy:
40             sql.append('order by')
41             sql.append(orderBy)
42         limit = kwargs.get('limit', None)
43         if limit is not None:
44             sql.append('limit')
45             if isinstance(limit, int):
46                 sql.append('?')
47                 args.append(limit)
48             elif isinstance(limit, tuple) and len(limit) == 2:
49                 sql.append('?, ?')
50                 args.extend(limit)
51             else:
52                 raise ValueError('Invalid limit value: %s' % str(limit))
53         rs = yield from select(' '.join(sql), args)  #传入sql语句及参数,调用select语句获取查询结果
54         return [cls(**r) for r in rs]
55 
56     @classmethod
57     @asyncio.coroutine
58     def findNumber(cls, selectField, where=None, args=None):
59         'find number by select and where.'
60         sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
61         if where:
62             sql.append('where')
63             sql.append(where)
64             rs = yield from select(' '.join(sql), args, 1)
65             if len(rs) == 0:
66                 return None
67             return rs[0]['_num_']
68 
69     @classmethod
70     @asyncio.coroutine
71     def find(cls, pk):
72         'find object by primary key.'
73         rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
74         if len(rs) == 0:
75             return None
76         return cls(**rs[0])
77 
78     @asyncio.coroutine
79     def save(self):
80         args = list(map(self.getValueOrDefault, self.__fields__))
81         args.append(self.getValueOrDefault(self.__primary_key__))
82         rows = yield from execute(self.__insert__, args)
83         if rows != 1:
84             logging.warn('failed to insert record: affected rows: %s' % rows)
85 
86     @asyncio.coroutine
87     def update(self):
88         args = list(map(self.getValue, self.__fields__))
89         args.append(self.getValue(self.__primary_key__))
90         rows = yield from execute(self.__update__, args)
91         if rows != 1:
92             logging.warn('failed to update by primary key: affected rows: %s' % rows)
93 
94     @asyncio.coroutine
95     def remove(self):
96         args = [self.getValue(self.__primary_key__)]
97         rows = yield from execute(self.__delete__, args)
98         if rows != 1:
99             logging.warn('failed to remove by primary key: affected rows: %s' % rows)

1、__getattr__为内置方法,当使用点号获取实例属性,例如 stu.score 时,如果属性score不存在就自动调用__getattr__方法。注意:已有的属性,比如name,不会在__getattr__中查找;

2、__setattr__当设置实例属性时自动调用,如 stu.score=5时,就会调用__setattr__方法  self.[score]=5;

3、getValueOrDefault()  ->   获取属性值,如果为空,则取默认值;

4、@classmethod装饰的方法是类方法,直接使用类名调用,所有子类都可以调用类方法。不需要实例化,不需要 self 参数,第一个参数是表示自身类的 cls 参数。

5、分析findAll() 方法:

(1)第53行语句: rs = yield from select(' '.join(sql), args),调试可见返回结果是list类型,元素是dict类型的每行表数据:

(2)第54行语句: return [cls(**r) for r in rs],不太能理解,故编写语句 result = User.findAll() 来将返回值保存在result参数中,调试可得:

由此可得出结论,[cls(**r) for r in rs] 是将查询数据库表得到的每行结果,生成cls类的对象。

6、分析save() 方法

我们编写下列语句调用save方法:

@asyncio.coroutine
    def save(self):
        args = list(map(self.getValueOrDefault, self.__fields__))
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = yield from execute(self.__insert__, args)
        if rows != 1:
            logging.warn('failed to insert record: affected rows: %s' % rows)

经分析可得,获取调用save函数的类属性__friends__及__primary_key__的值,有默认值就传入默认值,作为sql语句参数执行。

 八、定义映射数据库表的类

def next_id():
    return '%015d%s000' % (int(time.time() * 1000), uuid.uuid4().hex)

class User(Model):
    __table__ = 'users'
    id = StringField(primary_key=True, default=next_id, ddl='varchar(50)')
    email = StringField(ddl='varchar(50)')
    passwd = StringField(ddl='varchar(50)')
    admin = BooleanField()
    name = StringField(ddl='varchar(50)')
    image = StringField(ddl='varchar(500)')
    created_at = FloatField(default=time.time)

九、编写测试代码

import orm
from models import User, Blog, Comment
import asyncio

loop = asyncio.get_event_loop()

async def test():
    # 创建连接池,里面的host,port,user,password需要替换为自己数据库的信息
    await orm.create_pool(loop=loop, host='127.0.0.1', port=3306, user='root', password='root', db='awesome')
    # 没有设置默认值的一个都不能少
    u = User(name='Test', email='547280745@qq.com', passwd='1234567890', image='about:blank', id="123")
    await u.save()
    result = await User.findAll()

loop.run_until_complete(test())

在Mysql数据中查询结果可知导入数据成功:

原文地址:https://www.cnblogs.com/zwb8848happy/p/8799044.html