TensorFlow教程batch_xs,batch_ys = mnist.train.next_batch(100)的next_batch来自何处?
numpy
python
tensorflow
4
0

我正在尝试TensorFlow教程,但不了解此行中的next_batch来自何处?

 batch_xs, batch_ys = mnist.train.next_batch(100)

我看着

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

而且那里也没有看到next_batch。

现在,当在我自己的代码中尝试next_batch时,

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

所以我想了解next_batch来自哪里?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

next_batchDataSet类的方法(有关该类内容的更多信息,请参见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py ) 。

加载mnist数据并将其分配给变量mnist使用以下命令:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

看一下mnist.train的类。您可以通过键入以下内容进行查看:

print mnist.train.__class__

您会看到以下内容:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>

由于mnist.trainDataSet类的实例,因此可以使用该类的next_batch函数。有关类的更多信息,请查阅文档

收藏
评论

浏览过tensorflow存储库后,它似乎起源于这里:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

但是,如果您希望以自己的代码(针对自己的数据集)实现它,那么像我一样,将其自己编写到数据集对象中可能会更简单。据我了解,这是一种对整个数据集进行混洗,并从混洗后的数据集中返回$ mini_batch_size个样本的方法。

这是一些伪代码:

shuffle data.x and data.y while retaining relation return [data.x[:mb_n], data.y[:mb_n]]

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号