How to create a custom loss function based on cosine similarity?

76 views
Skip to first unread message

Steve Tan

unread,
Apr 14, 2020, 9:34:35 AM4/14/20
to theano-users
Hi,
I'm a total newb to theano and have been struggling with these problem for days,
how can I add a custom loss function based on cosine similarity as defined in this paper - github code here?

def inv_correlation(y_true, y_pred):
    """ Computes 1 minus the dot product between corresponding pairs of samples in two tensors. """
    return 1. - K.sum(y_true * y_pred, axis = -1)

I try doing by coping out the code from the theano nnet but am totally confused on how to do this properly.
For some reason my code is going to the second if else, the ndim doesn't match up, and I'm lost after that.

def cosine_similarity(coding_dist, true_dist):
    """
    H(p,q) = - \sum_x p(x) \log(q(x))
    """
    if true_dist.ndim == coding_dist.ndim: #it's not coming here
        return -T.sum(true_dist * T.log(coding_dist),
                           axis=coding_dist.ndim - 1)
    elif true_dist.ndim == coding_dist.ndim - 1: # try to modify this part but it seems to be a bit way over my head
        return cosine_similarity_1hot(coding_dist, true_dist)
    else:
        raise TypeError('rank mismatch between coding and true distributions')


class CosineSimilarity1HotGrad(gof.Op):

    __props__ = ()

    def make_node(self, g_y, coding_dist, true_one_of_n):
        return Apply(self, [g_y, coding_dist, true_one_of_n],
                     [coding_dist.type()])

    def perform(self, node, inp, out):
        g_y, coding_dist, true_one_of_n = inp
        g_coding_strg, = out
        g_coding = np.zeros_like(coding_dist)
        for i in xrange(len(g_y)):
            g_coding[i, true_one_of_n[i]] = (-g_y[i] /
                                             coding_dist[i, true_one_of_n[i]])
        g_coding_strg[0] = g_coding

    def infer_shape(self, node, in_shapes):
        return [in_shapes[1]]

cosine_similarity_1hot_grad = CosineSimilarity1HotGrad()


class CosineSimilarity1Hot(gof.Op):
    """
    Compute the cross entropy between a coding distribution and
    a true distribution of the form [0, 0, ... 0, 1, 0, ..., 0].
    .. math::
        y[i] = - \log(coding_dist[i, one_of_n[i])
    Notes
    -----
    In the case that the coding distribution is the output of a
    softmax, an application of this Op will probably be optimized
    away in favour of one with a C implementation.
    """
    __props__ = ()

    def make_node(self, coding_dist, true_one_of_n):
        """
        Parameters
        ----------
        coding_dist : dense matrix
        true_one_of_n : lvector
        Returns
        -------
        dvector
        """
        _coding_dist = tensor.as_tensor_variable(coding_dist)
        _true_one_of_n = tensor.as_tensor_variable(true_one_of_n)
        if _coding_dist.type.ndim != 2:
            raise TypeError('matrix required for argument: coding_dist')
        if _true_one_of_n.type not in (tensor.lvector, tensor.ivector):
            raise TypeError(
                'integer vector required for argument: true_one_of_n'
                '(got type: %s instead of: %s)' % (_true_one_of_n.type,
                                                   tensor.lvector))

        return Apply(self, [_coding_dist, _true_one_of_n],
                     [tensor.Tensor(dtype=_coding_dist.dtype,
                      broadcastable=[False])()])

    def perform(self, node, inp, out):
        coding, one_of_n = inp
        y_out, = out
        y = np.zeros_like(coding[:, 0])
        for i in xrange(len(y)):
            y[i] = -np.log(coding[i, one_of_n[i]])
            y[i] = 1. - T.sum(true_dist * coding_dist[0],  axis= -1)
        y_out[0] = y

    def infer_shape(self, node, in_shapes):
        return [(in_shapes[0][0],)]

    def grad(self, inp, grads):
        coding, one_of_n = inp
        g_y, = grads
        return [cosine_similarity_1hot_grad(g_y, coding, one_of_n),
                grad_not_implemented(self, 1, one_of_n)]

cosine_similarity_1hot = CosineSimilarity1Hot()


Reply all
Reply to author
Forward
0 new messages