如何查找保存在检查点中的变量名称和值?
tensorflow
5
0

我想查看保存在TensorFlow检查点中的变量及其值。如何找到保存在TensorFlow检查点中的变量名称?

我使用了tf.train.NewCheckpointReader此进行了说明。但是,它在TensorFlow的文档中没有给出。还有其他办法吗?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

您可以使用inspect_checkpoint.py工具。

因此,例如,如果将检查点存储在当前目录中,则可以按如下所示打印变量及其值

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
收藏
评论

一些细节。

例如,如果您的模型是使用V2格式保存的,则目录/my/dir/有以下文件

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

那么file_name参数应该仅是前缀,即

print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)

参见https://github.com/tensorflow/tensorflow/issues/7696进行讨论。

收藏
评论

用法示例:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

更新:Tensorflow 0.12.0-rc0开始, 已将 all_tensors参数添加到print_tensors_in_checkpoint_file因此如果需要,您可能需要添加all_tensors=Falseall_tensors=True

替代方法:

from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

希望能帮助到你。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号