Performance comparison of PyTorch LSTM with batch_first=True vs. batch_first=False

PyTorch's LSTM performs slightly differently with batch_first=True and batch_first=False, but the difference is small.
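For context, batch_first only changes the layout of the input and output tensors, not the computation itself. The following is a minimal sketch of that, assuming small illustrative sizes and two modules that share the same weights; the variable names are chosen for this example and are not from the original post.

```python
import torch

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size = 3, 5, 8, 16  # illustrative sizes

lstm_bf = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
lstm_sf = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # use the same weights in both modules

x_bf = torch.randn(batch, seq_len, input_size)  # (batch, seq, feature)
x_sf = x_bf.transpose(0, 1)                     # (seq, batch, feature)

with torch.no_grad():
    out_bf, (h_bf, c_bf) = lstm_bf(x_bf)
    out_sf, (h_sf, c_sf) = lstm_sf(x_sf)

print(out_bf.shape)  # torch.Size([3, 5, 16])
print(out_sf.shape)  # torch.Size([5, 3, 16])
print(h_bf.shape, h_sf.shape)  # the final hidden state layout does not depend on batch_first

# Up to the swapped first two dimensions, the outputs should agree.
same_values = torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5)
print(same_values)
```

Since both settings produce the same values, any speed difference comes from how the data is laid out and moved in memory, not from extra arithmetic.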

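The article quoted below concludes from its experiment that batch_first=True is faster than batch_first=False. When I ran the same code myself, however, I got the opposite result: batch_first=False was faster.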

Results from multiple runs, in seconds (left: batch_first=True, right: batch_first=False):

2.3414649963378906    2.0364670753479004
2.188401699066162     2.2298429012298584
2.25323224067688      2.202291488647461
2.2564923763275146    2.1362855434417725
2.3355021476745605    2.1648573875427246
2.367983818054199     2.4390225410461426
2.3107049465179443    2.3457281589508057
2.261659622192383     2.1843318939208984
2.2949719429016113    2.1492083072662354

In most of the runs (6 out of 9) the latter was faster, that is, batch_first=False was faster.
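To put numbers on "most of the runs", the timings listed above can be summarized with the standard library alone. This is only a sketch of the arithmetic; the list names are made up for this example.

```python
from statistics import mean

# Timings in seconds, copied from the runs listed above.
batch_first_true = [
    2.3414649963378906, 2.188401699066162, 2.25323224067688,
    2.2564923763275146, 2.3355021476745605, 2.367983818054199,
    2.3107049465179443, 2.261659622192383, 2.2949719429016113,
]
batch_first_false = [
    2.0364670753479004, 2.2298429012298584, 2.202291488647461,
    2.1362855434417725, 2.1648573875427246, 2.4390225410461426,
    2.3457281589508057, 2.1843318939208984, 2.1492083072662354,
]

mean_true = mean(batch_first_true)
mean_false = mean(batch_first_false)
# Number of runs in which batch_first=False finished faster.
false_wins = sum(f < t for t, f in zip(batch_first_true, batch_first_false))

print(f"mean batch_first=True:  {mean_true:.3f} s")   # 2.290 s
print(f"mean batch_first=False: {mean_false:.3f} s")  # 2.210 s
print(f"batch_first=False faster in {false_wins} of {len(batch_first_true)} runs")  # 6 of 9
```

The means differ by about 0.08 s (roughly 3.5%), which is smaller than the spread within a single column (the batch_first=True times alone range from about 2.19 s to 2.37 s), so these runs support "slightly faster" at most.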

Below is the result from an article on Zhihu:

https://zhuanlan.zhihu.com/p/50484629?from_voters_page=true

In my tests, batch_first=True turned out to be faster than batch_first=False (I don't know why PyTorch defaults to batch_first=False, and many places online claim that batch_first=False performs better).

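import time

import torch

# Same data in both layouts: x_1 is (batch=100, seq=200, feature=512),
# x_2 is the (seq, batch, feature) view of the same tensor.
x_1 = torch.randn(100, 200, 512)
x_2 = x_1.transpose(0, 1)

model_1 = torch.nn.LSTM(batch_first=True, hidden_size=1024, input_size=512)
model_2 = torch.nn.LSTM(batch_first=False, hidden_size=1024, input_size=512)

# Time one forward pass of each model, back to back.
start_time_1 = time.time()
result_1 = model_1(x_1)
end_time_1 = time.time()

result_2 = model_2(x_2)
end_time_2 = time.time()

# First value: batch_first=True; second value: batch_first=False (seconds).
print(end_time_1 - start_time_1, end_time_2 - end_time_1)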
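The snippet above times a single forward pass per model and always runs the batch_first=True model first, so warm-up effects and run order may account for part of the measured gap. The following is a rough sketch of a somewhat fairer comparison, not the method used in either article. The function name benchmark_lstm and its parameters are hypothetical, and the choices to share weights, disable gradient tracking, make the sequence-first input contiguous, and alternate the run order are all assumptions of this sketch.

```python
import time

import torch


def benchmark_lstm(batch=100, seq_len=200, input_size=512, hidden_size=1024,
                   repeats=5, warmup=1):
    """Return the mean forward-pass time in seconds for each batch_first setting."""
    torch.manual_seed(0)
    x_bf = torch.randn(batch, seq_len, input_size)  # (batch, seq, feature)
    # Assumption: give the sequence-first model its own contiguous copy.
    x_sf = x_bf.transpose(0, 1).contiguous()        # (seq, batch, feature)

    model_bf = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                             batch_first=True)
    model_sf = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                             batch_first=False)
    # Share weights so both models perform the same arithmetic.
    model_sf.load_state_dict(model_bf.state_dict())

    cases = {"batch_first=True": (model_bf, x_bf),
             "batch_first=False": (model_sf, x_sf)}
    totals = {name: 0.0 for name in cases}

    with torch.no_grad():  # assumption: measure inference-only forward passes
        # Untimed warm-up passes.
        for _ in range(warmup):
            for model, x in cases.values():
                model(x)
        for i in range(repeats):
            # Alternate the order so neither setting always runs first.
            names = list(cases) if i % 2 == 0 else list(cases)[::-1]
            for name in names:
                model, x = cases[name]
                start = time.perf_counter()
                model(x)
                totals[name] += time.perf_counter() - start

    return {name: total / repeats for name, total in totals.items()}
```

Calling benchmark_lstm() with its defaults uses the same tensor sizes as the snippet above and returns a dict with one mean time per setting. The result will still depend on the hardware, the PyTorch version, and whether the run is on CPU or GPU, so it should be read as a measurement of one setup, not as a general verdict.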

Original article: https://www.cnblogs.com/jiangkejie/p/13376281.html