并行化tf.data.Dataset.from_generator
tensorflow
tensorflow-datasets
6
0

我有一个非平凡的输入管道, from_generator非常适合...

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

其中complex_img_label_generator动态生成图像并返回代表(H, W, 3)图像和简单string标签的numpy数组。我无法将其表示为从文件和tf.image操作读取的处理。

我的问题是关于如何使生成器参数化?我如何让N个生成器在各自的线程中运行。

一种想法是使用带有num_parallel_calls dataset.map来处理线程。但是地图在张量上运行...另一个想法是创建多个生成器,每个生成器都有自己的prefetch并以某种方式加入它们,但是我看不到如何加入N个生成器流?

我可以遵循的任何经典示例吗?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

事实证明,如果我使生成器超轻量级(仅生成元数据),然后将实际的Dataset.map照明移动到无状态函数中,则可以使用Dataset.map 。这样,我可以使用py_func将繁重的部分与.map py_func

作品;但是感觉有点笨拙...能够将num_parallel_calls添加到from_generator :)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex pil and numpy work nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calulation(metadata, label):
  return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
                    inp = (metadata, label),
                    Tout = (tf.uint8,    # (H,W,3) img
                            tf.string))  # label
dataset = dataset.map(wrapped_complex_calulation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
收藏
评论

generator完成的工作限制在最低限度并使用map并行化昂贵的处理是明智的。

或者,您可以使用parallel_interleave “加入”多个生成器,如下所示:

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use
收藏
评论

我正在为from_indexabletf.data.Dataset https://github.com/tensorflow/tensorflow/issues/14448

from_indexable的优点是可以并行化,而python生成器则不能并行化。

函数from_indexable生成tf.data.range ,将可索引的包装在广义的tf.py_func并调用map。

对于那些现在想要from_indexable ,这里是lib代码

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

这是一个示例(注意: from_indexable具有num_parallel_calls argument

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

更新 2018年6月10日:由于https://github.com/tensorflow/tensorflow/pull/15121被合并, from_indexable的代码简化为:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号