有两种方法可以将单个新图像提供给cifar10模型。第一种方法是一种更简洁的方法,但是需要在主文件中进行修改,因此将需要重新培训。当用户不想修改模型文件而想使用现有的检查点/元图文件时,第二种方法适用。
第一种方法的代码如下:
import tensorflow as tf
import numpy as np
import cv2
sess = tf.Session('', tf.Graph())
with sess.graph.as_default():
# Read meta graph and checkpoint to restore tf session
saver = tf.train.import_meta_graph("/tmp/cifar10_train/model.ckpt-200.meta")
saver.restore(sess, "/tmp/cifar10_train/model.ckpt-200")
# Read a single image from a file.
img = cv2.imread('tmp.png')
img = np.expand_dims(img, axis=0)
# Start the queue runners. If they are not started the program will hang
# see e.g. https://www.tensorflow.org/programmers_guide/reading_data
coord = tf.train.Coordinator()
threads = []
for qr in sess.graph.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
start=True))
# In the graph created above, feed "is_training" and "imgs" placeholders.
# Feeding them will disconnect the path from queue runners to the graph
# and enable a path from the placeholder instead. The "img" placeholder will be
# fed with the image that was read above.
logits = sess.run('softmax_linear/softmax_linear:0',
feed_dict={'is_training:0': False, 'imgs:0': img})
#Print classifiction results.
print(logits)
该脚本要求用户创建两个占位符和一个条件执行语句才能使其正常工作。
占位符和条件执行语句添加到cifar10_train.py中,如下所示:
def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()
with tf.device('/cpu:0'):
images, labels = cifar10.distorted_inputs()
is_training = tf.placeholder(dtype=bool,shape=(),name='is_training')
imgs = tf.placeholder(tf.float32, (1, 32, 32, 3), name='imgs')
images = tf.cond(is_training, lambda:images, lambda:imgs)
logits = cifar10.inference(images)
cifar10模型中的输入连接到队列运行器对象,该对象是一个多阶段队列,可以并行地从文件中预取数据。 在这里查看队列赛跑者的漂亮动画
尽管队列运行器可以有效地预取大型数据集进行训练,但对于仅需要对单个文件进行分类的推理/测试来说,它们是一个过大的杀伤力,而且它们在修改/维护方面也涉及更多。因此,我添加了一个占位符“ is_training”,在训练时将其设置为False,如下所示:
import numpy as np
tmp_img = np.ndarray(shape=(1,32,32,3), dtype=float)
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op, feed_dict={is_training: True, imgs: tmp_img})
另一个占位符“ imgs”持有在推断期间将要馈送的图像的张量形状(1、32、32、3)-第一维是批大小,在这种情况下为一。我修改了cifar模型以接受32x32图像而不是24x24,因为原始的cifar10图像是32x32。
最后,条件语句将占位符或队列运行器的输出提供给图形。在推理过程中,将“ is_training”占位符设置为False,为“ img”占位符提供一个numpy数组-将numpy数组从3维向量重构为4维向量,以符合模型中的输入张量和推理函数。
这就是全部。可以使用单个/用户定义的测试数据来推断任何模型,如上面的脚本中所示。本质上是读取图,将数据馈送到图节点并运行图以获取最终输出。
现在第二种方法。另一种方法是对cifar10.py和cifar10_eval.py进行修改,以将批处理大小更改为1,并将来自队列运行器的数据替换为从文件读取的数据。
将批次大小设置为1:
tf.app.flags.DEFINE_integer('batch_size', 1,
"""Number of images to process in a batch.""")
读取图像文件的调用推断。
def evaluate(): with tf.Graph().as_default() as g:
# Get images and labels for CIFAR-10.
eval_data = FLAGS.eval_data == 'test'
images, labels = cifar10.inputs(eval_data=eval_data)
import cv2
img = cv2.imread('tmp.png')
img = np.expand_dims(img, axis=0)
img = tf.cast(img, tf.float32)
logits = cifar10.inference(img)
然后将logits传递给eval_once并修改一次eval以评估logits:
def eval_once(saver, summary_writer, top_k_op, logits, summary_op):
...
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_k_op])
print(sess.run(logits))
没有单独的脚本可以运行此推理方法,只需运行cifar10_eval.py即可从用户定义的位置读取文件,批处理大小为1。
0
我认为,如果针对CIFAR-10教程中的convnet创建的模型测试单个新图像这一关键任务有一个文档完善的解决方案,那么它对Tensorflow社区将有极大帮助。
我可能是错的,但是似乎缺少使经过训练的模型在实践中可用的关键步骤。该教程中有一个“缺失链接” —一个脚本,该脚本将直接加载单个图像(作为数组或二进制图像),将其与经过训练的模型进行比较,然后返回分类。
先前的答案给出了部分解决方案,这些解决方案说明了整体方法,但是我没有一个能够成功实现。可以在这里和那里找到其他零碎的部分,但不幸的是,这些零碎的部分尚未添加到有效的解决方案中。在将其标记为重复项或已回答之前,请考虑一下我所做的研究。
Tensorflow:如何保存/恢复模型?
恢复TensorFlow模型
无法在Tensorflow V0.8中还原模型
https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
最受欢迎的答案是第一个,其中@RyanSepassi和@YaroslavBulatov描述了问题和解决方法:一个需要“手动构造具有相同节点名称的图,并使用Saver将权重加载到其中”。尽管这两个答案都有帮助,但将其插入CIFAR-10项目的方式尚不明确。
一个功能齐全的解决方案将是非常可取的,因此我们可以将其移植到其他单个图像分类问题。在这方面有几个关于SO的问题需要解决,但仍然没有完整的答案(例如,使用tensorflow DNN加载Checkpoint并评估单个图像 )。
我希望我们可以集中在每个人都可以使用的工作脚本上。
以下脚本尚不起作用,并且很高兴听到您的意见,说明如何使用CIFAR-10 TF教程训练的模型来改进此方法以提供单图像分类的解决方案。
假定原始教程未涉及所有变量,文件名等。
新文件: cifar10_eval_single.py