被tf.cond的行为所迷惑
tensorflow
6
0

我的图形中需要一个条件控制流。如果predTrue ,则图应调用一个操作,该操作将更新变量,然后将其返回,否则它将返回不变的变量。简化的版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

但是,我发现pred=Truepred=False导致相同的结果y=[2] ,这意味着当update_x_2未选择tf.cond时,也会调用assign op。怎么解释呢?以及如何解决这个问题?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃
pred = tf.constant(False)
x = tf.Variable([1])

def update_x_2():
    assign_x_2 = tf.assign(x, [2])
    with tf.control_dependencies([assign_x_2]):
        return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

这将得到[1]的结果。

该答案与上述答案完全相同。但是我想分享的是,您可以将您想使用的所有操作都放入其分支功能中。因为给定您的示例代码,张量x是可以由update_x_2函数直接使用。

收藏
评论

TL; DR:如果要让tf.cond()在其中一个分支中执行副作用(如赋值),则必须创建在传递给tf.cond()的函数执行副作用的op。 。

tf.cond()的行为有点不直观。由于TensorFlow图中的执行向前遍历图,因此在评估条件之前,必须先执行任一分支中引用的所有操作。这意味着true和false分支都接收对tf.assign() op的控制依赖项,因此y始终设置为2 ,即使pred is False`。

解决方案是在定义真实分支的函数内创建tf.assign() op。例如,您可以按以下方式组织代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号