nn.Module — Building Custom Models
nn.Module is the base class for all PyTorch neural network components. Every layer (nn.Linear, nn.Conv2d, nn.TransformerEncoder) inherits from it, and so should every model you write. It gives you automatic parameter tracking, GPU movement with .to(device), and train/eval mode switching.
nn.Module — Everything You Need to Know
import torch
import torch.nn as nn
from typing import Optional
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Basic nn.Module structure
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class ResidualBlock(nn.Module):
    """
    Pre-norm residual feed-forward block: output = x + Dropout(FF(LayerNorm(x))).

    Skip connections are the core insight of ResNet (2015) that enabled
    100+ layer networks: without them, gradients in deep networks vanish;
    with them, there is a gradient highway from output to input.
    """

    def __init__(self, dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
        )
        # Kept for parity with a full Transformer block; unused in forward()
        # because the second (attention) sublayer is omitted in this example.
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward sublayer with a residual (skip) connection."""
        # Pre-norm ordering (normalize *before* the sublayer) is what modern
        # Transformers use instead of the original post-norm arrangement.
        residual = x
        out = self.dropout(self.ff(self.norm1(x)))
        return residual + out
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Composing modules into a full model
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class DeepResNet(nn.Module):
    """A stack of ResidualBlocks between an input and an output projection."""

    def __init__(self, input_dim: int, model_dim: int, n_layers: int, output_dim: int):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, model_dim)
        # Use nn.ModuleList, NOT a plain Python list: ModuleList registers
        # each sub-module, so .parameters(), .to(device), state_dict(), etc.
        # all see the blocks.
        self.layers = nn.ModuleList(
            ResidualBlock(model_dim) for _ in range(n_layers)
        )
        self.output_proj = nn.Linear(model_dim, output_dim)
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-initialize every Linear weight; zero every Linear bias."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project in, run every residual block in order, project out."""
        hidden = self.input_proj(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.output_proj(hidden)

    def count_parameters(self) -> dict:
        """Return {"total": ..., "trainable": ...} parameter counts."""
        counts = {"total": 0, "trainable": 0}
        for p in self.parameters():
            n = p.numel()
            counts["total"] += n
            if p.requires_grad:
                counts["trainable"] += n
        return counts
# Instantiate the model and inspect its size and architecture.
model = DeepResNet(input_dim=768, model_dim=512, n_layers=6, output_dim=10)
print(model.count_parameters())
# Plain string: the original used an f-string with no placeholders (ruff F541).
print("\nModel architecture:")
print(model)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Freezing and unfreezing parameters (for fine-tuning)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def freeze_backbone(model: nn.Module, unfreeze_last_n: Optional[int] = None):
    """Freeze all parameters of *model*, optionally unfreezing the last N layers.

    Typical fine-tuning recipe: freeze the pretrained backbone and train only
    the last few blocks (plus any new head).

    Args:
        model: Any ``nn.Module``. If it exposes a ``layers`` attribute (an
            iterable of sub-modules, e.g. an ``nn.ModuleList``), the last
            ``unfreeze_last_n`` of those are made trainable again.
        unfreeze_last_n: Number of trailing layers to keep trainable.
            ``None`` or ``0`` freezes everything.
    """
    for param in model.parameters():
        param.requires_grad = False  # freeze everything first
    if unfreeze_last_n and hasattr(model, 'layers'):
        # Re-enable gradients only for the last N sub-modules.
        for layer in list(model.layers)[-unfreeze_last_n:]:
            for param in layer.parameters():
                param.requires_grad = True
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Guard against a parameter-less model (ZeroDivisionError in the original).
    ratio = trainable / total if total else 0.0
    print(f"Trainable: {trainable:,} / {total:,} ({ratio:.1%})")
# Freeze the backbone, fine-tune only the last two residual blocks.
# (The original line had the stray text "Tip" fused onto it — a syntax error.)
freeze_backbone(model, unfreeze_last_n=2)
Tip
Practice building custom models with nn.Module in small, isolated examples before integrating them into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Deep Learning ⊂ Machine Learning ⊂ Artificial Intelligence
Practice Task
Note
Practice Task — (1) Write a working nn.Module custom-model example from scratch without looking at notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
Quick Quiz
Common Mistake
Warning
A common mistake when building custom models with nn.Module is skipping edge case testing — empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.