Batched (block-wise) multiplication of 3-D tensors

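import torch

# a: a batch of 2 matrices, each of shape (1, 2)
a = torch.arange(1, 5, dtype=torch.float32)   # [1., 2., 3., 4.]; torch.range is deprecated
a = a.reshape(2, 1, 2)
# b: a batch of 2 matrices, each of shape (2, 3)
b = torch.arange(1, 13, dtype=torch.float32)  # [1., 2., ..., 12.]
b = b.reshape(2, 2, 3)
# batched matrix multiply: c[i] = a[i] @ b[i], result shape (2, 1, 3)
c = torch.bmm(a, b)
print('c')
print(c)
print(c.shape)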
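# same result, computed one batch entry at a time with the 2-D torch.mm
d = torch.zeros(2, 1, 3)
for i in range(2):
    a_ = a[i, :, :]          # (1, 2)
    b_ = b[i, :, :]          # (2, 3)
    c_ = torch.mm(a_, b_)    # (1, 3)
    d[i, :, :] = c_
print('d')
print(d)
print(d.shape)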
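A quick sketch that checks the two approaches agree, rebuilding the same `a` and `b`. It also shows `torch.einsum` with the subscript string `'bij,bjk->bik'` (b = batch, i/j/k = matrix dims) as a third way to write the same batched product; `einsum` is an alternative not used in the original post. Since every entry is a small integer stored exactly in float32, exact comparison with `torch.equal` is safe here.

```python
import torch

a = torch.arange(1, 5, dtype=torch.float32).reshape(2, 1, 2)
b = torch.arange(1, 13, dtype=torch.float32).reshape(2, 2, 3)

c = torch.bmm(a, b)  # one call over the whole batch

# loop version: one 2-D matmul per batch entry, stacked along dim 0
d = torch.stack([torch.mm(a[i], b[i]) for i in range(a.shape[0])])

# einsum version: contract over j independently for each batch index b
e = torch.einsum('bij,bjk->bik', a, b)

print(torch.equal(c, d), torch.equal(c, e))
# c[0] = [[9, 12, 15]], c[1] = [[61, 68, 75]]
print(c)
```

The loop makes the per-block semantics explicit; `bmm` and `einsum` do the same work in a single call.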

torch.bmm only accepts 3-D tensors (shapes (b, n, m) and (b, m, p), giving (b, n, p)): https://blog.csdn.net/qq_40178291/article/details/100302375
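For inputs with more than one batch dimension, a minimal sketch of two workarounds, assuming hypothetical 4-D inputs whose two leading dims are both batch dims: fold the batch dims into one so `bmm` sees 3-D tensors and then unfold the result, or use `torch.matmul`, which treats all leading dims as batch dims.

```python
import torch

torch.manual_seed(0)
# hypothetical 4-D inputs: leading dims (2, 3) are batch dims
a = torch.randn(2, 3, 1, 2)
b = torch.randn(2, 3, 2, 3)

bmm_rejected = False
try:
    torch.bmm(a, b)              # bmm refuses non-3-D inputs
except RuntimeError:
    bmm_rejected = True

# workaround 1: fold batch dims into one, run bmm, then unfold
flat = torch.bmm(a.reshape(-1, 1, 2), b.reshape(-1, 2, 3))  # (6, 1, 3)
out_bmm = flat.reshape(2, 3, 1, 3)

# workaround 2: torch.matmul batches over all leading dims
out_matmul = torch.matmul(a, b)                              # (2, 3, 1, 3)

print(bmm_rejected, out_bmm.shape, torch.allclose(out_bmm, out_matmul))
```

The reshape trick keeps `bmm` in play; `matmul` is usually simpler and also broadcasts batch dims of different sizes.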

torch.mm multiplies two 2-D matrices only (no batching or broadcasting): https://blog.csdn.net/da_kao_la/article/details/87484403
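A small sketch of that restriction, reusing the first block from the example above: `torch.mm` handles the (1, 2) x (2, 3) product, but passing the same matrices with an added batch dimension raises a `RuntimeError`. That is why the loop slices `a[i, :, :]` and `b[i, :, :]` down to 2-D before calling it.

```python
import torch

x = torch.tensor([[1., 2.]])                 # (1, 2)
y = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])             # (2, 3)
z = torch.mm(x, y)                           # (1, 3)

mm_rejected_3d = False
try:
    torch.mm(x.unsqueeze(0), y.unsqueeze(0))  # 3-D inputs: (1, 1, 2) and (1, 2, 3)
except RuntimeError:
    mm_rejected_3d = True                    # mm neither batches nor broadcasts

print(z, mm_rejected_3d)
```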

Original post: https://www.cnblogs.com/tingtin/p/14529003.html