图代码简单无向图

简单无向图的定义:

方法一:

import torch
from torch_geometric.data import Data

#边,shape = [2,num_edge]
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
#点,shape = [num_nodes, num_node_features]
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])

 注意:edge_index中边的存储方式,有两个list。第 1 个list是边的起始点,第 2 个list是边的目标节点。注意与下面的存储方式的区别。

            由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每个节点都有自己的特征

 方法二:

import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())

这种情况edge_index需要先转置然后使用contiguous()方法。

Data中最基本的 4 个属性是xedge_indexposy,我们一般都需要这 4 个参数。
有了Data,我们可以创建自己的Dataset,读取并返回Data了。

原文地址:https://www.cnblogs.com/Catherinezhilin/p/15678534.html