如何为TensorFlow变量分配值?
deep-learning
neural-network
python
tensorflow
5
0

我正在尝试为python中的tensorflow变量分配一个新值。

import tensorflow as tf
import numpy as np

x = tf.Variable(0)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)

print(x.eval())

x.assign(1)
print(x.eval())

但是我得到的输出是

0
0

因此该值没有改变。我想念什么?

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

在TF1中,语句x.assign(1)实际上并未将值1分配给x ,而是创建了tf.Operation ,您必须显式运行以更新变量。*调用Operation.run()Session.run()可用于运行操作:

assign_op = x.assign(1)
sess.run(assign_op)  # or `assign_op.op.run()`
print(x.eval())
# ==> 1

(*实际上,它返回一个tf.Tensor ,与变量的更新值相对应,以便更轻松地链接分配。)

但是,在TF2中, x.assign(1)现在会急切地分配值:

x.assign(1)
print(x.numpy())
# ==> 1
收藏
评论

首先,您可以通过使用占位符的相同方式将值输入变量/常量来为它们分配值。因此,这样做完全合法:

import tensorflow as tf
x = tf.Variable(0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x, feed_dict={x: 3})

关于您与tf.assign()运算符的混淆。在TF中,在会话内部运行它之前不会执行任何操作。因此,您始终必须执行以下操作: op_name = tf.some_function_that_create_op(params) ,然后在会话内部运行sess.run(op_name) 。以assign为例,您将执行以下操作:

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x)
    print sess.run(y)
    print sess.run(x)
收藏
评论

另外,还必须注意,如果您使用的是your_tensor.assign() ,则无需显式调用tf.global_variables_initializer因为assign操作会在后台为您执行此操作。

例:

In [212]: w = tf.Variable(12)
In [213]: w_new = w.assign(34)

In [214]: with tf.Session() as sess:
     ...:     sess.run(w_new)
     ...:     print(w_new.eval())

# output
34 

但是,这不会初始化所有变量,而只会初始化执行assign的变量。

收藏
评论

您也可以为tf.Variable分配新值,而无需在图形上添加操作: tf.Variable.load(value, session) 。从图形外部分配值时,此功能还可以节省添加占位符的功能,这在图形完成时很有用。

import tensorflow as tf
x = tf.Variable(0)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(x))  # Prints 0.
x.load(1, sess)
print(sess.run(x))  # Prints 1.

更新:在TF2中对此进行了描述,因为默认情况下急切执行,并且在面向用户的API中不再显示图形。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号