TensorFlow:numpy.repeat()替代
tensorflow
10
0

我想以成对的方式比较来自我的神经网络的预测值yp ,所以我正在使用(返回我以前的numpy实现):

idx = np.repeat(np.arange(len(yp)), len(yp))
jdx = np.tile(np.arange(len(yp)), len(yp))
s = yp[[idx]] - yp[[jdx]]

基本上,这将创建一个索引网格,然后供我使用。 idx=[0,0,0,1,1,1,...]jdx=[0,1,2,0,1,2...]我不知道是否有更简单的方式来做...

无论如何,TensorFlow拥有一个tf.tile() ,但似乎缺少一个tf.repeat()

idx = np.repeat(np.arange(n), n)
v2 = v[idx]

我得到错误:

TypeError: Bad slice index [  0   0   0 ..., 215 215 215] of type <type 'numpy.ndarray'>

将TensorFlow常数用于索引也行不通:

idx = tf.constant(np.repeat(np.arange(n), n))
v2 = v[idx]

--

TypeError: Bad slice index Tensor("Const:0", shape=TensorShape([Dimension(46656)]), dtype=int64) of type <class 'tensorflow.python.framework.ops.Tensor'>

这个想法是将我的RankNet实现转换为TensorFlow。

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

您可以结合使用tf.tile()tf.reshape()来达到np.repeat()的效果:

idx = tf.range(len(yp))
idx = tf.reshape(idx, [-1, 1])    # Convert to a len(yp) x 1 matrix.
idx = tf.tile(idx, [1, len(yp)])  # Create multiple columns.
idx = tf.reshape(idx, [-1])       # Convert back to a vector.

您可以使用tf.tile()简单地计算jdx

jdx = tf.range(len(yp))
jdx = tf.tile(jdx, [len(yp)])

对于索引,您可以尝试使用tf.gather()yp张量中提取非连续切片:

s = tf.gather(yp, idx) - tf.gather(yp, jdx)
收藏
评论

根据tf api 文档tf.keras.backend.repeat_elements()np.repeat()进行相同的工作。例如,

x = tf.constant([1, 3, 3, 1], dtype=tf.float32)
rep_x = tf.keras.backend.repeat_elements(x, 5, axis=0)
# result: [1. 1. 1. 1. 1. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 1. 1. 1. 1. 1.]
收藏
评论

仅针对一维张量,我做了这个功能

def tf_repeat(y,repeat_num)   
        return tf.reshape(tf.tile(tf.expand_dims(y,axis=-1),[1,repeat_num]),[-1]) 
收藏
评论

看来您的问题如此受欢迎,以至于人们在TF跟踪器上引用了它 。遗憾的是,TF中仍未实现相同的功能。

您可以通过组合tf.tiletf.reshapetf.squeeze来实现它。这是从np.repeat转换示例的一种方法:

import numpy as np
import tensorflow as tf

x = [[1,2],[3,4]]
print np.repeat(3, 4)
print np.repeat(x, 2)
print np.repeat(x, 3, axis=1)

x = tf.constant([[1,2],[3,4]])
with tf.Session() as sess:
    print sess.run(tf.tile([3], [4]))
    print sess.run(tf.squeeze(tf.reshape(tf.tile(tf.reshape(x, (-1, 1)), (1, 2)), (1, -1))))
    print sess.run(tf.reshape(tf.tile(tf.reshape(x, (-1, 1)), (1, 3)), (2, -1)))

在最后一种情况下,每个元素的重复都不相同,您很可能需要循环

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号