Tensorflow中one_hot() 函数用法

官网默认定义如下:
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
该函数的功能主要是转换成one_hot类型的张量输出。


参数功能如下:
  1)indices中的元素指示on_value的位置,不指示的地方都为off_value。indices可以是向量、矩阵。
  2)depth表示输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]
  3)on_value默认为1
  4)off_value默认为0
  5)dtype默认为tf.float32


下面用几个例子说明一下:
1. indices是向量
 1 import tensorflow as tf
 2 
 3 indices = [0,2,3,5]
 4 depth1 = 6   # indices没有元素超过(depth-1)
 5 depth2 = 4   # indices有元素超过(depth-1)
 6 a = tf.one_hot(indices,depth1)
 7 b = tf.one_hot(indices,depth2)
 8 
 9 with tf.Session() as sess:
10     print('a = 
',sess.run(a))
11     print('b = 
',sess.run(b))

运行结果:

# 输入是一维的,则输出是一个二维的
a = [[1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1.]]      # shape=(4,6) b = [[1. 0. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.] [0. 0. 0. 0.]]          # shape=(4,4)

2. indices是矩阵

 1 import tensorflow as tf
 2 
 3 indices = [[2,3],[1,4]]
 4 depth1 = 9   # indices没有元素超过(depth-1)
 5 depth2 = 4   # indices有元素超过(depth-1)
 6 a = tf.one_hot(indices,depth1)
 7 b = tf.one_hot(indices,depth2)
 8 
 9 with tf.Session() as sess:
10     print('a = 
',sess.run(a))
11     print('b = 
',sess.run(b))

运行结果:

# 输入是二维的,则输出是三维的
a = [[[0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0.]] [[0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0.]]]    # shape=(2,2,9) b = [[[0. 0. 1. 0.] [0. 0. 0. 1.]] [[0. 1. 0. 0.] [0. 0. 0. 0.]]]             # shape=(2,2,4)
 
原文地址:https://www.cnblogs.com/muzidaitou/p/11262820.html