是否有一个示例如何生成保存训练有素的TensorFlow图的protobuf文件
tensorflow
10
0

我正在看Google如何在Android上部署和使用预训练的Tensorflow图(模型) 的示例 。本示例在以下位置使用.pb文件:

https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

这是自动下载文件的链接

该示例显示了如何将.pb文件加载到Tensorflow会话中并使用它执行分类,但是似乎没有提到在训练图之后(例如,在Python中)如何生成这样的.pb文件。

有什么例子可以做到吗?

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

freeze_graph()我以前使用freeze_graph()答案freeze_graph() ,它仅在您将其作为脚本调用时才freeze_graph() ,还有一个非常不错的函数,它将为您完成所有繁重的工作,适合从您的常规模型训练代码中调用。

convert_variables_to_constants()做两件事:

  • 通过用常量替换变量来冻结权重
  • 它删除与前馈预测无关的节点

假设sess是您的tf.Session()"output"是您的预测节点的名称,下面的代码会将您的最小图形序列化为文本和二进制protobuf。


from tensorflow.python.framework.graph_util import convert_variables_to_constants

minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])

tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False)
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)
收藏
评论

这是@Mostafa答案的另一种说法。运行tf.assign ops的一种更tf.assign是将它们存储在tf.group 。这是我的Python代码:

  ops = []
  for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    ops.append(tf.assign(v, vc));
  tf.group(*ops, name="assign_trained_variables")

在C ++中:

  std::vector<tensorflow::Tensor> tmp;
  status = session.Run({}, {}, { "assign_trained_variables" }, &tmp);
  if (!status.ok()) {
    // Handle error
  }

这样,您只有一个命名的op可以在C ++端运行,因此您不必在节点上进行迭代。

收藏
评论

我不知道如何实现mrry描述的方法。但是这是我如何解决的。我不确定这是否是解决问题的最佳方法,但至少可以解决问题。

由于write_graph也可以存储常量的值,因此在使用write_graph函数编写图形之前,我向python添加了以下代码:

for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    tf.assign(v, vc, name="assign_variables")

这将创建在训练后存储变量值的常量,然后创建张量“ assign_variables ”以将其分配给变量。现在,当您调用write_graph时,它将以常量形式将变量的值存储在文件中。

剩下的唯一部分是在c代码中将这些张量称为“ assign_variables ”,以确保为变量分配了存储在文件中的常量值。这是一种实现方法:

      Status status = NewSession(SessionOptions(), &session);
      std::vector<tensorflow::Tensor> outputs;
      char name[100];
      for(int i = 0;status.ok(); i++) {
        if (i==0)
            sprintf(name, "assign_variables");
        else
            sprintf(name, "assign_variables_%d", i);

        status = session->Run({}, {name}, {}, &outputs);
      }
收藏
评论

编辑: freeze_graph.py脚本是TensorFlow信息库的一部分,现在用作从现有TensorFlow GraphDef和保存的检查点生成代表“冻结”训练模型的协议缓冲区的工具。它使用与以下所述相同的步骤,但使用起来容易得多。


目前,该过程的文档尚不完善(有待完善),但大致步骤如下:

  1. 建立和训练你的模型作为tf.Graph称为g_1
  2. 获取每个变量的最终值,并将它们存储为numpy数组(使用Session.run() )。
  3. 名为g_2 的新tf.Graph ,使用在步骤2中获取的相应numpy数组的值为每个变量创建tf.constant()张量。
  4. 使用tf.import_graph_def()从复制节点g_1g_2 ,并使用input_map参数替换每个变量g_1与相应tf.constant()在第三步建立张量你也可能需要使用input_map指定新输入张量(例如,用tf.placeholder()替换输入管道 )。使用return_elements参数指定预测的输出张量的名称。

  5. 调用g_2.as_graph_def()以获取图形的协议缓冲区表示形式。

注意:生成的图将在图中包含额外的节点以进行训练。尽管它不是公共API的一部分,但您可能希望使用内部graph_util.extract_sub_graph()函数从图中剥离这些节点。)

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题