Introduction to Graph Neural Networks (GNNs)


Some data is naturally structured as graphs: social networks, molecules, knowledge bases, recommendation systems. Graph Neural Networks (GNNs) learn from this structure directly. Here’s an introduction.

Why Graphs?

Data That Doesn’t Fit in Tables

Data Type  | Structure      | Example
-----------|----------------|-----------------
Tabular    | Rows × Columns | Customer records
Sequential | Time series    | Stock prices
Grid       | 2D/3D arrays   | Images
Graph      | Nodes + Edges  | Social networks

Graph Examples

Social Network:
    Alice ─── Bob
      │        │
      └─ Carol ─┘

Molecules:
    C ─ C ═ C ─ O
    │           │
    H           H

Knowledge Graph:
    Paris ─[capital_of]─ France
      │                     │
   [located_in]         [in_continent]
      │                     │
    Europe ══════════════════

Graph Basics

Structure

# Nodes (vertices)
nodes = {'A', 'B', 'C', 'D'}

# Edges (connections)
edges = [('A', 'B'), ('B', 'C'), ('A', 'C'), ('C', 'D')]

# Node features (optional)
features = {
    'A': [0.1, 0.5, 0.3],  # Feature vector
    'B': [0.2, 0.4, 0.2],
    # ...
}

Adjacency Matrix

    A  B  C  D
A [ 0  1  1  0 ]
B [ 1  0  1  0 ]
C [ 1  1  0  1 ]
D [ 0  0  1  0 ]
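Building this matrix from the node and edge lists above takes a few lines of plain Python (a sketch; real libraries use sparse formats for large graphs):

```python
# Fix a node order so rows/columns have stable indices
nodes = ['A', 'B', 'C', 'D']
edges = [('A', 'B'), ('B', 'C'), ('A', 'C'), ('C', 'D')]

index = {n: i for i, n in enumerate(nodes)}
adj = [[0] * len(nodes) for _ in nodes]
for u, v in edges:
    adj[index[u]][index[v]] = 1
    adj[index[v]][index[u]] = 1  # symmetric: the graph is undirected
```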

The Key Insight: Message Passing

GNNs learn by passing information between connected nodes:

Step 1: Each node has features
         A[0.1, 0.5]  B[0.2, 0.4]

Step 2: Nodes aggregate neighbor information
         A' = f(A, neighbors(A)) = f(A, {B, C})

Step 3: Update node representations
         A_new = update(A, aggregated_messages)

Step 4: Repeat for multiple layers
         Deeper layers = wider neighborhood
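The four steps can be sketched in plain Python with mean aggregation (the feature values for C and D, and averaging as the update rule, are made up for illustration):

```python
features = {'A': [0.1, 0.5], 'B': [0.2, 0.4], 'C': [0.3, 0.1], 'D': [0.0, 0.2]}
neighbors = {'A': ['B', 'C'], 'B': ['A', 'C'], 'C': ['A', 'B', 'D'], 'D': ['C']}

def message_passing_step(features, neighbors):
    updated = {}
    for node, h in features.items():
        msgs = [features[n] for n in neighbors[node]]
        # Aggregate: element-wise mean over neighbor feature vectors
        agg = [sum(vals) / len(msgs) for vals in zip(*msgs)]
        # Update: here, average the node's own features with the aggregate
        updated[node] = [(a + b) / 2 for a, b in zip(h, agg)]
    return updated

h1 = message_passing_step(features, neighbors)   # one layer
h2 = message_passing_step(h1, neighbors)         # two layers = 2-hop neighborhood
```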

GNN Architectures

Graph Convolutional Network (GCN)

# Simplified GCN layer
h_new = σ(A_norm @ h @ W)

# Where:
# A_norm: Normalized adjacency matrix
# h: Node features
# W: Learnable weights
# σ: Activation function (ReLU)

Each layer aggregates neighbor features.
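The `A_norm` in the formula is the symmetrically normalized adjacency with self-loops, D⁻¹ᐟ² (A + I) D⁻¹ᐟ², where D is the degree matrix of A + I. A plain-Python sketch on the example matrix:

```python
import math

A = [[0, 1, 1, 0],
     [1, 0, 1, 0],
     [1, 1, 0, 1],
     [0, 0, 1, 0]]

n = len(A)
# Add self-loops: A_hat = A + I
A_hat = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
# Degree of each node in A_hat
deg = [sum(row) for row in A_hat]
# Symmetric normalization: A_norm[i][j] = A_hat[i][j] / sqrt(deg[i] * deg[j])
A_norm = [[A_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
          for i in range(n)]
```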

Graph Attention Network (GAT)

Attention weights for neighbor importance:

# Attention between nodes i and j (softmax over i's neighbors)
α_ij = softmax(attention_score(h_i, h_j))

# Aggregation with attention
h_new = σ(Σ α_ij * W * h_j)

Some neighbors matter more than others.
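A toy version of attention-weighted aggregation for a single node — note the simple dot-product score here stands in for GAT's learned attention mechanism, and the feature values are made up:

```python
import math

h_i = [0.1, 0.5]                          # the node being updated
neighbor_feats = [[0.2, 0.4], [0.3, 0.1]] # its neighbors' features

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Score each neighbor, then softmax so weights sum to 1
scores = [dot(h_i, h_j) for h_j in neighbor_feats]
exp_s = [math.exp(s) for s in scores]
alpha = [e / sum(exp_s) for e in exp_s]

# Attention-weighted sum of neighbor features
h_new = [sum(a * h_j[k] for a, h_j in zip(alpha, neighbor_feats))
         for k in range(len(h_i))]
```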

GraphSAGE

Sample and aggregate:

# Sample neighbors (for scalability)
sampled = sample(neighbors(node), k=10)

# Aggregate
neighbor_embedding = aggregate(sampled)

# Combine with self
h_new = combine(h, neighbor_embedding)

Works on large graphs by sampling.
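One sample-and-aggregate step can be sketched in plain Python (a mean aggregator and concatenation are assumed here; GraphSAGE supports several aggregators):

```python
import random

def sage_step(node, features, neighbors, k=2, seed=0):
    """Sample up to k neighbors, mean-aggregate them, concatenate with self."""
    rng = random.Random(seed)
    nbrs = neighbors[node]
    sampled = nbrs if len(nbrs) <= k else rng.sample(nbrs, k)
    dim = len(features[node])
    # Mean of the sampled neighbors' features
    agg = [sum(features[n][d] for n in sampled) / len(sampled) for d in range(dim)]
    # Combine self representation with the neighborhood summary
    return features[node] + agg

features = {'A': [0.1, 0.5], 'B': [0.2, 0.4], 'C': [0.3, 0.1], 'D': [0.0, 0.2]}
neighbors = {'A': ['B', 'C'], 'B': ['A', 'C'], 'C': ['A', 'B', 'D'], 'D': ['C']}
h_D = sage_step('D', features, neighbors)  # D has a single neighbor, C
```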

Implementation with PyTorch Geometric

Installation

pip install torch-geometric

Basic GNN

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(dataset.num_features, 16, dataset.num_classes)

Training

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()

Tasks on Graphs

Node Classification

Predict labels for nodes (e.g., paper topics in citation network):

# Output: one score (logit) per class for each node
predictions = model(data.x, data.edge_index)
node_labels = predictions.argmax(dim=1)

Graph Classification

Predict labels for entire graphs (e.g., molecule properties):

from torch_geometric.nn import global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        # Pool node embeddings to one graph-level vector per graph in the batch
        x = global_mean_pool(x, batch)
        return self.classifier(x)

Link Prediction

Predict missing edges (e.g., friend recommendations):

# Score potential edges
def link_score(z_i, z_j):
    return (z_i * z_j).sum(dim=-1)
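A toy run of dot-product scoring in plain Python, squashing the score to a probability with a sigmoid (the embedding values are made up for illustration):

```python
import math

# Hypothetical learned node embeddings
z = {'alice': [0.9, 0.1], 'bob': [0.8, 0.2], 'carol': [-0.5, 0.7]}

def link_prob(u, v):
    score = sum(a * b for a, b in zip(z[u], z[v]))  # dot product
    return 1 / (1 + math.exp(-score))               # sigmoid -> probability

p_ab = link_prob('alice', 'bob')    # similar embeddings -> higher probability
p_ac = link_prob('alice', 'carol')  # dissimilar embeddings -> lower probability
```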

Node Embedding

Learn useful representations:

# Use a trained GNN's hidden layer as node embeddings
with torch.no_grad():
    embeddings = model.conv1(data.x, data.edge_index).relu()
# Use for downstream tasks, visualization, similarity search

Applications

Social Networks

Friend recommendation, community detection, and influence prediction.

Molecules and Drug Discovery

Predicting molecular properties and screening drug candidates, with atoms as nodes and bonds as edges.

Recommendations

Learning from user–item interaction graphs; Pinterest's PinSage is a well-known large-scale example.

Knowledge Graphs

Link prediction and entity classification over facts and relations.

Traffic and Maps

Travel-time and traffic prediction on road networks, as in Google Maps' ETA system.

Challenges

Scalability

Large graphs don't fit in memory. Mini-batch training with neighbor sampling (as in GraphSAGE) or graph partitioning makes training feasible.

Over-smoothing

Deep GNNs make all nodes similar: after many rounds of message passing, node representations converge toward each other, which is why most GNNs use only 2–4 layers.

Heterogeneous Graphs

Different node and edge types (e.g., users, items, and reviews in one graph) need architectures that model each type separately, such as relational GCNs.

Final Thoughts

GNNs unlock machine learning for relational data. If your data has connections—social, molecular, knowledge—GNNs should be in your toolkit.

Start with PyTorch Geometric or DGL. The abstractions make experimentation accessible.


Everything is connected. Now we can learn from it.
