首先尝试:
import tensorflow as tf
graph = tf.get_default_graph()
然后,当您需要使用预测时:
with graph.as_default():
y = model.predict(X)
0
首先尝试:
import tensorflow as tf
graph = tf.get_default_graph()
然后,当您需要使用预测时:
with graph.as_default():
y = model.predict(X)
0
创建Model
,会话尚未还原。在Model.__init__
中定义的所有占位符,变量和操作都放置在新图中 ,这使其自身成为with
块的默认图。这是关键行:
with tf.Graph().as_default():
...
这意味着tf.Graph()
此实例等于with
block内部的tf.get_default_graph()
实例, 但 tf.get_default_graph()
之前或之后 。从这一刻起,存在两个不同的图。
稍后创建会话并将图形还原到其中时,您将无法在该会话中访问tf.Graph()
的先前实例。这是一个简短的示例:
with tf.Graph().as_default() as graph:
var = tf.get_variable("var", shape=[3], initializer=tf.zeros_initializer)
# This works
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(var)) # ok because `sess.graph == graph`
# This fails
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess, "/tmp/model.ckpt")
print(sess.run(var)) # var is from `graph`, not `sess.graph`!
解决此问题的最佳方法是为所有节点命名,例如'input'
, 'target'
等,保存模型,然后按名称在恢复的图中查找节点,如下所示:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess, "/tmp/model.ckpt")
input_data = sess.graph.get_tensor_by_name('input')
target = sess.graph.get_tensor_by_name('target')
此方法保证所有节点都将来自会话中的图。
0
我收到这个错误
如果不
with tf.Graph(). as_default():
,代码可以完美运行with tf.Graph(). as_default():
但是我需要多次调用M.sample(...)
,并且每次session.close()
之后内存都不可用。可能存在内存泄漏,但不确定在哪里。我想还原一个预先训练的神经网络,将其设置为默认图,并在默认图上对其进行多次测试(例如10000),而不必每次都使其变大。
代码是:
模型是:
输出为: