笔记 EINSUM IS ALL YOU NEED

原文 https://rockt.github.io/2018/04/30/einsum

就是说有一种运算,叫做einsum,可以做各种矩阵和向量的运算,而且特别简洁和优美

自己跑一下里面的例子,就知道是怎么回事了,

这里记录一下其中的tensor contraction,算是最general的形式了

先看 torch.einsum('ij,ij->', [a, b]) 是什么意思?

import torch

a = torch.arange(2*3).reshape(2, 3)
b = torch.arange(2*3).reshape(2, 3)
x = torch.einsum('ij,ij->', [a, b])
print(a)
print(b)
print(x)

res = 0
for i in range(2):
    for j in range(3):
            res += a[i,j] * b[i,j]

print(res)

结果:

(deeplearning) ➜  Catchfish python einsum_test.py
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor(55)
tensor(55)

相当于把对应位置相乘再相加,这样二维空间收缩为1个值

三维矩阵的收缩同理,torch.einsum('ijk,ijk->', [a, b]) 是什么意思?

其实二维矩阵的乘法也是tensor contraction,只不过只是将其中一维收缩,torch.einsum('ik,kj->ij', [a, b])

能收缩的条件是:只要对应维的长度相同即可

前面的讲完了,重点是高维矩阵是如何收缩的?

例子:

内部是怎么运算的呢?相同维数的3和5进行了收缩,相当于2,7,11,13,17固定

验证一下:取出一个固定状态,将相同的那两维收缩,与之前整体收缩再取同一状态对比,发现两个值一样

import torch

a = torch.arange(2*3*5*7).reshape(2,3,5,7)
b = torch.arange(11*13*3*17*5).reshape(11,13,3,17,5)
x = torch.einsum('pqrs,tuqvr->pstuv', [a, b])
print(x.shape)

m1 = a[1, :, :, 5]
m2 = b[6, 7, :, 8, :]
res = torch.einsum("ij,ij->", [m1, m2])
print(res)
print(x[1, 5, 6, 7, 8])

结果:

(deeplearning) ➜  Catchfish python tensor_contraction.py 
torch.Size([2, 7, 11, 13, 17])
tensor(52027730)
tensor(52027730)

维度的计算:相同维数的收缩了,剩下的各个维数组成结果的维数

自己可以试一下,收缩三个及更高的维数也是一样的做法。

个性签名:时间会解决一切
原文地址:https://www.cnblogs.com/lfri/p/15473640.html