I am trying to build a encoder-decoder setup but I am struggeling with making the decoder accept the encoders last state for initialization
# bi-lstm encoder
f_init = in_fwdRNN.initial_state()
b_init = in_bwdRNN.initial_state()
# do the encoding
fwdRnnStates = f_init.add_inputs(embeddings)
bwdRnnStates = b_init.add_inputs(reversed(embeddings))
# collect cell states and encoded input
src_encodings = []
forward_cells = []
backward_cells = []
for forward_state, backward_state in zip(fwdRnnStates, bwdRnnStates):
fwd_cell, fwd_enc = forward_state.s()
bak_cell, bak_enc = backward_state.s()
src_encodings.append(dy.concatenate([fwd_enc, bak_enc]))
forward_cells.append(fwd_cell)
backward_cells.append(bak_cell)
# decoder shall use the last state of the fwd/bwd pass in combination
decoder_init = dy.concatenate([forward_cells[-1], backward_cells[0]])
# too naive :(
dec = decoderRNN.initial_state(decoder_init)
What would be the proper way to initialize the decoder with the encoder's last state?
Thanks for a hint, I am not really sure how to solve this.