Permutation-invariant loss function

Livio C.

Jan 20, 2025, 10:36:05 AM
to Keras-users
I have a multioutput multiclass task: for each sample, I need to predict 3 numbers from 1 to 10 without repetition (e.g., [2,4,1]). Since the order of the predicted labels is not relevant for my task, I am struggling to write a permutation-invariant loss function in TensorFlow. SparseCategoricalCrossentropy should be computed between the predicted labels and every possible permutation of the gold labels, and the lowest score should be selected. I have found no easily reusable code online. Any suggestions? Thanks.
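For reference, the loss described above can be written as a formula. The notation is introduced here only for illustration and is not from the original post: p_{i,k}(c) is the model's predicted probability that output slot k of sample i is class c, y_i = (y_{i,1}, y_{i,2}, y_{i,3}) are the gold labels of sample i, N is the batch size, and S_3 is the set of the 3! = 6 permutations of the three slots.

```latex
\mathcal{L}_i = \min_{\pi \in S_3} \sum_{k=1}^{3} -\log p_{i,k}\!\left(y_{i,\pi(k)}\right),
\qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_i
```

Averaging over the three slots instead of summing only rescales every candidate permutation by the same factor 1/3, so the selected permutation is unchanged.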

Livio C.

Jan 22, 2025, 6:17:49 PM
to Keras-users
I have a multioutput multiclass classification task: for each sample, I need to predict 3 labels, each of which is an integer from 1 to 10, without repetition (e.g., [1,5,8]). Since the order of the predicted labels is not relevant for my task, I need a permutation-invariant loss function in TensorFlow: the predicted labels should be compared with every permutation of the gold labels, and the comparison returning the lowest score should be chosen. I am struggling to write it and have not found any easily reusable code online. Any suggestions?
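A minimal sketch of such a loss in TensorFlow, under assumptions that are not stated in the thread: the model outputs logits of shape (batch, 3, 10), one 10-way distribution per slot; the gold labels arrive as integers 1..10 with shape (batch, 3) and are shifted to the 0..9 class indices that the cross-entropy expects; and `make_permutation_invariant_loss` is a hypothetical helper name. It enumerates all 3! = 6 orderings of the gold labels, computes the per-slot cross-entropy for each ordering, and keeps the lowest total per sample.

```python
import itertools

import tensorflow as tf


def make_permutation_invariant_loss(num_slots=3, label_offset=1, from_logits=True):
    """Return a loss(y_true, y_pred) that ignores the order of the gold labels.

    Assumed shapes (hypothetical, adapt to your model):
      y_true: (batch, num_slots) integer labels, e.g. [[2, 4, 1]]
      y_pred: (batch, num_slots, num_classes) logits, or probabilities
              if from_logits=False
    """
    # Every ordering of the slot indices: 3! = 6 permutations for num_slots=3.
    perm_list = list(itertools.permutations(range(num_slots)))
    num_perms = len(perm_list)
    perms = tf.constant(perm_list, dtype=tf.int32)  # (num_perms, num_slots)

    def loss(y_true, y_pred):
        # Keras may pass labels as floats; cast, then shift 1..10 to 0..9.
        labels = tf.cast(y_true, tf.int32) - label_offset
        # Reordered gold labels: (batch, num_perms, num_slots).
        permuted = tf.gather(labels, perms, axis=1)

        y_pred = tf.convert_to_tensor(y_pred)
        if from_logits:
            logits = y_pred
        else:
            # softmax(log p) recovers p, so log-probabilities can act as logits.
            logits = tf.math.log(tf.clip_by_value(y_pred, 1e-7, 1.0))
        # Repeat the predictions once per permutation:
        # (batch, num_perms, num_slots, num_classes).
        logits = tf.tile(logits[:, tf.newaxis], [1, num_perms, 1, 1])

        # Per-slot cross-entropy: (batch, num_perms, num_slots).
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=permuted, logits=logits
        )
        # Total over slots, then keep the best-matching permutation per sample.
        return tf.reduce_min(tf.reduce_sum(ce, axis=-1), axis=-1)

    return loss
```

It can be passed to Keras like any other loss, e.g. `model.compile(optimizer="adam", loss=make_permutation_invariant_loss())`. With 3 slots, enumerating permutations is cheap (6 candidates), but the count grows factorially with the number of slots. Note also that the loss by itself does not prevent the model from predicting the same number in two slots, so the "without repetition" constraint would still need to be enforced when decoding predictions.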

Best,
Livio 

Samer Attrah

Jan 22, 2025, 8:10:08 PM
to Keras-users
It is not a difficult task. If you assign it to me, I could help you write a function with this behaviour, but I don't know of any existing function that already implements what you describe.

If you like the idea, please send me a direct email.