How to pick the best model after model.fit()


DSA

May 18, 2016, 2:54:01 PM
to Keras-users
Hi all,

After running model.fit() to train the network, how can I pick the best model state before running model.predict()? By default predict() uses the last state, which may not be the best; I'd like to use the model with the lowest val_loss rather than whatever the last epoch happened to produce.

I understand the ModelCheckpoint callback can save the best model to a file, but I don't need to save it to a file. I just want to run fit(), pick the best model, and run predict().

Thanks!
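For the in-memory variant asked about above, one option is to snapshot the weights whenever val_loss improves and restore that snapshot once training ends. Below is a minimal, framework-free sketch of that bookkeeping. `BestWeightsKeeper`, `train`, the one-element "weights" list and the val_loss values are all hypothetical stand-ins; wiring the same logic into Keras via a custom `Callback` whose `on_epoch_end` calls `model.get_weights()` and, after fit(), `model.set_weights()` is an assumption about how you might do it, not something tested here.

```python
import copy
import math


class BestWeightsKeeper:
    """Keeps an in-memory copy of the weights with the lowest val_loss seen so far."""

    def __init__(self):
        self.best_loss = math.inf
        self.best_epoch = None
        self.best_weights = None

    def on_epoch_end(self, epoch, val_loss, weights):
        # Snapshot only on a strict improvement in val_loss
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            # Deep copy so later in-place training updates cannot change the snapshot
            self.best_weights = copy.deepcopy(weights)


def train(val_losses):
    """Hypothetical training loop: each epoch mutates the weights and reports a val_loss."""
    weights = [0.0]
    keeper = BestWeightsKeeper()
    for epoch, val_loss in enumerate(val_losses):
        weights[0] += 1.0  # stand-in for one epoch of in-place weight updates
        keeper.on_epoch_end(epoch, val_loss, weights)
    last_weights = list(weights)
    # Restore the best snapshot instead of keeping the last state
    weights[:] = keeper.best_weights
    return last_weights, weights, keeper.best_epoch


last, best, best_epoch = train([0.9, 0.6, 0.4, 0.5, 0.7])
print("last:", last, "best:", best, "best epoch:", best_epoch)
```

Later Keras releases added a `restore_best_weights` argument to `EarlyStopping` that does this snapshot-and-restore for you; check the callbacks documentation for the version you are running.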

Akshay Chaturvedi

May 19, 2016, 7:06:06 AM
to Keras-users

Hi,

One way to do it is to import EarlyStopping (from keras.callbacks import EarlyStopping) and have it monitor val_loss with a patience value of your choice. Have a look at EarlyStopping at http://keras.io/callbacks/ for details.

Hope it helps.

Cheers

DSA

May 19, 2016, 1:54:16 PM
to Keras-users
From what I see in the docs and the code, it may help, but only partially: training stops once the monitored metric has not improved for the number of epochs given by the patience parameter. So it sounds like you'll either stop at a local optimum if patience=0, or end up with a model that is near its best but not the best (since it degraded after the best value and then ran out of patience).
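The point about patience can be checked with a small worked example. The function below is a simplified model of a patience-based stopping rule, not Keras's exact EarlyStopping source (its off-by-one details have varied between versions), and the val_loss history is made up for illustration.

```python
def run_with_patience(val_losses, patience):
    """Simplified patience rule.

    Stops at the first non-improving epoch that brings the count of consecutive
    non-improving epochs to at least `patience` (patience=0 stops at the first one).
    Returns (stop_epoch, loss_at_stop, best_epoch, best_loss).
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # Training halts here; the weights in memory are from this epoch
                return epoch, loss, best_epoch, best_loss
    # Ran out of epochs without triggering the stop
    last = len(val_losses) - 1
    return last, val_losses[last], best_epoch, best_loss


losses = [0.9, 0.7, 0.5, 0.55, 0.6, 0.45, 0.4]  # made-up val_loss history
for p in (0, 2, 3):
    print(p, run_with_patience(losses, p))
```

With patience 0 or 2 the run stops at an epoch whose val_loss (0.55 or 0.6) is worse than the best 0.5 from epoch 2, and it never reaches the later 0.4; with patience 3 it happens to end on its best epoch. Early stopping decides when to stop, not which weights to keep.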

DSA

May 19, 2016, 3:15:56 PM
to Keras-users
OK, I've figured out how to do it. One note: the Keras docs say I need to call model.compile() after loading, but it seems to work fine without it (I've compared results with and without it), perhaps because I already call model.compile() after the model definition.

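from keras.models import Sequential
from keras.callbacks import ModelCheckpoint

model = Sequential()
# rest of model definition goes here (opt, e, b and the data matrices are defined elsewhere)
model.compile(loss='mse', optimizer=opt)
# Write the weights to disk only when val_loss beats the best value seen so far
checkpointer = ModelCheckpoint(filepath="best.weights.h5", monitor='val_loss', verbose=1,
                               save_best_only=True, save_weights_only=True)
hist = model.fit(X_train_mat, Y_train_mat, epochs=e, batch_size=b,  # Keras 1.x used nb_epoch=e
                 validation_split=0.1, callbacks=[checkpointer])
# Restore the lowest-val_loss weights saved during fit() before predicting
model.load_weights('best.weights.h5')
predicted = model.predict(X_test_mat)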
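What save_best_only=True contributes here amounts to: after each epoch, compare val_loss with the best value so far and overwrite the file only on improvement; after fit() returns, load that file back. The sketch below mimics that flow in plain Python with `json` and a temporary file instead of HDF5. `fit_with_checkpoint`, `load_weights` and the dict "weights" are hypothetical stand-ins for ModelCheckpoint and model.load_weights(), not Keras code.

```python
import json
import os
import tempfile


def fit_with_checkpoint(val_losses, filepath):
    """Hypothetical fit(): writes the weights to filepath only when val_loss improves."""
    weights = {"step": 0}
    best = float("inf")
    saves = 0
    for loss in val_losses:
        weights["step"] += 1  # stand-in for one epoch of training
        if loss < best:  # save_best_only: skip the write unless val_loss improved
            best = loss
            with open(filepath, "w") as f:
                json.dump(weights, f)
            saves += 1
    return weights, saves


def load_weights(filepath):
    """Hypothetical loader: reads back whatever the last improving epoch wrote."""
    with open(filepath) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.json")
    final, n_saves = fit_with_checkpoint([0.8, 0.5, 0.6, 0.7], path)
    restored = load_weights(path)

print("final:", final, "restored:", restored, "saves:", n_saves)
```

The in-memory state after "training" is from the last epoch, while the file holds the epoch with the lowest val_loss, which is why loading it before predict() gives the best model.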

Akshay Chaturvedi

May 20, 2016, 3:59:19 AM
to Keras-users
Yeah, you are right. It works fine without calling model.compile() after loading. I didn't know about ModelCheckpoint. Thanks a lot for letting me know.

