如何正确使用tf.metrics.accuracy?
tensorflow
6
0

使用tf.metricsaccuracy函数tf.metrics以logit为输入的多分类问题时,我遇到了一些麻烦。

我的模型输出如下:

logits = [[0.1, 0.5, 0.4],
          [0.8, 0.1, 0.1],
          [0.6, 0.3, 0.2]]

我的标签是一种热编码向量:

labels = [[0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]]

当我尝试执行类似tf.metrics.accuracy(labels, logits)它永远不会给出正确的结果。我显然做错了,但我不知道是什么。

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

应用于cnn上,您可以编写:

x_len=24*24
y_len=2

x = tf.placeholder(tf.float32, shape=[None, x_len], name='input')

fc1 = ... # cnn's fully connected layer
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
layer_fc_dropout = tf.nn.dropout(fc1, keep_prob, name='dropout')

y_pred = tf.nn.softmax(fc1, name='output')
logits = tf.argmax(y_pred, axis=1)

y_true = tf.placeholder(tf.float32, shape=[None, y_len], name='y_true')
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(y_true, axis=1), predictions=tf.argmax(y_pred, 1))


sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

def print_accuracy(x_data, y_data, dropout=1.0):
    accuracy = sess.run(acc_op, feed_dict = {y_true: y_data, x: x_data, keep_prob: dropout})
    print('Accuracy: ', accuracy)
收藏
评论

TL; DR

精度函数tf.metrics.accuracy基于它创建的两个局部变量来计算预测与标签匹配的频率: totalcount ,用于计算对logitslabels匹配的频率。

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                  predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
  • acc(准确性):仅使用totalcount返回指标,而不更新指标。
  • acc_op(更新):更新指标。

要了解acc为什么返回0.0 ,请仔细阅读以下详细信息。


详细信息使用一个简单的示例:

logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),   
                                  predictions=tf.argmax(logits,1))

初始化变量:

由于metrics.accuracy创建了两个局部变量totalcount ,因此我们需要调用local_variables_initializer()进行初始化。

sess = tf.Session()

sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

stream_vars = [i for i in tf.local_variables()]
print(stream_vars)

#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]

了解更新操作和准确性计算:

print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [0.0, 0.0]

尽管给出了匹配的输入,但是由于totalcount均为零,因此上面的方法返回的精度为0.0。

print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
#ops: 1.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [2.0, 2.0]

使用新输入时,将在调用更新op时计算精度。注意:由于所有logits和label都匹配,因此我们的准确度为1.0,局部变量totalcount实际上给出了total correctly predictedtotal comparisons made并进行了total comparisons made

现在,我们将新输入(而不是更新操作)称为accuracy

print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0

准确性调用不会使用新的输入来更新指标,它只是使用两个局部变量返回值。注意:在这种情况下,logit和标签不匹配。现在再次调用更新操作:

print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75 
print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [3.0, 4.0]

指标已更新为新输入


有关培训期间如何使用指标以及在验证期间如何重置指标的更多信息,请参见此处

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号