function encoder_decoder:training_step(encoder_input, decoder_target)
   -- forward propagate through the encoder
   cutorch.setDevice(self.gpuID.encoder)
   local encoder_output = self.encoder:forward(encoder_input)
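   -- encoder_output now lives on the encoder GPU and cannot be fed directly
   -- to modules resident on the decoder GPU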
   -- move max-pool indices from encoder GPU to decoder GPU
   cutorch.setDevice(self.gpuID.decoder)
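   -- with cutorch, :clone() allocates the copy on the currently selected
   -- device, so each cloned pooling module (and its indices tensor) lands
   -- on the decoder GPU, where the unpooling modules will read it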
   for pooling_idx, unpool_mod in ipairs(self.unpooling_modules) do
      unpool_mod.pooling = self.pooling_modules[pooling_idx]:clone()
   end
   -- move compressed input to decoder GPU and forward propagate through the decoder
   local decoder_input = encoder_output:clone()
   local decoder_output = self.decoder:forward(decoder_input)
   -- compute the cost and its gradient with respect to the decoder output
   local decoder_cost = self.criterion:forward(decoder_output, decoder_target)
   local decoder_dcost_dout = self.criterion:backward(decoder_output, decoder_target)
   -- back propagate through the decoder
   local decoder_backward_output = self.decoder:backward(decoder_input, decoder_dcost_dout)
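   -- decoder_backward_output is dcost/d(decoder_input): the gradient at the
   -- decoder's input, still resident on the decoder GPU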
   -- transfer the back-propagated gradient to the encoder GPU and continue back propagation
   cutorch.setDevice(self.gpuID.encoder)
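   -- the same clone-under-the-target-device trick, now in the opposite
   -- direction: this copy moves the gradient from the decoder GPU to the
   -- encoder GPU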
   local encoder_dcost_dout = decoder_backward_output:clone()
   self.encoder:backward(encoder_input, encoder_dcost_dout)
   -- clean up the cloned max-pool state on the decoder GPU (required to avoid a memory leak!)
   cutorch.setDevice(self.gpuID.decoder)
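   -- each training step created fresh clones on this device, so release
   -- them here; otherwise every step would strand another copy of the
   -- pooling state on the decoder GPU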
   for _, unpool_mod in ipairs(self.unpooling_modules) do
      unpool_mod.pooling:empty()
   end
   cutorch.setDevice(self.gpuID.encoder)
   -- return the loss (previously computed but discarded) along with the reconstruction
   return decoder_cost, decoder_output
end
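
--[[
Usage sketch (illustrative, not from the original source): everything below
(the layer sizes, the use of findModules() to pair pooling with unpooling
modules, and the setmetatable() wiring) is an assumption about how the
surrounding code is organised. It builds a toy convolutional autoencoder with
the encoder on GPU 1 and the decoder on GPU 2, then runs one training step.
--]]
require 'cutorch'
require 'cunn'

local model = {}
model.gpuID = { encoder = 1, decoder = 2 }

-- build the encoder on its own GPU
cutorch.setDevice(model.gpuID.encoder)
model.encoder = nn.Sequential()
   :add(nn.SpatialConvolution(1, 16, 3, 3, 1, 1, 1, 1))
   :add(nn.ReLU(true))
   :add(nn.SpatialMaxPooling(2, 2, 2, 2))
   :cuda()
model.pooling_modules = model.encoder:findModules('nn.SpatialMaxPooling')

-- build the decoder on the second GPU; each unpooling module is constructed
-- from its partner pooling module so their indices stay compatible
cutorch.setDevice(model.gpuID.decoder)
model.decoder = nn.Sequential()
   :add(nn.SpatialMaxUnpooling(model.pooling_modules[1]))
   :add(nn.SpatialConvolution(16, 1, 3, 3, 1, 1, 1, 1))
   :cuda()
model.unpooling_modules = model.decoder:findModules('nn.SpatialMaxUnpooling')

-- the criterion and the target live on the decoder GPU, where the loss is evaluated
model.criterion = nn.MSECriterion():cuda()
local target = torch.CudaTensor(8, 1, 32, 32):normal()

-- the input lives on the encoder GPU
cutorch.setDevice(model.gpuID.encoder)
local input = torch.CudaTensor(8, 1, 32, 32):normal()

-- borrow the method defined above (assumes encoder_decoder is a plain method table)
setmetatable(model, { __index = encoder_decoder })
local loss, reconstruction = model:training_step(input, target)
print(('training step loss: %.4f'):format(loss))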