Tensorflow图像读取和显示
python
tensorflow
9
0

我有一堆图像,格式类似于Cifar10(二进制文件,每个图像size = 96*96*3字节),一个接一个( STL-10数据集 )。我打开的文件有138MB。

我试图阅读并检查包含图像的张量的内容,以确保阅读正确,但是我有两个问题-

  1. FixedLengthRecordReader是否加载整个文件,但是一次只提供一个输入?由于读取第一个size字节应该相对较快。但是,该代码大约需要两分钟才能运行。
  2. 如何获得可显示格式的实际图像内容,或在内部显示它们以验证图像是否被良好阅读?我做了sess.run(uint8image) ,但是结果为空。

代码如下:

import tensorflow as tf
def read_stl10(filename_queue):
  class STL10Record(object):
    pass
  result = STL10Record()

  result.height = 96
  result.width = 96
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  record_bytes = image_bytes

  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)
  print value
  record_bytes = tf.decode_raw(value, tf.uint8)

  depth_major = tf.reshape(tf.slice(record_bytes, [0], [image_bytes]),
                       [result.depth, result.height, result.width])
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])
  return result
# probably a hack since I should've provided a string tensor

filename_queue = tf.train.string_input_producer(['./data/train_X'])
image = read_stl10(filename_queue)

print image.uint8image
with tf.Session() as sess:
  result = sess.run(image.uint8image)
  print result, type(result)

输出:

Tensor("ReaderRead:1", shape=TensorShape([]), dtype=string)
Tensor("transpose:0", shape=TensorShape([Dimension(96), Dimension(96), Dimension(3)]), dtype=uint8)
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
[empty line for last print]
Process finished with exit code 137

我正在我的CPU上运行它,如果添加了任何内容。

编辑:感谢Rosa,我找到了纯TensorFlow解决方案。显然,在使用string_input_producer ,为了查看结果,您需要初始化队列运行器。添加到上面的代码中唯一需要做的事情是下面的第二行:

...
with tf.Session() as sess:
    tf.train.start_queue_runners(sess=sess)
...

之后,可以使用matplotlib.pyplot.imshow(result)显示result的图像。我希望这可以帮助别人。如果您还有其他问题,请随时询问我或查看Rosa答案中的链接。

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

使用tf.train.match_filenames_once加载名称获得要通过tf.size打开会话进行迭代的文件数,并享受;-)

import tensorflow as tf
import numpy as np
import matplotlib;
from PIL import Image

matplotlib.use('Agg')
import matplotlib.pyplot as plt


filenames = tf.train.match_filenames_once('./images/*.jpg')
count_num_files = tf.size(filenames)
filename_queue = tf.train.string_input_producer(filenames)

reader=tf.WholeFileReader()
key,value=reader.read(filename_queue)
img = tf.image.decode_jpeg(value)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    num_files = sess.run(count_num_files)
    for i in range(num_files):
        image=img.eval()
        print(image.shape)
        Image.fromarray(np.asarray(image)).save('te.jpeg')
收藏
评论

只是给出一个完整的答案:

filename_queue = tf.train.string_input_producer(['/Users/HANEL/Desktop/tf.png']) #  list of files to read

reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)

my_img = tf.image.decode_png(value) # use png or jpg decoder based on your files.

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init_op)

  # Start populating the filename queue.

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(1): #length of your filename list
    image = my_img.eval() #here is your image Tensor :) 

  print(image.shape)
  Image.fromarray(np.asarray(image)).show()

  coord.request_stop()
  coord.join(threads)

或者,如果您有图像目录,则可以通过此Github源文件将其全部添加

@mttk和@ salvador-dali:我希望这是您需要的

收藏
评论

在评论中与您交谈后,我相信您可以使用numpy / scipy进行此操作。想法是读取numpy 3d数组中的图像并将其输入到变量中。

from scipy import misc
import tensorflow as tf

img = misc.imread('01.png')
print img.shape    # (32, 32, 3)

img_tf = tf.Variable(img)
print img_tf.get_shape().as_list()  # [32, 32, 3]

然后,您可以运行图形:

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
im = sess.run(img_tf)

并验证是否相同:

import matplotlib.pyplot as plt
fig = plt.figure()
fig.add_subplot(1,2,1)
plt.imshow(im)
fig.add_subplot(1,2,2)
plt.imshow(img)
plt.show()

在此处输入图片说明

您提到的PSSince it's supposed to parallelize reading, it seems useful to know. 。我可以说,在数据分析中很少会读取数据的瓶颈。您大部分时间将花在训练模型上。

收藏
评论

根据文档,您可以解码JPEG / PNG图像。

应该是这样的:

import tensorflow as tf

filenames = ['/image_dir/img.jpg']
filename_queue = tf.train.string_input_producer(filenames)

reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)

images = tf.image.decode_jpeg(value, channels=3)

您可以在这里找到更多信息

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题