使用TensorFlow数据集进行批处理,重复和随机播放有什么作用?
tensorflow
14
0

我目前正在学习TensorFlow,但我在这段代码中遇到了困惑:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

我知道首先数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集做了什么?请给我一个例子的解释

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

想象一下,您有一个数据集: [1, 2, 3, 4, 5, 6] ,然后:

ds.shuffle()如何工作

dataset.shuffle(buffer_size=3)将分配一个大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。我们可以这样成像:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

假设条目2来自随机缓冲区。可用空间由源缓冲区中的下一个元素填充,即4

2 <= [1,3,4] <= [5,6]

我们继续阅读,直到一无所有:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

ds.repeat()如何工作

从数据集中读取所有条目并尝试读取下一个元素后,数据集将引发错误。这就是ds.repeat()发挥作用的地方。它将重新初始化数据集,使其再次如下所示:

[1,2,3] <= [4,5,6]

ds.batch()将产生什么

ds.batch()将获取第一个batch_size条目并从中进行批处理。因此,示例数据集的批处理大小为3将产生两个批处理记录:

[2,1,5]
[3,6,4]

由于批处理之前有一个ds.repeat() ,因此将继续生成数据。但是由于ds.random() ,元素的顺序将有所不同。应该考虑的是,由于随机缓冲区的大小,第一批中将永远不会出现6

收藏
评论

tf.Dataset中的以下方法:

  1. repeat( count=0 )该方法重复数据集count的次数。
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)该方法可对数据集中的样本进行shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)buffer_size是随机化并作为tf.Dataset返回的样本数。
  3. batch(batch_size,drop_remainder=False)创建数据集的批次,其批次大小指定为batch_size ,这也是批次的长度。
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号