tf.nn.softmax_cross_entropy_with_logits()函数解析(最清晰的解释)

欢迎关注WX公众号:【程序员管小亮】

最近学习中碰到了以前学过的tf.nn.softmax_cross_entropy_with_logits()函数,特此记录。

tf.nn.softmax_cross_entropy_with_logits()函数用于计算交叉熵。

tf.nn.softmax_cross_entropy_with_logits(
	_sentinel=None, 
	labels=None, 
	logits=None, 
	name=None
)

参数:

  • _sentinel:本质上是不用的参数,不用填
  • labels:和logits具有相同的type(float)和shape的张量(tensor)
  • logits:一个数据类型(type)是float32或float64
  • name:操作的名字,可填可不填

返回值:

长度为batch_size的一维Tensor

注意:如果labels的每一行是one-hot表示,也就是只有一个地方为1,其他地方为0,可以使用tf.sparse_softmax_cross_entropy_with_logits()

Tensorflow交叉熵计算函数输入中的logits都不是softmax或sigmoid的输出,而是softmax或sigmoid函数的输入,因为它在函数内部进行sigmoid或softmax操作。

参数labels,logits必须有相同的形状 [batch_size, num_classes] 和相同的类型(float16, float32, float64)中的一种。

例子1:

import tensorflow as tf

labels = [[0.2,0.3,0.5],
          [0.1,0.6,0.3]]
logits = [[2,0.5,1],
          [0.1,1,3]]
logits_scaled = tf.nn.softmax(logits)

result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
result2 = -tf.reduce_sum(labels*tf.log(logits_scaled),1)
result3 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits_scaled)

with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))
    print(sess.run(result3))
> [1.4143689 1.6642545]
  [1.4143689 1.6642545]
  [1.1718578 1.1757141]

上述例子中,labels的每一行是一个概率分布,而logits未经缩放(每行加起来不为1),我们用定义法计算得到交叉熵result2,和套用tf.nn.softmax_cross_entropy_with_logits()得到相同的结果, 但是将缩放后的logits_scaled输tf.nn.softmax_cross_entropy_with_logits(), 却得到错误的结果,所以一定要注意,这个操作的输入logits是未经缩放的。

python课程推荐。
在这里插入图片描述

参考文章:

tf.nn.softmax_cross_entropy_with_logits()笔记及交叉熵

原文地址:https://www.cnblogs.com/hzcya1995/p/13302841.html