在Pytorch中使用Dropout:nn.Dropout与F.dropout
deep-learning
neural-network
pytorch
17
0

通过使用pyTorch,有两种方法可以torch.nn.Dropouttorch.nn.functional.Dropout

我很难看到它们之间的区别:

  • 什么时候使用什么?
  • 这有什么不同吗?

切换它们时,我看不到任何性能差异。

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

技术差异已在其他答案中显示。但是主要区别在于nn.Dropout是一个手电筒模块,它具有一些便利:

一个简短的例子来说明一些差异:

import torch
import torch.nn as nn

class Model1(nn.Module):
    # Model 1 using functional dropout
    def __init__(self, p=0.0):
        super().__init__()
        self.p = p

    def forward(self, inputs):
        return nn.functional.dropout(inputs, p=self.p, training=True)

class Model2(nn.Module):
    # Model 2 using dropout module
    def __init__(self, p=0.0):
        super().__init__()
        self.drop_layer = nn.Dropout(p=p)

    def forward(self, inputs):
        return self.drop_layer(inputs)
model1 = Model1(p=0.5) # functional dropout 
model2 = Model2(p=0.5) # dropout module

# creating inputs
inputs = torch.rand(10)
# forwarding inputs in train mode
print('Normal (train) model:')
print('Model 1', model1(inputs))
print('Model 2', model2(inputs))
print()

# switching to eval mode
model1.eval()
model2.eval()

# forwarding inputs in evaluation mode
print('Evaluation mode:')
print('Model 1', model1(inputs))
print('Model 2', model2(inputs))
# show model summary
print('Print summary:')
print(model1)
print(model2)

输出:

Normal (train) model:
Model 1 tensor([ 1.5040,  0.0000,  0.0000,  0.8563,  0.0000,  0.0000,  1.5951,
         0.0000,  0.0000,  0.0946])
Model 2 tensor([ 0.0000,  0.3713,  1.9303,  0.0000,  0.0000,  0.3574,  0.0000,
         1.1273,  1.5818,  0.0946])

Evaluation mode:
Model 1 tensor([ 0.0000,  0.3713,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000])
Model 2 tensor([ 0.7520,  0.1857,  0.9651,  0.4281,  0.7883,  0.1787,  0.7975,
         0.5636,  0.7909,  0.0473])
Print summary:
Model1()
Model2(
  (drop_layer): Dropout(p=0.5)
)

那我应该使用哪个呢?

两者在应用nn.Dropout方面是完全等效的,即使用法的差异不是很大,但出于某些原因,我们还是倾向于nn.Dropout不是nn.functional.dropout

辍学被设计为仅在训练期间应用,因此在进行模型的预测或评估时,您希望关闭辍学。

退出模块nn.Dropout方便地处理此问题,并在模型进入评估模式后立即关闭退出,而功能辍学并不在乎评估/预测模式。

即使您可以将功能辍学设置为training=False来关闭它,它仍然不是像nn.Dropout这样方便的解决方案。

下降率也存储在模块中,因此您不必将其保存在额外的变量中。在较大的网络中,您可能需要创建具有不同丢弃率的不同丢弃层-在此处nn.Dropout可以提高可读性,并且在多次使用这些层时也可以带来一些便利。

最后,分配给您模型的所有模块都将在您的模型中注册。因此,您的模型类会跟踪它们,这就是为什么您可以通过调用eval()来关闭dropout模块的原因。使用功能性辍学时,您的模型不知道它,因此它不会出现在任何摘要中。

收藏
评论

如果查看nn.DropoutFunctional.Dropout的源代码, 可以看到Functional是一个接口,并且nn模块针对此接口实现功能。
查看nn类中的实现:

from .. import functional as F
class Dropout(_DropoutNd):
    def forward(self, input):
        return F.dropout(input, self.p, self.training, self.inplace)

class Dropout2d(_DropoutNd):
    def forward(self, input):
        return F.dropout2d(input, self.p, self.training, self.inplace)

等等。

Functional类的实现:

def dropout(input, p=0.5, training=False, inplace=False):
    return _functions.dropout.Dropout.apply(input, p, training, inplace)

def dropout2d(input, p=0.5, training=False, inplace=False):
    return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)

看下面的例子了解:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

有一个F.dropoutforward()函数和nn.Dropout__init__()函数。现在说明一下:

在PyTorch中,您将模型定义为torch.nn.Module的子类。

init函数中,您应该初始化要使用的层。与keras不同,Pytorch的层次更低,您必须指定网络的大小,以便所有内容都匹配。

在forward方法中,指定图层的连接。这意味着您将使用已经初始化的层,以便对您进行的每个数据前传重复使用同一层。

torch.nn.Functional包含一些有用的函数,例如激活函数和可以使用的卷积运算。但是,这些不是完整的图层,因此,如果要指定任何种类的图层,都应使用torch.nn.Module。

您可以使用torch.nn.Functional conv操作来定义自定义层,例如通过卷积操作,而不是定义标准卷积层。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号