tensorflow equivalent of torch.gather


Ravi

Sep 2, 2018, 12:42:56 AM
to Discuss
Hi,

I have a tensor of shape (16, 4096, 3) and a tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim=1. This was originally done in PyTorch using the gather function, as shown below:

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

Please note that the output b has the same shape as idx. However, when I apply TensorFlow's gather function, I get a completely different output; the output shape does not match, as shown below:

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

I also tried tf.gather_nd, but in vain. See below:

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)


Why am I getting tensors of different shapes? I want a tensor of the same shape as the one computed by PyTorch.


In other words, I want to know the tensorflow equivalent of torch.gather.


-

Thanks

Ravi
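
The shapes reported in the question above follow from how the two TensorFlow ops combine their argument shapes. The sketch below is a plain-Python model of those shape rules, not TensorFlow code; the helper names gather_shape and gather_nd_shape are made up for illustration.

```python
def gather_shape(params_shape, indices_shape, axis=0):
    """Shape rule of a tf.gather-style lookup: the gathered axis of
    `params` is replaced by the entire shape of `indices`."""
    return (tuple(params_shape[:axis])
            + tuple(indices_shape)
            + tuple(params_shape[axis + 1:]))


def gather_nd_shape(params_shape, indices_shape):
    """Shape rule of a tf.gather_nd-style lookup: the last dimension of
    `indices` is consumed as a multi-index into the leading axes of `params`."""
    depth = indices_shape[-1]
    return tuple(indices_shape[:-1]) + tuple(params_shape[depth:])


a_shape = (16, 4096, 3)
idx_shape = (16, 32768, 3)

print(gather_shape(a_shape, idx_shape, axis=1))  # (16, 16, 32768, 3, 3)
print(gather_nd_shape(a_shape, idx_shape))       # (16, 32768)
```

Neither rule yields the shape of idx itself, which is what torch.gather returns, so neither op is a drop-in replacement when called with the raw idx tensor.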

Ravi

Sep 3, 2018, 10:57:16 AM
to Discuss
Any suggestions, please?

Dirk Toewe

Sep 3, 2018, 4:41:29 PM
to dis...@tensorflow.org

Hi Ravi,

gather_nd interprets the last dimension of the index tensor as a multi-index, so [[0,0,1],[0,1,0]], for example, would collect the entries at [0,0,1] and [0,1,0] from a tensor. So if you would like to collect entries along some axis, you have to stack your indices for that axis together with position indices for all the other dimensions. Here is a "little" example:

import tensorflow as tf
from tensorflow import newaxis as tf_new


ten = tf.constant([
  [[ 1, 2, 3, 4],
   [ 5, 6, 7, 8]],
  [[ 9,10,11,12],
   [13,14,15,16]],
  [[17,18,19,20],
   [21,22,23,24]]
])

index0 = tf.tile( tf.constant([0,1,2])[:,tf_new,tf_new], [1,2,4] )
index1 = tf.constant([[[ 0, 0, 0, 0 ],
                       [ 0, 0, 0, 1 ]],
                      [[ 0, 0, 1, 0 ],
                       [ 0, 0, 1, 1 ]],
                      [[ 0, 1, 0, 0 ],
                       [ 0, 1, 0, 1 ]]])
index2 = tf.tile( tf.constant([0,1,2,3])[tf_new,tf_new,:], [3,2,1] )

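# index0 and index2 only enumerate the positions along axes 0 and 2;
# index1 holds the indices that are actually gathered along axis 1.
nd_index = tf.stack( (index0, index1, index2), axis=-1 )  # shape (3, 2, 4, 3)

gathered = tf.gather_nd(ten, nd_index)  # shape (3, 2, 4), same as index1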

with tf.Session() as sess:
  print('Multi-Index:')
  print( sess.run(nd_index) )
  print('Result:')
  print( sess.run(gathered) )

Yes, this is somewhat tedious! But tf.gather_nd covers pretty much every possible indexing pattern.


Hope this helps,
Dirk
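
To make the role of the three stacked index tensors in the example above concrete, here is a plain-Python model of the same idea on nested lists. It is a sketch of the semantics only, not TensorFlow code: build_nd_index and gather_nd are hypothetical helpers written for this illustration, and the data is the ten and index1 from the example. In TensorFlow the position indices could presumably be generated with tf.range and tf.meshgrid instead of being written out by hand, but that is an assumption not tested here.

```python
def build_nd_index(idx):
    """Expand axis-1 indices of shape (B, M, C) into full (b, row, c) triples.

    The first and last entries of each triple are just the position of the
    element; only the middle entry comes from `idx`.
    """
    return [[[(i, idx[i][j][k], k) for k in range(len(idx[i][j]))]
             for j in range(len(idx[i]))]
            for i in range(len(idx))]


def gather_nd(params, nd_index):
    """Plain-Python model of gather_nd for rank-3 `params` and full triples."""
    return [[[params[b][r][c] for (b, r, c) in row] for row in plane]
            for plane in nd_index]


ten = [[[1, 2, 3, 4], [5, 6, 7, 8]],
       [[9, 10, 11, 12], [13, 14, 15, 16]],
       [[17, 18, 19, 20], [21, 22, 23, 24]]]

index1 = [[[0, 0, 0, 0], [0, 0, 0, 1]],
          [[0, 0, 1, 0], [0, 0, 1, 1]],
          [[0, 1, 0, 0], [0, 1, 0, 1]]]

gathered = gather_nd(ten, build_nd_index(index1))
print(gathered)
# [[[1, 2, 3, 4], [1, 2, 3, 8]],
#  [[9, 10, 15, 12], [9, 10, 15, 16]],
#  [[17, 22, 19, 20], [17, 22, 19, 24]]]
```

The result has the shape of index1, (3, 2, 4), and each element is ten[i][index1[i][j][k]][k], which is the rule torch.gather applies for dim=1. The extra trailing dimension of size 3 in nd_index only carries the position bookkeeping.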


Ravi

Sep 4, 2018, 5:38:46 AM
to Discuss, dirk...@gmail.com
Hi Dirk,

Thank you very much for explaining gather_nd. It seems that gather_nd is very powerful and can be used in many situations. However, I am still trying to understand it based on the example you provided.

First, I noticed that the input tensor has shape (3, 2, 4), the indices have shape (3, 2, 4, 3), and the resulting tensor has shape (3, 2, 4). This confuses me, because the torch equivalent of this function, torch.gather(), behaves differently: its output has the same shape as the indices.

Could you please tell me how to compose the new indices so that I get the same tensor as computed by torch?

Thanks again

-
Ravi

Alexandre Passos

Sep 4, 2018, 12:04:44 PM
to ravi20...@gmail.com, dis...@tensorflow.org, dirk...@gmail.com

Ravi

Sep 5, 2018, 12:06:09 AM
to Discuss, apa...@google.com
Thanks, Alex.

I tried tf.batch_gather on my input 'a' and indices 'idx'. Unfortunately, it throws the following error:

Dimensions must be equal, but are 32768 and 4096 for 'add' (op: 'Add') with input shapes: [16,32768,3], [1,4096,1].


Below is the environment information:

pip install tf-nightly
Python version = 2.7.15
tf.__version__ = 1.11.0-dev20180904

Any workaround, please?

-
Thanks
Ravi
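
One possible workaround for the error above, offered as an untested assumption rather than a confirmed fix: a batched gather treats every dimension of the indices except the last as a batch dimension, so it expects a and idx to agree on their first two dimensions, and 4096 versus 32768 does not. If the gathered axis is moved to the end first (for example with tf.transpose and perm [0, 2, 1] on both tensors), the batch dimensions become (16, 3) for both, and transposing the result back should give the torch.gather layout. The plain-Python model below checks that reasoning on nested lists; transpose_12, batch_gather_last and gather_axis1 are hypothetical helpers, not TensorFlow functions.

```python
def transpose_12(t):
    """Swap axes 1 and 2 of a rank-3 nested list: (B, N, C) -> (B, C, N)."""
    return [[list(col) for col in zip(*plane)] for plane in t]


def batch_gather_last(params, indices):
    """Model of a batched gather along the last axis:
    out[i][k][j] = params[i][k][indices[i][k][j]]."""
    return [[[p_row[j] for j in i_row]
             for p_row, i_row in zip(p_plane, i_plane)]
            for p_plane, i_plane in zip(params, indices)]


def gather_axis1(a, idx):
    """torch.gather(a, 1, idx) semantics: transpose, batched gather, transpose back."""
    return transpose_12(batch_gather_last(transpose_12(a), transpose_12(idx)))


a = [[[1, 2], [3, 4], [5, 6]]]            # shape (1, 3, 2)
idx = [[[2, 0], [0, 0], [1, 2], [2, 1]]]  # shape (1, 4, 2), longer than a along axis 1

b = gather_axis1(a, idx)
print(b)  # [[[5, 2], [1, 2], [3, 6], [5, 4]]]
```

The output has the shape of idx, and each entry equals a[i][idx[i][j][k]][k], so the transposed formulation reproduces the dim=1 gather even though a and idx differ in size along that axis.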