Torch常用的函数总结

矩阵运算相关:

torch.mul(a,b)  是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵

torch.mm(a,b)  是矩阵a和b的 矩阵相乘。比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

torch.transpose(Phi, 0, 1)  是交换一个tensor的两个维度,返回的类型也是tensor。即“torch.transpose(input, dim0, dim1) → Tensor”,需要注意的是transpose中的两个维度参数的顺序是可以交换位置的,即transpose(x, 0, 1) 和transpose(x, 1, 0)效果是相同的。

view(-1,...,) 在torch里面,view函数相当于numpy的reshape

Note:经常会看见参数为-1,这里-1表示一个不确定的数,让电脑帮我们计算。例如,一个长度的16向量xx.view(-1, 2)等价于x.view(8,2)

模型相关

1. 统计模型参数

model = Net(layer_num)
print('Total number of parameters net:',
      sum(p.numel() for p in model.parameters()))
原文地址:https://www.cnblogs.com/HuangYJ/p/13849298.html