在TensorFlow培训期间打印损失
python
tensorflow
10
0

我正在看TensorFlow“ MNIST对于ML初学者 ”教程,我想在每个训练步骤之后打印出训练损失。

我的训练循环目前看起来像这样:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

现在, train_step定义为:

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

其中cross_entropy是我要打印的损失:

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

一种打印方法是在训练循环中显式计算cross_entropy

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    print 'loss = ' + str(cross_entropy)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

我现在有两个问题:

  1. 假设cross_entropy已在sess.run(train_step, ...)期间计算出, sess.run(train_step, ...)其计算两次效率低下,这需要所有训练数据的前向通过次数的两倍。是否有访问的价值的方式cross_entropy时期间计算sess.run(train_step, ...)

  2. 我如何打印tf.Variable ?使用str(cross_entropy)给我一个错误...

谢谢!

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

不仅要运行training_step,还要运行cross_entropy节点,以便将其值返回给您。请记住:

var_as_a_python_value = sess.run(tensorflow_variable)

会给你想要的东西,所以你可以这样做:

[_, cross_entropy_py] = sess.run([train_step, cross_entropy],
                                 feed_dict={x: batch_xs, y_: batch_ys})

既可以进行训练,又可以提取出交叉熵在迭代过程中计算出的值。请注意,我同时将sess.run的参数和返回值都转换为列表,以便两者都发生。

收藏
评论

您可以通过将cross_entropy的值添加到sess.run(...)的参数列表中来获取该值。例如,您的for -loop可以重写如下:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    _, loss_val = sess.run([train_step, cross_entropy],
                           feed_dict={x: batch_xs, y_: batch_ys})
    print 'loss = ' + loss_val

可以使用相同的方法来打印变量的当前值。假设除了cross_entropy的值cross_entropy ,您还想打印tf.Variable的值。 tf.Variable W可以执行以下操作:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    _, loss_val, W_val = sess.run([train_step, cross_entropy, W],
                                  feed_dict={x: batch_xs, y_: batch_ys})
    print 'loss = %s' % loss_val
    print 'W = %s' % W_val
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题