Keras中RNN、LSTM和GRU的参数计算

1. RNN

 

RNN结构图

计算公式:

 

代码:

1 model = Sequential()
2 model.add(SimpleRNN(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 可见,共70个参数

记输入维度(x的维度,本例中为2)为dx, 输出维度(h的维度, 与隐藏单元数目一致,本例中为7)为dh

则公式中U的shape应该是dh*dx, W的shape因该是dh*dh, b的shape应该是dh*1

这样计算的h(t)维度才能是dh

计算公式:

nums = dh * ( dh + dx ) + dh

括号中可以理解为x和h(t-1)合并

70 = 7 *( 7 + 2 ) + 7

2. LSTM

https://zhuanlan.zhihu.com/p/147496732

参考这篇吧,讲的不错

LSTM单元结构图

代码:

1 model = Sequential()
2 model.add(LSTM(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 计算公式:

nums = 4 * [ dh * (dh + dx) + dh ]

 280 = 4 * [ 7 * (7 + 2) + 7 ]

3. GRU

 

GRU单元结构图

代码:

1 model = Sequential()
2 model.add(GRU(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 计算方式:

nums = 3 * [ dh * (dh + dx) + dh ]

 210 = 3 * [ 7 * (7 + 2) + 7 ]

原文地址:https://www.cnblogs.com/eastblue/p/13582223.html