如何使用布尔张量制作if语句
tensorflow
4
0

如何使用布尔张量制作if语句?更准确地说,我正在尝试将大小为1的张量与常量进行比较,以检查张量中的值是否小于常量。我发现我必须使常数成为自己的大小1张量,并使用方法检查第一个张量是否小于第二个张量,但是我不确定如何使生成的布尔张量正确放入if语句中。只需将其作为对if语句的查询,就可以使if语句始终返回true。

编辑:这或多或少是看起来的代码。但是,无论是否有参数,我都会得到错误的'bool' object has no attribute 'name'的信息,这使我认为问题在于它没有返回TensorFlow对象。

pred = tf.placeholder(tf.bool)

def if_true(x, y, z):
  #act on x, y, and z
  return True

def if_false():
  return False

# Will be `tf.cond()` in the next release.
from tensorflow.python.ops import control_flow_ops
from functools import partial
x = ...
y = ...
z = ...

result = control_flow_ops.cond(pred, partial(if_true, x, y, z), if_false)
参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

TL; DR:您需要使用Session.run()来获取Python布尔值,但是还有其他方法可以达到相同的结果,而这种方法可能更有效。

看来您已经想出了如何从值中获取布尔张量,但是为了其他读者的利益,它看起来像这样:

computed_val = ...
constant_val = tf.constant(37.0)
pred = tf.less(computed_val, constant_val)  # N.B. Types of the two args must match

下一部分是如何使用它作为条件的。最简单的方法是使用Python if语句,但必须使用Session.run() 评估张量pred

sess = tf.Session()

if sess.run(pred):
  # Do something.
else:
  # Do something else.

关于使用Python if语句的一个警告是,您必须对整个表达式进行评估,直到pred为止,这使得重用已经计算出的中间值变得棘手。我想提请您注意另外两种使用TensorFlow计算条件表达式的方式,这些方式不需要您评估谓词并返回Python值。

第一种方法使用tf.select() op有条件地传递来自作为参数传递的两个张量的值:

pred = tf.placeholder(tf.bool)  # Can be any computed boolean expression.
val_if_true = tf.constant(28.0)
val_if_false = tf.constant(12.0)
result = tf.select(pred, val_if_true, val_if_false)

sess = tf.Session()
sess.run(result, feed_dict={pred: True})   # ==> 28.0
sess.run(result, feed_dict={pred: False})  # ==> 12.0

tf.select() op在其所有参数上都按元素进行工作,这使您可以组合来自两个输入张量的值。有关更多详细信息,请参见其文档tf.select()的缺点在于,它在计算结果之前先评估val_if_trueval_if_false ,如果它们是复杂的表达式,则可能会很昂贵。

第二种方法使用tf.cond() op,它有条件地计算两个表达式之一。如果表达式很昂贵,这尤其有用,如果它们具有副作用 ,则必不可少。基本模式是指定两个Python函数(或lambda表达式)来构建将在true或false分支上执行的子图:

# Define some large matrices
a = ...
b = ...
c = ...

pred = tf.placeholder(tf.bool)

def if_true():
  return tf.matmul(a, b)

def if_false():
  return tf.matmul(b, c)

# Will be `tf.cond()` in the next release.
from tensorflow.python.ops import control_flow_ops

result = tf.cond(pred, if_true, if_false)

sess = tf.Session()
sess.run(result, feed_dict={pred: True})   # ==> executes only (a x b)
sess.run(result, feed_dict={pred: False})  # ==> executes only (b x c)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号