Self-Attention — QKV Computation from Scratch
Self-attention allows every position in a sequence to directly attend to every other position in ONE step — no sequential computation. Queries (what I'm looking for), Keys (what I offer), and Values (what I return) are all learned linear projections of the same input. This is the breakthrough that makes Transformers parallelizable and capable of capturing long-range dependencies.
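Before the full module below, here is a minimal sketch of the same computation with bare tensors (toy sizes and random stand-in projection matrices chosen purely for illustration): project one input into Q, K, and V, score every pair of positions, and take a weighted sum of the Values.
import torch, math
T, D = 4, 8                                    # toy sequence length and embedding dim
x = torch.randn(T, D)                          # one sequence of token embeddings
Wq, Wk, Wv = torch.randn(D, D), torch.randn(D, D), torch.randn(D, D)  # stand-ins for learned projections
Q, K, V = x @ Wq, x @ Wk, x @ Wv               # Queries, Keys, Values: all [T, D]
scores = Q @ K.T / math.sqrt(D)                # [T, T] pairwise compatibility scores
weights = scores.softmax(dim=-1)               # each row sums to 1
out = weights @ V                              # [T, D]: every position mixes all Values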
Self-Attention — Complete Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# SELF-ATTENTION — step by step
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class SelfAttention(nn.Module):
    """Single-head self-attention with optional causal mask."""

    def __init__(self, embed_dim: int, dropout: float = 0.1):
        super().__init__()
        self.d_k = embed_dim
        # Three learned linear projections — the only LEARNED parts of attention
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)  # Value projection
        self.W_o = nn.Linear(embed_dim, embed_dim)               # Output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, causal: bool = False) -> tuple:
        # x: [B, T, D] — batch, sequence length, embedding dim
        B, T, D = x.shape
        Q = self.W_q(x)  # [B, T, D] — Queries: "what am I looking for?"
        K = self.W_k(x)  # [B, T, D] — Keys: "what do I have to offer?"
        V = self.W_v(x)  # [B, T, D] — Values: "what I return if matched"
        # Attention scores: QK^T / sqrt(d_k)
        # Why sqrt(d_k)? For large d_k, dot products have large variance
        # → softmax gets very sharp → vanishing gradients. Scaling prevents this.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: [B, T, T] — score[i, j] = how much token i attends to token j
        if causal:
            # CAUSAL MASK: token i cannot attend to future tokens j > i
            # Creates a lower-triangular 1s mask
            mask = torch.tril(torch.ones(T, T, device=x.device))   # [T, T]
            scores = scores.masked_fill(mask == 0, float("-inf"))  # future = -inf → 0 after softmax
        weights = F.softmax(scores, dim=-1)  # [B, T, T] — rows sum to 1
        weights = self.dropout(weights)      # attention dropout (regularization)
        out = torch.matmul(weights, V)       # [B, T, D] — weighted sum of Values
        return self.W_o(out), weights        # output + attention weights for visualization
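# Quick numeric sketch of why we scale by sqrt(d_k) (illustrative values, not part of the class above):
# dot products of random d_k-dimensional vectors have standard deviation ~sqrt(d_k),
# which pushes softmax into its saturated, near-one-hot regime unless we rescale.
d_k_demo = 512
q_demo, k_demo = torch.randn(1000, d_k_demo), torch.randn(1000, d_k_demo)
raw = (q_demo * k_demo).sum(-1)                  # 1000 unscaled dot products
print(raw.std().item())                          # roughly sqrt(512) ~ 22.6
print((raw / math.sqrt(d_k_demo)).std().item())  # roughly 1.0 after scaling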
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Intuition with a concrete example
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
attn = SelfAttention(embed_dim=64)
x = torch.randn(2, 10, 64) # batch=2, seq_len=10, dim=64
out, weights = attn(x)
print(f"Input: {x.shape}") # [2, 10, 64]
print(f"Output: {out.shape}") # [2, 10, 64]
print(f"Weights: {weights.shape}") # [2, 10, 10] — 10×10 attention matrix per sequence
# What the attention matrix means:
# weights[0, 5, 3] = 0.4 means:
# "In sample 0, token position 5 devotes 40% of its attention to token position 3"
# Row 5 sums to 1.0 across all columns
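# A quick sanity check of that row-sum property (a small addition, not part of the original demo).
# Note: dropout perturbs the weights in training mode, so switch to eval() for the exact check.
attn.eval()
_, weights_eval = attn(x)
print(torch.allclose(weights_eval[0].sum(dim=-1), torch.ones(10)))  # True: each row is a distribution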
# CAUSAL SELF-ATTENTION (for GPT-style language models)
out_causal, weights_causal = attn(x, causal=True)
# weights_causal[0, 5, 6] = 0.0 — position 5 CANNOT see position 6 (future)
# This enables autoregressive generation: predict the next token from past context only
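# Sanity check of the causal property (a small addition, assumes weights_causal from above):
future = torch.triu(weights_causal[0], diagonal=1)   # strictly-future entries of the [10, 10] matrix
print(future.abs().max().item())                      # 0.0: no attention ever flows to future positions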
Tip
Practice the QKV computation from scratch in small, isolated examples before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Practice Task
(1) Write a working example of self-attention QKV computation from scratch without looking at notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
Common Mistake
A common mistake when implementing self-attention from scratch is skipping edge-case testing — empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.
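As one concrete instance of such an edge case, batches are usually padded to a common length, and padded positions should never receive attention. The snippet below is a small illustrative sketch of masking padded keys before the softmax; the pad_mask tensor and stand-in scores here are hypothetical and not part of the SelfAttention class above.
pad_mask = torch.tensor([[1, 1, 1, 0, 0]])                              # [B, T]: 1 = real token, 0 = padding
scores = torch.randn(1, 5, 5)                                           # stand-in attention scores [B, T, T]
scores = scores.masked_fill(pad_mask[:, None, :] == 0, float("-inf"))   # hide padded key positions
weights = F.softmax(scores, dim=-1)                                     # padded columns get zero attention weight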