Tensorflow Python:访问张量中的各个元素
python
python-2.7
tensorflow
5
0

这个问题是关于访问张量中的各个元素的,例如[[1,2,3]]。我需要访问内部元素[1,2,3](可以使用.eval()或sess.run()执行,但是在张量很大时会花费更长的时间)

有什么方法可以更快地完成同样的工作吗?

提前致谢。

参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

访问张量中的元素子集的主要方法有两种,每种方法均适用于您的示例。

  1. 使用索引运算符(基于tf.slice() )从张量中提取连续切片。

     input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) output = input[0, :] print sess.run(output) # ==> [1 2 3] 

    索引运算符支持许多与NumPy相同的切片规范。

  2. 使用tf.gather()操作从张量中选择不连续的切片。

     input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) output = tf.gather(input, 0) print sess.run(output) # ==> [1 2 3] output = tf.gather(input, [0, 2]) print sess.run(output) # ==> [[1 2 3] [7 8 9]] 

    请注意, tf.gather()仅允许您选择第0维的整个切片(在矩阵示例中为整行),因此您可能需要输入tf.reshape()tf.transpose()才能获得适当的元素。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号