Full Transformer Block — Encoder and Decoder
The Transformer has two halves: an Encoder (reads and encodes input into contextual representations — bidirectional) and a Decoder (generates output one token at a time — causal). BERT uses encoder-only. GPT uses decoder-only. T5/BART use both (full seq2seq).
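The practical difference between the two halves is the attention mask: encoder tokens attend to every other token (bidirectional), while decoder tokens are blocked from attending to future positions (causal). A minimal illustration of the two mask patterns (the 4-token length is arbitrary):

import torch

T = 4
encoder_mask = torch.zeros(T, T, dtype=torch.bool)                # nothing blocked → bidirectional
decoder_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), 1)  # True above the diagonal → future blocked
print(decoder_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])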
Complete Transformer Encoder-Decoder
import torch
import torch.nn as nn


class TransformerEncoderBlock(nn.Module):
    """
    Single Transformer encoder block (BERT-base sizes, but with pre-norm ordering):
        x = x + MultiHeadAttention(LayerNorm(x))
        x = x + FFN(LayerNorm(x)),  where FFN = Linear → GELU → Linear
    """
    def __init__(self, d_model: int = 768, n_heads: int = 12,
                 ffn_dim: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        # FFN: expand 4x → GELU → compress back (standard ratio)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model), nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-norm architecture (more stable training than post-norm)
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, key_padding_mask=src_key_padding_mask)
        x = x + attn_out                 # residual connection
        x = x + self.ffn(self.norm2(x))  # FFN + residual
        return x
class TransformerDecoderBlock(nn.Module):
    """
    Single Transformer decoder block (as in GPT, plus cross-attention for encoder-decoder use):
    1. Causal self-attention over the decoder input (each position sees only the past)
    2. Cross-attention over the encoder outputs (reads the source)
    3. FFN
    """
    def __init__(self, d_model: int = 768, n_heads: int = 12,
                 ffn_dim: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm3 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model), nn.Dropout(dropout),
        )

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor,
                tgt_mask: torch.Tensor = None) -> torch.Tensor:
        # 1. Causal self-attention (decoder can only see its own past tokens)
        normed = self.norm1(tgt)
        self_attn_out, _ = self.self_attn(normed, normed, normed, attn_mask=tgt_mask)
        tgt = tgt + self_attn_out
        # 2. Cross-attention: Q from decoder, K & V from encoder
        normed = self.norm2(tgt)
        cross_out, _ = self.cross_attn(normed, memory, memory)
        tgt = tgt + cross_out
        # 3. FFN + residual
        tgt = tgt + self.ffn(self.norm3(tgt))
        return tgt
# Full encoder + decoder Transformer (like the original "Attention Is All You Need")
# NOTE: positional encodings are omitted here to keep the example short;
# a real model adds them to the embeddings before the first block.
class Transformer(nn.Module):
    def __init__(self, vocab_size=30000, d_model=512, n_heads=8, n_enc=6, n_dec=6, ffn_dim=2048):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.tgt_embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.encoder = nn.ModuleList([TransformerEncoderBlock(d_model, n_heads, ffn_dim) for _ in range(n_enc)])
        self.decoder = nn.ModuleList([TransformerDecoderBlock(d_model, n_heads, ffn_dim) for _ in range(n_dec)])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def encode(self, src: torch.Tensor) -> torch.Tensor:
        x = self.src_embed(src)
        for block in self.encoder:
            x = block(x)
        return self.norm(x)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        T = tgt.size(1)
        # Boolean causal mask: True above the diagonal = "not allowed to attend"
        causal_mask = torch.triu(torch.ones(T, T, device=tgt.device), diagonal=1).bool()
        x = self.tgt_embed(tgt)
        for block in self.decoder:
            x = block(x, memory, tgt_mask=causal_mask)
        return self.lm_head(self.norm(x))


model = Transformer()
src = torch.randint(1, 30000, (2, 50))  # source sequence
tgt = torch.randint(1, 30000, (2, 30))  # target sequence (what the decoder generates)
memory = model.encode(src)
logits = model.decode(tgt, memory)
print(f"Decoder output logits: {logits.shape}")  # [2, 30, 30000]
Tip
Practice the full Transformer block (encoder and decoder) in small, isolated examples before integrating it into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
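For example, one small experiment is to confirm that a causal mask really blocks attention to future positions. A quick sanity check, independent of the classes above:

attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
x = torch.randn(1, 5, 16)                               # batch of 1, 5 tokens
mask = torch.triu(torch.ones(5, 5), diagonal=1).bool()  # True = blocked
_, weights = attn(x, x, x, attn_mask=mask, need_weights=True)
print(weights[0].triu(diagonal=1).abs().max())          # ~0: no weight lands on future positions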
Practice Task
(1) Write a working example of a full Transformer block (encoder and decoder) from scratch, without looking at notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
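As a starting point for step (2), one realistic edge case is a padded batch: sequences of different lengths are padded with token id 0 (matching padding_idx=0 in the embeddings above), and those positions should be excluded from attention. A hedged sketch that reuses the TransformerEncoderBlock defined earlier via its src_key_padding_mask argument:

block = TransformerEncoderBlock(d_model=512, n_heads=8, ffn_dim=2048)
embed = nn.Embedding(30000, 512, padding_idx=0)

src = torch.tensor([[5, 17, 42,  0, 0],   # length-3 sequence padded to 5
                    [9,  3,  8, 12, 6]])  # full-length sequence
pad_mask = src.eq(0)                      # True where the token is padding
out = block(embed(src), src_key_padding_mask=pad_mask)
print(out.shape)                          # [2, 5, 512]; padded positions are ignored by attention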
Common Mistake
A common mistake with Transformer encoder and decoder blocks is skipping edge-case testing: empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready code.