提前停止使用tf.estimator,怎么办?
keras
neural-network
python
tensorflow
6
0

我在TensorFlow 1.4中使用tf.estimator ,并且tf.estimator.train_and_evaluate很好,但是我需要尽早停止。首选的添加方式是什么?

我假设tf.train.SessionRunHook有一些tf.train.SessionRunHook 。我看到有一个带有ValidationMonitor的旧的contrib程序包,它似乎早已停止运行,但在1.4中似乎不再存在。还是将来首选的方法是依靠tf.keras (使用它确实很容易尽早停止),而不是tf.estimator/tf.layers/tf.data

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

好消息! tf.estimator现在已在master上提供了早期停止支持,并且看起来将在1.10中。

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))
收藏
评论

是的,有tf.train.StopAtStepHook

在执行了多个步骤或到达最后一个步骤之后,该挂钩请求将停止。只能指定两个选项之一。

您还可以扩展它,并根据步骤结果实施自己的停止策略。

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()
收藏
评论

首先,您必须命名损失以使其可用于提前停止的呼叫。如果您的损失变量在估算器中被命名为“损失”,则该行

copyloss = tf.identity(loss, name="loss")

就在它下面会起作用。

然后,使用此代码创建一个钩子。

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

这会将指数平滑损失验证与最小值进行比较,如果容忍度更高,则会停止训练。如果停止太早,提高公差和平滑度将使其稍后停止。保持平滑度低于1,否则它将永不停止。

如果要基于其他条件停止运行,则可以将after_run中的逻辑替换为其他内容。

现在,将此钩子添加到评估规范中。您的代码应如下所示:

eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

重要说明:run_context.request_stop()函数在train_and_evaluate调用中已损坏,并且不会停止训练。因此,我提出了一个价值错误以停止训练。因此,您必须将train_and_evaluate调用包装在try catch块中,如下所示:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

如果您不这样做,则训练停止时代码将崩溃并显示错误。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号