pytorch中的model.train()有什么作用?
python
pytorch
35
0

它会在nn.Module调用forward()吗?我以为,当我们调用模型时, forward使用forward方法。为什么我们需要指定train()?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

model.train()告诉模型您正在训练模型。因此,在培训和测试过程中表现不同的有效层(例如辍学,batchnorm等)可以知道发生了什么,因此可以相应地表现。

更多详细信息:它设置训练模式(请参阅源代码 )。您可以调用model.eval()model.train(mode=False)来告诉您正在测试。期望train功能能够训练模型有点直观,但是并不能做到这一点。它只是设置模式。

收藏
评论

这是module.train()的代码:

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这是module.eval

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

模式traineval是我们可以在其中设置模块的仅有的两种模式,它们是完全相反的。

那只是一个self.training标志,目前只有 辍学bachnorm关心那个标志。

默认情况下,此标志设置为True

收藏
评论

有两种方法可以让模型知道您的意图,即您要训练模型还是要使用模型进行评估。在使用model.train()的情况下,模型知道必须学习层,并且当我们使用model.eval()时,它指示模型没有新知识要学习,并且该模型用于测试。 model.eval()也是必需的,因为在pytorch中,如果我们使用batchnorm,而在测试过程中,如果我们只想传递单个图像,则如果未指定model.eval(),则pytorch会引发错误。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号