TensorFlow中的步骤和纪元之间是什么关系?
tensorflow
8
0

我正在阅读TensorFlow 入门教程 。在tf.contrib.learn示例中,这些是两行代码:

input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x}, y, batch_size=4, num_epochs=1000)
estimator.fit(input_fn=input_fn, steps=1000)

我想知道是什么参数的区别steps中调用fit功能和num_epochsnumpy_input_fn通话。不应该只有一个论点吗?它们如何连接?

我发现该代码以某种方式将这两者中的min作为本教程玩具示例中的步骤数。

至少num_epochssteps这两个参数之一必须是冗余的。我们可以彼此计算。有没有办法知道我的算法实际执行了多少步骤(参数更新的次数)?

我很好奇哪个优先。并取决于其他一些参数吗?

参考资料:
Stack Overflow
收藏
评论
共 5 个回答
高赞 时间 活跃

时代:一次遍历整个数据。

批次大小:一批中没有示例。

如果有1000个示例,并且批处理大小为100,则每个时期将有10个步骤。

时期和批量大小完全定义了步骤数。

steps_cal =(不符合/批量大小)* no_of_epochs

estimator.fit(input_fn=input_fn)

如果您只是编写上述代码,则“ steps”的值将由上述公式中的“ steps_cal”给出。

estimator.fit(input_fn=input_fn, steps  = steps_less)

如果您提供的值(例如``steps_less'')小于``steps_cal'',则仅执行``steps_less''个步骤,在这种情况下,培训将不会涵盖所提到的所有时期。

estimator.fit(input_fn=input_fn, steps  = steps_more)

如果您提供的值(例如,steps_more)多于steps_cal,则也会执行“ steps_cal”,而不会执行任何步骤。

收藏
评论

该答案基于我对入门教程代码所做的实验。

Mad Wombat对num_epochsbatch_sizesteps术语进行了详细说明。这个答案是对他答案的扩展。

num_epochs-程序可以在一train()迭代整个数据集的最大次数。使用此参数,我们可以限制在执行一个train()方法期间可以处理的批处理数量。

batch_size - input_fn发出的单个批次中的示例数

步骤 LinearRegressor.train()方法可在一次执行中处理的批次数量

max_stepsLinearRegressor.train()方法的另一个参数。此参数定义了LinearRegressor()对象生命周期中可以处理的最大步数(批次)。

我们这意味着什么。以下实验更改了教程提供的两行代码。其余代码保持不变。

注意:对于所有示例,假定训练次数,即x_train的长度等于4。

例1:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=4, num_epochs=2, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

在此示例中,我们定义了batch_size = 4和num_epochs =2。因此,对于一次train()执行, input_fn只能发出2批输入数据。即使我们定义的步数 = 10, train()方法也会在2个步骤后停止。

现在,再次执行estimator.train(input_fn=input_fn, steps=10) 。我们可以看到还执行了2个步骤。我们可以一次又一次地继续执行train()方法。如果我们执行train() 50次,则总共执行了100个步骤。

例2:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=2, num_epochs=2, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

在此示例中, batch_size的值更改为2(在示例1中等于4)。现在,在每次执行train()方法时,将处理4个步骤。第四步之后,没有要继续运行的批处理。如果再次执行train()方法,则将处理另外4个步骤,使其总共达到8个步骤。

在这里, 步骤的值无关紧要,因为train()方法最多可以获取4个批次。如果steps的值小于( num_epochs x training_size )/ batch_size ,请参阅示例 3。

例3:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=2, num_epochs=8, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

现在,让batch_size = 2, num_epochs = 8, 步数 = input_fn可以在一次train()方法中总共发出16个批处理。但是,将step设置为10。这意味着,即使input_fn可以提供16个批处理来执行, train()必须在10个步骤之后停止。当然,可以重新执行train()方法以累积更多步骤。


从示例1、2和3中,我们可以清楚地看到stepnum_epochbatch_size的值如何影响一次可以由train()方法执行的步骤数。

的参数MAX_STEPS train()方法的限制,可以通过累积运行步骤的总数train()

例4:

如果batch_size = 4, num_epochs = 2,则input_fn可以发出2个批处理以执行一次train() 。但是,如果将max_steps设置为20,那么无论执行多少次train() ,优化中都将只运行20个步骤。这与示例1相反,在示例1中,如果将train()方法提取100次,则优化器可以运行200步。

希望这能对这些论点的含义有一个详细的了解。

收藏
评论

TL; DR :一个时期是您的模型一次遍历整个训练数据。步骤是当您的模型训练单个批次(如果您一个个地发送样本,则训练单个样本)。在1000个样本上训练5个时期,每批10个样本将需要500个步骤。

contrib.learn.io模块的文档不够好,但是numpy_input_fn()函数似乎需要一些numpy数组并将它们分批作为分类器的输入。因此,时期的数量可能意味着“停止之前要经过多少次输入数据”。在这种情况下,它们以4批元素的方式喂入两个长度为4的数组,因此这仅意味着输入函数最多会执行1000次此操作,然后再引发“数据不足”异常。 estimator fit()函数中的steps参数是estimator应该执行训练循环的次数。这个特定示例有些不正确,所以让我组成另一个示例,使事情更清楚(希望如此)。

假设您有两个要训练的numpy数组(样本和标签)。它们每个都是100个元素。您希望培训以每批10个样本的方式进行。因此,经过10批处理后,您将遍历所有训练数据。那是一个时代。如果将输入生成器设置为10个纪元,它将在停止之前经过10次训练集,即最多生成100个批次。

再次说明,io模块没有文档说明,但是考虑到tensorflow中其他与输入相关的API如何工作,应该有可能使其生成无限次数的数据,因此控制训练时间的唯一方法就是步骤。这为您提供了一些额外的灵活性,让您可以更好地进行培训。您可以一次执行多个纪元,或者一次执行多个步骤,或者两者都执行。

收藏
评论

num_epochs:最大纪元数(查看每个数据点)。

步骤:(参数的)更新次数。

当批次大小小于训练数据数量时,您可以在一个时期内多次更新。

收藏
评论

让我们从相反的顺序开始:

1) 步骤 -学习算法中的训练循环运行以更新模型中参数的次数。在每次循环迭代中,它将处理大块数据,这基本上是一个批处理。通常,此循环基于“ 梯度下降”算法。

2) 批量大小 -您在学习算法的每个循环中提供的数据块的大小。您可以提供整个数据集,在这种情况下,批处理大小等于数据集大小。您也可以一次提供一个示例。或者,您可以提供一些N个示例。

3) 时代 -您遍历数据集提取批处理以供学习算法使用的次数。

假设您有1000个示例。设置批处理大小= 100,历元= 1,步数= 200,可在整个数据集中进行一次处理(一次历时)。在每遍中,它将为该算法提供100个示例的批处理。该算法每批将运行200个步骤。总共可以看到10个批次。如果将纪元更改为25,则它将执行25次,总共将看到25x10批。

我们为什么需要这个?梯度下降(批量,随机,微型批量)以及用于优化学习参数的其他算法(例如L-BFGS)有很多变体。其中一些需要分批查看数据,而其他一些则一次只能查看一个基准。另外,其中一些包含随机因素/步骤,因此您可能需要对数据进行多次遍历才能获得良好的收敛性。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号