“ Flatten”在Keras中的作用是什么?
deep-learning
keras
machine-learning
neural-network
9
0

我正在尝试了解Keras中Flatten函数的作用。下面是我的代码,它是一个简单的两层网络。它接收形状为(3,2)的二维数据,并输出形状为(1,4)的一维数据:

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

打印出y形状为(1、4)。但是,如果我删除展Flatten线,则它会打印出y形状为(1、3、4)。

我不明白根据我对神经网络的理解, model.add(Dense(16, input_shape=(3, 2)))函数正在创建一个具有16个节点的隐藏的完全连接层。这些节点中的每个都连接到3x2输入元素中的每个。因此,该第一层的输出处的16个节点已经“平坦”。因此,第一层的输出形状应为(1、16)。然后,第二层将此作为输入,并输出形状为(1、4)的数据。

因此,如果第一层的输出已经“平坦”并且形状为(1,16),为什么还要进一步使其平坦?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

在此处输入图片说明 Flatten是将Matrix转换为单个数组的工作方式。

收藏
评论

简短阅读:

展平张量意味着除去除一个以外的所有尺寸。这正是Flatten层所做的。

长读:

如果我们考虑创建的原始模型(具有Flatten层),则可以得到以下模型摘要:

Layer (type)                 Output Shape              Param #   
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0         
_________________________________________________________________
F (Flatten)                  (None, 48)                0         
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196       
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

对于此摘要,下一张图像有望对每一层的输入和输出大小提供更多的了解。

可以读取的Flatten层的输出形状为(None, 48) 。这里是提示。您应该阅读(1, 48)(2, 48)或...或(16, 48) ...或(32, 48) ,...

实际上,该位置的“ None ”表示任何批量。对于召回的输入,第一维表示批处理大小,第二维表示输入要素的数量。

在Keras中Flatten层的作用非常简单:

对张量进行展平操作可将张量整形,使其形状等于不包含批次尺寸的张量中包含的元素数量。

在此处输入图片说明


注意:我使用了model.summary()方法来提供输出形状和参数详细信息。

收藏
评论

如果阅读Dense的Keras文档条目,您将看到以下调用:

Dense(16, input_shape=(5,3))

会形成一个具有3个输入和16个输出的Dense网络,这将独立地应用于5个步骤中的每个步骤。因此,如果D(x)将3维矢量转换为16维矢量,则从图层输出的输出将是矢量序列: [D(x[0,:]), D(x[1,:]),..., D(x[4,:])]的形状为(5, 16) [D(x[0,:]), D(x[1,:]),..., D(x[4,:])] (5, 16) 。为了具有您指定的行为,您可以先将输入展Flatten为15 d向量,然后应用Dense

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

编辑:由于有些人难以理解-在这里,您有一个解释性的图像:

在此处输入图片说明

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号