详解pytorch中的max方法

实际上pytorch官方文档中的应该是torch.max(input)方法,而本文要讲的可能严格意义上不是torch中的,而是针对torch中的张量方法,即input.max(axis)[index]
其中input表示要求取最大值的张量,axis可以为0(表示求取每列的最大值),也可以为1(每行的最大值)。index为0表示只返回最大值本身,为1表示只返回最大值对应的索引。如下,其中axis可以省去:

a = torch.Tensor([[0,3,2],[4,0,0]])
print(a.max(axis=0)[0]) # tensor([4., 3., 2.]),即第一列为[0 4]最大值为4,第二列为[3 0],依此类推
print(a.max(axis=0)[1]) # tensor([1, 0, 0]),索引也是列的索引
print(a.max(axis=1)[0]) # tensor([3., 4.]),取各行的最大值
print(a.max(axis=1)[1]) # tensor([1, 0]),对应的索引

应用

在求解强化学习中需要qmaxq_{max}qmax对应的action时,通常是输入一个张量即神经网络算出的q值,然后输出q值对应的索引,输出的是int型,如下:

import torch
q = torch.Tensor([[0,3,2,1]])
action=q.max(1)[1].item() # .item()将只有一个元素的张量变为对应的元素
action=q.max(1)[1].view(1,1).item() # 如果不放心可在前面加view方法shape成只有一个元素的张量
原文地址:https://www.cnblogs.com/hzcya1995/p/13281640.html