这有点骇人听闻,但其他答案对我根本没有用
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
更新9/2017
我不确定这是否由于更新而开始起作用,但是以下方法似乎对于使global_step正确更新和加载有效:
创建两个操作。一个用于保存global_step,另一个用于对其进行递增:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')
现在,在您的训练循环中,每次您运行训练操作时都要运行增量操作。
sess.run([train_op,increment_global_step],feed_dict=feed_dict)
如果您想随时以整数形式检索全局步长值,只需在加载模型后使用以下命令:
sess.run(global_step)
这对于创建文件名或计算当前时期是有用的,而无需第二个tensorflow变量来保存该值。例如,计算加载时的当前时间将类似于:
loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
0
我正在保存会话状态,如下所示:
稍后还原时,我想获取要还原的检查点的global_step值。这是为了从中设置一些超级参数。
这样做的怪癖方法是遍历并解析检查点目录中的文件名。但是,真的必须有一种更好的内置方法来做到这一点吗?