Are gradient operations in TF themselves differentiable?


heiz...@gmail.com

Mar 30, 2016, 4:04:19 AM
to Discuss

Hello,

- How do I find out in TF whether an operation supports gradients, other than at runtime, where I get an error if it doesn't?
- Do the tf.gradients() or compute_gradients() operations themselves support gradients? Are they differentiable?


TL;DR: Can I put a gradients() operation inside the cost function of some optimization?


Thanks in advance for any answers.

Yaroslav Bulatov

Mar 30, 2016, 8:04:35 AM
to heiz...@gmail.com, Discuss
"gradients" produces TensorFlow graphs, any part of that graph can be further differentiated.

Some gotchas:
- "gradients" takes derivatives with respect to single y, so need to call it several times for multi-dimensional y.
   (technically gradients function can take a list of ys, but it sums them up before differentiating)
- "gradients" produces Python list instead of Tensor, so need to use "tf.pack" to convert to Tensor
- "gradients" can produce None in cases when gradient is 0, but that's an illegal input to "gradients" so you need to replace None's with 0's

Here's an example of getting the Hessian matrix of a loss by calling "gradients" twice:

import tensorflow as tf

def replace_none_with_zero(l):
  # "gradients" returns None where a gradient is identically zero
  return [0. if i is None else i for i in l]

tf.reset_default_graph()

x = tf.Variable(1.)
y = tf.Variable(1.)
loss = tf.square(x) + tf.square(y)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

grads = tf.gradients([loss], [x, y])                              # first derivatives
hess0 = replace_none_with_zero(tf.gradients([grads[0]], [x, y]))  # row 0 of the Hessian
hess1 = replace_none_with_zero(tf.gradients([grads[1]], [x, y]))  # row 1 of the Hessian
hessian = tf.pack([tf.pack(hess0), tf.pack(hess1)])
print(hessian.eval())
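(For this loss the Hessian is constant, so the script should print the 2x2 matrix [[2, 0], [0, 2]]; the off-diagonal entries come back from "gradients" as None and are replaced by zeros.)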



heiz...@gmail.com

Mar 30, 2016, 10:16:09 AM
to Discuss, heiz...@gmail.com
Thanks a lot for the reply and the gotchas. I will try it out tonight.

George Dahl

Mar 30, 2016, 2:27:18 PM
to Discuss, heiz...@gmail.com
Yaroslav already covered many of the issues, but depending on what you are doing you might get bad results. If you sum a gradient up or average it into a scalar, you can probably differentiate again without too many issues. But if you aren't careful you can get a very large and expensive graph, or hit one of the issues Yaroslav mentioned.
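For example, here is a minimal sketch of that pattern (not from the thread; the variable size and the 0.1 penalty weight are arbitrary placeholders), differentiating through a scalar built from a gradient, in this case a squared-gradient-norm penalty:

import tensorflow as tf

# Minimal sketch: build a scalar from a gradient (its squared norm) and
# differentiate the combined objective; TF backpropagates through the
# graph that tf.gradients produced. Sizes and weights are placeholders.
x = tf.Variable(tf.random_normal([10]))
loss = tf.reduce_sum(tf.square(x))

grad_x = tf.gradients(loss, [x])[0]              # first-order gradient of the loss
grad_penalty = tf.reduce_sum(tf.square(grad_x))  # scalar function of that gradient
total = loss + 0.1 * grad_penalty

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(total)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
sess.run(train_op)  # one step on the combined objective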

heiz...@gmail.com

Mar 31, 2016, 6:08:01 AM
to Discuss, heiz...@gmail.com
Unfortunately I have not had any success so far. I'm getting the error "LookupError: No gradient defined for operation 'gradients/MaxPool_grad/MaxPoolGrad'",
which I guess means that there is no gradient defined for the max-pool gradient operation.

What I basically want to do is input an image X, forward propagate it up to a layer L, then set the gradient equal to the prediction vector and backpropagate it back into pixel space.
Then I want to compute the gradients of a cost function between the backpropagated heat map and a binary mask, and update the weights in layer L.

# X, Y are placeholders and model/weights define the network (all defined elsewhere in my script)
py_x = model(X, weights)  # prediction at layer L

# Evaluate model
predict_op = tf.argmax(py_x, 1)
scores_pred = tf.round(py_x)

# Project gradients back to pixel space (per-pixel max of |d scores_pred / d X| over channels)
heatMap_grads_op = tf.reduce_max(
    tf.abs(tf.gradients(scores_pred, X, scores_pred)[0]),
    reduction_indices=[3], keep_dims=True)

# Cost between the backpropagated gradient heat map and an annotated binary mask Y
cost = tf.reduce_mean(tf.square(Y - heatMap_grads_op))

# Optimization
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(cost)

As you can see, this is similar to forwarding the activations to layer L and then constructing a 'deconv' network with transposed weights that passes them back through to reconstruct the image.
The problem is that I don't know how (or whether it's possible) in TF to copy the pooling activation maps: a lot of spatial information is lost upon pooling, and I would like to reuse the pooled activation
maps for the unpooling.
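
One op that looks relevant here (I have not verified it, and I believe it was GPU-only at the time) is tf.nn.max_pool_with_argmax, which returns the pooled output together with the flattened indices of the maxima, so the pooling "switches" could be kept around for the unpooling. A rough sketch, with an arbitrary 28x28x1 placeholder:

import tensorflow as tf

# Rough sketch: keep the pooling "switches" so they can be reused for unpooling.
images = tf.placeholder(tf.float32, [None, 28, 28, 1])

pooled, argmax = tf.nn.max_pool_with_argmax(
    images,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding='SAME')

# `argmax` holds, for each pooling window, the flattened index of the max
# element; an unpooling layer can scatter values back to those positions.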

Yaroslav Bulatov

Mar 31, 2016, 7:06:15 AM
to Shagas Heizenberg, Discuss
I suspect that the gradient of the max-pool gradient op behaves like the max-pool gradient itself, so maybe register that as its gradient function and see what happens? (Max-pool backprop sends first-order sensitivities back to the largest element of each pool, and second-order sensitivities should propagate in the same way.)

But I wouldn't be surprised if there are other specialized operations that don't have a gradient implemented.
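
If you want to try that, the registration mechanism is tf.RegisterGradient; the snippet below only sketches the mechanism, with the actual routing of the second-order sensitivities left as the experiment (returning None for every input just zeroes out that path). Note this only works if no gradient is already registered for the op, which was the case for MaxPoolGrad in the TF version at the time.

import tensorflow as tf

# Sketch of the registration mechanism only. MaxPoolGrad has three inputs
# (orig_input, orig_output, grad_of_output), so the registered function must
# return one gradient per input. Returning None for all of them makes the
# second-order contribution through max pooling zero -- replace with the
# real routing before trusting any Hessian that goes through a pool layer.
@tf.RegisterGradient("MaxPoolGrad")
def _max_pool_grad_grad(op, grad):
  # `grad` is the incoming second-order sensitivity, shaped like orig_input
  return [None, None, None]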
