LSTM training


Arjun Sharma

Jun 23, 2015, 7:25:45 AM
to tor...@googlegroups.com

I am trying to train an LSTM (single layer) with input vectors of size 26 for classification into 18 categories, using 'nnx' (code by nicholas-leonard). I use the following code to build the LSTM:

inputSize = 26
hiddenSize = 256
outputSize = 18
lr = 0.01
updateInterval = 10

l = nn.LSTM(inputSize, hiddenSize)

lstm = nn.Sequential()
lstm:add(l)
lstm:add(nn.Linear(hiddenSize, outputSize))
lstm:add(nn.LogSoftMax())
lstm:cuda()

criterion = nn.ClassNLLCriterion()
criterion:cuda()

The code for training each example is:

for step = 1, nSteps do
   input = inputs[step]
   target = targets[step]

   local output = lstm:forward(input:cuda())
   local err = criterion:forward(output:cuda(), target:cuda())
   local gradOutput = criterion:backward(output:cuda(), target:cuda())

   lstm:backward(input:cuda(), gradOutput:cuda())
   if step % updateInterval == 0 or step == nSteps then
      l:updateParameters(lr)
   end
end

 

The output of the forward pass becomes 'nan' after the first time the LSTM's parameters are updated via updateParameters. This code is similar in spirit to the RNN code at https://github.com/Element-Research/rnn.

Can someone point out what I am doing wrong? The code does not produce 'nan' when a simple RNN is used instead of the LSTM.

Thanks

abc123

Jun 23, 2015, 9:42:38 AM
to tor...@googlegroups.com
Might be the exploding gradient problem.
Check the norms of your gradients at your first update.
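
For example, something like this (a sketch, not from the original post; it assumes the flattened-tensor idiom from nn and the model name lstm used above):

-- call getParameters() once, before training starts
local params, gradParams = lstm:getParameters()

-- ... after lstm:backward(...) on the first example ...
print('gradient L2 norm:', gradParams:norm())
print('gradient max/min:', gradParams:max(), gradParams:min())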

Sergey Zagoruyko

Jun 24, 2015, 3:28:26 PM
to tor...@googlegroups.com
I don't see gradParameters:zero() at the beginning of your for loop. Without it the gradients keep accumulating, so they become garbage and the network diverges.
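
Given the accumulate-then-update scheme in the loop above, one natural place for it is right after the parameter update (a sketch of the fix, reusing the names from the original post; zeroing the whole container also clears the Linear layer's gradients):

if step % updateInterval == 0 or step == nSteps then
   l:updateParameters(lr)
   lstm:zeroGradParameters() -- reset accumulated gradients before the next interval
end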

Arjun Sharma

Jul 11, 2015, 9:34:17 AM
to tor...@googlegroups.com
Thanks for your replies, guys. The exploding gradient problem was indeed the cause. Initializing the weights differently and setting lower learning rates did help (I still get NaNs once in a while, though).
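
(For reference, one common initialisation in Torch RNN code of that era, e.g. char-rnn and learning-to-execute, is small uniform weights; a sketch of that approach, not necessarily what was used here:)

local params, gradParams = lstm:getParameters()
params:uniform(-0.08, 0.08) -- small uniform init over all weights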

Alex Graves suggests clipping gradients to the [-1, 1] or [-10, 10] range during BPTT. Any ideas on how this can be implemented in the current 'rnn' module code?

Also, I am struggling to save the model and resume training by reloading it. (This would help a lot, since the learning rate often needs to be decreased after checking the performance.)
I am not talking about only running forward propagations after loading; I want to resume training.

Other threads suggest saving the optim state, but as far as I know we are not using the optim package here, right? Simply saving the model and reloading it later does not seem to work for me (validation and training loss are much higher after reloading than before). Saving the parameters and gradients and reloading them does not work either. Has anyone successfully resumed training from a saved model/weights with the 'rnn' package from Nicholas Leonard?

Thanks & Regards 

andrew morgan

Oct 5, 2015, 5:05:48 PM
to torch7
May I ask how you initialised your weights to solve that divergence problem? I think I have the same issue.
Many thanks,
Andrew

Brendan

Oct 6, 2015, 6:36:59 AM
to torch7
You can do element-wise clipping like this, where gradParams is your flattened gradient vector (e.g. local params, gradParams = lstm:getParameters() to get it all in one tensor):

gradParams:clamp(-5, 5) -- now each element that is >5 or <-5 is clipped

Alternatively, to clip the norm of the gradient vector, which is what some other people do:

local norm = gradParams:norm()
if norm > 5 then
   gradParams:mul(5 / norm)
end
-- now gradParams has norm at most 5

You should also zero the gradients right after each update; that is the standard way to train RNNs with truncated BPTT. As it stands, your gradients keep accumulating into the same buffer, which is a likely source of your problem. You can zero them with gradParams:zero() or lstm:zeroGradParameters().

I suggest switching to the optim package at some point, since it lets you easily swap in non-SGD optimizers like rmsprop or adam.
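
For instance, a minimal adam setup could look like this (a sketch under the assumption that input and target hold the current, already CUDA'd training example, and that lstm and criterion are the objects from the original post):

require 'optim'

local params, gradParams = lstm:getParameters()
local optimState = {learningRate = 1e-3}

-- closure that computes the loss and the gradient w.r.t. params
local function feval(p)
   if p ~= params then params:copy(p) end
   gradParams:zero()
   local output = lstm:forward(input)
   local loss = criterion:forward(output, target)
   lstm:backward(input, criterion:backward(output, target))
   gradParams:clamp(-5, 5) -- optional: element-wise gradient clipping, as above
   return loss, gradParams
end

optim.adam(feval, params, optimState)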

If you'd like a full code example, see the learning-to-execute code by Wojciech Zaremba, the Oxford course's practical 6, or Andrej Karpathy's char-rnn. All of these implement truncated BPTT (albeit manually, rather than via the rnn package) and use optim.

I can't answer the serialization question; I don't use the rnn package.

