您可以使用Dataset.take()
和Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
为了更笼统,我举了一个使用70/15/15 train / val / test split的示例,但是如果您不需要测试或val集,则只需忽略最后两行。
采取 :
从此数据集中创建一个最多包含count个元素的数据集。
跳过 :
创建一个数据集,该数据集从该数据集中跳过计数元素。
您可能还需要研究Dataset.shard()
:
创建一个仅包含此数据集的1 / num_shards的数据集。
免责声明我就这个问题回答绊倒后, 这一个 ,所以我想我会传播爱
0
有谁知道如何将Tensorflow中由数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?