Graph Neural Networks
Data isn't always tabular. Social networks, molecules, knowledge graphs, and citation networks are inherently graph-structured. Graph Neural Networks (GNNs) learn on this structure by aggregating information from neighbors — the message passing paradigm.
Graph Message Passing
Graph Fundamentals
A graph G = (V, E) has nodes V and edges E. Each node has features, and the task might be node classification, link prediction, or graph classification.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import numpy as np
import warnings
warnings.filterwarnings('ignore')
Creating Graph Data
# Build a small random graph for a node-classification demo.
num_nodes = 100
num_features = 16
num_classes = 3

# Node feature matrix: one row of `num_features` random values per node.
x = torch.randn(num_nodes, num_features)

# Edges as (source, target) pairs stored column-wise in a 2 x 300 tensor.
# Endpoints are drawn independently, so self-loops and duplicates can occur.
edge_index = torch.randint(0, num_nodes, (2, 300))

# One integer class label per node.
y = torch.randint(0, num_classes, (num_nodes,))

# Boolean masks: the first 70 nodes are used for training, the rest for testing.
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:70] = True
test_mask[70:] = True

data = Data(x=x, edge_index=edge_index, y=y,
            train_mask=train_mask, test_mask=test_mask)

print(f"Graph: {data.num_nodes} nodes, {data.num_edges} edges")
print(f"Node features: {data.num_node_features}")
# A single `Data` object has no `num_classes` attribute (that property lives
# on PyG datasets), so report the class count this graph was generated with.
print(f"Classes: {num_classes}")
Graph Convolutional Network (GCN)
class GCN(nn.Module):
    """Three stacked GCNConv layers that map node features to class scores."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Layer widths: in -> hidden -> hidden -> out.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the unnormalised class scores produced by the last layer."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        # Dropout is applied only after the first layer and only in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = self.conv2(hidden, edge_index)
        hidden = F.relu(hidden)
        return self.conv3(hidden, edge_index)
# GCN with 32 hidden channels, optimised by Adam with weight decay.
model = GCN(
    in_channels=num_features,
    hidden_channels=32,
    out_channels=num_classes,
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
# Training
def train(model, data, opt=None):
    """Run one full-batch optimisation step and return the training loss.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``; its output
            is indexed by ``data.train_mask`` and fed to cross-entropy.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        opt: Optimizer to step. Pass the optimizer that owns ``model``'s
            parameters when training anything other than the default model.
            When omitted, the module-level ``optimizer`` is used, which is
            the behaviour of the original two-argument call.

    Returns:
        The cross-entropy loss on the training nodes as a Python float.
    """
    if opt is None:
        # Backward-compatible default: fall back to the script-level optimizer.
        opt = optimizer
    model.train()
    opt.zero_grad()
    out = model(data.x, data.edge_index)
    # Only the nodes selected by the training mask contribute to the loss.
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    opt.step()
    return loss.item()
def test(model, data):
    """Evaluate ``model`` and return ``[train_accuracy, test_accuracy]``.

    Args:
        model: Module called as ``model(data.x, data.edge_index)`` returning
            one row of class scores per node.
        data: Object exposing ``x``, ``edge_index``, ``y``, ``train_mask``
            and ``test_mask``.

    Returns:
        A two-element list of Python floats: the fraction of correctly
        predicted nodes under the training mask, then under the test mask.
    """
    model.eval()
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)
        return [
            (pred[mask] == data.y[mask]).float().mean().item()
            for mask in (data.train_mask, data.test_mask)
        ]
# Full-batch training for 200 epochs, reporting every 50th epoch.
for epoch in range(200):
    loss = train(model, data)
    # The accuracies are only used in the report below, so run the extra
    # evaluation forward pass on reporting epochs instead of on every epoch.
    if (epoch + 1) % 50 == 0:
        train_acc, test_acc = test(model, data)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}")
Graph Attention Network (GAT)
class GAT(nn.Module):
    """Three GATConv layers: two multi-head hidden layers and a single-head output."""

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        # With concat=True the hidden layers emit `hidden_channels * heads`
        # features, which is the input width of the layer that follows.
        concat_width = hidden_channels * heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        self.conv2 = GATConv(concat_width, hidden_channels, heads=heads, concat=True)
        self.conv3 = GATConv(concat_width, out_channels, heads=1, concat=False)

    def forward(self, x, edge_index):
        """Return the output of the final attention layer for every node."""
        hidden = self.conv1(x, edge_index)
        hidden = F.elu(hidden)
        # Dropout is applied only after the first layer and only in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = self.conv2(hidden, edge_index)
        hidden = F.elu(hidden)
        return self.conv3(hidden, edge_index)
# Instantiate a GAT with 8 hidden channels per head and report its size.
gat_model = GAT(
    in_channels=num_features,
    hidden_channels=8,
    out_channels=num_classes,
    heads=4,
)
print(f"GAT parameters: {sum(weight.numel() for weight in gat_model.parameters()):,}")
GraphSAGE
class GraphSAGE(nn.Module):
    """Three stacked SAGEConv layers with ReLU between them."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Layer widths: in -> hidden -> hidden -> out.
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the output of the final layer for every node."""
        hidden = x
        # The two hidden layers share the same conv -> ReLU pattern.
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        return self.conv3(hidden, edge_index)
# Instantiate GraphSAGE with 32 hidden channels.
sage_model = GraphSAGE(
    in_channels=num_features,
    hidden_channels=32,
    out_channels=num_classes,
)
print("GraphSAGE: learns from sampled neighbors, scalable to large graphs")
Link Prediction
class LinkPredictor(nn.Module):
    """Encode nodes with one GCNConv layer, then score node pairs with an MLP."""

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNConv(in_channels, hidden_channels)
        # The decoder sees both endpoint embeddings concatenated, hence the
        # doubled input width; it emits a single score per pair.
        pair_width = hidden_channels * 2
        self.decoder = nn.Sequential(
            nn.Linear(pair_width, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return ReLU-activated node embeddings from the GCN encoder."""
        embeddings = self.encoder(x, edge_index)
        return F.relu(embeddings)

    def decode(self, z, edge_index):
        """Return a sigmoid score for every column of ``edge_index``."""
        source, target = edge_index
        pair_features = torch.cat([z[source], z[target]], dim=1)
        logits = self.decoder(pair_features)
        return torch.sigmoid(logits)
from torch_geometric.utils import negative_sampling

# Positive examples are the edges that exist in the graph.
pos_edge_index = edge_index
n_pos = pos_edge_index.size(1)

# Negative examples must be node pairs that are NOT connected. Uniformly random
# pairs can coincide with real edges and would then be labelled as negatives,
# so draw them with PyG's sampler, which avoids the edges passed to it.
# NOTE: the sampler may return fewer than `num_neg_samples` pairs when few
# unconnected pairs are available.
neg_edge_index = negative_sampling(
    pos_edge_index, num_nodes=num_nodes, num_neg_samples=n_pos
)

# Random 80/20 split of the positive edges into train and test sets.
perm = torch.randperm(n_pos)
n_train = int(0.8 * n_pos)
train_pos = pos_edge_index[:, perm[:n_train]]
test_pos = pos_edge_index[:, perm[n_train:]]

link_model = LinkPredictor(num_features, 32)
print(f"Link predictor ready: {link_model}")
Message Passing Framework
from torch_geometric.nn import MessagePassing
class CustomGNN(MessagePassing):
    """Message-passing layer: transform each neighbor with an MLP, then aggregate."""

    def __init__(self, in_channels, out_channels):
        # Messages are combined with 'add'; 'mean' and 'max' are alternatives.
        super().__init__(aggr='add')
        layers = [
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        """Propagate node features ``x`` along ``edge_index``."""
        propagated = self.propagate(edge_index, x=x)
        return propagated

    def message(self, x_j):
        """Build the message from neighbor features ``x_j`` by applying the MLP."""
        message = self.mlp(x_j)
        return message

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
# Run the custom layer once on the demo graph and show the output shape.
custom_gnn = CustomGNN(in_channels=num_features, out_channels=32)
out = custom_gnn(data.x, edge_index=data.edge_index)
print(f"Custom GNN output: {out.shape}")
Graph Classification
# Build 100 random graphs, each carrying one binary graph-level label.
graphs = []
for _ in range(100):
    # Cast to a plain int so the tensor sizes below do not depend on how
    # NumPy integer scalars are handled.
    n = int(np.random.randint(10, 50))
    # Graph-local names: reusing `x` / `edge_index` / `y` here would overwrite
    # the module-level tensors that describe the node-classification graph.
    graph_edge_index = torch.randint(0, n, (2, n * 2))
    graph_x = torch.randn(n, num_features)
    graph_y = torch.randint(0, 2, (1,))
    # The all-zero `batch` vector marks every node as belonging to graph 0.
    graph = Data(x=graph_x, edge_index=graph_edge_index, y=graph_y,
                 batch=torch.zeros(n, dtype=torch.long))
    graphs.append(graph)
class GraphClassifier(nn.Module):
    """Two GCNConv layers, global mean pooling, and a linear classification head."""

    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return one row of class scores per graph identified in ``batch``."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        # Pool the node embeddings into one vector per graph before classifying.
        pooled = global_mean_pool(hidden, batch)
        return self.classifier(pooled)
# Binary graph classifier with 32 hidden channels.
graph_model = GraphClassifier(
    in_channels=num_features,
    hidden_channels=32,
    num_classes=2,
)
print("Graph classifier with global mean pooling")
Best Practices
- Normalize adjacency matrix — symmetric normalization for GCN
- Use skip connections — for deep GNNs (>3 layers)
- Watch for oversmoothing — in deep GCNs, all node representations converge toward the same values
- Mini-batch with neighbor sampling — for large graphs (GraphSAGE)
- Evaluate node/link/graph tasks separately — different metrics apply
- Handle heterogeneity — different node/edge types need specialized architectures
Summary
GNNs learn on graph-structured data through message passing. GCN, GAT, and GraphSAGE are foundational architectures for node classification, link prediction, and graph classification. Master message passing and neighborhood aggregation to work with social networks, molecules, and knowledge graphs.