Analyzing Kaldi's attention component

xconfig example

num_targets=3766                 # output dimension (number of targets)
learning_rate_factor=20
dir=$(mktemp -d)                 # scratch directory for the generated configs
mkdir -p $dir/configs
# write the xconfig; $num_targets is expanded because EOF is unquoted
cat <<EOF > $dir/configs/network.xconfig
input dim=71 name=input
attention-relu-renorm-layer name=attention1 num-heads=5 value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 time-stride=3
output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
EOF
# run xconfig_to_configs.py from a recipe directory so that steps/ is available
(cd ~/kaldi/egs/wsj/s5; steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/)
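Each xconfig line is a layer type followed by `key=value` options. The minimal Python sketch below splits such a line into a dictionary so the attention parameters used throughout this post are easy to read off. It is not Kaldi's own xconfig parser; `parse_xconfig_line` is a hypothetical helper that assumes whitespace-separated pairs with no quoting.

```python
def parse_xconfig_line(line):
    """Split one xconfig line into (layer_type, options).

    Hypothetical helper: assumes whitespace-separated key=value pairs with no
    quoting; values that parse as int become int, all others stay strings.
    """
    layer_type, *pairs = line.split()
    options = {}
    for pair in pairs:
        key, value = pair.split("=", 1)
        try:
            options[key] = int(value)
        except ValueError:
            options[key] = value
    return layer_type, options


line = ("attention-relu-renorm-layer name=attention1 num-heads=5 value-dim=40 "
        "key-dim=20 num-left-inputs=5 num-right-inputs=2 time-stride=3")
layer_type, opts = parse_xconfig_line(line)
print(layer_type)  # attention-relu-renorm-layer
print(opts)
```

Applied to the output-layer line, `include-log-softmax=false` and `max-change=1.5` stay strings, since only integer values are converted.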

config example

component name=attention1.attention type=RestrictedAttentionComponent value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 num-left-inputs-required=-1 num-right-inputs-required=-1 output-context=True time-stride=3 num-heads=5 key-scale=0.158113883008

component-node name=attention1.attention component=attention1.attention input=attention1.affine
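The parameters on this config line fully determine the component's input and output widths. The Python sketch below assumes the layout described in Kaldi's nnet-attention-component.h. Each head reads a block of key, value and query columns. The query has key-dim + context-dim columns, where context-dim = num-left-inputs + 1 + num-right-inputs. With output-context=True, each head also outputs its context-dim attention weights after its value-dim values. `attention_dims` is a hypothetical name. The numbers it yields match the gdb session below (440 in, 240 out).

```python
def attention_dims(num_heads, key_dim, value_dim, num_left_inputs,
                   num_right_inputs, output_context=True):
    """Return (input_dim, output_dim) of RestrictedAttentionComponent (sketch)."""
    context_dim = num_left_inputs + 1 + num_right_inputs
    query_dim = key_dim + context_dim       # query carries an extra positional part
    per_head_in = key_dim + value_dim + query_dim
    per_head_out = value_dim + (context_dim if output_context else 0)
    return num_heads * per_head_in, num_heads * per_head_out


# Values from the config line above.
in_dim, out_dim = attention_dims(num_heads=5, key_dim=20, value_dim=40,
                                 num_left_inputs=5, num_right_inputs=2)
print(in_dim, out_dim)  # 440 240
```

So attention1.affine has to output 440 = 5 x 88 dimensions, with each head's 88 columns made up of 20 (key) + 40 (value) + 28 (query).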

raw.txt example

<ComponentName> attention1.attention <RestrictedAttentionComponent> <NumHeads> 5 <KeyDim> 20 <ValueDim> 40 <NumLeftInputs> 5 <NumRightInputs> 2 <TimeStride> 3 <NumLeftInputsRequired> 5 <NumRightInputsRequired> 2 <OutputContext> T <KeyScale> 0.1581139 <StatsCount> 0 <EntropyStats> [ ]

<PosteriorStats> [ ]

</RestrictedAttentionComponent>
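The text form of the component is a sequence of `<Tag> value` tokens. The Python sketch below collects them into a dictionary, so the two dumps are easy to compare. For example, the -1 given for num-left-inputs-required and num-right-inputs-required on the config line appears here as 5 and 2, the same values as num-left-inputs and num-right-inputs. `parse_raw_component` is a hypothetical helper. Kaldi itself reads this format with the component's C++ `Read()` method.

```python
def parse_raw_component(text):
    """Collect '<Tag> value' pairs from an nnet3 text-format component (sketch).

    Type tags with no value (e.g. <RestrictedAttentionComponent>) map to None,
    '[ ... ]' vectors become lists of floats, closing tags are skipped.
    """
    tokens = text.split()
    fields = {}
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.startswith("</") or not tok.startswith("<"):
            i += 1                       # closing tag or stray token
            continue
        tag = tok[1:-1]
        nxt = tokens[i + 1] if i + 1 < len(tokens) else "<"
        if nxt.startswith("<"):
            fields[tag] = None           # tag with no value
            i += 1
        elif nxt == "[":
            end = tokens.index("]", i + 2)
            fields[tag] = [float(t) for t in tokens[i + 2:end]]
            i = end + 1
        else:
            fields[tag] = nxt            # scalar values are kept as strings
            i += 2
    return fields


raw = """<ComponentName> attention1.attention <RestrictedAttentionComponent> <NumHeads> 5 <KeyDim> 20 <ValueDim> 40 <NumLeftInputs> 5 <NumRightInputs> 2 <TimeStride> 3 <NumLeftInputsRequired> 5 <NumRightInputsRequired> 2 <OutputContext> T <KeyScale> 0.1581139 <StatsCount> 0 <EntropyStats> [ ]
<PosteriorStats> [ ]
</RestrictedAttentionComponent>"""
f = parse_raw_component(raw)
print(f["NumLeftInputsRequired"], f["NumRightInputsRequired"])  # 5 2
print(f["EntropyStats"])  # []
```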

Topology

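The topology shows that the kaldi nnet3 RestrictedAttentionComponent behaves like a nonlinearity layer: it has no trainable parameters of its own (the raw.txt dump above holds only configuration values and statistics). The keys, values and queries are all produced by the preceding affine component attention1.affine. The attention output then goes through ReLU and renormalization, as the layer type attention-relu-renorm-layer indicates.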

gdb example

$ gdb -d ~/kaldi/src/nnet3 --args nnet3-compute ref.raw ark,t:/tmp/feat ark:/dev/null

(gdb) rb kaldi::nnet3::.*::Propagate

(gdb) run

Breakpoint 3, kaldi::nnet3::AffineComponent::Propagate (this=0x11a6ec80, indexes=0x0, in=...,
    out=0x7fffffffb790) at nnet-simple-component.cc:1236

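The input is a 71x71 matrix (71 frames, each a 71-dimensional feature vector), and the affine component maps it to 71x440: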

(gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 71, 71, 440
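AffineComponent's forward pass computes out = in · W^T + b, one row per frame. W is stored as (output-dim x input-dim), which here is 440x71, so a 71x71 input yields a 71x440 output. The pure-Python sketch below shows that product with tiny made-up sizes. The function name and the numbers are illustrative only.

```python
def affine_propagate(x, W, b):
    """out[t] = W · x[t] + b for every frame t, i.e. x · W^T + b.

    x: num_frames x input_dim, W: output_dim x input_dim, b: output_dim.
    Returns num_frames x output_dim.
    """
    return [[sum(w_ij * x_j for w_ij, x_j in zip(w_row, frame)) + b_i
             for w_row, b_i in zip(W, b)]
            for frame in x]


# Toy sizes: 3 frames, input_dim 2, output_dim 4 (the real layer is 71 -> 440).
x = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
b = [0.5, 0.0, 0.0, 1.0]
y = affine_propagate(x, W, b)
print(len(y), len(y[0]))  # 3 4
print(y[0])               # [1.5, 2.0, 3.0, 1.0]
```

The number of rows (frames) passes through unchanged, and only the column dimension changes. That is why both matrices above have 71 rows.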

(gdb) c

Breakpoint 43, kaldi::nnet3::RestrictedAttentionComponent::Propagate (this=0x12f07e40, indexes_in=0x12f09a00,
    in=..., out=0x7fffffffb790) at nnet-attention-component.cc:134

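The input is the 71x440 output of attention1.affine, split column-wise into 5 heads of 440 / 5 = 88 columns each:

Head 1 | Head 2 | Head 3 | Head 4 | Head 5
71x88  | 71x88  | 71x88  | 71x88  | 71x88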

(gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 440, 50, 240
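Each head takes a consecutive 88-column block of the input and writes a consecutive 48-column block of the output. The Python sketch below prints those column ranges, assuming the per-head order is key, then value, then query, as in Kaldi's nnet-attention-component.cc. `head_column_ranges` is a hypothetical helper. Ranges are half-open [start, end).

```python
def head_column_ranges(num_heads, key_dim, value_dim, context_dim,
                       output_context=True):
    """Half-open column ranges of each head's key/value/query input block
    and of its output block (assumed per-head layout: key, value, query)."""
    query_dim = key_dim + context_dim
    in_per_head = key_dim + value_dim + query_dim
    out_per_head = value_dim + (context_dim if output_context else 0)
    layout = []
    for h in range(num_heads):
        base = h * in_per_head
        layout.append({
            "key": (base, base + key_dim),
            "value": (base + key_dim, base + key_dim + value_dim),
            "query": (base + key_dim + value_dim, base + in_per_head),
            "output": (h * out_per_head, (h + 1) * out_per_head),
        })
    return layout


layout = head_column_ranges(num_heads=5, key_dim=20, value_dim=40, context_dim=8)
for h, r in enumerate(layout, start=1):
    print(f"head {h}: key {r['key']} value {r['value']} "
          f"query {r['query']} -> output {r['output']}")
```

Head 1 reads columns 0 to 88 and writes 0 to 48, and head 5 reads 352 to 440 and writes 192 to 240. That accounts for the 440 input and 240 output columns shown above.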

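Attention is then computed separately for each head by PropagateOneHead. The 50x240 output is 5 heads x 48 columns for each of the 50 output frames.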

(gdb) c

Breakpoint 44, kaldi::nnet3::RestrictedAttentionComponent::PropagateOneHead (this=this@entry=0x12f07e40,
    io=..., in=..., c=c@entry=0x7fffffffb630, out=out@entry=0x7fffffffb650) at nnet-attention-component.cc:164

164 CuMatrixBase<BaseFloat> *out) const {

(gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 88, 50, 48

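The 71 input frames consist of (15 + 50 + 6 = 71):

  1. num-left-inputs*time-stride = 5*3 = 15 frames of left context, not output
  2. 50 frames in the middle, which are output
  3. num-right-inputs*time-stride = 2*3 = 6 frames of right context, not output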
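The row counts follow from this context arithmetic. Each output frame t looks at context-dim = 5 + 1 + 2 = 8 input frames, t - 15, t - 12, ..., t + 6, spaced time-stride = 3 apart. The Python sketch below reproduces the 71-row input for 50 output frames and lists the input rows that the first output row attends to. It assumes a single sequence with contiguous frames, and row indices count from 0 within the 71-row input. `attention_rows` is a hypothetical helper.

```python
def attention_rows(num_output_frames, num_left_inputs, num_right_inputs,
                   time_stride):
    """Return (num_input_rows, input rows attended by each output row)."""
    left = num_left_inputs * time_stride     # left-context frames (not output)
    right = num_right_inputs * time_stride   # right-context frames (not output)
    num_input_rows = left + num_output_frames + right
    context_dim = num_left_inputs + 1 + num_right_inputs
    # Output row i is centred on input row i + left; its context starts at row i.
    attended = [[i + j * time_stride for j in range(context_dim)]
                for i in range(num_output_frames)]
    return num_input_rows, attended


n_in, attended = attention_rows(50, 5, 2, 3)
print(n_in)         # 71
print(attended[0])  # [0, 3, 6, 9, 12, 15, 18, 21]
```

Row 15, the centre of the first window, is the first frame that is actually output. The last output row reaches input row 49 + 21 = 70, the final row of the input.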

An example of the PropagateOneHead computation:
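The pure-Python sketch below shows what one head computes. It assumes a single contiguous sequence and a per-head input layout of key | value | query, and it follows the structure of Kaldi's attention::AttentionForward as read from the source, but it is not that code. Output row i takes its query from input row i + 15. For each context position j, the logit is the query's last context-dim entries (a position-dependent term) plus key-scale times the dot product of the query's first key-dim entries with the key at input row i + 3j. A softmax over j gives the weights c. The output is the c-weighted sum of the values, and with output-context=True the 8 weights are appended, giving 40 + 8 = 48 columns. `propagate_one_head` is a hypothetical name.

```python
import math


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def propagate_one_head(inp, key_dim, value_dim, num_left_inputs,
                       num_right_inputs, time_stride, key_scale,
                       output_context=True):
    """Sketch of one attention head on one contiguous sequence.

    inp: one row per frame, laid out as
         key (key_dim) | value (value_dim) | query (key_dim + context_dim).
    Returns one row per output frame: value_dim weighted-average values,
    followed by the context_dim attention weights if output_context is True.
    """
    context_dim = num_left_inputs + 1 + num_right_inputs
    rows_left = num_left_inputs * time_stride              # left-context rows
    num_out = len(inp) - (context_dim - 1) * time_stride  # drop both contexts
    q_start = key_dim + value_dim
    out = []
    for i in range(num_out):
        query = inp[i + rows_left][q_start:]
        q_key, q_pos = query[:key_dim], query[key_dim:]  # dot-product / position parts
        ctx = [inp[i + j * time_stride] for j in range(context_dim)]
        logits = [q_pos[j] + key_scale * sum(a * b for a, b in zip(q_key, row[:key_dim]))
                  for j, row in enumerate(ctx)]
        c = softmax(logits)
        value = [sum(c[j] * row[key_dim + d] for j, row in enumerate(ctx))
                 for d in range(value_dim)]
        out.append(value + c if output_context else value)
    return out


# Synthetic 71 x 88 input (made-up numbers, not real activations).
inp = [[math.sin(0.1 * t + 0.01 * d) for d in range(88)] for t in range(71)]
out = propagate_one_head(inp, key_dim=20, value_dim=40, num_left_inputs=5,
                         num_right_inputs=2, time_stride=3, key_scale=0.1581139)
print(len(out), len(out[0]))       # 50 48
print(round(sum(out[0][40:]), 6))  # 1.0
```

The 71x88 input and 50x48 output match the shapes printed at the PropagateOneHead breakpoint.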

   

The computation flow of the whole RestrictedAttentionComponent:

   

Original article: https://www.cnblogs.com/JarvanWang/p/11084359.html