Fine-tuning with tfhub + Keras API

Colin Morris
Jul 30, 2018, 4:33:54 PM
to TensorFlow Hub
I've been playing around with tfhub using the Keras API, and had a pretty confusing experience trying to fine-tune weights. It took me a while to realize this, but even though I was setting trainable=True when loading the module, the embedding weights weren't changing at all. 

Here's a small repro. I build a model that takes a string and outputs the component-wise average of its embedding. Then I train it to try to push certain words to have very large embeddings. But the embeddings don't change.

Eventually, I figured out that a Keras model will only train variables that belong to the 'trainable_weights' property of one of its layers (it's not enough that a variable is part of the model's calculations and belongs to the 'trainable_variables' collection). So to accomplish any fine-tuning, I needed to add some code like this when creating the embedding layer:

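import tensorflow as tf
from tensorflow import keras

# `embedder` is the hub.Module loaded with trainable=True and `embed_lambda` is the
# function that calls it; both come from the repro and are not shown here.
embed_layer = keras.layers.Lambda(embed_lambda)
all_trainables = tf.get_default_graph().get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for module_vars in embedder.variable_map.values():
    # variable_map values are a single variable or a list of them (e.g. partitioned variables)
    if not isinstance(module_vars, list):
        module_vars = [module_vars]
    for var in module_vars:
        if var in all_trainables:
            embed_layer.trainable_weights.append(var)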
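A rough way to see the behaviour described above, without TensorFlow, is a tiny pure-Python model of it. The classes below (`Variable`, `LambdaLayer`, `Model`) are hypothetical stand-ins, not Keras APIs, and the "optimizer" is a plain SGD step. The only point is that a model which collects its variables from its layers never updates a variable it is not told about, even when that variable takes part in the forward pass.

```python
# Toy model (hypothetical classes, not Keras APIs) of how a Keras model
# decides which variables to train.


class Variable:
    """A scalar parameter; stands in for a tf.Variable."""

    def __init__(self, value):
        self.value = value


class LambdaLayer:
    """Like keras.layers.Lambda: wraps a function and registers no weights."""

    def __init__(self, fn):
        self.fn = fn
        self.trainable_weights = []

    def __call__(self, x):
        return self.fn(x)


class Model:
    def __init__(self, layers):
        self.layers = layers

    def collect_trainable_weights(self):
        # Gather weights from the layers only, never from a global collection.
        return [w for layer in self.layers for w in layer.trainable_weights]

    def sgd_step(self, grads, lr=0.1):
        # `grads` maps Variable -> gradient; unregistered variables are skipped.
        for w in self.collect_trainable_weights():
            if w in grads:
                w.value -= lr * grads[w]


# The module's embedding variable is used inside the wrapped function...
embedding = Variable(1.0)
layer = LambdaLayer(lambda x: embedding.value * x)
model = Model([layer])
print(layer(3.0))  # 3.0: the variable takes part in the forward pass

model.sgd_step({embedding: -5.0})
print(embedding.value)  # 1.0: never registered, so never updated

# ...and only starts training once it is appended to the layer's weights.
layer.trainable_weights.append(embedding)
model.sgd_step({embedding: -5.0})
print(embedding.value)  # 1.5
```

Appending to the layer's list mirrors the workaround above; in real Keras the list that matters is the one the model reads when it builds its training step.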
            
I found this pretty surprising. In the "build a simple text classifier" tutorial, which uses the estimator framework, fine-tuning just works when trainable is set to True, with no further configuration required. 

Should this be documented somewhere? Or, better yet, maybe the hub API could include text_embedding_layer and image_embedding_layer helpers (in the same way that it exports text_embedding_column and image_embedding_column for the estimator framework).

André Susano Pinto
Aug 6, 2018, 10:37:47 AM
to colin...@google.com, TensorFlow Hub
Hi Colin,

Thanks for sharing your experience! Native support for Keras is on our list of feature requests, but we haven't had time to address it yet.

As you found out, keras.layers.Lambda assumes the function being wrapped is stateless, so it is not suitable for wrapping modules that are to be fine-tuned.
You can try to patch in the trainable variables as you suggested, but then one has to wonder what other properties need patching. For example, if the variables are not trainable, they should still be reflected in layer.non_trainable_weights; then there are update ops, regularizer losses, etc.
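The bookkeeping André describes can be sketched with another pure-Python toy. The `ModuleLayer` class and its `register_module_state` method are hypothetical, not part of Keras or TF Hub: a wrapper that owns a module's state has to route trainable and frozen variables to different lists and also keep the module's regularization penalties. Update ops are left out to keep it short.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass(eq=False)  # compare variables by identity, not by field values
class Variable:
    name: str
    value: float
    trainable: bool = True


@dataclass
class ModuleLayer:
    """Hypothetical layer that owns a module's state (not a Keras class)."""
    trainable_weights: List[Variable] = field(default_factory=list)
    non_trainable_weights: List[Variable] = field(default_factory=list)
    loss_fns: List[Callable[[], float]] = field(default_factory=list)

    def register_module_state(self, module_vars, regularizers):
        for var in module_vars:
            # Frozen variables are still layer state (e.g. for saving), so route
            # them to non_trainable_weights instead of dropping them.
            if var.trainable:
                self.trainable_weights.append(var)
            else:
                self.non_trainable_weights.append(var)
            reg = regularizers.get(var.name)
            if reg is not None:
                # Bind var/reg now; the penalty is recomputed from the current value.
                self.loss_fns.append(lambda v=var, r=reg: r(v.value))

    def total_regularization_loss(self):
        return sum(fn() for fn in self.loss_fns)


layer = ModuleLayer()
emb = Variable("embeddings", 2.0)
stats = Variable("moving_mean", 0.0, trainable=False)
layer.register_module_state([emb, stats], {"embeddings": lambda v: 0.01 * v * v})

print([v.name for v in layer.trainable_weights])      # ['embeddings']
print([v.name for v in layer.non_trainable_weights])  # ['moving_mean']
print(layer.total_regularization_loss())              # 0.04
```

Patching only trainable_weights, as in the original workaround, would cover the first list and silently lose the other two.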

Shubhanshu Mishra
Mar 25, 2019, 5:52:59 PM
to TensorFlow Hub
Hi Colin,

I had a similar issue. I was able to add the trainable weights of ELMo to the Keras model by subclassing Layer and then updating _trainable_weights.

My code can be found at: https://gist.github.com/napsternxg/324d219471398015fb40030b496a5c39

I hope this helps. 

Shubhanshu Mishra
Mar 25, 2019, 5:54:50 PM
to TensorFlow Hub
Also, I was using the ELMo model instead of nnlm.