其他答案也不错,但是需要注意的是,直接对大型numpy数组使用from_tensor_slices
可以快速填满您的内存,因为IIRC,这些值将作为tf.constants
复制到图形中。以我的经验,这将导致无声的失败,最终培训将开始,但损失等方面的改善将不会发生。
更好的方法是使用占位符。例如,这是我的代码,用于为图像及其单一目标创建发生器:
def create_generator_tf_dataset(self, images, onehots, batch_size):
# Get shapes
img_size = images.shape
img_size = (None, img_size[1], img_size[2], img_size[3])
onehot_size = onehots.shape
onehot_size = (None, onehot_size[1])
# Placeholders
images_tensor = tf.placeholder(tf.float32, shape=img_size)
onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)
# Dataset
dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))
# Map function (e.g. augmentation)
if map_fn is not None:
dataset = dataset.map(lambda x, y: (map_fn(x), y), num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Combined shuffle and infinite repeat
dataset = dataset.apply(
tf.data.experimental.shuffle_and_repeat(len(images), None))
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)
# Make the iterator
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
next_val = iterator.get_next()
with K.get_session().as_default() as sess:
sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
while True:
inputs, labels = sess.run(next_val)
yield inputs, labels
0
fit_generator()
模型方法需要一个生成器,该生成器生成形状(输入,目标)的元组,其中两个元素都是NumPy数组。 该文档似乎暗示着,如果我只是将Dataset
迭代器包装在生成器中,并确保将Tensors转换为NumPy数组,那我应该很好。这段代码给我一个错误:这是我得到的错误:
奇怪的是,在我初始化
datagen
之后直接添加包含next(datagen)
的行会导致代码运行正常,没有错误。为什么我的原始代码不起作用?将行添加到代码中后,为什么它开始起作用?是否有一种更有效的方式将TensorFlow的Dataset API与Keras结合使用,而无需将Tensors转换为NumPy数组然后再次返回?