class MyModel(tf.keras.Model):
    """Model that sums two passes through one shared ``Dense`` layer.

    The first pass is differentiable; the second is wrapped in
    ``tf.stop_gradient`` so it affects the output value only.
    """

    def __init__(self):
        super().__init__()
        # One 10-unit layer; both branches in ``call`` reuse its weights.
        self.dense = tf.keras.layers.Dense(10)

    def call(self, x):
        """Return ``dense(x) + stop_gradient(dense(x))`` for input ``x``."""
        trainable_branch = self.dense(x)
        # Second pass through the same layer: it adds to the forward value
        # but is cut out of backpropagation by stop_gradient.
        frozen_branch = tf.stop_gradient(self.dense(x))
        return trainable_branch + frozen_branch
# Instantiate the model and feed it a 10x10 tensor of ones.
model = MyModel()
x = tf.ones(shape=(10, 10))

# Run the forward pass under the tape so the operations producing `loss`
# are recorded for differentiation.
with tf.GradientTape() as tape:
    y = model(x)
    loss = 2 * y

# NOTE: despite its name, `grads_and_vars` holds only the gradients (one per
# entry of model.trainable_variables), not (gradient, variable) pairs.
# `loss` is not reduced to a scalar here, so the tape differentiates the sum
# of its elements.
grads_and_vars = tape.gradient(
    target=loss,
    sources=model.trainable_variables,
)