Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sparse] A hetero-relational GCN example #6157

Merged
merged 13 commits into from
Aug 23, 2023
284 changes: 284 additions & 0 deletions examples/sparse/hetero-rgcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
"""
Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn

This script trains and tests a Hetero Relational Graph Convolutional Networks
(Hetero-RGCN) model based on the information of a full graph.

This flowchart describes the main functional sequence of the provided example.
main
├───> Load and preprocess full dataset
├───> Instantiate Hetero-RGCN model
├───> train
│ │
│ └───> Training loop
│ │
│ └───> Hetero-RGCN.forward
└───> test
└───> Evaluate the model
"""
import argparse
import time

import dgl
import dgl.sparse as dglsp

import numpy as np

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset


class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""

def __init__(
self,
ntype_num,
embed_size,
):
super(RelGraphEmbed, self).__init__()
self.embed_size = embed_size
self.dropout = nn.Dropout(0.0)

# Create weight embeddings for each node for each relation.
self.embeds = nn.ParameterDict()
for ntype, num_nodes in ntype_num.items():
embed = nn.Parameter(th.Tensor(num_nodes, self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain("relu"))
self.embeds[ntype] = embed

def forward(self):
return self.embeds


class HeteroRelationalGraphConv(nn.Module):
r"""HeteroRelational graph convolution layer.

Parameters
----------
in_size : int
Input feature size.
out_size : int
Output feature size.
relation_names : list[str]
Relation names.
"""

def __init__(
self,
in_size,
out_size,
relation_names,
activation=None,
):
super(HeteroRelationalGraphConv, self).__init__()
self.in_size = in_size
self.out_size = out_size
self.relation_names = relation_names
self.activation = activation

########################################################################
# (HIGHLIGHT) HeteroGraphConv is a graph convolution operator over
# heterogeneous graphs. A dictionary is passed where the key is the
# relation name and the value is the insatnce of conv layer.
########################################################################
self.W = nn.ModuleDict(
{str(rel): nn.Linear(in_size, out_size) for rel in relation_names}
)

self.dropout = nn.Dropout(0.0)

def forward(self, A, inputs):
"""Forward computation

Parameters
----------
A : Hetero Sparse Matrix
Input graph.
inputs : dict[str, torch.Tensor]
Node feature for each node type.

Returns
-------
dict[str, torch.Tensor]
New node features for each node type.
"""
hs = {}
for rel in A:
src_type, edge_type, dst_type = rel
if dst_type not in hs:
hs[dst_type] = th.zeros(
inputs[dst_type].shape[0], self.out_size
)
####################################################################
# (HIGHLIGHT) Sparse library use hetero sparse matrix to present
# heterogeneous graphs. A dictionary is passed where the key is
# the tuple of (source node type, edge type, destination node type)
# and the value is the sparse matrix contructed from the key on
# global graph. The convolution operation is the multiplication of
# sparse matrix and convolutional layer.
####################################################################
hs[dst_type] = hs[dst_type] + (
A[rel].T @ self.W[str(edge_type)](inputs[src_type])
)
if self.activation:
hs[dst_type] = self.activation(hs[dst_type])
hs[dst_type] = self.dropout(hs[dst_type])

return hs


class EntityClassify(nn.Module):
def __init__(
self,
in_size,
out_size,
relation_names,
embed_layer,
):
super(EntityClassify, self).__init__()
self.in_size = in_size
self.out_size = out_size
self.relation_names = relation_names
self.relation_names.sort()
self.embed_layer = embed_layer

self.layers = nn.ModuleList()
# Input to hidden.
self.layers.append(
HeteroRelationalGraphConv(
self.in_size,
self.in_size,
self.relation_names,
activation=F.relu,
)
)
# Hidden to output.
self.layers.append(
HeteroRelationalGraphConv(
self.in_size,
self.out_size,
self.relation_names,
)
)

def forward(self, A):
h = self.embed_layer()
for layer in self.layers:
h = layer(A, h)
return h


def main(args):
# Load graph data.
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()

g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop("labels")

# Split dataset into train, validate, test.
val_idx = train_idx[: len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5 :]

embed_layer = RelGraphEmbed(
{ntype: g.num_nodes(ntype) for ntype in g.ntypes}, 16
)

# Create model.
model = EntityClassify(
16,
num_classes,
list(set(g.etypes)),
embed_layer,
)

# Optimizer.
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)

# Construct hetero sparse matrix.
A = {}
for stype, etype, dtype in g.canonical_etypes:
eg = g[stype, etype, dtype]
indices = th.stack(eg.edges("uv"))
A[(stype, etype, dtype)] = dglsp.spmatrix(
indices, shape=(g.num_nodes(stype), g.num_nodes(dtype))
)
###########################################################
# (HIGHLIGHT) Compute the normalized adjacency matrix with
# Sparse Matrix API
###########################################################
D1_hat = dglsp.diag(A[(stype, etype, dtype)].sum(1)) ** -0.5
D2_hat = dglsp.diag(A[(stype, etype, dtype)].sum(0)) ** -0.5
A[(stype, etype, dtype)] = D1_hat @ A[(stype, etype, dtype)] @ D2_hat

# Training loop.
print("start training...")
model.train()
for epoch in range(20):
optimizer.zero_grad()
logits = model(A)[category]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward()
optimizer.step()

train_acc = th.sum(
logits[train_idx].argmax(dim=1) == labels[train_idx]
).item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = th.sum(
logits[val_idx].argmax(dim=1) == labels[val_idx]
).item() / len(val_idx)
print(
f"Epoch {epoch:05d} | Train Acc: {train_acc:.4f} | "
f"Train Loss: {loss.item():.4f} | Valid Acc: {val_acc:.4f} | "
f"Valid loss: {val_loss.item():.4f} "
)
print()

model.eval()
logits = model.forward(A)[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print(
"Test Acc: {:.4f} | Test loss: {:.4f}".format(
test_acc, test_loss.item()
)
)
print()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)

args = parser.parse_args()
print(args)
main(args)