Transfer Learning — Reuse Pre-trained Models
Training from scratch on ImageNet-scale data takes weeks on hundreds of GPUs. Transfer learning lets you start from a model already trained on millions of examples and fine-tune it on YOUR task in hours with a fraction of the data. The large majority of real-world vision models are built this way.
Transfer Learning with torchvision Pre-trained Models
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# STRATEGY 1: Feature Extraction (backbone frozen, new head only)
# When to use: small dataset (<1000 images), similar domain to ImageNet
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Load ResNet-50 pre-trained on ImageNet (V2 weights, top-1 accuracy: 80.9%)
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Count pre-trained params
total = sum(p.numel() for p in backbone.parameters())
print(f"ResNet-50 total params: {total:,}") # 25.6M
# Freeze ALL backbone weights
for param in backbone.parameters():
    param.requires_grad = False
# Replace the final classification head
# ResNet-50 last layer: Linear(2048 → 1000) for 1000 ImageNet classes
n_classes = 5 # your task (e.g., 5 flower species)
backbone.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Linear(512, n_classes),
)
# Only new head has requires_grad=True
trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f"Trainable params: {trainable:,}")  # only ~1.05M vs 25.6M total
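A minimal training step for this feature-extraction setup can be sketched as follows. The toy model below stands in for the frozen ResNet-50 so the idea runs without downloading weights; the key point is that only parameters with requires_grad=True are handed to the optimizer.

```python
import torch
import torch.nn as nn

# Toy stand-in: layer 0 plays the role of the frozen backbone,
# layer 2 the new trainable head.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 5))
for p in model[0].parameters():
    p.requires_grad = False  # "backbone" frozen

# Pass ONLY the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 8)              # fake batch
y = torch.randint(0, 5, (4,))      # fake labels
loss = criterion(model(x), y)
loss.backward()                    # frozen layer gets no gradients
optimizer.step()
```

With the real ResNet-50 above, the same pattern applies: `backbone.fc.parameters()` is what ends up in the optimizer.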
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# STRATEGY 2: Full Fine-tuning (all layers, very low lr)
# When to use: large dataset (>10K images), domain shift from ImageNet
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
backbone2 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
backbone2.fc = nn.Linear(2048, n_classes)
# Differential learning rates: backbone needs MUCH lower lr to avoid catastrophic forgetting
optimizer = torch.optim.AdamW([
    {"params": backbone2.layer1.parameters(), "lr": 1e-5},  # earliest features
    {"params": backbone2.layer2.parameters(), "lr": 1e-5},
    {"params": backbone2.layer3.parameters(), "lr": 1e-4},
    {"params": backbone2.layer4.parameters(), "lr": 1e-4},
    {"params": backbone2.fc.parameters(), "lr": 1e-3},      # new head: highest lr
], weight_decay=1e-4)
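Per-group learning rates combine cleanly with a scheduler: each group decays by the same proportion, so the backbone/head ratio is preserved. A sketch (the tiny model and 10-epoch horizon are placeholders for the setup above):

```python
import torch
import torch.nn as nn

# Stand-in model with two param groups at different base lrs,
# mirroring the backbone/head split used with backbone2 above.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW([
    {"params": [model.weight], "lr": 1e-5},  # "backbone" group
    {"params": [model.bias], "lr": 1e-3},    # "head" group
], weight_decay=1e-4)

# Cosine decay scales every group by the same factor each step
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    # ... forward pass, loss.backward(), then:
    optimizer.step()
    scheduler.step()
```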
# STRATEGY 3: Modern backbones (2024)
# EfficientNet-B4: better accuracy/parameter trade-off than ResNet
effnet = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.IMAGENET1K_V1)
effnet.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(1792, n_classes),
)
# Vision Transformer (ViT) — Transformer architecture for images
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
vit.heads = nn.Linear(768, n_classes)
# Quick comparison:
models_comparison = {
    "ResNet-50": {"params": "25.6M", "top1": "76.1% (V1) / 80.9% (V2)", "speed": "fast"},
    "EfficientNet-B4": {"params": "19.3M", "top1": "83.4%", "speed": "medium"},
    "ViT-B/16": {"params": "86.6M", "top1": "81.1%", "speed": "fast on GPU"},
    "ConvNeXt-Base": {"params": "88.6M", "top1": "84.1%", "speed": "medium"},
}
print("\nBackbone Comparison:")
for name, stats in models_comparison.items():
    print(f" {name:20s} params={stats['params']} top1={stats['top1']}")
Tip
Practice transfer learning in small, isolated examples before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
In short: fine-tune a pre-trained backbone rather than training from scratch, and reach for parameter-efficient methods like LoRA when full fine-tuning is too expensive.
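LoRA (Low-Rank Adaptation) freezes the pre-trained weight matrix and learns only a small low-rank update on top of it. The sketch below is a minimal illustration of the idea, not a real library API (the `LoRALinear` name and hyperparameters are ours):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: frozen base Linear plus a low-rank trainable delta.
    Effective weight is W + (alpha/r) * B @ A; only A and B are trained."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pre-trained weights
        # A gets small random init, B starts at zero so the wrapped layer
        # is exactly the base layer at step 0
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.t() @ self.B.t()) * self.scaling

layer = LoRALinear(nn.Linear(2048, 2048), r=8)
```

For a 2048-by-2048 layer with r=8, that is about 33K trainable parameters instead of ~4.2M, which is why LoRA is the default choice when memory is tight.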
Practice Task — (1) Write a working transfer-learning example from scratch without looking at notes. (2) Modify it to handle an edge case (empty batch, wrong input shape, or missing labels). (3) Share your solution in the Priygop community for feedback.
Common Mistake
A common mistake with transfer learning is skipping edge-case testing — empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.