How to freeze model parameters?


Qisong Wang

Jun 7, 2018, 11:21:38 PM
to TensorFlow.js Discussion
Hi,

I'm wondering if there is a way to fix variable values of a model when training models like GAN? 

I know it's possible to select variables via varList parameter of optimizer.minimize(), but I didn't manage to access things like that in the model or layer objects.


Thanks. 

Shanqing Cai

Jun 7, 2018, 11:34:10 PM
to TensorFlow.js Discussion
Are you using the Layers API? If so, it is fairly easy to freeze weights of given layers by setting their `trainable` property. 

See the following example (available in CodePen: https://codepen.io/caisq/pen/NzbbXL?editors=1011)

const model = tf.sequential();
model.add(tf.layers.dense({units: 4, inputShape: [2], activation: 'relu'}));
model.add(tf.layers.dense({units: 1}));

const layer0 = model.getLayer(null, 0);
const layer1 = model.getLayer(null, 1);

// Freeze the first layer.
layer0.trainable = false;

// Before training, print the values of the first and second layers' weights
console.log('=== Weights before training: ===');
layer0.getWeights()[0].print();
layer0.getWeights()[1].print();
layer1.getWeights()[0].print();
layer1.getWeights()[1].print();

const xs = tf.tensor2d([[1, 2], [3, 4]]);
const ys = tf.tensor2d([[-5], [6]]);
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
model.fit(xs, ys, {epochs: 100}).then(history => {
  // After training, print the weight values again.
  console.log('=== Weights after training: ===');
  layer0.getWeights()[0].print();
  layer0.getWeights()[1].print();
  layer1.getWeights()[0].print();
  layer1.getWeights()[1].print();
  
  // Notice how the weights of `layer0` remain the same after training,
  // while the weights of `layer1` change.
});

This API is consistent with Python Keras.

Qisong Wang

Jun 8, 2018, 12:34:27 AM
to TensorFlow.js Discussion
Thanks for the prompt reply, it works like a charm! I did try setting the model's trainable property to false, but forgot to try it on the individual layers.

Qisong Wang

Jun 12, 2018, 6:40:12 AM
to TensorFlow.js Discussion
This method does not seem to stop optimizer.minimize() from updating the parameters of frozen layers. I am using optimizer.minimize() instead of model.compile() and fit() because it allows me to minimize composite losses. For example, in the following code, I want to optimize the encoder and decoder network parameters while keeping the discriminator network unchanged. Is there a solution?

freeze(discriminator); // Freeze all layers in the discriminator model using a loop.
optimizer.minimize(() => {
  return tf.tidy(() => {
    const z = encoder.apply(tf.concat([x, eps], 1));
    const loss1 = tf.metrics.binaryCrossentropy(x, decoder.apply(z)).mean();
    const loss2 = discriminator.apply(tf.concat([x, z], 1)).mean();
    return loss1.add(loss2);
  });
});


Shanqing Cai

Jun 12, 2018, 1:40:48 PM
to TensorFlow.js Discussion
Qisong, your observation is correct. Optimizer.minimize() does not currently respect the trainable attribute of tf.Model or tf.Layer. Only tf.Model.fit() respects that attribute.

The reason is as follows. Optimizer.minimize() is a construct in TensorFlow.js Core, while trainable and fit() are constructs in TensorFlow.js Layers.
If you want to minimize() a subset of variables, you should use the third input argument of minimize() to specify the Array of variables. See this API doc for details:

We should document this better. 
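
To make the varList suggestion concrete, here is a minimal sketch rather than a definitive recipe. The helper name `collectTrainableVariables` is hypothetical, and it assumes that each layer exposes `trainable` and `trainableWeights`, and that each weight exposes its underlying variable through `read()`. Please check those assumptions against the TensorFlow.js version you have installed.

```javascript
// Hypothetical helper (not part of TensorFlow.js): gather the variables
// that the optimizer is allowed to update from a list of models.
// Assumes each model has `layers`, each layer has `trainable` and
// `trainableWeights`, and each weight returns its variable from `read()`.
function collectTrainableVariables(models) {
  const vars = [];
  for (const model of models) {
    for (const layer of model.layers) {
      // Skip frozen layers explicitly, so the result does not depend on
      // how `trainableWeights` behaves for a non-trainable layer.
      if (!layer.trainable) continue;
      for (const weight of layer.trainableWeights) {
        vars.push(weight.read());
      }
    }
  }
  return vars;
}

// Usage sketch, assuming `encoder`, `decoder`, `optimizer` and a loss
// function `computeLoss` are defined elsewhere:
//
//   const varList = collectTrainableVariables([encoder, decoder]);
//   optimizer.minimize(() => computeLoss(), /* returnCost */ false, varList);
```

Because the discriminator is simply left out of the list passed to the helper, its variables never reach the optimizer, whether or not its layers are marked as frozen.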

Daniel Smilkov

Jun 13, 2018, 10:19:52 AM
to Shanqing Cai, TensorFlow.js Discussion
Thanks for the feedback. We are tracking this as an issue. We'll make sure that the non-trainable bit from Layers propagates to the optimizer.

Daniel

