欢迎关注WX公众号:【程序员管小亮】
最近学习中碰到了以前学过的tf.nn.softmax_cross_entropy_with_logits()函数,特此记录。
tf.nn.softmax_cross_entropy_with_logits()函数用于计算交叉熵。
tf.nn.softmax_cross_entropy_with_logits(
_sentinel=None,
labels=None,
logits=None,
name=None
)
参数:
- _sentinel:本质上是不用的参数,不用填
- labels:和logits具有相同的type(float)和shape的张量(tensor)
- logits:一个数据类型(type)是float32或float64
- name:操作的名字,可填可不填
返回值:
长度为batch_size的一维Tensor
注意:如果labels的每一行是one-hot表示,也就是只有一个地方为1,其他地方为0,可以使用tf.sparse_softmax_cross_entropy_with_logits()
Tensorflow交叉熵计算函数输入中的logits都不是softmax或sigmoid的输出,而是softmax或sigmoid函数的输入,因为它在函数内部进行sigmoid或softmax操作。
参数labels,logits必须有相同的形状 [batch_size, num_classes] 和相同的类型(float16, float32, float64)中的一种。
例子1:
import tensorflow as tf
labels = [[0.2,0.3,0.5],
[0.1,0.6,0.3]]
logits = [[2,0.5,1],
[0.1,1,3]]
logits_scaled = tf.nn.softmax(logits)
result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
result2 = -tf.reduce_sum(labels*tf.log(logits_scaled),1)
result3 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits_scaled)
with tf.Session() as sess:
print(sess.run(result1))
print(sess.run(result2))
print(sess.run(result3))
> [1.4143689 1.6642545]
[1.4143689 1.6642545]
[1.1718578 1.1757141]
上述例子中,labels的每一行是一个概率分布,而logits未经缩放(每行加起来不为1),我们用定义法计算得到交叉熵result2,和套用tf.nn.softmax_cross_entropy_with_logits()得到相同的结果, 但是将缩放后的logits_scaled输tf.nn.softmax_cross_entropy_with_logits(), 却得到错误的结果,所以一定要注意,这个操作的输入logits是未经缩放的。
python课程推荐。