Graph Neural Networks
Graph Convolutional Network
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
    """Graph convolution layer on a dense adjacency matrix.

    Computes ``D^-1/2 A_hat D^-1/2 (X W^T + b)``, where ``A_hat`` is ``adj``
    with self-loops filled in (see ``add_self_loops``) and ``D`` is the
    diagonal row-sum (degree) matrix of ``A_hat``.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        add_self_loops: If True (default), add a unit self-loop to every node
            whose diagonal entry in ``adj`` is zero, so each node keeps its
            own features in the propagation. Non-zero diagonal entries that
            the caller already supplied are left untouched.
    """

    def __init__(self, in_features, out_features, add_self_loops=True):
        super().__init__()
        self.add_self_loops = add_self_loops
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        """Propagate linearly transformed node features over the graph.

        Args:
            x: Node features, 2-D ``(N, in_features)`` (required by
                ``torch.mm``).
            adj: Dense square ``(N, N)`` adjacency; non-zero entries are
                edges (or edge weights).

        Returns:
            Tensor of shape ``(N, out_features)``.
        """
        if self.add_self_loops:
            # Fill only missing diagonal entries so pre-added self-loops
            # are not counted twice.
            missing = (adj.diagonal() == 0).to(adj.dtype)
            adj = adj + torch.diag(missing)
        deg = adj.sum(dim=1, keepdim=True)
        deg_inv_sqrt = deg.pow(-0.5)
        # A zero-degree row (e.g. an isolated node with self-loops disabled)
        # yields inf; zero it so that row simply propagates nothing.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # Symmetric normalization: scale rows and columns by D^-1/2.
        adj_norm = deg_inv_sqrt * adj * deg_inv_sqrt.T
        support = self.linear(x)
        return torch.mm(adj_norm, support)
class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        n_features: Size of each input node feature vector.
        n_hidden: Width of the hidden graph-convolution layer.
        n_classes: Number of output classes.
    """

    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.conv1 = GCNLayer(n_features, n_hidden)
        self.conv2 = GCNLayer(n_hidden, n_classes)

    def forward(self, x, adj):
        """Return log-softmax class scores, one row per node."""
        hidden = self.conv1(x, adj).relu()
        logits = self.conv2(hidden, adj)
        return logits.log_softmax(dim=1)


gcn = GCN(
    n_features=1433,
    n_hidden=16,
    n_classes=7,
)
Graph Attention Network
class GATLayer(nn.Module):
    """Multi-head attention layer restricted to graph neighbourhoods.

    Every node attends to itself and to the nodes it is connected to in
    ``adj``. Scores are scaled dot products between per-head query and key
    projections (Transformer-style scoring, not the additive LeakyReLU
    scoring of the original GAT paper). Head outputs are concatenated and
    passed through a final linear projection.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Total output size, split evenly across the heads.
        n_heads: Number of attention heads; must evenly divide
            ``out_features``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``out_features`` is not a positive multiple of
            ``n_heads``.
    """

    def __init__(self, in_features, out_features, n_heads=8, dropout=0.2):
        super().__init__()
        # A zero or fractional head size would break the per-head reshape in
        # forward(), so reject it at construction time.
        if n_heads <= 0 or out_features < n_heads or out_features % n_heads:
            raise ValueError(
                f"out_features ({out_features}) must be a positive multiple "
                f"of n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = out_features // n_heads
        self.W_q = nn.Linear(in_features, out_features)
        self.W_k = nn.Linear(in_features, out_features)
        self.W_v = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(out_features, out_features)

    def _split_heads(self, t):
        """Reshape ``(N, H * d_k)`` to ``(H, N, d_k)`` so bmm batches over heads."""
        return t.view(t.shape[0], self.n_heads, self.d_k).transpose(0, 1)

    def forward(self, x, adj):
        """Attend over each node's neighbours and return ``(N, out_features)``.

        Args:
            x: Node features ``(N, in_features)``.
            adj: Dense ``(N, N)`` adjacency; non-zero entries are edges.
        """
        num_nodes = x.shape[0]
        q = self._split_heads(self.W_q(x))
        k = self._split_heads(self.W_k(x))
        v = self._split_heads(self.W_v(x))

        # (H, N, N): per head, score of node j as seen from node i.
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.d_k ** 0.5)

        # Allow edges from adj plus self-loops. The self-loop guarantees each
        # row keeps at least one finite score, so isolated nodes attend to
        # themselves instead of producing an all -inf (NaN) softmax row.
        eye = torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
        allowed = (adj != 0) | eye
        scores = scores.masked_fill(~allowed.unsqueeze(0), float('-inf'))

        attn = self.dropout(F.softmax(scores, dim=-1))
        context = torch.bmm(attn, v)  # (H, N, d_k)
        # Back to (N, H, d_k), then concatenate heads -> (N, H * d_k).
        context = context.transpose(0, 1).reshape(num_nodes, -1)
        return self.out_proj(context)
class GAT(nn.Module):
    """Two-layer attention network for node classification.

    Args:
        n_features: Size of each input node feature vector.
        n_hidden: Total hidden width of the first layer (split across
            ``n_heads`` by ``GATLayer``).
        n_classes: Number of output classes.
        n_heads: Number of attention heads in the hidden layer.
        out_heads: Number of attention heads in the output layer; must
            divide ``n_classes``. Defaults to 1.
    """

    def __init__(self, n_features, n_hidden, n_classes, n_heads=8, out_heads=1):
        super().__init__()
        self.conv1 = GATLayer(n_features, n_hidden, n_heads)
        # GATLayer splits its output width as out_features // n_heads. With
        # e.g. n_classes=7 and 8 heads that head size would be 0, so the
        # output layer uses out_heads (a single head by default).
        self.conv2 = GATLayer(n_hidden, n_classes, out_heads)

    def forward(self, x, adj):
        """Return unnormalized class scores, one row per node."""
        x = F.elu(self.conv1(x, adj))
        return self.conv2(x, adj)


gat = GAT(n_features=1433, n_hidden=8, n_classes=7, n_heads=8)
GraphSAGE
class GraphSAGELayer(nn.Module):
    """GraphSAGE layer: concatenate each node's features with an aggregate of
    its neighbours' features, then apply a linear map.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        aggr: Neighbour aggregation. ``"mean"`` (default) is the GraphSAGE
            mean aggregator; ``"sum"`` is the unnormalized ``adj @ x``.
        activation: If True (default), apply ReLU to the output. Disable it
            for a final layer that must emit raw class scores.

    Raises:
        ValueError: If ``aggr`` is not ``"mean"`` or ``"sum"``.
    """

    def __init__(self, in_features, out_features, aggr="mean", activation=True):
        super().__init__()
        if aggr not in ("mean", "sum"):
            raise ValueError(f"aggr must be 'mean' or 'sum', got {aggr!r}")
        self.aggr = aggr
        self.activation = activation
        # Input is [own features || aggregated neighbour features].
        self.linear = nn.Linear(in_features * 2, out_features)

    def aggregate(self, x, adj):
        """Aggregate neighbour features for every node using the dense ``adj``."""
        support = torch.mm(adj, x)
        if self.aggr == "mean":
            # Divide by (weighted) degree so high-degree nodes are not scaled
            # up; isolated nodes (degree 0) keep a zero aggregate.
            deg = adj.sum(dim=1, keepdim=True)
            support = support / deg.masked_fill(deg == 0, 1)
        return support

    def forward(self, x, adj):
        """Return ``(N, out_features)`` node embeddings."""
        neighbor_agg = self.aggregate(x, adj)
        combined = torch.cat([x, neighbor_agg], dim=1)
        out = self.linear(combined)
        return F.relu(out) if self.activation else out


class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE network returning raw class scores per node.

    Args:
        n_features: Size of each input node feature vector.
        n_hidden: Width of the hidden layer.
        n_classes: Number of output classes.
    """

    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.conv1 = GraphSAGELayer(n_features, n_hidden)
        # No ReLU on the output layer: clamping class scores at 0 would
        # discard every negative logit.
        self.conv2 = GraphSAGELayer(n_hidden, n_classes, activation=False)

    def forward(self, x, adj):
        x = self.conv1(x, adj)
        return self.conv2(x, adj)


graphsage = GraphSAGE(n_features=1433, n_hidden=64, n_classes=7)
Best Practices
- Use residual connections for deep GNNs
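- Apply dropout (to node features and, in attention models, to attention weights) to reduce overfitting; it does not prevent over-smoothing, which is better handled by keeping GNNs shallow or by using residual connections, normalization such as PairNorm, or DropEdge (randomly removing edges during training)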
- Use batch normalization for training stability
- Consider graph pooling for graph-level tasks
- Use skip connections to preserve node features