图神经网络:图的分类

文章说明:
1)参考资料:PYG的文档。文档超链。
2)博主水平不高,如有错误,还望批评指正。
3)我在百度网盘上传这篇文章的jupyter notebook以及有关文献。提取码8848。

MUTAG数据集说明

MUTAG数据集是一个分子图形数据集。每个分子包含一个二元标签表示该分子是否为一种类固醇化合物。我们的任务是进行一个二分类的任务,判断某个分子是否类固醇化合物。
导入依赖

from torch_geometric.datasets import TUDataset
import torch

导入数据

dataset=TUDataset(root='C:/Users/19216/Desktop/project/Project1/Graph_Classification/TUDataset',name='MUTAG')

打乱顺序,训测拆分

dataset=dataset.shuffle()
train_dataset=dataset[:150]
test_dataset=dataset[150:]

我们下面观察数据

图分类的常见手段

在这里插入图片描述
为充分利用GPU资源,我们使用如上方式。1)创建一个包含多个孤立图的超巨型图,2)特征矩阵简单连接。如上。优点如下:1)不同图间不会进行信息传递,2)稀疏矩阵保存不会占用内存。
导入依赖

from torch_geometric.loader import DataLoader

观察数据

train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=64,shuffle=False)
for step,data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
#输出如下:
#Step 1:
#=======
#Number of graphs in the current batch: 64
#DataBatch(edge_index=[2, 2626], x=[1187, 7], edge_attr=[2626, 4], y=[64], batch=[1187], ptr=[65])

#Step 2:
#=======
#Number of graphs in the current batch: 64
#DataBatch(edge_index=[2, 2448], x=[1107, 7], edge_attr=[2448, 4], y=[64], batch=[1107], ptr=[65])

#Step 3:
#=======
#Number of graphs in the current batch: 22
#DataBatch(edge_index=[2, 978], x=[441, 7], edge_attr=[978, 4], y=[22], batch=[441], ptr=[23])

图分类的基本流程

1.多轮信息传递嵌入每个节点
2.聚合节点嵌入变为嵌入的图
3.训练一个嵌入图分类器
PS:第二步使用如下的公式: X G = 1 ∣ V ∣ ∑ v ∈ V x v L \mathcal{X_{\mathcal{G}}}=\frac{1}{|\mathcal{V}|}\sum_{\mathcal{v} \in \mathcal{V}}\mathcal{x}_{\mathcal{v}}^{L} XG=V1vVxvL。将每个节点的特征向量简单地加起来取个平均。
搭建模型

from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GraphConv
import torch.nn.functional as F
from torch.nn import Linear

class GCN(torch.nn.Module):
    
    def __init__(self,hidden_channels):
        super(GCN, self).__init__()
        self.conv1=GraphConv(dataset.num_node_features,hidden_channels)
        self.conv2=GraphConv(hidden_channels,hidden_channels)
        self.conv3=GraphConv(hidden_channels,hidden_channels)
        self.lin=Linear(hidden_channels,dataset.num_classes)

    def forward(self,x,edge_index,batch):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=self.conv2(x,edge_index)
        x=x.relu()
        x=self.conv3(x,edge_index)
        x=x.relu()
        x=global_mean_pool(x,batch)
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.lin(x)
        return x

model=GCN(hidden_channels=64)
print(model)
#输出如下:
#GCN(
#  (conv1): GCNConv(7, 64)
#  (conv2): GCNConv(64, 64)
#  (conv3): GCNConv(64, 64)
#  (lin): Linear(in_features=64, out_features=2, bias=True)
#)

得出结果

model=GCN(hidden_channels=64)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
criterion=torch.nn.CrossEntropyLoss()

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
criterion=torch.nn.CrossEntropyLoss()

def train():
    model.train()
    for data in train_loader: 
         out=model(data.x,data.edge_index,data.batch)
         loss=criterion(out,data.y)
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()

def test(loader):
     model.eval()
     correct=0
     for data in loader:
         out=model(data.x,data.edge_index,data.batch)  
         pred=out.argmax(dim=1)
         correct+=int((pred==data.y).sum())
     return correct/len(loader.dataset)

for epoch in range(1,171):
    train()
    train_acc=test(train_loader)
    test_acc=test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
#输出如下(这里只有最后一个):
#Epoch: 170, Train Acc: 0.9800, Test Acc: 0.8158

改进算法

第一篇文章:HOW POWERFUL ARE GRAPH NEURAL NETWORKS?
主要工作:1)他们证明了区分图结构方面,GNNs的表达能力小于等于Weisfeiler-Lehman test。2)他们具体指出什么情况两个算法是效果相同的 3)他们具体指出GNNs及变体能够识别哪些图的结构以及不能识别哪些图的结构。4)他们开发一种简单的GIN结构,效果等同Weisfeiler-Lehman test算法。
第二篇文章:Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks
主要工作:1)他们证明了区分图结构方面,GNNs的表达能力小于等于Weisfeiler-Lehman test。2)提出一种1-k-GNNs的算法。3)高阶的图属性对于分类回归十分重要。

WL-1伪代码:博主感觉这张图的比较好懂
在这里插入图片描述