python unittest 源码学习

主要学习的模块:

├── suite.py
├── case.py
├── main.py
├── runner.py
├── warnings.py
├── signals.py
├── result.py
└── loader.py
  • suite.py：TestSuite 所在文件，TestSuite 是 TestCase（也可以嵌套其他 TestSuite）的集合

  • case.py 就是我们平时继承的 unittest.TestCase

  • main.py：TestProgram 所在文件，负责执行 parseArgs 和 runTests

  • runner.py：TextTestResult 和 TextTestRunner 所在文件，实际运行单测时默认加载的就是它们

  • warnings.py处理相关警告信息

  • signals.py 处理相关信号

  • result.py 保存结果的基类

  • loader.py 加载测试用例


学习用的示例代码：

# The demo must import unittest itself; without this line the class
# statement below fails with NameError before any test can be loaded.
import unittest


class BaiduTestClass1(unittest.TestCase):
    """First sample TestCase: two trivially passing ``test_*`` methods."""

    def setUp(self):
        # No per-test fixture is needed; placeholder hook.
        pass

    def test_baidu1_func1(self):
        print('test_baidu1_func1')
        self.assertEqual("He", "He")

    def test_baidu1_func2(self):
        print('test_baidu1_func2')
        self.assertEqual("He", "He")

    def tearDown(self):
        # Nothing to clean up; placeholder hook.
        pass

class BaiduTestClass2(unittest.TestCase):
    """Second sample TestCase, mirroring BaiduTestClass1 with different values."""

    def setUp(self):
        """No fixture to build before each test."""

    def test_baidu2_func1(self):
        label = 'test_baidu2_func1'
        print(label)
        self.assertEqual("He2", "He2")

    def test_baidu2_func2(self):
        label = 'test_baidu2_func2'
        print(label)
        self.assertEqual("He2", "He2")

    def tearDown(self):
        """No fixture to release after each test."""


class GetAttrTestCase(unittest.TestCase):
    """TestCase whose only method lacks the ``test`` prefix.

    The loader only collects methods whose names start with
    ``testMethodPrefix`` ('test'), so this class contributes no tests.
    It is named differently from ``GetAttrTest`` below: reusing that name
    would let the later definition shadow this one and hide it from the
    loader entirely.
    """

    def getattrtest(self):
        pass


class GetAttrTest(object):
    """Plain (non-TestCase) class: loadTestsFromModule skips it entirely."""

    def getattrtest(self):
        pass

if __name__ == "__main__":
    # Hand control to unittest's TestProgram, which loads and runs the
    # TestCase subclasses defined in this module.
    unittest.main(module="__main__")

  1. 程序启动后，unittest.main()（即 main.py 中的 TestProgram）被调用，进入 main.py
'''
main.py
'''
main = TestProgram
...
class TestProgram(object):
...

    def __init__(self, module='__main__', defaultTest=None, argv=None,...
        self.testRunner = testRunner
        self.testLoader = testLoader  
        self.progName = os.path.basename(argv[0])
        self.parseArgs(argv)
        self.runTests()

上面 main.py 中的 self.testLoader = testLoader 使用的默认值（testLoader=loader.defaultTestLoader）来自 loader.py：

'''
main.py
'''
testLoader=loader.defaultTestLoader # 调用loader.py中的

'''
loader.py
'''
defaultTestLoader = TestLoader()
testMethodPrefix = 'test'  #确定case的前缀

  1. 上面第 1 步中 main.py 的 self.parseArgs(argv) 在未指定 defaultTest 时会设置 self.testNames = None，并执行 createTests()
'''
main.py
'''
elif self.defaultTest is None:
     # createTests will load tests from self.module
     self.testNames = None
...
self.createTests()

createTests() 会调用 self.testLoader，给 self.test 赋值：

'''
main.py
'''
elif self.testNames is None:
	self.test = self.testLoader.loadTestsFromModule(self.module)

上面的代码会调用 loader.py 中的 loadTestsFromModule，返回 tests（suiteClass 的实例）：

'''
loader.py
'''
...
tests = []
	for name in dir(module):
	obj = getattr(module, name)
            #tests = [] 增加case 必须 attrname.startswith(self.testMethodPrefix):
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))
        load_tests = getattr(module, 'load_tests', None)      
         # test 列表转换成suite
        tests = self.suiteClass(tests)
		... 
        return tests
  1. 上面第 1 步中的 runTests() 会构建 testRunner（self.testRunner 的实例），并执行 testRunner.run(self.test)，其中 self.test 就是上面返回的 tests（suiteClass 的实例）
'''
main.py
'''
try:
      try:
        testRunner = self.testRunner(verbosity=self.verbosity,
                                    failfast=self.failfast,
                                    buffer=self.buffer,
                                    warnings=self.warnings,
                                    tb_locals=self.tb_locals)
...
self.result = testRunner.run(self.test)
  1. testRunner.run(self.test) 会执行下面 runner.py 中的代码，关键是 try 块中的 test(result)
'''
runner.py
'''
    def run(self, test):
        """Run the given test case or test suite.

        Builds a result object with ``self._makeResult()``, copies the
        failfast/buffer/tb_locals settings onto it, and calls
        ``test(result)`` while the configured warning filters are active.
        Afterwards it writes a summary to ``self.stream``: "Ran N test(s) in
        X.XXXs", then "OK" or "FAILED", plus any non-zero counters.

        Returns:
            The result object that collected the outcome of the run.
        """
        result = self._makeResult()
        # Register the result with unittest's signal handling (signals.py).
        registerResult(result)
        result.failfast = self.failfast
        result.buffer = self.buffer
        result.tb_locals = self.tb_locals
        with warnings.catch_warnings():
            if self.warnings:
                # if self.warnings is set, use it to filter all the warnings
                warnings.simplefilter(self.warnings)
                # if the filter is 'default' or 'always', special-case the
                # warnings from the deprecated unittest methods to show them
                # no more than once per module, because they can be fairly
                # noisy.  The -Wd and -Wa flags can be used to bypass this
                # only when self.warnings is None.
                if self.warnings in ['default', 'always']:
                    # ``\w`` is required here; without the backslash the
                    # pattern would only match the literal text "assertw".
                    warnings.filterwarnings('module',
                            category=DeprecationWarning,
                            message=r'Please use assert\w+ instead.')
            startTime = time.perf_counter()
            startTestRun = getattr(result, 'startTestRun', None)
            if startTestRun is not None:
                startTestRun()
            try:
                # Suites and cases are callable (__call__ forwards to
                # run()), so this runs the whole tree against `result`.
                test(result)
            finally:
                stopTestRun = getattr(result, 'stopTestRun', None)
                if stopTestRun is not None:
                    stopTestRun()
            stopTime = time.perf_counter()
        timeTaken = stopTime - startTime
        result.printErrors()
        if hasattr(result, 'separator2'):
            self.stream.writeln(result.separator2)
        run = result.testsRun
        self.stream.writeln("Ran %d test%s in %.3fs" %
                            (run, "s" if run != 1 else "", timeTaken))
        self.stream.writeln()

        expectedFails = unexpectedSuccesses = skipped = 0
        try:
            results = map(len, (result.expectedFailures,
                                result.unexpectedSuccesses,
                                result.skipped))
        except AttributeError:
            # Minimal result objects may not track these categories.
            pass
        else:
            expectedFails, unexpectedSuccesses, skipped = results

        infos = []
        if not result.wasSuccessful():
            self.stream.write("FAILED")
            failed, errored = len(result.failures), len(result.errors)
            if failed:
                infos.append("failures=%d" % failed)
            if errored:
                infos.append("errors=%d" % errored)
        else:
            self.stream.write("OK")
        if skipped:
            infos.append("skipped=%d" % skipped)
        if expectedFails:
            infos.append("expected failures=%d" % expectedFails)
        if unexpectedSuccesses:
            infos.append("unexpected successes=%d" % unexpectedSuccesses)
        if infos:
            self.stream.writeln(" (%s)" % (", ".join(infos),))
        else:
            # Terminate the "OK"/"FAILED" line.
            self.stream.write("\n")
        return result

test 是 TestSuite 实例，所以 test(result) 会触发 class TestSuite(BaseTestSuite) 从父类 BaseTestSuite 继承的 __call__ 方法：

def __call__(self, *args, **kwds):
    """Calling a suite is shorthand for ``suite.run(...)`` with the same arguments."""
    bound_run = self.run
    return bound_run(*args, **kwds)
  1. 上面的调用会执行 suite（TestSuite 实例）的 run() 方法；遍历时如果某个 test 本身也是 TestSuite，就会递归执行 test(result)
'''
suite.py
'''
...
# BaseTestSuite实现了__iter__方法,这里就可以递归执行,test(result) 因为调用了enumerate(self)
for index, test in enumerate(self):             
            if result.shouldStop:
                break
            if _isnotsuite(test):
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__
                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if not debug:              
                test(result)

如果上面的 test 是 TestCase 实例，就会调用 TestCase 的 __call__ 方法，进而触发 run 方法：

'''
case.py
'''
def __call__(self, *args, **kwds):
    """Calling a test case forwards all arguments to ``self.run``."""
    runner = self.run
    return runner(*args, **kwds)
       
def run(self, result=None):
		...
        # 获取测试方法
        testMethod = getattr(self, self._testMethodName)
        
        # 是否跳过
        if (getattr(self.__class__, "__unittest_skip__", False) or
            getattr(testMethod, "__unittest_skip__", False)):
            # If the class or method was skipped.
            try:
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                            or getattr(testMethod, '__unittest_skip_why__', ''))
                self._addSkip(result, self, skip_why)
            finally:
                result.stopTest(self)
            return
      	....
        try:
            self._outcome = outcome

            #执行setup
            with outcome.testPartExecutor(self):
                self._callSetUp()
            if outcome.success:
			    outcome.expecting_failure = expecting_failure
                with outcome.testPartExecutor(self, isTest=True):             
                    # amize 这里真正执行测试用例
                    self._callTestMethod(testMethod)
 

一段便于理解 unittest 原理的简化实现（myunittest 模块，以及使用它的 test_demo 模块和运行入口）：


import importlib
import logging


class TestCase:
    """Minimal base class for test classes discovered by ``Loader``.

    Subclasses define ``test_*`` methods and may override ``setup`` and
    ``teardown``, which ``Runner`` calls before and after running a
    class's test methods.
    """

    def __init__(self, name):
        # Runner passes the concrete class's __name__ here.
        self.name = name

    def setup(self):
        """Hook run before the test methods; does nothing by default."""

    def teardown(self):
        """Hook run after the test methods; does nothing by default."""


class Loader(object):
    def __init__(self):
        self.cases = {}

    def load(self, path):
        module = importlib.import_module(path)
        for test_class_name in dir(module):
            test_class = getattr(module, test_class_name)
            if (
                    isinstance(test_class, type) and
                    issubclass(test_class, TestCase)
            ):
                self.cases.update({
                    test_class: self.find_test_method(test_class) or []
                })

    def find_test_method(self, test_class):
        test_methods = []

        for method in dir(test_class):
            if method.startswith("test_"):
                test_methods.append(
                    getattr(test_class, method)
                )

        return test_methods

    def __iter__(self):
        for test_class, test_cases in self.cases.items():
            yield test_class, test_cases


class Runner(object):
    """Load the test classes of one module and run their test methods.

    For each test class: create one instance, call ``setup()`` once, call
    every test method, then call ``teardown()``.  A test method that raises
    is logged and skipped; the remaining methods of the class still run.
    """

    def __init__(self, path):
        # Module name passed to importlib by Loader.load (e.g. "test_demo").
        self.path = path

    def run(self):
        """Run all loaded test methods, logging (not raising) per-method errors."""
        loader = Loader()
        loader.load(self.path)

        for test_class, test_cases in loader:
            test_instance = test_class(test_class.__name__)
            test_instance.setup()
            try:
                for test_case in test_cases:
                    # Isolate each method so one failure does not skip the
                    # rest of the class.  Catching Exception rather than
                    # using a bare except lets KeyboardInterrupt/SystemExit
                    # still stop the run.
                    try:
                        test_case(test_instance)
                    except Exception:
                        logging.exception("error occured, skip this method")
            finally:
                # Once setup() has succeeded, always call teardown(), even if
                # the run is interrupted.
                test_instance.teardown()
from myunittest import TestCase


class DemoTestCase(TestCase):
    """Sample test class for the toy runner: one passing and one raising test."""

    def test_exception(self):
        # Deliberately raise so the runner's error handling is exercised.
        raise Exception("haha, exception here!")

    def test_normal(self):
        print("test normal function")

    def setup(self):
        print("setup")

    def teardown(self):
        print("teardown")

from myunittest import Runner


if __name__ == "__main__":
    # Import the test_demo module and run every test method it defines.
    runner = Runner(path="test_demo")
    runner.run()

补充:

  1. case 的排序问题（list 自带的 sort 配合 functools.cmp_to_key 函数）

  2. failfast装饰器

def failfast(method):
    """Decorate a result method so it first calls ``self.stop()`` when failfast is on.

    The flag is read from ``self.failfast``; a missing attribute counts as
    False.  The wrapped method is always called and its value returned.
    """
    @wraps(method)
    def wrapper(self, *args, **kw):
        should_stop = getattr(self, 'failfast', False)
        if should_stop:
            self.stop()
        return method(self, *args, **kw)
    return wrapper
  1. HTMLTestRunner 中的 run 方法：
    def run(self, test):
        """Run the given test case or test suite.

        Collects outcomes in a ``_TestResult`` built with ``self.verbosity``,
        records ``self.stopTime``, then passes the test and its result to
        ``self.generateReport``.  Returns the result object.
        """
        collected = _TestResult(self.verbosity)
        # Suites and cases are callable; calling one runs it against the result.
        test(collected)
        self.stopTime = datetime.datetime.now()
        self.generateReport(test, collected)
        return collected

原文地址:https://www.cnblogs.com/amize/p/13247060.html