将Tensorflow数据集API创建的数据集拆分为训练和测试?
tensorflow
8
0

有谁知道如何将Tensorflow中由数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

您可以使用Dataset.take()Dataset.skip()

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

为了更笼统,我举了一个使用70/15/15 train / val / test split的示例,但是如果您不需要测试或val集,则只需忽略最后两行。

采取

从此数据集中创建一个最多包含count个元素的数据集。

跳过

创建一个数据集,该数据集从该数据集中跳过计数元素。

您可能还需要研究Dataset.shard()

创建一个仅包含此数据集的1 / num_shards的数据集。


免责声明我就这个问题回答绊倒后, 这一个 ,所以我想我会传播爱

收藏
评论

假设您具有tf.data.Dataset类型的all_dataset变量:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

测试数据集现在具有前1000个元素,其余的用于训练。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号