freeze_graph()
我以前使用freeze_graph()
答案freeze_graph()
,它仅在您将其作为脚本调用时才freeze_graph()
,还有一个非常不错的函数,它将为您完成所有繁重的工作,适合从您的常规模型训练代码中调用。
convert_variables_to_constants()
做两件事:
- 通过用常量替换变量来冻结权重
- 它删除与前馈预测无关的节点
假设sess
是您的tf.Session()
, "output"
是您的预测节点的名称,下面的代码会将您的最小图形序列化为文本和二进制protobuf。
from tensorflow.python.framework.graph_util import convert_variables_to_constants
minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False)
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)
0
我正在看Google如何在Android上部署和使用预训练的Tensorflow图(模型) 的示例 。本示例在以下位置使用
.pb
文件:这是自动下载文件的链接 。
该示例显示了如何将
.pb
文件加载到Tensorflow会话中并使用它执行分类,但是似乎没有提到在训练图之后(例如,在Python中)如何生成这样的.pb
文件。有什么例子可以做到吗?