对tf.Tensor.set_shape()的澄清
tensorflow
5
0

我的图像为478 x 717 x 3 = 1028178像素,等级为1。我通过调用tf.shape和tf.rank进行了验证。

当我调用image.set_shape([478,717,3])时,它将引发以下错误。

"Shapes %s and %s must have the same rank" % (self, other)) 
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank

我通过首先强制转换为1028178再次进行了测试,但该错误仍然存在。

ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank

好吧,这确实是有道理的,因为其中一个的等级为1,另一个的等级为3。但是,为什么要抛出错误,因为像素总数仍然匹配。

我当然可以使用tf.reshape,它可以工作,但我认为这不是最佳选择。

如TensorFlow常见问题解答所述

x.set_shape()和x = tf.reshape(x)有什么区别?

tf.Tensor.set_shape()方法更新Tensor对象的静态形状,当无法直接推断时,通常用于提供其他形状信息。它不会改变张量的动态形状。

tf.reshape()操作创建一个具有不同动态形状的新张量。

创建一个新的张量涉及内存分配,而当涉及到更多的训练示例时,这可能会花费更高的成本。这是设计使然,还是我在这里错过了什么?

参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

据我所知(并编写了该代码), Tensor.set_shape()没有错误。我认为误解源于该方法的名称混乱。

为了详细说明您引用FAQ条目Tensor.set_shape()是一个纯Python函数, 可以改善给定tf.Tensor对象的形状信息。 “改善”是指“使内容更加具体”。

因此,当您具有形状为(?,)Tensor对象t时,这是一个未知长度的一维张量。您可以调用t.set_shape((1028178,)) ,然后在调用t.get_shape()t将具有形状(1028178,) t.get_shape() 。这不会影响基础存储,也不会影响后端的任何内容;它仅表示使用t进行的后续形状推断可以依赖于断言它是长度为1028178的向量。

如果t具有形状(?,) ,则对t.set_shape((478, 717, 3)) (?,)的调用将失败,因为TensorFlow已经知道t是向量,因此它不能具有形状(478, 717, 3) 。如果要根据t的内容制作具有该形状的新Tensor,可以使用reshaped_t = tf.reshape(t, (478, 717, 3)) 。这将在Python中创建一个新的tf.Tensor对象。 tf.reshape()的实际实现是使用张量缓冲区的浅表副本执行此操作的,因此在实践中它不昂贵。

一个类比是Tensor.set_shape()就像是一种在运行时Tensor.set_shape()的语言,如Java。例如,如果您有一个指向Object的指针,但是知道它实际上是一个String ,则可以执行强制转换(String) obj以便将obj传递给需要String参数的方法。但是,如果你有一个String s ,并尝试将其转换为java.util.Vector ,编译器会给你一个错误,因为这两种类型无关。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号