MXNet中bucket机制注记

Preface

之前看API以为bucket是一个根植于底层操作的接口(MXNet doc功不可没 -_-|| )。从LSTM看过来,接触到了一些相关的程序,后面再把bucketing_module.py那部分查看了下,发现bucket只是一个应用层机制,主要的实现存在于module/bucketing_module.py里面。原理清晰,实现简洁,在这做个记号。

Code & Comments

先放些相关的链接,做个预备。

  1. MXNet 官方的文档( ucao 出个文档真不容易,还带时效性...)
  2. 大神的blog阐述,鞭辟入里
  3. 之前关于LSTM的blog
    鉴于大神已经在这篇[blog]里面说得生动透彻了,这里就能省就省,然后说些大神没功夫顾及的细节。
    另外考虑到MXNet的链接经常表现出不靠谱的症状(kuxia),归结一下1中有些用的结论:要使用bucket机制,初始化Module时传入的symbol应该是一个函数,这个函数在被调用时将被传入迭代器中的bucket_key参数

从调用路径的顺序来走一遍把。
fit里面经过bind,init等操作,后面会调用prepare对预取出的数据(如果有)进行准备:

# module/bucketing_module.py
    def prepare(self, data_batch):
        """Prepares a data batch for forward.

        Parameters
        ----------
        data_batch : DataBatch
        """
        # perform bind if haven't done so
        assert self.binded and self.params_initialized
        bucket_key = data_batch.bucket_key
        original_bucket_key = self._curr_bucket_key
        data_shapes = data_batch.provide_data
        label_shapes = data_batch.provide_label
        self.switch_bucket(bucket_key, data_shapes, label_shapes)
        # switch back
        self.switch_bucket(original_bucket_key, None, None)

显然,switch_bucket就是负责进行重新绑定的:

# module/bucketing_module.py
    def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
         assert self.binded, 'call bind before switching bucket'
        if not bucket_key in self._buckets:    # check if there is already...
            symbol, data_names, label_names = self._sym_gen(bucket_key)
            module = Module(symbol, data_names, label_names,
                            logger=self.logger, context=self._context,
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names,
                            state_names=self._state_names)
            module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                        self._curr_module.inputs_need_grad,
                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
            self._buckets[bucket_key] = module

        self._curr_module = self._buckets[bucket_key]
        self._curr_bucket_key = bucket_key

逻辑很明白,_curr_module里面放了众多的module,这些module的参数全都指向同一组。如果出入的bucket_key没有出现过,就bind一个并放入_curr_module列表里面去;如果已经有了(包括刚刚bind出来的),就切换到那个module上。

Misc

其他有一些相关的材料顺带放在这。

  1. 上一篇blog里面推测bucket机制可能会对补齐的那部分进行处理,这一点与io.py里面的DataBatchpad变量有些联系。在module/base_module.py中,查找pad的引用,发现和io.py里面的注释一致,只在prediction的时候进行了使用,训练的时候被忽视。
  2. exmple/rnn/bucketing里面有更高层接口的使用示例。
原文地址:https://www.cnblogs.com/chenyliang/p/8060014.html