用法示例:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
更新:从Tensorflow 0.12.0-rc0开始, 已将 all_tensors
参数添加到print_tensors_in_checkpoint_file
因此如果需要,您可能需要添加all_tensors=False
或all_tensors=True
。
替代方法:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
希望能帮助到你。
0
我想查看保存在TensorFlow检查点中的变量及其值。如何找到保存在TensorFlow检查点中的变量名称?
我使用了
tf.train.NewCheckpointReader
, 在此进行了说明。但是,它在TensorFlow的文档中没有给出。还有其他办法吗?