Transfer Learning in NLP: Using BERT for Classification

ai machine-learning nlp transformers

BERT made transfer learning practical for NLP. Pre-train once, fine-tune for any task. Here’s how to use it for text classification.

The Transfer Learning Paradigm

Before BERT, each NLP task required training a model from scratch, with large labeled datasets and days of compute.

With BERT:

  1. Use pre-trained BERT (trained on massive unlabeled text)
  2. Add a classification head
  3. Fine-tune on your small labeled dataset
  4. Done in hours, not days

Setting Up

pip install transformers torch

from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch

Loading Pre-trained BERT

# Tokenizer converts text to token IDs
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Model with classification head
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2  # Binary classification
)

Preparing Data

from datasets import load_dataset

# Load a dataset
dataset = load_dataset('imdb')

# Tokenize
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=512
    )

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Fine-Tuning

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    evaluation_strategy='epoch',  # renamed to eval_strategy in newer transformers releases
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

trainer.train()
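By default the Trainer only reports loss. A small `compute_metrics` function adds accuracy to each evaluation; this sketch uses plain NumPy (the `evaluate` library works too):

```python
import numpy as np

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair of NumPy arrays
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': (predictions == labels).mean()}
```

Pass it as `Trainer(..., compute_metrics=compute_metrics)` and accuracy shows up in the evaluation logs alongside loss.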

Making Predictions

# Single prediction
text = "This movie was absolutely fantastic!"
inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    
print(f"Positive: {predictions[0][1]:.2%}, Negative: {predictions[0][0]:.2%}")
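To turn the probabilities into a discrete label, take the argmax and map it through a label dictionary. A minimal sketch with hypothetical logits; the 0 = negative, 1 = positive mapping is an assumption and must match how your training labels were encoded:

```python
import torch

# Hypothetical logits for one example
logits = torch.tensor([[-1.2, 2.3]])
probs = torch.nn.functional.softmax(logits, dim=-1)

id2label = {0: 'negative', 1: 'positive'}  # assumed label encoding
label = id2label[int(probs.argmax(dim=-1))]
print(label)  # positive
```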

Using Pipelines (Easier)

from transformers import pipeline

classifier = pipeline(
    'sentiment-analysis',
    model=model,
    tokenizer=tokenizer
)

result = classifier("I love this product!")
print(result)  # e.g. [{'label': 'LABEL_1', 'score': 0.99}]; label names come from the model's id2label config

Multi-Class Classification

from transformers import BertForSequenceClassification

# For multiple classes
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=5,  # 5 categories
    id2label={0: 'sports', 1: 'tech', 2: 'politics', 3: 'entertainment', 4: 'business'},
    label2id={'sports': 0, 'tech': 1, 'politics': 2, 'entertainment': 3, 'business': 4}
)

Multi-Label Classification

from transformers import BertForSequenceClassification
import torch.nn as nn

# Multiple labels per input
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=10,
    problem_type='multi_label_classification'
)

# With problem_type set, training uses BCEWithLogitsLoss automatically
# At inference, apply sigmoid (not softmax) to the logits
predictions = torch.sigmoid(outputs.logits)
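After the sigmoid, each probability is thresholded independently, conventionally at 0.5, to decide which labels apply. A sketch with hypothetical probabilities:

```python
import torch

# Hypothetical sigmoid outputs for one example with 5 possible labels
probs = torch.tensor([0.92, 0.11, 0.67, 0.34, 0.05])
active = (probs > 0.5).int().tolist()
print(active)  # [1, 0, 1, 0, 0]
```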

Handling Long Text

BERT's position embeddings cap inputs at 512 tokens (including [CLS] and [SEP]). For longer texts:

Truncation

# Just take first 512 tokens
tokenizer(text, max_length=512, truncation=True)
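Plain truncation throws away the end of the document, which often carries the conclusion. An alternative is to keep tokens from both the head and the tail; a minimal sketch operating on raw token IDs (the 128-token head is an assumption, tune the split for your data):

```python
def head_tail_truncate(input_ids, max_length=512, head=128):
    # Keep the first `head` tokens and enough from the end to fill max_length
    if len(input_ids) <= max_length:
        return input_ids
    tail = max_length - head
    return input_ids[:head] + input_ids[-tail:]
```

Apply it to the token IDs from the tokenizer before building the model inputs.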

Sliding Window

def classify_long_text(text, model, tokenizer, window_size=510, stride=256):
    # Tokenize without special tokens; we re-add [CLS]/[SEP] per window
    encoding = tokenizer(text, add_special_tokens=False, return_tensors='pt')
    input_ids = encoding['input_ids'][0]

    cls_id = torch.tensor([tokenizer.cls_token_id])
    sep_id = torch.tensor([tokenizer.sep_token_id])

    scores = []
    for i in range(0, len(input_ids), stride):
        window = input_ids[i:i + window_size]
        if len(window) < 50:  # Skip tiny trailing windows
            continue
        # Each window gets its own [CLS] and [SEP] so it looks like a normal input
        window = torch.cat([cls_id, window, sep_id])

        with torch.no_grad():
            outputs = model(window.unsqueeze(0))
            scores.append(outputs.logits)

    # Average logits across windows
    return torch.stack(scores).mean(dim=0)

Use Longformer

from transformers import LongformerForSequenceClassification

model = LongformerForSequenceClassification.from_pretrained(
    'allenai/longformer-base-4096',
    num_labels=2
)
# Handles up to 4096 tokens

Model Variants

Model | Parameters | Use Case
bert-base-uncased | 110M | General purpose
bert-large-uncased | 340M | Higher accuracy, more resources
distilbert-base-uncased | 66M | Faster, smaller, ~97% of BERT's accuracy
roberta-base | 125M | Often outperforms BERT
albert-base-v2 | 12M | Very lightweight

Practical Tips

Learning Rate

# BERT-specific learning rates
# Lower than typical neural networks
training_args = TrainingArguments(
    learning_rate=2e-5,  # 5e-5, 3e-5, 2e-5 are common
    ...
)

Epochs

Fine-tuning usually needs 2-4 epochs. More can cause overfitting.
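If you are unsure how many epochs you need, transformers ships an EarlyStoppingCallback: set a generous epoch budget and stop once the validation metric stops improving. A configuration sketch, reusing the model and datasets from above:

```python
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,             # upper bound; early stopping ends sooner
    evaluation_strategy='epoch',
    save_strategy='epoch',           # must match the evaluation strategy
    load_best_model_at_end=True,     # required by the callback
    metric_for_best_model='eval_loss',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
```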

Batch Size

Smaller batches (8-16) with gradient accumulation:

training_args = TrainingArguments(
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # Effective batch size 32
    ...
)

Freezing Layers

For small datasets, freeze early layers:

for param in model.bert.embeddings.parameters():
    param.requires_grad = False
for layer in model.bert.encoder.layer[:6]:
    for param in layer.parameters():
        param.requires_grad = False
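After freezing, it is worth verifying how much of the model is still trainable. The counting idiom is plain PyTorch; it is demonstrated here on a tiny stand-in network so the snippet runs without downloading BERT, but the same two summing lines work on BertForSequenceClassification:

```python
import torch.nn as nn

# Tiny stand-in network, not BERT
demo_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for param in demo_model[0].parameters():
    param.requires_grad = False  # freeze the first layer

trainable = sum(p.numel() for p in demo_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in demo_model.parameters())
print(f"{trainable} / {total} parameters trainable")  # 10 / 30
```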

Saving and Loading

# Save
model.save_pretrained('./my_model')
tokenizer.save_pretrained('./my_model')

# Load
model = BertForSequenceClassification.from_pretrained('./my_model')
tokenizer = BertTokenizer.from_pretrained('./my_model')

Final Thoughts

BERT democratized NLP. With pre-trained models, you can build state-of-the-art classifiers with hundreds of labeled examples, not millions.

Start with distilbert-base-uncased for speed. Move to larger models if accuracy matters more than inference time.

The Hugging Face ecosystem makes this accessible. Experiment and iterate quickly.


Pre-training is the new ImageNet moment for NLP.
