使用同一张图显示TensorFlow中的训练和验证准确性
machine-learning
python
tensorboard
tensorflow
7
0

我有一个TensorFlow模型,该模型的一部分评估了准确性。 accuracy只是张量流图中的另一个节点,该节点接受logitslabels

当我想绘制训练精度时,这很简单:我有类似以下内容:

tf.scalar_summary("Training Accuracy", accuracy)
tf.scalar_summary("SomethingElse", foo)
summary_op = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('/me/mydir/', graph=sess.graph)

然后,在我的训练循环中,我会看到以下内容:

for n in xrange(1000):
  ...
  summary, ..., ... = sess.run([summary_op, ..., ...], feed_dict)
  writer.add_summary(summary, n)
  ...

同样在for循环中,每说100次迭代,我想评估验证的准确性。我为此有一个单独的feed_dict,我能够在python中很好地评估验证准确性。

但是,这是我的问题:我想通过使用accuracy节点来为验证准确性做另外一个总结 。我不清楚如何执行此操作。由于我具有accuracy节点,因此我应该能够重复使用它,但是我不确定如何确切地做到这一点,这样我也可以将验证准确性写成单独的scalar_summary ...

这怎么可能?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

您可以重复使用准确性节点,但是需要使用两个不同的SummaryWriter,一个用于训练运行,另一个用于测试数据。另外,您还必须将标量摘要分配给变量以确保准确性。

accuracy_summary = tf.scalar_summary("Training Accuracy", accuracy)
tf.scalar_summary("SomethingElse", foo)
summary_op = tf.merge_all_summaries()
summaries_dir = '/me/mydir/'
train_writer = tf.train.SummaryWriter(summaries_dir + '/train', sess.graph)
test_writer = tf.train.SummaryWriter(summaries_dir + '/test')

然后,在训练循环中,您将接受常规训练,并使用train_writer记录总结。此外,您每进行第100次迭代就在测试集上运行图形,并仅使用test_writer记录准确性摘要。

# Record train set summaries, and train
summary, _ = sess.run([summary_op, train_step], feed_dict=...)
train_writer.add_summary(summary, n)
if n % 100 == 0:  # Record summaries and test-set accuracy
  summary, acc = sess.run([accuracy_summary, accuracy], feed_dict=...)
  test_writer.add_summary(summary, n)
  print('Accuracy at step %s: %s' % (n, acc))

然后,您可以将TensorBoard指向父目录(summaries_dir),它将加载两个数据集。

也可以在TensorFlow HowTo的https://www.tensorflow.org/versions/r0.11/how_tos/summaries_and_tensorboard/index.html中找到

收藏
评论

要运行相同的操作但使用不同的feed_dict数据获取摘要,只需将两个摘要操作附加到该操作即可。假设您要对验证和测试数据都运行准确性操作,并希望同时获取两者的摘要:

validation_acc_summary = tf.summary.scalar('validation_accuracy', accuracy)  # intended to run on validation set
test_acc_summary = tf.summary.scalar('test_accuracy', accuracy)  # intended to run on test set
with tf.Session() as sess:
    # do your thing
    # ...
    # accuracy op just needs labels y_ and input x to compute logits 
    validation_summary_str = sess.run(validation_acc_summary, feed_dict=feed_dict={x: mnist.validation.images,y_: mnist.validation.labels})
    test_summary_str = sess.run(test_acc_summary, feed_dict={x: mnist.test.images,y_: mnist.test.labels})

    # assuming you have a tf.summary.FileWriter setup
    file_writer.add_summary(validation_summary_str)
    file_writer.add_summary(test_summary_str)

还请记住,您始终可以像这样始终从原型bummary_str中提取原始(标量)数据,并进行自己的日志记录。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号