如何在TensorFlow中更改变量的形状?
tensorflow
8
0

TensorFlow教程说,在创建时,我们需要指定张量的形状。该形状自动变为张量的形状。它还说TensorFlow提供了重塑变量的高级机制。我怎样才能做到这一点?任何代码示例?

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

建议使用tf.Variable类创建变量,但是一旦创建变量,它将限制您更改变量形状的能力。

如果需要更改变量的形状,则可以执行以下操作(例如,对于32位浮点张量):

var = tf.Variable(tf.placeholder(tf.float32))
# ...
new_value = ...  # Tensor or numpy array.
change_shape_op = tf.assign(var, new_value, validate_shape=False)
# ...
sess.run(change_shape_op)  # Changes the shape of `var` to new_value's shape.

请注意,此功能不在记录的公共API中,因此可能会发生更改。如果您确实需要使用此功能,请告诉我们,我们可以研究一种支持其发展的方法。

收藏
评论

查看TensorFlow文档中的形状和形状 。它描述了可用的不同形状转换。

最常见的功能可能是tf.reshape ,类似于它的numpy等效项。只要元素数量保持不变,它就可以指定所需的任何形状。文档中有一些示例。

收藏
评论

文档显示了重塑方法。他们是:

  • 重塑
  • 压缩(从张量的形状中删除尺寸为1的尺寸)
  • expand_dims(添加尺寸为1的尺寸)

以及一系列获取张量的shapesizerank的方法。可能最常用的是reshape ,这是一个带有几个边缘情况(-1)的代码示例:

import tensorflow as tf

v1 = tf.Variable([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
])
v2 = tf.reshape(v1, [2, 6])
v3 = tf.reshape(v1, [2, 2, -1])
v4 = tf.reshape(v1, [-1])
# v5 = tf.reshape(v1, [2, 4, -1]) will fail, because you can not find such an integer for -1
v6 = tf.reshape(v1, [1, 4, 1, 3, 1])
v6_shape = tf.shape(v6)
v6_squeezed = tf.squeeze(v6)
v6_squeezed_shape = tf.shape(v6_squeezed)

init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)
a, b, c, d, e, f, g = sess.run([v2, v3, v4, v6, v6_shape, v6_squeezed, v6_squeezed_shape])
# print all variables to see what is there
print e # shape of v6
print g # shape of v6_squeezed
收藏
评论
tf.Variable(tf.placeholder(tf.float32))

在tensorflow 1.2.1上无效

在python shell中:

import tensorflow as tf
tf.Variable(tf.placeholder(tf.float32))

你会得到:

ValueError: initial_value must have a shape specified: Tensor("Placeholder:0", dtype=float32)

更新:如果添加validate_shape=False ,将没有错误。

tf.Variable(tf.placeholder(tf.float32), validate_shape=False)

如果tf.py_func符合您的要求:

def init():
    return numpy.random.rand(2,3)
a = tf.pyfun(init, [], tf.float32)

您可以通过传递自己的init函数来创建具有任何形状的变量。

另一种方式:

var = tf.get_varible('my-name', initializer=init, shape=(1,1))

您可以传递tf.constant或任何返回numpy数组的init函数。提供的形状不会被验证。输出形状是您的实际数据形状。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号