Vision Transformers (ViT) — Transformers for Images
ViT (Vision Transformer, Dosovitskiy et al., 2020) showed that a pure Transformer architecture works for images: split an image into 16×16 patches and treat each patch as a token. With enough pre-training data, ViT matches or beats comparable CNNs despite building in none of the convolutional inductive biases such as locality and translation equivariance.
Vision Transformer — Patch Embedding
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """
    ViT key insight: split image into 16x16 patches, embed each as a token.
    An image of 224×224 → 14×14 = 196 patches → 196 tokens (like 196 words).
    """
    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2  # 196 for 224/16
        # Conv2d with stride=patch_size elegantly performs patch extraction + projection
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W] → [B, embed_dim, H/P, W/P] → [B, n_patches, embed_dim]
        x = self.proj(x)       # [B, 768, 14, 14]
        x = x.flatten(2)       # [B, 768, 196]
        x = x.transpose(1, 2)  # [B, 196, 768] — patches as sequence
        return x
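A quick shape check (a minimal sketch with random tensor values, using the defaults above) confirms that the patch embedding turns a batch of images into a token sequence:

# Sanity check: each 224×224 image becomes 196 tokens of dimension 768
pe = PatchEmbedding()
tokens = pe(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768])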
class VisionTransformer(nn.Module):
    """Simplified ViT-Base architecture."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768,
                 depth=12, n_heads=12, mlp_ratio=4, n_classes=1000):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))              # learned [CLS] token
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # learnable PE
        self.pos_drop = nn.Dropout(p=0.1)
        # Stack of Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, n_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        x = self.patch_embed(x)                 # [B, 196, 768]
        cls = self.cls_token.expand(B, -1, -1)  # [B, 1, 768]
        x = torch.cat([cls, x], dim=1)          # [B, 197, 768] — prepend CLS
        x = self.pos_drop(x + self.pos_embed)   # add positional embedding
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return self.head(x[:, 0])               # CLS token → classification
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block, as in ViT: x + Attn(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, dim: int, n_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=0.1, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(mlp_dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)  # self-attention
        x = x + attn_out                                 # residual
        x = x + self.mlp(self.norm2(x))                  # FFN + residual
        return x
# Usage
vit = VisionTransformer(n_classes=1000)
img = torch.randn(2, 3, 224, 224)
out = vit(img)
print(f"ViT output: {out.shape}")  # [2, 1000]
print(f"Parameters: {sum(p.numel() for p in vit.parameters()):,}")  # ~86M (ViT-Base)