SparseCategoricalCrossEntropy loss on 2-d tensor within a batch


Phuong LE-HONG

Jan 5, 2023, 1:22:08 AM
to BigDL User Group
Hi all, 

I have a problem with the SparseCategoricalCrossEntropy loss function and need an idea for working around it. 

For each input example, my model outputs a matrix, say of shape 2x3, in which each row is a probability vector produced by the softmax layer:

z = 
z11 z12 z13
z21 z22 z23

The target is a vector of length 2 containing two integer indices, the correct class label for each row (in the range 0, 1, 2):

y= 
y1
y2

Since the targets are indices, I use the SparseCategoricalCrossEntropy loss. It computes correctly when applied to a single example:

loss.forward(z, y)
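For reference, the value being computed for one example is the negative log-probability of the correct class in each row, averaged over the rows. A minimal NumPy sketch, assuming z holds softmax probabilities, the labels are 0-based and the loss is averaged over rows (BigDL's own averaging and label conventions may differ); the function name sparse_cce and the example values are hypothetical:

```python
import numpy as np

def sparse_cce(z, y):
    """Sparse categorical cross-entropy for a single example.

    z: (rows, classes) array; each row is a softmax probability vector.
    y: (rows,) array of 0-based integer class indices, one per row.
    Returns the negative log-probability of the correct class, averaged over rows.
    """
    rows = np.arange(z.shape[0])
    return float(-np.mean(np.log(z[rows, y])))

# Hypothetical 2x3 model output and its 2-element target.
z = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.3, 0.6]])
y = np.array([0, 2])
print(sparse_cce(z, y))
```

Each row contributes exactly one term, so the two rows stay independent distributions.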

However, once the mini-batch dimension is added (say, a batch size of 16), the input to the loss is a 3-D tensor Z of shape 16x2x3 and the target Y has shape 16x2. The loss then fails to compute:

loss.forward(Z, Y)

It raises the error: 
===
Caused by: java.lang.IllegalArgumentException: ClassNLLCriterion:
  The input to the layer needs to be a vector(or a mini-batch of vectors);
===

I understand that this is because the current ClassNLLCriterion does not support 3-D input (2-D examples plus the batch dimension). 

But I have not found a way around it. Reshaping z to one dimension does not work, because each row must remain an independent probability distribution, and the batch dimension cannot be changed.

Using SplitTensor to split z into two vectors turns the output into a Table, which is incompatible with the target y, so the loss again fails to compute. 

Can anyone suggest a way to solve this problem?

Thank you very much,

Phuong





Phuong LE-HONG

Jan 5, 2023, 6:31:44 AM
to BigDL User Group
I found a solution. 

For anyone with the same problem: use TimeDistributedCriterion together with ClassNLLCriterion, both from the nn package (not the Keras package).
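The idea behind this combination is that the wrapped criterion only ever sees a mini-batch of vectors, one step of dimension 1 at a time. A NumPy sketch of what it computes, assuming the inner criterion averages over the batch, the wrapper averages the per-step losses, labels are 0-based and the input is log-probabilities. BigDL's actual conventions for ClassNLLCriterion (label base, log-probability vs. probability input, size averaging) should be checked in its documentation; class_nll and time_distributed_nll are hypothetical names, not BigDL functions.

```python
import numpy as np

def class_nll(log_probs, target):
    """NLL for a mini-batch of vectors.

    log_probs: (batch, classes) log-probabilities; target: (batch,) 0-based indices.
    """
    rows = np.arange(log_probs.shape[0])
    return float(-np.mean(log_probs[rows, target]))

def time_distributed_nll(log_probs, target):
    """Apply class_nll independently at each step of dimension 1, then average.

    log_probs: (batch, steps, classes); target: (batch, steps).
    """
    steps = log_probs.shape[1]
    per_step = [class_nll(log_probs[:, t, :], target[:, t]) for t in range(steps)]
    return float(np.mean(per_step))

rng = np.random.default_rng(0)
logits = rng.normal(size=(16, 2, 3))
# Log-softmax over the class axis: each (example, step) row is its own distribution.
shifted = logits - logits.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
Y = rng.integers(0, 3, size=(16, 2))
print(time_distributed_nll(log_probs, Y))
```

The 16x2x3 shape never reaches the inner criterion; it only receives 16x3 slices, which is the 2-D case the error message says it supports.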

This works for training, but the built-in validation metrics such as Top1Accuracy or Loss cannot be used for the validation run after each epoch.  

In my case, I wrote a custom metric, a kind of DistributedTop1Accuracy, and passed it to the setValidation method. Now the model works well. :-)
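The custom metric itself is not shown in the thread. Its core computation, a top-1 accuracy taken independently at every step, can be sketched in NumPy as below. The function name is hypothetical, labels are assumed 0-based, and plugging it into setValidation would additionally require implementing BigDL's validation-method interface, which is not shown here.

```python
import numpy as np

def time_distributed_top1_accuracy(output, target):
    """Fraction of (example, step) positions whose argmax class equals the target.

    output: (batch, steps, classes) scores, probabilities or log-probabilities.
    target: (batch, steps) 0-based class indices.
    """
    predicted = np.argmax(output, axis=-1)  # (batch, steps)
    return float(np.mean(predicted == target))

# Hypothetical batch of 2 examples, each with 2 rows over 3 classes.
Z = np.array([[[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]],
              [[0.2, 0.5, 0.3], [0.5, 0.2, 0.3]]])
Y = np.array([[0, 2],
              [1, 1]])
print(time_distributed_top1_accuracy(Z, Y))
```

Three of the four rows are predicted correctly here; the last row picks class 0 instead of 1.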

Cheers,

Phuong



guoqiongsong

Jan 5, 2023, 10:04:51 PM
to User Group for BigDL
The output of the model and the loss function should align. 
Can you please show the code that defines the model, as well as a couple of records of the data?