Mini Project: Train an MLP on MNIST
MNIST is the 'Hello World' of deep learning: 70,000 handwritten digit images, split into 60,000 for training and 10,000 for testing. In this project you train an MLP from scratch to over 98% test accuracy using everything from this module: the MLP architecture, CrossEntropyLoss, AdamW, BatchNorm, Dropout, a cosine LR schedule, and early stopping.
Complete MNIST MLP Project
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 1. DATA LOADING
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
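# shuffle=True draws a fresh permutation of the training set each epoch;
# the test loader keeps a fixed order so evaluation is reproducible.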
print(f"Training samples: {len(train_dataset):,}") # 60,000
print(f"Test samples: {len(test_dataset):,}") # 10,000
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 2. MODEL ARCHITECTURE
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class DigitClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),            # [B, 1, 28, 28] → [B, 784]
            nn.Linear(784, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),      # logits for the 10 digit classes
        )
        # He (Kaiming) init for ReLU networks
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DigitClassifier().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") # ~500K
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 3. TRAINING
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
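# CosineAnnealingLR decays the LR smoothly from its initial value toward
# eta_min (default 0) over T_max steps; it is stepped once per epoch below:
#   lr(t) = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2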
loss_fn = nn.CrossEntropyLoss()
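# CrossEntropyLoss applies log-softmax internally, which is why the model's
# final layer outputs raw logits rather than probabilities.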
best_acc = 0.0
for epoch in range(20):
    # Train
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(imgs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # one scheduler step per epoch (matches T_max=20)

    # Evaluate
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    best_acc = max(best_acc, acc)
    print(f"Epoch {epoch+1:2d} | Test Accuracy: {acc:.2%}")
print(f"\nBest accuracy: {best_acc:.2%}") # should reach 98-99% on MNISTTip
Tip
Practice each piece of this project (data loading, the model, the training loop) in small, isolated examples before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Practice Task
(1) Rebuild this MNIST MLP from scratch without looking at notes. (2) Modify it to handle an edge case (an empty input, a null value, or an error state). (3) Share your solution in the Priygop community for feedback.
Warning
A common mistake with this project is skipping edge-case testing: empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.
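A minimal sketch of that kind of validation, using a hypothetical safe_predict helper (the name, checks, and error messages are illustrative assumptions, not part of the project code):

def safe_predict(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical guard: validate the batch before running inference.
    if batch.numel() == 0:
        raise ValueError("empty batch: tensor has no elements")
    if batch.dim() != 4 or batch.shape[1:] != (1, 28, 28):
        raise ValueError(f"expected shape [B, 1, 28, 28], got {tuple(batch.shape)}")
    if not batch.dtype.is_floating_point:
        raise TypeError(f"expected a float tensor, got {batch.dtype}")
    model.eval()
    with torch.no_grad():
        return model(batch).argmax(dim=1)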