在Tensorflow中恢复检查点时如何获取global_step?
tensorflow
9
0

我正在保存会话状态,如下所示:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

稍后还原时,我想获取要还原的检查点的global_step值。这是为了从中设置一些超级参数。

这样做的怪癖方法是遍历并解析检查点目录中的文件名。但是,真的必须有一种更好的内置方法来做到这一点吗?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

这有点骇人听闻,但其他答案对我根本没有用

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

更新9/2017

我不确定这是否由于更新而开始起作用,但是以下方法似乎对于使global_step正确更新和加载有效:

创建两个操作。一个用于保存global_step,另一个用于对其进行递增:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')

现在,在您的训练循环中,每次您运行训练操作时都要运行增量操作。

sess.run([train_op,increment_global_step],feed_dict=feed_dict)

如果您想随时以整数形式检索全局步长值,只需在加载模型后使用以下命令:

sess.run(global_step)

这对于创建文件名或计算当前时期是有用的,而无需第二个tensorflow变量来保存该值。例如,计算加载时的当前时间将类似于:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
收藏
评论

一般模式是使用global_step变量来跟踪步骤

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

然后,您可以保存

saver.save(sess, save_path, global_step=global_step)

还原时,也会还原global_step的值

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号