Attention Is All You Need: Breaking Down Transformers

ai machine-learning nlp transformers

“Attention Is All You Need” might be the most important machine learning paper of the decade. Published by Google researchers in 2017, it introduced the Transformer architecture that now powers BERT, GPT, and most modern NLP systems.

Let’s break it down.

The Problem with RNNs

Recurrent networks process sequences step-by-step. This creates two problems:

  1. Sequential bottleneck: Can’t parallelize across time steps
  2. Long-range dependencies: Information must flow through many steps

The Transformer solves both with a radical idea: process all positions simultaneously using attention.

The Core Innovation: Self-Attention

Self-attention lets each position attend to all other positions in the sequence. No recurrence, no convolutions—just attention.

Intuition

Given the sentence “The cat sat on the mat because it was tired”, what does “it” refer to? When the model processes “it”, self-attention lets that position look directly at every other word and attend strongly to “cat”, resolving the pronoun in a single step.

The Math

Self-attention computes three vectors for each position: a query (Q), a key (K), and a value (V), each produced by a learned linear projection of the input.

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Where:

  1. Q, K, V: Matrices of queries, keys, and values, stacked over all positions
  2. d_k: The dimension of the key vectors
  3. √d_k: A scaling factor that keeps dot products from growing large and saturating the softmax

Code Example

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    attention_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attention_weights, V)
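
A quick sanity check of the function above, re-stated here so the snippet runs on its own (and returning the weights as well, for inspection): the attention weights along the last dimension sum to 1, and the output has the same shape as V.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

torch.manual_seed(0)
Q = torch.randn(2, 5, 16)   # (batch, seq_len, d_k)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)            # same shape as V: torch.Size([2, 5, 16])
print(weights.sum(dim=-1))  # every row of attention weights sums to 1
```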

Multi-Head Attention

Single attention has limited expressiveness. Multi-head attention runs multiple attention operations in parallel:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # Linear projections and reshape for multi-head
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attention
        attn_output = scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(attn_output)
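
Rather than wiring up the class above by hand, you can exercise the same mechanism with PyTorch's built-in nn.MultiheadAttention; a minimal self-attention sketch (batch_first=True matches the batch-first shapes used here):

```python
import torch
import torch.nn as nn

# Built-in multi-head attention: d_model = 64 split across 8 heads
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

x = torch.randn(2, 10, 64)        # (batch, seq_len, d_model)
out, attn_weights = mha(x, x, x)  # self-attention: Q = K = V = x

print(out.shape)           # torch.Size([2, 10, 64])
print(attn_weights.shape)  # torch.Size([2, 10, 10]) -- averaged over heads
```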

Different heads can learn different types of relationships: in practice, one head might track syntactic dependencies, another coreference, and another simple positional patterns such as attending to the previous token.

The Complete Architecture

Encoder

Input Embedding + Positional Encoding

[Multi-Head Self-Attention]

Add & Normalize (residual connection)

[Feed-Forward Network]

Add & Normalize

(Repeat N times)

Decoder

Output Embedding + Positional Encoding

[Masked Multi-Head Self-Attention]  ← Can only attend to previous positions

Add & Normalize

[Multi-Head Cross-Attention]  ← Attends to encoder output

Add & Normalize

[Feed-Forward Network]

Add & Normalize

(Repeat N times)

Linear + Softmax → Output probabilities
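
The “masked” self-attention in the decoder is typically implemented with a lower-triangular mask over the score matrix; a minimal sketch of how the mask zeroes out attention to future positions:

```python
import torch

seq_len = 5
# Lower-triangular matrix of ones: position i may attend to positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.randn(seq_len, seq_len)
# Where the mask is 0 (future positions), push scores to -inf before softmax
masked = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = torch.softmax(masked, dim=-1)

# exp(-inf) = 0, so future positions receive exactly zero attention weight
print(weights[0])  # row 0 can only attend to position 0: [1, 0, 0, 0, 0]
```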

Positional Encoding

Self-attention is position-independent—it treats “cat sat” the same as “sat cat”. Positional encodings add position information.

The paper uses sinusoidal functions:

import math
import torch

def positional_encoding(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe

Learned positional embeddings work equally well (used in BERT).
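
The learned variant is just an embedding table indexed by position; a minimal sketch (shapes are illustrative):

```python
import torch
import torch.nn as nn

max_len, d_model = 512, 64
pos_embedding = nn.Embedding(max_len, d_model)  # one learned vector per position

x = torch.randn(2, 10, d_model)            # token embeddings (batch, seq, d_model)
positions = torch.arange(10).unsqueeze(0)  # (1, seq) -- broadcasts over batch
x = x + pos_embedding(positions)           # add position information

print(x.shape)  # torch.Size([2, 10, 64])
```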

Why It Works

Parallelization

All positions computed simultaneously. Training is dramatically faster than RNNs.

Constant Path Length

Any two positions are connected with O(1) operations. In RNNs, distant positions require O(n) steps.

Interpretable Attention

Attention weights show what the model focuses on—useful for debugging and understanding.

The Impact

The Transformer enabled:

  1. BERT and bidirectional pre-training for language understanding
  2. The GPT family of large-scale generative models
  3. Training runs that scale to billions of parameters, thanks to parallelism
  4. Extensions beyond text, including Vision Transformers for images

In hindsight, the title was prophetic. Attention really is all you need.

Practical Advice

If you’re implementing Transformers:

  1. Use established libraries: Hugging Face Transformers, PyTorch’s nn.Transformer
  2. Start with pre-trained models: Training from scratch is expensive
  3. Pay attention to attention patterns: They reveal what the model learns
  4. Watch memory usage: Attention is O(n²) in sequence length
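
The quadratic cost in point 4 is easy to see: the score matrix alone is n × n per head. A rough back-of-the-envelope sketch (assuming float32 scores and 8 heads, one layer):

```python
# Attention scores form an (n x n) matrix per head, so memory grows as n^2
def score_matrix_bytes(seq_len, num_heads=8, bytes_per_float=4):
    return seq_len * seq_len * num_heads * bytes_per_float

for n in (512, 2048, 8192):
    mb = score_matrix_bytes(n) / 1e6
    print(f"seq_len={n:5d}: {mb:8.1f} MB per layer (scores alone)")
```

Quadrupling the sequence length multiplies the score memory by 16, which is why long-context work often turns to sparse or chunked attention variants.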

Final Thoughts

The Transformer is elegant in its simplicity. The core mechanism—scaled dot-product attention—is just a few lines of code. Yet it scales to billions of parameters and achieves remarkable performance.

Understanding Transformers is now essential for anyone working in AI. This paper is where modern NLP began.


Simple ideas, scaled massively.
