Frankly, I still fail to understand what is going on with the xent regularization. Apparently the function that computes the chain loss and both the chain and xent derivatives takes only one network output as input, nnet_output, which I assume is the output of the chain head of the network. https://github.com/kaldi-asr/kaldi/blob/f8b678a61e932f4858115dbe2d11caed48a7dbac/src/chain/chain-training.h#L118
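For reference, the declaration I am looking at reads roughly like this (paraphrased from memory, so argument names and defaults may be slightly off):

```cpp
// From src/chain/chain-training.h (paraphrased).  The only network output that
// goes in is nnet_output; xent_output_deriv is an optional *output* argument.
void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                              const DenominatorGraph &den_graph,
                              const Supervision &supervision,
                              const CuMatrixBase<BaseFloat> &nnet_output,
                              BaseFloat *objf,
                              BaseFloat *l2_term,
                              BaseFloat *weight,
                              CuMatrixBase<BaseFloat> *nnet_output_deriv,
                              CuMatrix<BaseFloat> *xent_output_deriv = NULL);
```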
However, if xent regularization is used, the xent derivatives are then pushed through a different head of the network, if I understand it correctly: https://github.com/kaldi-asr/kaldi/blob/f8b678a61e932f4858115dbe2d11caed48a7dbac/src/nnet3/nnet-chain-training.cc#L343
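Concretely, a bit further down I read the relevant part roughly as follows (paraphrasing, the exact code may differ):

```cpp
// In NnetChainTrainer::ProcessOutputs() (paraphrased): the xent_deriv that
// came out of ComputeChainObjfAndDeriv() is scaled and then fed back in as
// the output derivative of the *xent* output node.
if (use_xent) {
  xent_deriv.Scale(opts_.chain_config.xent_regularize);
  computer->AcceptInput(xent_name, &xent_deriv);  // xent_name is e.g. "output-xent"
}
```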
I don't see why this works. Given that the supervision head is different from the xent head, I don't understand how we can compute the gradient of the xent loss w.r.t. the supervision output. Probably my reasoning goes wrong at some earlier point?
Best,
Ilya
The xent output seems to be fetched at this line: https://github.com/kaldi-asr/kaldi/blob/f8b678a61e932f4858115dbe2d11caed48a7dbac/src/nnet3/nnet-chain-training.cc#L316 . As it does not seem to influence xent_deriv, it looks like the xent head gets updated with a gradient that might not be correct.
Consider the extreme case where the computation performed by the xent head is the negation of the computation performed by the supervision head. Then, if we estimate gradients for the xent head using the supervision (chain) output, we would actually be increasing the loss.
Realistically, in most cases I would expect the xent and supervision heads to be uncorrelated, which would turn this update into random gradient noise.
I suspect I am missing something trivial early on that would explain the rationale.
My current understanding is that in the swbd nnet3-chain config we see a network with a common stem, which splits into two branches after a certain layer close to the top. Both branches are parameterized and do not share parameters. One of them ends in the "supervision" or "chain" output (linear), the other ends in the xent output (the last transformation on the xent branch is a LogSoftmax).
When we call ComputeChainObjfAndDeriv, we pass the output of the first head as the nnet_output matrix argument. We get back two matrices of derivatives, nnet_output_deriv and xent_deriv. Afterwards xent_deriv is not modified (except for scaling by weights).
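As far as I can tell, the call site looks roughly like this (paraphrased from NnetChainTrainer::ProcessOutputs(); names may be slightly off):

```cpp
const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                      nnet_output.NumCols(),
                                      kUndefined);
bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
std::string xent_name = sup.name + "-xent";  // name of the xent output node
CuMatrix<BaseFloat> xent_deriv;              // filled only if use_xent
BaseFloat tot_objf, tot_l2_term, tot_weight;
ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_, sup.supervision,
                         nnet_output,  // only the chain head's output goes in
                         &tot_objf, &tot_l2_term, &tot_weight,
                         &nnet_output_deriv,
                         (use_xent ? &xent_deriv : NULL));
```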
xent_output is used to compute the xent_objf, but doesn't affect the xent_deriv: https://github.com/kaldi-asr/kaldi/blob/f8b678a61e932f4858115dbe2d11caed48a7dbac/src/nnet3/nnet-chain-training.cc#L320
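That line, if I read it correctly, is essentially (paraphrasing):

```cpp
// The xent objective is the inner product of the two matrices, i.e. the sum
// over frames of (posterior from the chain numerator) * (log-prob from the
// xent head); xent_deriv itself is left untouched here.
BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
```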
I assume that at this point xent_deriv is not really a derivative but a soft alignment (per-frame posteriors w.r.t. the output of the supervision head?). The comment in the code also seems to say these are posteriors, not gradients.
It looks plausible that the backprop of LogSoftmax handles these posteriors correctly as its input. Still, I find it strange that we update the xent head with a gradient whose computation did not use the xent head's parameters (e.g., the affine transforms in it) in the forward pass.
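To convince myself of the first part, here is a toy, self-contained sketch (my own code, not Kaldi's) of what I believe the LogSoftmax backprop does for a single frame when it is handed these posteriors as the output derivative:

```cpp
// Toy sketch, not Kaldi code.  'x' are the xent head's pre-softmax activations
// for one frame, 'post' is the corresponding row of xent_deriv (the per-frame
// posteriors) used as the output derivative of the LogSoftmax.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<double> LogSoftmaxBackprop(const std::vector<double> &x,
                                       const std::vector<double> &post) {
  // forward pass: softmax_i(x) = exp(x_i - max) / sum_j exp(x_j - max)
  double max = *std::max_element(x.begin(), x.end());
  double denom = 0.0;
  for (double v : x) denom += std::exp(v - max);
  std::vector<double> softmax(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    softmax[i] = std::exp(x[i] - max) / denom;
  // backprop of y = log-softmax(x) with output derivative 'post':
  // d/dx_i (sum_j post_j * y_j) = post_i - softmax_i * sum_j post_j
  double post_sum = 0.0;
  for (double p : post) post_sum += p;
  std::vector<double> in_deriv(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    in_deriv[i] = post[i] - softmax[i] * post_sum;
  return in_deriv;
}
```

If the posteriors sum to one for a frame, this comes out as post - softmax(x), i.e. the standard cross-entropy gradient with respect to the pre-softmax activations.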
Could you please correct my understanding?
Thanks for your patience!
Ilya