Tensorflow:如何替换计算图中的节点?
python
tensorflow
5
0

如果您有两个不相交的图形,并且想要链接它们,请执行以下操作:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

到这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)

有没有办法做到这一点?在某些情况下,它似乎可以使施工更容易。

例如,如果您有一个将输入图像作为tf.placeholder ,并且想要优化输入图像(深梦风格),是否有一种方法可以仅用tf.variable节点替换占位符?还是在构建图形之前必须考虑一下?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

如果要组合训练好的模型(例如在新模型中重用预先训练的模型的一部分),则可以使用Saver保存第一个模型的检查点,然后将该模型(全部或部分)还原到另一个模型中。

例如,假设您要在模型2中重用模型1的权重w ,并将x从占位符转换为变量:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.
收藏
评论

TL; DR:如果可以将这两个计算定义为Python函数,则应该这样做。如果不能,那么TensorFlow中有更多高级功能可用于序列化和导入图形,这使您可以从不同来源组成图形。

在TensorFlow中执行此操作的一种方法是将不相交的计算构建为单独的tf.Graph对象,然后使用Graph.as_graph_def()将它们转换为序列化的协议缓冲区:

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

然后,您可以使用tf.import_graph_def()gdef_1gdef_2组成第三张图:

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]
收藏
评论

事实证明, tf.train.import_meta_graph将所有其他参数传递给具有input_map参数的基础import_scoped_meta_graph ,并在对input_map进行自己的(内部)调用时import_graph_def

它没有记录,花了我很多时间才找到它,但是它有效!

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号