tf.argmax()函数作用

tf.argmax()函数原型:

def argmax(input,
           axis=None,
           name=None,
           dimension=None,
           output_type=dtypes.int64)

作用是返回每列/行的最大值的索引。

input是一个张量,

axis是0或1,0返回各列最大值索引,1返回各行最大值索引。

其他3个参数不常用,常用写法是 a = tf.argmax(tensor, 1)。

import tensorflow as tf
sess = tf.InteractiveSession()

a = tf.constant([[12, 3, 9],
                 [3, 6, 13]]) 

b_1 = tf.argmax(a, 0)   # 返回ndarry,元素是每列的最大值索引
b_2 = tf.argmax(a, 1)

print(b_1)   # >>array([0, 1, 1], dtype=int64)
print(b_2)   # >>array([0, 2], dtype=int64)
原文地址:https://www.cnblogs.com/panda-blog/p/12354055.html