Attention Mechanism — The Precursor to Transformers
Attention was originally introduced to help LSTM-based Seq2Seq models look back at the encoder's outputs when decoding. It allows the decoder to focus on relevant input positions for each output word — like how a human translator focuses on corresponding source words when translating. This mechanism became the foundation of Transformers.
Attention from Scratch
import torch
import torch.nn as nn
import torch.nn.functional as F
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# BAHDANAU ATTENTION (original, 2014)
# Used in: RNN encoder-decoder with alignment
# The foundation idea that led to Transformers
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class BahdanauAttention(nn.Module):
    """
    Learns to align the decoder hidden state with encoder outputs.
    Scores each encoder position → softmax → weighted sum of encoder outputs.
    """
    def __init__(self, encoder_dim: int, decoder_dim: int, attn_dim: int = 256):
        super().__init__()
        self.W_encoder = nn.Linear(encoder_dim, attn_dim)  # project encoder states
        self.W_decoder = nn.Linear(decoder_dim, attn_dim)  # project decoder state
        self.v = nn.Linear(attn_dim, 1, bias=False)        # score per position

    def forward(self, encoder_outputs: torch.Tensor,
                decoder_hidden: torch.Tensor) -> tuple:
        # encoder_outputs: [B, src_len, enc_dim]
        # decoder_hidden:  [B, dec_dim]
        energy = torch.tanh(
            self.W_encoder(encoder_outputs) +            # [B, src_len, attn_dim]
            self.W_decoder(decoder_hidden).unsqueeze(1)  # [B, 1, attn_dim] → broadcast
        )                                                # [B, src_len, attn_dim]
        scores = self.v(energy).squeeze(-1)   # [B, src_len] — one score per position
        weights = F.softmax(scores, dim=-1)   # [B, src_len] — sum to 1
        # Weighted sum: what does the decoder attend to?
        context = (weights.unsqueeze(-1) * encoder_outputs).sum(dim=1)  # [B, enc_dim]
        return context, weights
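To see the shapes concretely, here is the same forward pass written out functionally with plain tensors standing in for the learned `Linear` layers (all sizes here are arbitrary illustration values, not anything prescribed by the original paper):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, src_len, enc_dim, dec_dim, attn_dim = 2, 5, 16, 32, 8
enc_out = torch.randn(B, src_len, enc_dim)   # encoder outputs
dec_h = torch.randn(B, dec_dim)              # current decoder hidden state

# Random projections standing in for the learned W_encoder, W_decoder, v
W_enc = torch.randn(attn_dim, enc_dim) * 0.1
W_dec = torch.randn(attn_dim, dec_dim) * 0.1
v = torch.randn(attn_dim) * 0.1

energy = torch.tanh(enc_out @ W_enc.T + (dec_h @ W_dec.T).unsqueeze(1))  # [B, src_len, attn_dim]
scores = energy @ v                                                      # [B, src_len]
weights = F.softmax(scores, dim=-1)                                      # [B, src_len]
context = (weights.unsqueeze(-1) * enc_out).sum(dim=1)                   # [B, enc_dim]

print(context.shape)    # torch.Size([2, 16])
print(weights.sum(-1))  # each row ≈ 1.0
```

Note that the context vector has the encoder's dimensionality, not the decoder's: attention retrieves a blend of encoder states, which the decoder then consumes alongside its own hidden state.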
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# SCALED DOT-PRODUCT ATTENTION (Transformer, 2017)
# The evolution: replaces learned alignment with QKV dot products
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor,
                                 v: torch.Tensor,
                                 mask: torch.Tensor = None) -> tuple:
    """
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    Q (Query): what we're looking for
    K (Key):   what each position offers to search against
    V (Value): what we actually retrieve for matching positions
    """
    d_k = q.shape[-1]
    # Q·K^T / sqrt(d_k) — the division prevents softmax saturation for large d_k
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [B, H, T_q, T_k]
    if mask is not None:
        # Masked positions (mask == 0) get -inf → 0 attention weight after softmax
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # [B, H, T_q, T_k] — attention distribution
    output = torch.matmul(weights, v)    # [B, H, T_q, d_v]
    return output, weights
# Demonstration:
B, H, T, d_k = 2, 8, 10, 64 # batch, heads, seq_len, key_dim
q = k = v = torch.randn(B, H, T, d_k)
output, attn_weights = scaled_dot_product_attention(q, k, v)
print(f"Attention output: {output.shape}") # [2, 8, 10, 64]
print(f"Attention weights: {attn_weights.shape}") # [2, 8, 10, 10]
# Attention weights: attn_weights[b, h, i, j] =
# "How much does position i in the output attend to position j in the input?"
# Sum over j for any i = 1.0 (probability distribution)
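The `mask` argument is how decoders enforce causality: a lower-triangular mask stops each position from attending to positions after it. A minimal standalone sketch (re-deriving the same math inline so it runs on its own; sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d_k = 1, 2, 4, 8
q = k = v = torch.randn(B, H, T, d_k)

# Lower-triangular causal mask: position i may attend only to positions j <= i
mask = torch.tril(torch.ones(T, T))                           # [T, T], 1 = keep
scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [B, H, T, T]
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

# Upper triangle is exactly zero: no attention to future positions,
# and each row still sums to 1 over the allowed positions.
print(weights[0, 0])
```

The first row of `weights[0, 0]` is `[1, 0, 0, 0]`: with only one visible position, softmax puts all the mass there. This is why the very first token of a decoder always attends entirely to itself.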