mxnet包含NDArray的列表更新

Jul 26, 2017

之前写的用来人工设定batch_sizeacc_xxx发现出现了问题。最终发现是列表更新的问题。
想想之前的NDArray处理,也是奇葩了。比如,你能告诉我下面这段中注释的与非注释的,产生差别的原理么?...what?!居然会有差别?You kidding?

    def acc_update(self,normsize=1):
        assert self.binded and self.params_initialized and self.optimizer_initialized
#        self._curr_module._exec_group.grad_arrays=
#                      [[grad.copyto(grad.context)/normsize if grad is not None else None for grad in grads] for grads in self.grad]
        for acc_grads, mod_grads in zip(self.grad,self._curr_module._exec_group.grad_arrays):
            for acc_grad, mod_grad in zip(acc_grads, mod_grads):
                if acc_grad is not None:
                    mod_grad = acc_grad.copyto(mod_grad.context)/normsize
                else:
                    mod_grad = None
        ...

Oct 22, 2017

最近发现接口又改了((⊙﹏⊙)b),新版的(V0.11.1)里面这样做也不合适,用分片的方法可能是对的(从一些结果上来看,还不能肯定没问题)。

Oct 23, 2017

对比了累计更新和一次更新作为一个batch的输出,初步验证程序的正确性。


两处的目的都很明显:想用self.grad的内容更新self._curr_module._exec_group.grad_arrays
然而调试的结果是,没被注释掉的能够完成这项预期,另外一个不能(可能是暂时的)归纳出其规律,表现某种单一增长的特征。
感觉上应该是列表之间的替换,但却没有这样运行
后面再看问题出在哪?

Sep 13, 2017

没有发现可能的问题,打算先放一放了。开始的时候打算从_exec_group.grad_arays的接口入手,发现是从_exec_group._execs[].grad_array中传过来的,但在update的时候,用的是前者,猜测可能在更新前有过同步,在没有找到的情况下,直接将后者del或者赋值为None,但都没有效果,和昨天的情况保持了相同;此外,将上段程序中的normsize设置为0发现,更新后确实也没有显著变化(细微的变化应该是由weight decay引起的——仅在小数点变化),也就是说,acc_update中的self.grad确实起到了作用。所以被凌乱了...mess
贴上两次的结果对比吧,以伺观code者得焉

>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
5873.4209
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
269556.56
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
1444888.8
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
4637960.0
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
11257292.0
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> 
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
22884572.0
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
41182624.0
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()

上面这段应该是非正常结果的,下面这段是归为正常结果。

>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
5873.4209
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> 
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
269556.56
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> 
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
741699.12
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
33154.039
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
30.383278
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
>>> mod.forward(d)
>>> #mod.get_outputs()[0].asnumpy().max()
... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
30.155006
>>> #mod.backward()
... #mod.update()
... mod.acc_backward()
>>> mod.acc_update()
原文地址:https://www.cnblogs.com/chenyliang/p/7512347.html