import keras
import numpy as np
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense
from functools import partial, update_wrapper
# callback to update the loss weight (held in the partial) at the end of each epoch
class LossWeightUpdate(keras.callbacks.Callback):

    def __init__(self):
        super(LossWeightUpdate, self).__init__()

    def on_epoch_end(self, epoch, logs=None):
        # model.loss is the partial object, so the weight is reachable
        # through its keyword arguments
        weight = self.model.loss.keywords['weight']
        # decay the weight by a quarter; the weight is a backend variable
        # (see __main__), so setting its value also updates the compiled loss
        new_weight = K.get_value(weight) - K.get_value(weight) / 4
        K.set_value(weight, new_weight)
        print('loss weight after epoch %d: %.4f' % (epoch, new_weight))
# custom loss function with an extra weight argument
def cce(y_true, y_pred, weight=0):
    # clip predictions to avoid log(0)
    crossentropy = -K.sum(y_true * K.log(K.clip(y_pred, K.epsilon(), 1.0)))
    return weight * crossentropy
if __name__ == '__main__':
    # dummy data: 100 samples with 32 features and binary labels
    data = np.random.random((100, 32))
    labels = (np.random.random((100, 1)) > 0.5).astype('float32')

    # dummy model: a single-unit output needs a sigmoid, not a softmax
    # (softmax over one unit is always 1.0, so the loss would be constant)
    a = Input(shape=(32,))
    b = Dense(1, activation='sigmoid')(a)
    model = Model(inputs=a, outputs=b)

    # wrap the loss function with partial; the weight is a backend variable
    # so the callback can change it after the model has been compiled
    wrapped_cce = update_wrapper(partial(cce, weight=K.variable(1.0)), cce)

    # compile model
    model.compile(optimizer='adam', loss=wrapped_cce)

    # fit for a few epochs so the weight decay is visible
    model.fit(data, labels, epochs=5, callbacks=[LossWeightUpdate()])
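
    # sanity check (optional sketch): the decayed weight can be read back from
    # the partial that was passed to compile(); it should have dropped by 25%
    # per epoch from its initial value of 1.0
    print('final loss weight:', K.get_value(model.loss.keywords['weight']))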