Update <partial> loss function during training

SC

Mar 22, 2017, 9:51:13 AM
to Keras-users
Hello everybody,

I am trying to find a way to change the loss function of a model during training. My current setup uses a custom loss function wrapped in a Python partial function. I use a partial because the loss needs extra arguments in addition to y_true and y_pred. What I want is to change these extra arguments during training from a callback, similar to the way LearningRateScheduler changes the learning rate. Here is a toy example:

import keras
import numpy as np
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense
from functools import partial, update_wrapper
# debug
from ipdb import set_trace as bp

# callback to update the partial function at the end of each epoch
class LossWeightUpdate(keras.callbacks.Callback):

    def __init__(self):
        super(LossWeightUpdate, self).__init__()

    def on_epoch_end(self, epoch, logs=None):
        # model.loss holds the object passed to compile(), i.e. the wrapped partial,
        # so the bound keyword argument can be read back from it
        weight = self.model.loss.keywords['weight']
        new_weight = weight - weight/4

        ## Update loss <partial> function using the new weight
        bp()

# custom loss function with extra argument
def cce(y_true, y_pred, weight=0):
    crossentropy = (-K.sum(y_true * K.log(y_pred)))
    return weight*crossentropy

if __name__ == '__main__':

    # dummy data
    data = np.random.random((100,32))
    labels = np.random.random((100,1)) > 0.5

    # dummy model
    # (note: softmax over a single unit always outputs 1.0, so this toy loss is always zero)
    a = Input(shape=(32,))
    b = Dense(1, activation='softmax')(a)
    model = Model(input=a, output=b)

    # wrapping of loss function with partial
    wrapped_cce = update_wrapper(partial(cce, weight=1.), cce)

    # compiling model
    model.compile(optimizer='adam', loss=wrapped_cce)

    # fit
    model.fit(data, labels, callbacks=[LossWeightUpdate()])
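For reference, the update rule in on_epoch_end (new_weight = weight - weight/4) scales the weight by 3/4 per epoch, so after n updates it equals the initial weight times (3/4)**n. A minimal plain-Python sketch of that schedule, written as a function of the epoch index in the style of the schedule function that LearningRateScheduler takes, together with a check of what partial and update_wrapper expose. The names weight_schedule and cce_stub are hypothetical and used only for illustration.

```python
from functools import partial, update_wrapper


def cce_stub(y_true, y_pred, weight=0.0):
    # Hypothetical stand-in for cce: only the weight handling matters here.
    return weight


def weight_schedule(epoch, initial_weight=1.0):
    # Closed form of applying `weight - weight / 4` once per completed epoch.
    return initial_weight * 0.75 ** epoch


# partial binds weight=1.0; update_wrapper copies __name__ (and other metadata)
# from cce_stub onto the partial object, which has no __name__ of its own.
wrapped = update_wrapper(partial(cce_stub, weight=1.0), cce_stub)

print(wrapped.keywords)                        # {'weight': 1.0}
print(wrapped.__name__)                        # cce_stub
print([weight_schedule(e) for e in range(4)])  # [1.0, 0.75, 0.5625, 0.421875]
```

Reading .keywords works, which is why the callback can see the current weight; the open question is how to make a new value reach the already compiled loss.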

I am using the Theano backend, and model.compile() builds and compiles the loss into the computation graph, so as far as I understand the value weight=1. is fixed in the graph at that point. Is there a way to update this argument from a callback?
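One common workaround is to keep the extra argument in a backend variable instead of a Python float: the compiled graph then reads the variable's current value on every batch, so a callback can overwrite it (in Keras, with K.set_value(var, new_value), and read it with K.get_value(var)). The sketch below emulates that pattern in plain Python so it runs without a backend; the Variable class, make_weighted_cce, and the plain-list cross-entropy are hypothetical stand-ins, and in a real model the variable would come from K.variable(1.) and the loss would use K operations.

```python
import math


class Variable:
    """Hypothetical stand-in for a backend variable such as K.variable(1.0)."""

    def __init__(self, value):
        self.value = value


def make_weighted_cce(weight_var):
    # The returned loss keeps a reference to the variable, not a copy of its
    # value, so every later call sees the most recent update.
    def weighted_cce(y_true, y_pred):
        crossentropy = -sum(t * math.log(p) for t, p in zip(y_true, y_pred))
        return weight_var.value * crossentropy
    return weighted_cce


class LossWeightUpdate:
    """Callback-like updater mirroring on_epoch_end from the question."""

    def __init__(self, weight_var):
        self.weight_var = weight_var

    def on_epoch_end(self, epoch, logs=None):
        weight = self.weight_var.value               # K.get_value(...) in Keras
        self.weight_var.value = weight - weight / 4  # K.set_value(...) in Keras


weight = Variable(1.0)
loss = make_weighted_cce(weight)
updater = LossWeightUpdate(weight)

y_true, y_pred = [0.0, 1.0], [0.5, 0.5]
for epoch in range(3):
    print(epoch, weight.value, loss(y_true, y_pred))
    updater.on_epoch_end(epoch)
```

The design choice is that the loss closes over a mutable container rather than a number, so nothing needs to be recompiled; this sketch has not been checked against the Theano backend itself.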

Thank you!