backpropagating gradient in a self-created nngraph module

Gaurav Pandey

May 5, 2016, 8:47:01 AM
to torch7
Hi,

I created a small nn module for multiplying a Tensor by a scalar. However, when I use the module inside an nngraph graph and try to backpropagate gradients through it, I find that the inputs to the backward function are nil. This is the module:


local MulTable, parent = torch.class('nn.MulTable', 'nn.Module')

function MulTable:__init(scalar)
   parent.__init(self)
   self.scalar = scalar
   self.gradInput = {}
end

function MulTable:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   self.output:mul(self.scalar)
 
   return self.output
end

function MulTable:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or input.new()
   self.gradInput:resizeAs(input):copy(gradOutput)
   self.gradInput:mul(self.scalar)
   return self.gradInput
end
****************************************************************
This is what I try to execute:

require 'nn'
require 'MulTable'
require 'nngraph'
input1 = nn.Identity()()
output = nn.MulTable(.1)(input1)
model = nn.gModule({input1},  {output})
model:forward(torch.zeros(10)+3) -- torch.ones(10)
model:backward(torch.zeros(10)+3, torch.ones(10))
***************************************************************

It gives me an error on the model:backward line.

Hugh Perkins

May 5, 2016, 10:23:47 AM
to torch7
In your __init, you assign {} to self.gradInput. Therefore, in your updateGradInput, the line self.gradInput = self.gradInput or input.new() will do nothing ({} is truthy in Lua), leaving self.gradInput still a table, so the next line will fail.

By the way, a useful debug method is sprinkling 'print' statements liberally in your code. I noticed the issue by visual inspection, but if you'd put 'print(torch.type(self.gradInput))' just before the failing line, the problem would have become obvious :-)
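
For reference, here is one way the module could be written so that backward works; this is an untested sketch based on the diagnosis above. nn.Module's __init already initializes self.gradInput to an empty Tensor, so the fix is simply to drop the {} assignment:

```lua
local MulTable, parent = torch.class('nn.MulTable', 'nn.Module')

function MulTable:__init(scalar)
   parent.__init(self)  -- already sets self.gradInput = torch.Tensor()
   self.scalar = scalar
   -- do NOT overwrite self.gradInput with {} here
end

function MulTable:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   self.output:mul(self.scalar)
   return self.output
end

function MulTable:updateGradInput(input, gradOutput)
   -- self.gradInput is a Tensor, so resizeAs/copy now work
   self.gradInput:resizeAs(input):copy(gradOutput)
   self.gradInput:mul(self.scalar)
   return self.gradInput
end
```

Note also that stock nn already provides nn.MulConstant(scalar) for multiplying a Tensor by a fixed scalar, which may remove the need for a custom module entirely.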