Model Saving, Loading & Checkpointing
Saving and loading models correctly is essential — a wrong checkpoint can cost hours of training. Always save the full training state (model + optimizer + epoch + loss) for resumable training, and just the model weights for inference deployment.
Checkpointing Patterns
import torch
import torch.nn as nn
# Minimal model and optimizer shared by the saving/loading patterns below.
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# PATTERN 1: Save ONLY WEIGHTS (.pt / .pth)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Use for: deployment, sharing models, HuggingFace Hub
# Only the state dict is written, not the model object itself, so whoever
# loads the file has to construct the matching architecture first.
torch.save(model.state_dict(), "model_weights.pt")
# Loading:
new_model = nn.Linear(10, 2)  # must match the architecture that was saved
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file cannot run arbitrary code while loading. map_location="cpu"
# keeps the load independent of the device the weights were saved from.
new_model.load_state_dict(
    torch.load("model_weights.pt", map_location="cpu", weights_only=True)
)
new_model.eval() # ALWAYS call eval() before inference
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# PATTERN 2: Save FULL TRAINING STATE (for resumable training)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Bundle everything needed to continue a run into one dict: the model weights,
# the optimizer state, and the bookkeeping values (epoch, metric, config).
# The literal values below are illustrative placeholders.
checkpoint = {
    "epoch": 42,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "val_loss": 0.234,
    "config": {"lr": 1e-3, "batch_size": 64, "architecture": "resnet50"},
}
torch.save(checkpoint, "checkpoint_epoch42.pt")
# Resuming training:
# The checkpoint contains only tensors, dicts and scalars, so it can be read
# with weights_only=True, which refuses to unpickle arbitrary Python objects.
ckpt = torch.load("checkpoint_epoch42.pt", map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1 # resume from next epoch
print(f"Resuming from epoch {start_epoch}, val_loss={ckpt['val_loss']}")
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# PATTERN 3: Export to ONNX (cross-platform deployment)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# The exporter needs an example input; only its shape and dtype matter here.
dummy_input = torch.randn(1, 10)
torch.onnx.export(
    model, dummy_input, "model.onnx",
    input_names=["input"], output_names=["output"],
    # Mark dimension 0 as dynamic so the exported graph accepts any batch
    # size, not just the batch of 1 used for the example input.
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=17,
)
# ONNX inference (CPU or GPU, framework-agnostic):
import onnxruntime as ort
import numpy as np
# Providers are listed in priority order: prefer CUDA, and name the CPU
# provider explicitly as a fallback so the session does not depend on a GPU
# build of onnxruntime being installed.
sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Feed a batch of 4 to exercise the dynamic batch axis; float32 matches the
# dtype of the example input used for export.
result = sess.run(["output"], {"input": np.random.randn(4, 10).astype(np.float32)})
print(f"ONNX output shape: {result[0].shape}") # [4, 2]
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# HuggingFace push_to_hub — share with the world
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# from transformers import AutoModelForSequenceClassification
# model_hf = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
# model_hf.push_to_hub("your-username/my-bert-classifier")
# model_hf = AutoModelForSequenceClassification.from_pretrained("your-username/my-bert-classifier")
Tip
Practice model saving, loading, and checkpointing in small, isolated examples before integrating them into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Technical diagram.
Practice Task
Note
Practice Task — (1) Write a working example of model saving, loading, and checkpointing from scratch without looking at notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
Quick Quiz
Common Mistake
Warning
A common mistake with model saving, loading, and checkpointing is skipping edge-case testing — empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.