Kaldi nnet3的前向计算

  • 根据任务,构建ComputationRequst
  • 编译ComputationRequst,获取NnetComputation

    std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

    • 创建计算——CreateComputation

      compiler.CreateComputation(opts, computation);

      • 从输出节点开始逐步向前计算依赖关系

        ComputationGraphBuilder builder(nnet_, &graph_);

        builder.Compute(*(requests_[segment]));

        每次向前深入一层,并计算所有Cindexes的依赖关系

        BuildGraphOneIter();

        对其中的每个Cindex,若需要计算其依赖:

        AddDependencies(cindex_id);

        • 若为kDescriptordesc.GetDependencies(index, &input_cindexes);
        • 若为kComponentcomponent->GetInputIndexes(request_->misc_info, index, &input_indexes);
        • 若为kDimRangeinput_cindexes[0] = Cindex(node.u.node_index, index);
        • 若为kInput,不需要依赖
      • 检查是否所有的输出都是可计算的

        if (!builder.AllOutputsAreComputable())

      • 将数据与运算组织为计算步

        对每个chunkCindexes根据不同网络层切分为phases,并以chunk为单位进行处理

        steps_computer.ComputeForSegment(*(requests_[segment]),phases_per_segment[segment]);

        phases以节点为单位切分为sub-phases,并以sub-phases为单位进行处理

        ProcessSubPhase(request, sub_phases[j]);

        sub-phases对于节点类型为:

        component-nodeProcessComponentStep(sub_phase);

        kSimpleComponent:除索引数-1外,将step复制为input_step

        else:从graph_->dependencies[c]获取依赖并插入到input_step

        input-nodeProcessInputOrOutputStep(request, false, sub_phase);

        output-nodeProcessInputOrOutputStep(request, true, sub_phase);

        dim-range-nodeProcessDimRangeSubPhase(sub_phase);

    • 优化计算——Optimize

      Optimize(opt_config_, nnet_,

      MaxOutputTimeInRequest(request),

      computation);

  • 根据NnetComputation构建NnetComputer

    NnetComputer computer(opts_.compute_config, *computation,

    nnet_, nnet_to_update);

  • 运行NnetComputer

    computer.Run();

    NnetComputation中所有Command迭代地运行

    ExecuteCommand();

    kPropagatevoid *memo = component->Propagate(indexes, input, &output);

    kBackpropcomponent->Backprop(debug_str.str(), indexes,

    in_value, out_value, out_deriv,

    memo, upd_component,

    c.arg6 == 0 ? NULL : &in_deriv);

    ...

  • NnetComputer获取输出

    computer.GetOutputDestructive("output", &cu_output);

原文地址:https://www.cnblogs.com/JarvanWang/p/10183427.html