Exclude layer from gradient (Keras)

Jorge Fernández de Cossío Díaz

Jan 14, 2020, 12:12:10 PM
to Discuss
I am using the Keras subclassing API. My model has two layers (call them A and B) that share weights. Whenever I call A on some input, that call should not contribute to the gradient of the shared weights. Whenever I call B on some input, it should contribute to the gradient of those weights.

Is it possible to arrange for this in some easy way, while still training the model by just calling model.fit?

Or is the only way to write the training loop explicitly using GradientTape?

Jorge Fernández de Cossío Díaz

Jan 14, 2020, 12:27:39 PM
to Discuss
By the way, this is what I am trying to do: https://www.tensorflow.org/api_docs/python/tf/stop_gradient. If I understand the docs correctly, I should call `tf.stop_gradient` on the output of the layer?
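
For reference, a minimal sketch of what `tf.stop_gradient` does on a single tensor. It assumes TensorFlow 2.x with eager execution, and the names `w`, `frozen`, and `live` are made up for illustration:

```python
import tensorflow as tf

w = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # Wrapped in stop_gradient: used in the forward pass, treated as a constant in backprop.
    frozen = tf.stop_gradient(w * 2.0)
    # Not wrapped: contributes to the gradient as usual.
    live = w * 5.0
    y = frozen + live

# Only the `live` branch is differentiated, so dy/dw is 5.0 rather than 7.0.
grad = tape.gradient(y, w)
print(float(grad))
```

The forward value of `y` is unchanged by `tf.stop_gradient`; only the backward pass through the wrapped tensor is cut.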

Jorge Fernández de Cossío Díaz

Jan 14, 2020, 1:05:30 PM
to Discuss
Is there a difference between tf.stop_gradient and tf.keras.backend.stop_gradient?


Tom O'Malley

Jan 14, 2020, 2:49:02 PM
to Discuss
This should work:

import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x):
    # This call contributes to the gradients of the shared weights.
    out1 = self.dense(x)

    # This call will not contribute to the gradients.
    out2 = self.dense(x)
    out2 = tf.stop_gradient(out2)

    return out1 + out2

model = MyModel()
x = tf.ones((10, 10))
with tf.GradientTape() as tape:
  y = model(x)
  loss = 2 * y
# tape.gradient returns one gradient tensor per trainable variable.
grads = tape.gradient(loss, model.trainable_variables)
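
Since `tf.stop_gradient` lives inside `call`, the same pattern should also work when training through `model.fit`, with no explicit `GradientTape`. The following is only a sketch under assumptions: the class name `SharedDenseModel`, the random placeholder data, and the choice of SGD with MSE loss are all illustrative and not from this thread.

```python
import numpy as np
import tensorflow as tf


class SharedDenseModel(tf.keras.Model):
    """Hypothetical model: one Dense layer called twice on the same input."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, x):
        out1 = self.dense(x)                    # contributes to the gradient
        out2 = tf.stop_gradient(self.dense(x))  # treated as a constant in backprop
        return out1 + out2


# Random placeholder data, only to exercise the training loop.
rng = np.random.default_rng(0)
features = rng.normal(size=(32, 10)).astype("float32")
targets = rng.normal(size=(32, 10)).astype("float32")

model = SharedDenseModel()
model.compile(optimizer="sgd", loss="mse")

# Call the model once so the weights exist before taking a snapshot.
model(tf.constant(features[:1]))
weights_before = [w.copy() for w in model.get_weights()]

history = model.fit(features, targets, epochs=1, verbose=0)
weights_after = model.get_weights()

# The shared kernel still gets updated, but only through the out1 branch.
print(np.abs(weights_after[0] - weights_before[0]).max())
```

The gradient that `fit` applies here comes only from `out1`; the `out2` branch affects the loss value but not the weight update.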

Jorge Fernández de Cossío Díaz

Jan 14, 2020, 2:59:26 PM
to Discuss
Thanks. So it's fine to use either tf.stop_gradient or tf.keras.backend.stop_gradient?

Tom O'Malley

Jan 14, 2020, 3:07:02 PM
to Discuss
I'd recommend using tf.stop_gradient. The tf.keras.backend utilities exist primarily for compatibility with Keras before it was brought into TensorFlow, and while we're not getting rid of them, I'd encourage new projects to avoid them. Equivalent functionality should exist elsewhere in TensorFlow.