数据不平衡和加权交叉熵
deep-learning
machine-learning
python
tensorflow
6
0

我正在尝试训练数据不平衡的网络。我有A(198个样本),B(436个样本),C(710个样本),D(272个样本),并且我已经阅读了有关“ weighted_cross_entropy_with_logits”的信息,但是我发现的所有示例都是针对二进制分类的,因此我不太了解对如何设置这些权重充满信心。

样本总数:1616

A_weight:198/1616 = 0.12?

如果我理解的话,其背后的想法是惩罚市长阶层的错误,更积极地重视少数族裔的打击,对吧?

我的一段代码:

weights = tf.constant([0.12, 0.26, 0.43, 0.17])
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred, targets=y, pos_weight=weights))

我已经阅读了这个和其他带有二进制分类的示例,但仍然不太清楚。

提前致谢。

参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

请注意, weighted_cross_entropy_with_logitssigmoid_cross_entropy_with_logits的加权变体。 S形交叉熵通常用于二进制分类。是的,它可以处理多个标签,但是S型交叉熵基本上是对每个标签做出(二进制)决策-例如,对于人脸识别网,这些(但不互斥)标签可能是“ 对象戴眼镜吗? ”,“ 对象是女性吗? ”等。

在二进制分类中,每个输出通道都对应一个二进制(软)判决。因此,需要在损失的计算中进行加权。这是weighted_cross_entropy_with_logits功能,方法是将交叉熵的一项权重于另一项。

在互斥的多softmax_cross_entropy_with_logits分类中,我们使用softmax_cross_entropy_with_logits ,其行为不同:每个输出通道对应于一个类别候选者的分数。该决定是 ,通过比较每个信道的相应输出。

因此,在做出最终决定之前进行加权很简单,通常是通过与权重相乘,在比较分数之前对其进行修改。例如,对于三元分类任务,

# your class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(onehot_labels, logits)
# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)

您还可以依靠tf.losses.softmax_cross_entropy处理最后三个步骤。

在您的情况下,如果您需要解决数据不平衡问题,那么班级权数的确可能与火车数据中它们的频率成反比。对它们进行规范化,以使它们加起来等于一个或多个类,这也是有意义的。

请注意,在上文中,我们根据样本的真实标签对损失进行了处罚。我们还可以通过简单地定义基于估计标签的损失

weights = class_weights

由于广播魔术,其余代码无需更改。

在一般情况下,您希望权重取决于您所犯错误的类型。换句话说,对于每对标签XY ,您可以选择当真实标签为Y时如何惩罚选择标签X您最终得到一个完整的先验权重矩阵,这将导致以上的weights为完整的(num_samples, num_classes)张量。这超出了您想要的范围,但是了解上面代码中仅需要更改权重张量的定义可能会很有用。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号