Hi Rowan,
I read the code and have some questions. As I understand it, the following two lines correspond to the Grounding part of the R2C model.
q_rep, q_obj_reps = self.embed_span(question, question_tags, question_mask, obj_reps['obj_reps'])
a_rep, a_obj_reps = self.embed_span(answers, answer_tags, answer_mask, obj_reps['obj_reps'])
In the contextualization part, Eq. (2), $\alpha_{i,j} = \mathrm{softmax}_j(r_i W q_j)$, $\hat{q}_i = \sum_j \alpha_{i,j} q_j$, is used to compute the attended query. The corresponding code is below, but I'm a little confused: does qa_similarity correspond to the $r_i W q_j$ term in that equation? qa_similarity and question_mask are used to compute the attention weights, and those weights are then passed to torch.einsum to get attended_q. I cannot match the equation to the code.
qa_similarity = self.span_attention(
    q_rep.view(q_rep.shape[0] * q_rep.shape[1], q_rep.shape[2], q_rep.shape[3]),
    a_rep.view(a_rep.shape[0] * a_rep.shape[1], a_rep.shape[2], a_rep.shape[3]),
).view(a_rep.shape[0], a_rep.shape[1], q_rep.shape[2], a_rep.shape[2])
qa_attention_weights = masked_softmax(qa_similarity, question_mask[..., None], dim=2)
attended_q = torch.einsum('bnqa,bnqd->bnad', (qa_attention_weights, q_rep))
# Have a second attention over the objects, do A by Objs
# [batch_size, 4, answer_length, num_objs]
atoo_similarity = self.obj_attention(a_rep.view(a_rep.shape[0], a_rep.shape[1] * a_rep.shape[2], -1),
                                     obj_reps['obj_reps']).view(a_rep.shape[0], a_rep.shape[1],
                                                                a_rep.shape[2], obj_reps['obj_reps'].shape[1])
atoo_attention_weights = masked_softmax(atoo_similarity, box_mask[:,None,None])
attended_o = torch.einsum('bnao,bod->bnad', (atoo_attention_weights, obj_reps['obj_reps']))
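To make the question about Eq. (2) concrete, here is a minimal sketch of how the question-attention lines above might map onto the equation. It uses NumPy in place of PyTorch so it runs standalone, and it assumes that span_attention is a plain bilinear form (no bias), which the snippet above does not confirm. The `masked_softmax` here is a simplified stand-in, and `W`, the tensor sizes, and the random inputs are all made up for illustration.

```python
import numpy as np

def masked_softmax(scores, mask, axis):
    # Simplified stand-in for the masked_softmax used above:
    # masked positions receive exactly zero weight.
    scores = np.where(mask, scores, -1e9)
    scores = scores - scores.max(axis=axis, keepdims=True)
    weights = np.exp(scores) * mask
    return weights / np.clip(weights.sum(axis=axis, keepdims=True), 1e-13, None)

rng = np.random.default_rng(0)
b, n, q_len, a_len, d = 2, 4, 5, 3, 8         # hypothetical sizes
q_rep = rng.normal(size=(b, n, q_len, d))     # question tokens q_j
a_rep = rng.normal(size=(b, n, a_len, d))     # answer tokens r_i
W = rng.normal(size=(d, d))                   # assumed bilinear weight

question_mask = np.ones((b, n, q_len), dtype=bool)
question_mask[:, :, -1] = False               # pretend the last token is padding

# Under the bilinear assumption, qa_similarity[b, n, j, i] = q_j W r_i,
# i.e. the unnormalized score inside the softmax of Eq. (2).
qa_similarity = np.einsum('bnqd,de,bnae->bnqa', q_rep, W, a_rep)

# Softmax over the question axis (dim=2) gives alpha_{i,j}.
qa_attention_weights = masked_softmax(qa_similarity, question_mask[..., None], axis=2)

# The einsum sums over q: attended_q[b, n, i] = sum_j alpha_{i,j} q_j.
attended_q = np.einsum('bnqa,bnqd->bnad', qa_attention_weights, q_rep)

def attended_query_by_hand(bi, ni, ai):
    # Eq. (2) written out for one answer token, skipping padded positions.
    valid = question_mask[bi, ni]
    scores = (q_rep[bi, ni] @ W @ a_rep[bi, ni, ai])[valid]
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()
    return alpha @ q_rep[bi, ni][valid]
```

If this reading is right, qa_similarity is the score before the softmax, qa_attention_weights is $\alpha_{i,j}$, and attended_q is $\hat{q}_i$; the only difference from the paper's notation would be the order of the two arguments to the bilinear form.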
In the reasoning part, the input is the torch.cat of a_rep, attended_o, and attended_q. But why are self.reasoning_use_answer, self.reasoning_use_obj, and self.reasoning_use_question needed? As far as I can see, all of these boolean parameters are set to True.
reasoning_inp = torch.cat([x for x, to_pool in [(a_rep, self.reasoning_use_answer),
                                                (attended_o, self.reasoning_use_obj),
                                                (attended_q, self.reasoning_use_question)]
                           if to_pool], -1)
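For reference, here is a small sketch of what the flag-filtered concatenation above does when a flag is switched off. It uses NumPy rather than PyTorch, and `build_reasoning_input` and its keyword arguments are hypothetical names introduced only for this illustration. My guess is that the flags exist so that individual inputs can be dropped (for example for ablations), but the snippet itself does not confirm that.

```python
import numpy as np

def build_reasoning_input(a_rep, attended_o, attended_q,
                          use_answer=True, use_obj=True, use_question=True):
    # Keep only the tensors whose flag is True, then join them on the last axis,
    # mirroring the list comprehension inside torch.cat above.
    parts = [x for x, keep in [(a_rep, use_answer),
                               (attended_o, use_obj),
                               (attended_q, use_question)]
             if keep]
    return np.concatenate(parts, axis=-1)

# Hypothetical shapes: [batch_size, 4, answer_length, dim]
a_rep = np.zeros((2, 4, 3, 8))
attended_o = np.ones((2, 4, 3, 8))
attended_q = np.full((2, 4, 3, 8), 2.0)

full = build_reasoning_input(a_rep, attended_o, attended_q)
no_obj = build_reasoning_input(a_rep, attended_o, attended_q, use_obj=False)
```

With all three flags True the last dimension is three times the per-tensor size; turning one flag off shrinks the reasoning input accordingly, so the downstream layer's input size would have to match.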
Sorry about the black formatting; I don't know how to change the color.
Best regards,
Xuejiao