Training Loop — Complete Implementation
A production training loop is more than just forward → backward → step. It includes proper train/eval mode switching, gradient clipping, learning rate scheduling, metric tracking, checkpointing, and early stopping. This template is the foundation for all future projects.
Production Training Loop Template
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import Optional
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 50,
    lr: float = 1e-3,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    checkpoint_path: Optional[str] = "best_model.pt",
) -> dict:
    """
    Complete training loop with:
    - Train/eval mode switching
    - Gradient clipping
    - LR scheduling
    - Validation loop
    - Best model checkpointing
    - Early stopping
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    loss_fn = nn.CrossEntropyLoss()

    history = {"train_loss": [], "val_loss": [], "val_acc": []}
    best_val_loss = float("inf")
    patience_counter = 0
    PATIENCE = 5

    for epoch in range(n_epochs):
        # ─── TRAIN PHASE ───────────────────────────────────────
        model.train()  # enables Dropout + BatchNorm train mode
        train_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()                # 1. clear gradients
            logits = model(X_batch)              # 2. forward pass
            loss = loss_fn(logits, y_batch)      # 3. compute loss
            loss.backward()                      # 4. backward pass
            torch.nn.utils.clip_grad_norm_(      # 5. clip gradients
                model.parameters(), max_norm=1.0)
            optimizer.step()                     # 6. update weights

            train_loss += loss.item()

        scheduler.step()  # update learning rate after each epoch
        avg_train_loss = train_loss / len(train_loader)

        # ─── VALIDATION PHASE ──────────────────────────────────
        model.eval()  # disables Dropout, uses running stats in BatchNorm
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():  # no gradient computation — saves memory
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                logits = model(X_batch)
                val_loss += loss_fn(logits, y_batch).item()
                correct += (logits.argmax(dim=1) == y_batch).sum().item()
                total += y_batch.size(0)

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = correct / total

        # ─── LOGGING ───────────────────────────────────────────
        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_accuracy)
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch [{epoch+1:3d}/{n_epochs}] "
              f"train_loss={avg_train_loss:.4f} "
              f"val_loss={avg_val_loss:.4f} "
              f"val_acc={val_accuracy:.2%} "
              f"lr={current_lr:.2e}")

        # ─── CHECKPOINTING ─────────────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            if checkpoint_path:
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_loss": best_val_loss,
                }, checkpoint_path)
                print(f" ✅ New best model saved (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1} (no improvement for {PATIENCE} epochs)")
                break

    return history
# Usage:
# X = torch.randn(1000, 784)
# y = torch.randint(0, 10, (1000,))
# dataset = TensorDataset(X, y)
# train_ds, val_ds = torch.utils.data.random_split(dataset, [800, 200])
# train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
# val_loader = DataLoader(val_ds, batch_size=64)
# model = MLP(784, [256, 128], 10)
# history = train_model(model, train_loader, val_loader)
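The checkpoint saved by train_model can be restored later for inference or to resume training. A minimal sketch, assuming the MLP class from the usage example above and the default "best_model.pt" path:
# Restore the best checkpoint saved by train_model (sketch).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(784, [256, 128], 10)
checkpoint = torch.load("best_model.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # switch to eval mode before inference
print(f"Restored epoch {checkpoint['epoch']} (val_loss={checkpoint['val_loss']:.4f})")
# To resume training instead, also restore the optimizer:
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])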
Tip
Practice the training loop in small, isolated experiments before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone. As a starting point, the sketch below isolates a single ingredient: train/eval mode switching.
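This minimal experiment, using nothing beyond standard PyTorch, shows how a single Dropout layer changes behavior between train() and eval() mode, which is exactly what the mode switching in train_model controls:
import torch
import torch.nn as nn

# Isolated experiment: Dropout behaves differently in train vs eval mode.
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()    # training mode: each element zeroed with probability 0.5,
print(drop(x))  #   survivors scaled by 1/(1-p) = 2.0 (random mix of 0s and 2s)

drop.eval()     # eval mode: Dropout is a no-op
print(drop(x))  # all ones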
Practice Task
(1) Write a working training loop from scratch without looking at notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
Common Mistake
A common mistake with training loops is skipping edge case testing — empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code. A guard for the template above is sketched below.
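One concrete boundary condition for train_model: an empty DataLoader makes len(train_loader) zero, so the loss-averaging step fails with a ZeroDivisionError. A minimal guard is sketched below; check_loaders is a hypothetical helper, not part of train_model.
from torch.utils.data import DataLoader

# Hypothetical helper (not part of train_model): fail fast on empty
# loaders instead of crashing in the loss-averaging step.
def check_loaders(train_loader: DataLoader, val_loader: DataLoader) -> None:
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")
    if len(val_loader) == 0:
        raise ValueError("val_loader yields no batches")

# check_loaders(train_loader, val_loader)  # call before train_model(...)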