TensorFlow,为什么保存模型后会有3个文件?
tensorflow
5
0

阅读文档后 ,我将模型保存在TensorFlow ,这是我的演示代码:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

但是之后,我发现有3个文件

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

而且我无法通过还原model.ckpt文件来还原模型,因为没有这样的文件。这是我的代码

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

那么,为什么有3个文件?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

尝试这个:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

TensorFlow保存方法保存三种文件,因为它与变量值分开存储图结构.meta文件描述了保存的图形结构,因此您需要在还原检查点之前将其导入(否则,它不知道保存的检查点值对应于哪些变量)。

或者,您可以这样做:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

即使没有名为model.ckpt文件,在还原文件时,仍会使用该名称来引用已保存的检查点。从saver.py源代码

用户只需要与用户指定的前缀进行交互...,而不需要与任何物理路径名进行交互。

收藏
评论

我正在从Word2Vec tensorflow教程中恢复经过训练的单词嵌入。

如果您创建了多个检查点:

例如,创建的文件如下所示

型号.ckpt-55695.data-00000-of-00001

型号.ckpt-55695.index

型号.ckpt-55695.meta

尝试这个

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

调用restore_session()时:

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
收藏
评论
  • meta文件 :描述保存的图形结构,包括GraphDef,SaverDef等;然后应用tf.train.import_meta_graph('/tmp/model.ckpt.meta') ,将还原SaverGraph

  • 索引文件 :它是一个不可变的字符串表(tensorflow :: table :: Table)。每个键是张量的名称,其值是序列化的BundleEntryProto。每个BundleEntryProto都描述张量的元数据:哪个“数据”文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等。

  • 数据文件 :它是TensorBundle集合,保存所有变量的值。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号