在Tensorflow中,如何将tf.gather()用于最后一个维度?
deep-learning
python
tensorflow
7
0

我正在尝试根据最后一个维度收集张量的切片,以实现层之间的部分连接。因为输出张量的形状为[batch_size, h, w, depth] ,所以我想根据最后一个维度选择切片,例如

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]

但是, tf.gather(L, [0, 2,3,8])似乎仅适用于第一维(对吗?)有人可以告诉我该怎么做吗?

参考资料:
Stack Overflow
收藏
评论
共 6 个回答
高赞 时间 活跃

@Andrei答案的正确版本为

cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
收藏
评论

您可以尝试这种方式,例如(在大多数情况下,至少在NLP中),

参数的形状为[batch_size, depth] ,索引为[i,j,k,n,m],其长度为batch_size。然后gather_nd可能会有所帮助。

parameters = tf.constant([
                          [11, 12, 13], 
                          [21, 22, 23], 
                          [31, 32, 33], 
                          [41, 42, 43]])    
targets = tf.constant([2, 1, 0, 1])    
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])     
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number   
items = tf.gather_nd(parameters, indices)  
# which is what we want: [13, 22, 31, 42]

该代码段首先通过batch_num找到拳头尺寸,然后通过目标编号沿该尺寸获取物品。

收藏
评论

这里有一个跟踪错误来支持此用例: https : //github.com/tensorflow/tensorflow/issues/206

现在,您可以:

  1. 转置矩阵,以便首先收集维数(转置非常昂贵)

  2. 将张量重整为1d(重整很便宜),然后在线性索引处将您的收集列索引转换为单个元素索引的列表,然后重新整形

  3. 使用gather_nd 。仍然需要将您的列索引转换为单个元素索引的列表。
收藏
评论

从TensorFlow 1.3开始, tf.gather具有axis参数,因此这里不再需要各种解决方法。

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223

收藏
评论

使用tf.unstack(...),tf.gather(...)和tf.stack(..)的另一种解决方案

码:

import tensorflow as tf
import numpy as np

shape = [2, 2, 2, 10] 
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)

indices = [0, 2, 3, 8]
axis = -1 # last dimension

def gather_axis(params, indices, axis=0):
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)

print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
    print(partL)

结果:

L = 
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]

  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]


 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]

  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]

partL = 
[[[[ 0  2  3  8]
   [10 12 13 18]]

  [[20 22 23 28]
   [30 32 33 38]]]


 [[[40 42 43 48]
   [50 52 53 58]]

  [[60 62 63 68]
   [70 72 73 78]]]]
收藏
评论

现在,您可以使用collect_nd执行以下操作:

cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)

同样,正如用户Nova在@Yaroslav Bulatov所引用的线程中所报告的那样:

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print y.eval()  # [2 4 9]

要点是将张量展平并使用带有tf.gather(...)的跨步一维寻址。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号