TensorFlow:dataset.train.next_batch如何定义?
neural-network
python-3.x
tensorflow
5
0

我正在尝试学习TensorFlow并在以下位置研究示例: https : //github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后,我在下面的代码中有一些疑问:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))

由于mnist只是一个数据集, mnist.train.next_batch到底是mnist.train.next_batch意思?如何定义dataset.train.next_batch

谢谢!

参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

mnist对象是从tf.contrib.learn模块中定义的read_data_sets()函数返回的。 mnist.train.next_batch(batch_size)方法在此处实现,它返回两个数组的元组,其中第一个表示一批batch_size MNIST图像,第二个表示一批与这些图像相对应的batch-size标签。

图像以大小为[batch_size, 784]的2-D NumPy数组形式返回(因为MNIST图像中有784个像素),并且标签以大小为[batch_size]的一维NumPy数组形式返回(如果使用one_hot=False调用read_data_sets()或使用[batch_size, 10]大小的二维NumPy数组(如果使用one_hot=True调用read_data_sets() )。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号