想象一下,您有一个数据集: [1, 2, 3, 4, 5, 6]
,然后:
ds.shuffle()如何工作
dataset.shuffle(buffer_size=3)
将分配一个大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。我们可以这样成像:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
假设条目2
来自随机缓冲区。可用空间由源缓冲区中的下一个元素填充,即4
:
2 <= [1,3,4] <= [5,6]
我们继续阅读,直到一无所有:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
ds.repeat()如何工作
从数据集中读取所有条目并尝试读取下一个元素后,数据集将引发错误。这就是ds.repeat()
发挥作用的地方。它将重新初始化数据集,使其再次如下所示:
[1,2,3] <= [4,5,6]
ds.batch()将产生什么
ds.batch()
将获取第一个batch_size
条目并从中进行批处理。因此,示例数据集的批处理大小为3将产生两个批处理记录:
[2,1,5]
[3,6,4]
由于批处理之前有一个ds.repeat()
,因此将继续生成数据。但是由于ds.random()
,元素的顺序将有所不同。应该考虑的是,由于随机缓冲区的大小,第一批中将永远不会出现6
。
0
我目前正在学习TensorFlow,但我在这段代码中遇到了困惑:
我知道首先数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集做了什么?请给我一个例子的解释