[The Annotated Transformer] Iterators

Iterators

对torchtext的batch实现的修改算法原理

Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.

这里是对torchtext中默认的batching操作进行的优化修改。

参考:https://towardsdatascience.com/how-to-use-torchtext-for-neural-machine-translation-plus-hack-to-make-it-5x-faster-77f3884d95

Torchtext本身已经很好了,并且sort_key使得dataset中的数据排序,这样batching后序列长度相近的会被放在同一个batch中,可以很大程度上降低padding的个数。

但是下面代码又进行了优化:根据每个batch中序列的最大长度,动态更改batch_size,使得可以更好的利用计算资源。

举个例子:

假设你的RAM每个iteration可以处理1500个tokens, batch_size = 20, 那么只有当batch中的序列长度为sequence length = 1500 / 20 = 75时,才可以将计算资源利用完全。

现实中,每个batch的sequence length的显然是在变化的,那么如果希望尽量多的利用计算资源,就需要可以动态调整当前的batch_size.

Transformer中的MyIterator重载了data.Iterator中的create_batches函数:

 1 class MyIterator(data.Iterator):
 2     def create_batches(self):
 3         if self.train:
 4             def pool(d, random_shuffler):
 5                 for p in data.batch(d, self.batch_size * 100):
 6                     p_batch = data.batch(
 7                         sorted(p, key=self.sort_key),
 8                         self.batch_size, self.batch_size_fn)
 9                     for b in random_shuffler(list(p_batch)):
10                         yield b
11             self.batches = pool(self.data(), self.random_shuffler)
12             
13         else:
14             self.batches = []
15             for b in data.batch(self.data(), self.batch_size,
16                                           self.batch_size_fn):
17                 self.batches.append(sorted(b, key=self.sort_key))
18 
19 def rebatch(pad_idx, batch):
20     "Fix order in torchtext to match ours"
21     src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
22     return Batch(src, trg, pad_idx)

pool函数

其中pool函数的功能与https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py中定义的class BucketIterator(Iterator)的pool函数功能类似。

1. 将原始的data分成大小为 100 * batch_size的一些chunks => (以上迭代 p 即为 每个chunk)

2. 在每个chunk中根据 sort_key 对examples进行排序,并对每个chunk按照batch_size分成100个batch =>

p_batch = data.batch( sorted(p, key=self.sort_key), self.batch_size, self.batch_size_fn) )

3. 将这些chunks进行shuffle  => (random_shuffler(list(p_batch)))

4. 在每个chunk中再把examples分成 大小为 batch_size 的 100 个 batch => (以上 b 即为每个 batch)

5. 生成器每次 yield一个batch  => (yield b)

原文地址:https://www.cnblogs.com/shiyublog/p/10919988.html