nngraph(全称:Neural Network Graph)是 Torch 里面的一个 package。
code example:
require 'torch'
require 'nn'
require 'nngraph'
--- Build a small nngraph network and wrap it as a single module.
-- The graph feeds the input through Linear(input_size, hidden_size) and then
-- Linear(hidden_size, input_size), and combines that result with the
-- original input via nn.CMulTable (element-wise product).
-- @param input_size  size of the input vector; the output has the same size
-- @param hidden_size optional width of the intermediate Linear layer
--                    (defaults to 100, the previously hard-coded value)
-- @return an nn.gModule with one input node and one output node
function CreateModule(input_size, hidden_size)
  -- keep the old behaviour for callers that pass only input_size
  hidden_size = hidden_size or 100

  local input = nn.Identity()() -- network input node
  local hidden = nn.Linear(input_size, hidden_size)(input)
  local projected = nn.Linear(hidden_size, input_size)(hidden)
  -- combine the raw input with the projection of itself
  local output = nn.CMulTable()({input, projected})

  -- pack the graph into a convenient module with the standard API
  -- (:forward(), :backward())
  return nn.gModule({input}, {output})
end
-- Demo: one forward pass and one backward pass through the graph module.
-- Everything is declared `local` so the example does not leak names into
-- the global table; the locals stay visible to the rest of this file.
local input = torch.rand(30)
local my_module = CreateModule(input:size(1))
local output = my_module:forward(input)
-- random tensor with the same size as the output, used here as the
-- gradOutput argument of :backward()
local criterion_err = torch.rand(output:size())
local gradInput = my_module:backward(input, criterion_err)
print(gradInput)