我如何在张量流中复制变量
tensorflow
5
0

在numpy中,我可以使用numpy.copy创建变量的副本。是否可以使用类似的方法在TensorFlow中创建Tensor的副本?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

您可以通过以下两种方式进行操作。

  • 这将为您创建一个副本: v2 = tf.Variable(v1)
  • 您还可以使用Identity opv2 = tf.identity(v1) (我认为这是正确的做法。

这是一个代码示例:

import tensorflow as tf

v1 = tf.Variable([[1, 2], [3, 4]])
v_copy1 = tf.Variable(v1)
v_copy2 = tf.identity(v1)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
a, b = sess.run([v_copy1, v_copy2])
sess.close()

print a
print b

它们都将打印相同的张量。

收藏
评论

您询问如何在标题中复制变量,但如何在问题中复制张量。让我们看一下不同的可能答案。

(1)您想要创建一个张量 ,该张量具有当前存储在我们称为var变量中的

tensor = tf.identity(var)

但是请记住,“张量”是一个图节点,在求值时将具有该值,并且在您每次对其求值时,它将获取var当前值。您可以使用诸如with_dependencies()类的控制流程操作来查看变量更新的顺序和标识的时间安排。

(2)您要创建另一个变量并将其值设置为当前存储在变量中的值:

import tensorflow as tf
var = tf.Variable(0.9)
var2 = tf.Variable(0.0)
copy_first_variable = var2.assign(var)
init = tf.initialize_all_variables()
sess = tf.Session()

sess.run(init)

print sess.run(var2)
sess.run(copy_first_variable)
print sess.run(var2)

(3)您想定义一个变量并将其起始值设置为与您已初始化变量相同的东西(这是上面的答案nivwu ..):

var2 = tf.Variable(var.initialized_value())

当您调用tf.initialize_all_variables时, var2将被初始化。在初始化图形并开始运行后,不能使用它来复制var。

收藏
评论

执行深层复制

copied_variable = tf.Variable(source_variable.initialized_value())

它还可以正确地处理初始化,即

tf.intialize_all_variables()

会先正确初始化source_variable,然后将该值复制到copyed_variable

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号