nngraph

nngraph(全称:neural network graph)是 Torch 里面的一个 package,用于以计算图(graph)的方式搭建神经网络。

code example:

require 'torch'
require 'nn'
require 'nngraph'

--- Build a small nngraph network whose output combines the raw input with a
-- two-layer transformation of that input via nn.CMulTable.
--
-- Graph layout:
--   x -> Linear(input_size, 100) -> Linear(100, input_size) --\
--   x ---------------------------------------------------------> CMulTable
--
-- @param input_size size of the input; also the output size of the second
--        Linear layer, so both CMulTable operands have the same size
-- @return an nn.gModule wrapping the graph, with one input node and one
--         output node
function CreateModule(input_size)
    -- width of the intermediate Linear layer
    local hidden_size = 100

    -- entry node of the graph (network input)
    local x = nn.Identity()()

    -- two stacked Linear layers: input_size -> hidden_size -> input_size
    local hidden = nn.Linear(input_size, hidden_size)(x)
    local projected = nn.Linear(hidden_size, input_size)(hidden)

    -- combine the untouched input with its transformed version
    local combined = nn.CMulTable()({x, projected})

    -- wrap the graph in a module exposing the standard API
    -- (:forward(), :backward())
    return nn.gModule({x}, {combined})
end


-- Demo driver: builds the module, runs one forward pass and one backward pass.
-- NOTE(review): every variable below is a global (no `local`); fine for a
-- throwaway example, but consider `local` if this is reused in a larger file.

-- Example input created with torch.rand(30).
input = torch.rand(30)

-- Size the network from the input's first dimension.
my_module = CreateModule(input:size(1))

-- Forward pass through the graph module.
output = my_module:forward(input)
-- Random tensor with the same size as `output`; it is passed to :backward()
-- below in place of a gradient that a real criterion would supply.
criterion_err = torch.rand(output:size())

-- Backward pass: the result is stored in gradInput and printed.
gradInput = my_module:backward(input, criterion_err)
print(gradInput)
原文地址:https://www.cnblogs.com/lightsun/p/7127228.html