Flatten层在Keras中如何工作?
keras
keras-layer
neural-network
tensorflow
7
0

我正在使用TensorFlow后端。

我依次应用卷积,最大池化,展平和密集层。卷积需要3D输入(高度,宽度,color_channels_depth)。

卷积后,该值变为(高度,宽度,Number_of_filters)。

应用最大池化后,高度和宽度会发生变化。但是,在应用平坦层之后,究竟会发生什么?例如,如果扁平化之前的输入为(24,24,32),那么如何扁平化呢?

它是按顺序(24 * 24)顺序排列还是按其他方式按顺序排列每个过滤器编号的高度,重量?实际值将是一个示例。

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

它是连续的,类似于24 * 24 * 32,并按以下代码所示对其进行重塑。

def batch_flatten(x):
    """Turn a nD tensor into a 2D tensor with same 0th dimension.
    In other words, it flattens each data samples of a batch.
    # Arguments
        x: A tensor or variable.
    # Returns
        A tensor.
    """
    x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])]))
    return x
收藏
评论

展平张量意味着除去除一个以外的所有尺寸。

Keras中的Flatten层将张量整形为具有等于张量中包含的元素数量的形状。

这与制作元素的一维数组相同。

例如,在VGG16模型中,您可能会很容易理解:

>>> model.summary()
Layer (type)                     Output Shape          Param #
================================================================
vgg16 (Model)                    (None, 4, 4, 512)     14714688
________________________________________________________________
flatten_1 (Flatten)              (None, 8192)          0
________________________________________________________________
dense_1 (Dense)                  (None, 256)           2097408
________________________________________________________________
dense_2 (Dense)                  (None, 1)             257
===============================================================

请注意flatten_1图层的形状如何(无,8192),其中8192实际上是4 * 4 * 512。


PS,“无”表示任何尺寸(或动态尺寸),但通常可以将其读取为1。您可以在此处找到更多详细信息。

收藏
评论

Flatten()运算符将展开从最后一个维度开始的值(至少对于Theano,这是“通道优先”,而不是像TF的“通道最后”。我无法在我的环境中运行TensorFlow)。这等效于numpy.reshape '排序:

“ C”表示使用类似C的索引顺序读取/写入元素,最后一个轴索引更改最快,回到第一个轴索引更改最慢。

这是一个使用Keras Functional API的Flatten运算符的独立示例。您应该能够轻松适应您的环境。

import numpy as np
from keras.layers import Input, Flatten
from keras.models import Model
inputs = Input(shape=(3,2,4))

# Define a model consisting only of the Flatten operation
prediction = Flatten()(inputs)
model = Model(inputs=inputs, outputs=prediction)

X = np.arange(0,24).reshape(1,3,2,4)
print(X)
#[[[[ 0  1  2  3]
#   [ 4  5  6  7]]
#
#  [[ 8  9 10 11]
#   [12 13 14 15]]
#
#  [[16 17 18 19]
#   [20 21 22 23]]]]
model.predict(X)
#array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
#         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
#         22.,  23.]], dtype=float32)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号