flat_inputs = tf.layers.flatten(inputs)

将张量流平展

共 3 个回答
高赞
时间
活跃
0

0

您可以在运行时使用动态重塑来通过tf.batch
获取批次尺寸的值,并将整个新尺寸的集合计算为tf.reshape
。这是在不知道列表长度的情况下将平面列表重塑为方矩阵的示例。
tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# convert [3, 3] Python list into [3, 3] Tensor
newshape = tf.pack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})
你应该看到
array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=int32)
0

您可以使用tf.reshape()轻松完成此操作,而无需知道批处理大小。
x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list() # a list: [None, 9, 2]
dim = numpy.prod(shape[1:]) # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim]) # -1 means "all"
无论运行时批处理大小如何,最后一行中的-1
表示整个列。您可以在tf.reshape()中看到它。
更新:shape = [无,3,无]
谢谢@kbrose。对于未定义1个以上维度的情况,可以将tf.shape()与tf.reduce_prod()交替使用。
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
tf.shape()返回可在运行时评估的形状张量。 tf.get_shape()和tf.shape()之间的区别可以在doc中看到。
我还在另一个尝试过tf.contrib.layers.flatten()。对于第一种情况,这是最简单的,但对于第二种情况,它是无法处理的。
新手导航
- 社区规范
- 提出问题
- 进行投票
- 个人资料
- 优化问题
- 回答问题
0
我有一个形状为
[None, 9, 2]
9,2[None, 9, 2]
(其中None
是批处理)。要对其执行其他操作(例如matmul),我需要将其转换为
[None, 18]
形状。怎么做?