# Introduction to Graph Neural Networks (GNNs)
Some data is naturally structured as graphs: social networks, molecules, knowledge bases, recommendation systems. Graph Neural Networks (GNNs) learn from that structure directly, rather than forcing it into tables. Here's an introduction.
## Why Graphs?

### Data That Doesn't Fit in Tables
| Data Type | Structure | Example |
|---|---|---|
| Tabular | Rows × Columns | Customer records |
| Sequential | Time series | Stock prices |
| Grid | 2D/3D arrays | Images |
| Graph | Nodes + Edges | Social networks |
### Graph Examples

Social network:

```
Alice ─── Bob
  │        │
  └─ Carol ┘
```

Molecule:

```
C ─ C ═ C ─ O
│   │
H   H
```

Knowledge graph:

```
Paris ─[capital_of]→ France
  │                    │
[located_in]     [in_continent]
  │                    │
  └─────→ Europe ←─────┘
```
## Graph Basics

### Structure

```python
# Nodes (vertices)
nodes = {'A', 'B', 'C', 'D'}

# Edges (connections)
edges = [('A', 'B'), ('B', 'C'), ('A', 'C'), ('C', 'D')]

# Node features (optional)
features = {
    'A': [0.1, 0.5, 0.3],  # Feature vector
    'B': [0.2, 0.4, 0.2],
    # ...
}
```
### Adjacency Matrix

```
     A  B  C  D
A  [ 0  1  1  0 ]
B  [ 1  0  1  0 ]
C  [ 1  1  0  1 ]
D  [ 0  0  1  0 ]
```
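The matrix can be built directly from the edge list; a minimal pure-Python sketch (variable names are illustrative):

```python
# Build a dense adjacency matrix from the edge list above
nodes = ['A', 'B', 'C', 'D']
edges = [('A', 'B'), ('B', 'C'), ('A', 'C'), ('C', 'D')]

index = {n: i for i, n in enumerate(nodes)}
adj = [[0] * len(nodes) for _ in nodes]
for u, v in edges:
    adj[index[u]][index[v]] = 1
    adj[index[v]][index[u]] = 1  # undirected graph: mirror each edge
```

For large, sparse graphs you would store the edge list itself (as PyTorch Geometric does with `edge_index`) rather than a dense matrix.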
## The Key Insight: Message Passing
GNNs learn by passing information between connected nodes:
1. **Each node has features:** `A = [0.1, 0.5]`, `B = [0.2, 0.4]`
2. **Each node aggregates neighbor information:** `A' = f(A, neighbors(A)) = f(A, {B, C})`
3. **Node representations are updated:** `A_new = update(A, aggregated_messages)`
4. **Repeat for multiple layers:** each additional layer widens the neighborhood a node can see.
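The steps above can be sketched in plain Python with mean aggregation, a deliberately simplified stand-in for the learned aggregate and update functions a real GNN uses:

```python
features = {'A': [0.1, 0.5], 'B': [0.2, 0.4], 'C': [0.3, 0.1]}
neighbors = {'A': ['B', 'C'], 'B': ['A', 'C'], 'C': ['A', 'B']}

def mean(vectors):
    # Element-wise average of a list of equal-length vectors
    return [sum(vals) / len(vals) for vals in zip(*vectors)]

def message_pass(features, neighbors):
    updated = {}
    for node, h in features.items():
        # Aggregate: average the neighbors' feature vectors
        msg = mean([features[n] for n in neighbors[node]])
        # Update: here just an average of self and the aggregated message
        updated[node] = mean([h, msg])
    return updated

features = message_pass(features, neighbors)  # one "layer"
```

Calling `message_pass` repeatedly corresponds to stacking layers: after two calls, each node's representation reflects its 2-hop neighborhood.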
## GNN Architectures

### Graph Convolutional Network (GCN)

```python
# Simplified GCN layer
h_new = σ(A_norm @ h @ W)

# Where:
#   A_norm: normalized adjacency matrix
#   h: node features
#   W: learnable weights
#   σ: activation function (e.g., ReLU)
```
Each layer aggregates neighbor features.
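The normalization used by GCN is `A_norm = D̂^(-1/2) (A + I) D̂^(-1/2)`, where `I` adds self-loops and `D̂` is the degree matrix after self-loops. A small pure-Python sketch, using the 4-node adjacency matrix from earlier (a real implementation would use sparse matrices):

```python
import math

A = [[0, 1, 1, 0],
     [1, 0, 1, 0],
     [1, 1, 0, 1],
     [0, 0, 1, 0]]
n = len(A)

# Add self-loops so each node also keeps its own features
A_hat = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
deg = [sum(row) for row in A_hat]  # degrees including self-loops

# Symmetric normalization: entry (i, j) becomes A_hat[i][j] / sqrt(d_i * d_j)
A_norm = [[A_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
          for i in range(n)]
```

Without this normalization, high-degree nodes would accumulate much larger feature sums than low-degree ones, destabilizing training.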
### Graph Attention Network (GAT)

Attention weights for neighbor importance:

```python
# Attention between nodes i and j
α_ij = softmax(attention_score(h_i, h_j))

# Aggregation with attention
h_new = σ(Σ α_ij * W * h_j)
```
Some neighbors matter more than others.
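A toy version of attention-weighted aggregation for a single node, using a plain dot product as the score; real GAT learns the scoring function with a linear layer and a LeakyReLU, so this is only an illustration of the mechanism:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    # Normalize scores into weights that sum to 1
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

h_i = [0.1, 0.5]                        # the node being updated
neighbor_feats = [[0.2, 0.4], [0.9, 0.1]]  # its neighbors' features

scores = [dot(h_i, h_j) for h_j in neighbor_feats]
alphas = softmax(scores)  # one attention weight per neighbor

# Weighted sum of neighbor features
h_new = [sum(a * h_j[d] for a, h_j in zip(alphas, neighbor_feats))
         for d in range(len(h_i))]
```

The first neighbor's features point in a more similar direction to `h_i`, so it receives the larger weight.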
### GraphSAGE

Sample and aggregate:

```python
# Sample neighbors (for scalability)
sampled = sample(neighbors(node), k=10)

# Aggregate
neighbor_embedding = aggregate(sampled)

# Combine with self
h_new = combine(h, neighbor_embedding)
```
Works on large graphs by sampling.
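A rough sketch of the sample-then-aggregate idea in plain Python; `sample_neighbors` and `sage_step` are hypothetical helper names, and the mean aggregator stands in for GraphSAGE's learned aggregators:

```python
import random

def sample_neighbors(neighbors, k, rng):
    # Cap the number of neighbors used per node; this bound on fan-out
    # is what keeps the per-node cost fixed on huge graphs
    if len(neighbors) <= k:
        return list(neighbors)
    return rng.sample(neighbors, k)

def sage_step(h_self, neighbor_feats):
    # Mean aggregator (the paper also proposes LSTM and pooling variants)
    agg = [sum(vals) / len(vals) for vals in zip(*neighbor_feats)]
    # "Combine" as concatenation of self and aggregated vectors, which a
    # learned weight matrix would then project back down
    return h_self + agg

rng = random.Random(0)  # seeded for reproducibility
sampled = sample_neighbors(['B', 'C', 'D'], k=2, rng=rng)
```

Because only sampled neighborhoods are touched, embeddings can also be computed for nodes unseen at training time (inductive learning), unlike a fixed full-graph adjacency matrix.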
## Implementation with PyTorch Geometric

### Installation

```shell
pip install torch-geometric  # requires PyTorch to be installed first
```
### Basic GNN

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Load the Cora citation dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(dataset.num_features, 16, dataset.num_classes)
```
### Training

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Only the training nodes contribute to the loss (transductive setup)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()
```
## Tasks on Graphs

### Node Classification

Predict labels for nodes (e.g., paper topics in a citation network):

```python
# Output: one score (logit) per class, per node
predictions = model(data.x, data.edge_index)
node_labels = predictions.argmax(dim=1)
```
### Graph Classification

Predict labels for entire graphs (e.g., molecule properties):

```python
from torch_geometric.nn import global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # Pool node embeddings into one graph-level representation
        x = global_mean_pool(x, batch)
        return self.classifier(x)
```
### Link Prediction

Predict missing edges (e.g., friend recommendations):

```python
# Score a potential edge from the two nodes' embeddings z_i, z_j
def link_score(z_i, z_j):
    return (z_i * z_j).sum(dim=-1)
```
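The same dot-product score can be passed through a sigmoid to read as an edge probability; a plain-Python sketch with made-up embedding values:

```python
import math

def link_score(z_i, z_j):
    # Dot product of two node embeddings
    return sum(a * b for a, b in zip(z_i, z_j))

def link_prob(z_i, z_j):
    # Sigmoid squashes the score into (0, 1)
    return 1 / (1 + math.exp(-link_score(z_i, z_j)))

z_alice = [0.9, 0.1]
z_bob = [0.8, 0.2]     # similar embedding -> high edge probability
z_carol = [-0.7, 0.9]  # dissimilar embedding -> low edge probability

p_ab = link_prob(z_alice, z_bob)
p_ac = link_prob(z_alice, z_carol)
```

Training typically contrasts scores of real edges against sampled non-edges (negative sampling), pushing `p` up for the former and down for the latter.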
### Node Embedding

Learn reusable representations:

```python
# Use a hidden layer's output as embeddings; the GCN above has no
# separate embedding method, so take the first layer's output directly
embeddings = model.conv1(data.x, data.edge_index)
# Use for downstream tasks, visualization, similarity search
```
## Applications

### Social Networks

- Friend recommendation
- Community detection
- Influence prediction

### Molecules and Drug Discovery

- Property prediction
- Drug-target interaction
- Molecule generation

### Recommendations

- User-item graphs
- Session-based recommendations
- Knowledge-graph-enhanced recommendations

### Knowledge Graphs

- Entity typing
- Relation prediction
- Question answering

### Traffic and Maps

- Traffic prediction
- Route optimization
- ETA estimation
## Challenges

### Scalability

Large graphs don't fit in memory. Common remedies:

- Neighbor sampling (GraphSAGE)
- Mini-batching neighborhoods
- Distributed training
### Over-smoothing

Stacking many layers makes node representations converge toward each other, washing out the distinctions the classifier needs. Mitigations:

- Skip connections
- Layer normalization
- Jumping knowledge connections
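Over-smoothing is easy to demonstrate: repeated mean aggregation on a connected graph drives every node's value toward the same number. A toy sketch on a 3-node path graph:

```python
# Path graph A - B - C with scalar "features"
neighbors = {'A': ['B'], 'B': ['A', 'C'], 'C': ['B']}
h = {'A': 1.0, 'B': 0.0, 'C': -1.0}

for layer in range(50):  # 50 "layers" of plain mean aggregation
    h = {node: (h[node] + sum(h[n] for n in neighbors[node]))
               / (1 + len(neighbors[node]))
         for node in h}

# After many layers, all nodes carry nearly identical values
spread = max(h.values()) - min(h.values())
```

Learned weights and nonlinearities slow this collapse but do not eliminate it, which is why most GCN-style models in practice stay shallow (2-4 layers) or use the mitigations above.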
### Heterogeneous Graphs

Real graphs often mix node and edge types (e.g., users, items, and reviews in one graph):

- Heterogeneous GNNs
- Relation-specific transforms
## Final Thoughts
GNNs unlock machine learning for relational data. If your data has connections—social, molecular, knowledge—GNNs should be in your toolkit.
Start with PyTorch Geometric or DGL. The abstractions make experimentation accessible.
Everything is connected. Now we can learn from it.