Cross-Attention & Encoder-Decoder Communication
Cross-attention is the bridge in encoder-decoder models: the decoder's queries attend over the encoder's keys and values to pull out the source information relevant to each decoding step. This is how a translation model maps source context to target words. It is also the fundamental mechanism behind attention-based image captioning and multimodal models.
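In matrix form, each decoder position computes softmax(Q_dec K_enc^T / sqrt(d_k)) V_enc, with queries projected from decoder states and keys/values projected from encoder states. Here is a minimal sketch of that single computation using plain tensor operations (shapes and variable names are illustrative, not tied to any particular library API):

import math
import torch

B, T_dec, T_enc, d_k = 2, 20, 30, 64
q = torch.randn(B, T_dec, d_k)  # projected decoder states (queries)
k = torch.randn(B, T_enc, d_k)  # projected encoder states (keys)
v = torch.randn(B, T_enc, d_k)  # projected encoder states (values)

scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [B, T_dec, T_enc]
weights = scores.softmax(dim=-1)  # each target position distributes attention over the source
context = weights @ v             # [B, T_dec, d_k]: source information gathered per target position

The full pre-norm block below wraps this same computation in nn.MultiheadAttention, adds residual connections, and follows it with a feed-forward sub-layer.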
Cross-Attention Explained
import torch
import torch.nn as nn

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# CROSS-ATTENTION: Q from decoder, K&V from encoder
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# In cross-attention:
#   Q = decoder hidden states (what am I generating now?)
#   K = encoder output states (index into source information)
#   V = encoder output states (the actual source information to retrieve)

class CrossAttentionBlock(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.norm_dec = nn.LayerNorm(d_model)
        self.norm_enc = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_out = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
    def forward(self, decoder_states: torch.Tensor,
                encoder_output: torch.Tensor,
                encoder_padding_mask: torch.Tensor = None):
        """
        decoder_states: [B, T_dec, D] — what the decoder has generated so far
        encoder_output: [B, T_enc, D] — encoded source sequence (fixed)
        Returns the updated decoder states and the cross-attention weights.
        """
        # Normalize both before cross-attention (stable training)
        dec_norm = self.norm_dec(decoder_states)
        enc_norm = self.norm_enc(encoder_output)
        # Q = decoder; K,V = encoder
        cross_out, attn_weights = self.cross_attn(
            query=dec_norm,                         # decoder asks the question
            key=enc_norm,                           # encoder provides the index
            value=enc_norm,                         # encoder provides the content
            key_padding_mask=encoder_padding_mask,  # ignore encoder padding
        )
        decoder_states = decoder_states + cross_out  # residual connection
        # FFN with its own pre-norm and residual
        decoder_states = decoder_states + self.ffn(self.norm_out(decoder_states))
        return decoder_states, attn_weights
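# NOTE (sketch, not part of the original block): in a full Transformer decoder layer,
# masked self-attention usually runs BEFORE this cross-attention block, roughly:
#   x = x + self_attn(norm(x), causal_mask)  # decoder attends to its own prefix
#   x = x + cross_attn(norm(x), enc_out)     # decoder attends to the source (this block)
#   x = x + ffn(norm(x))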
# Example: English→French translation
B, T_enc, T_dec, D = 2, 30, 20, 512
encoder_output = torch.randn(B, T_enc, D) # encoded English tokens
decoder_states = torch.randn(B, T_dec, D) # French tokens decoded so far
cross_block = CrossAttentionBlock(d_model=D, n_heads=8)
output, attn_w = cross_block(decoder_states, encoder_output)
print(f"Cross-attention output: {output.shape}") # [2, 20, 512]
print(f"Attention weights: {attn_w.shape}") # [2, 20, 30]
# attn_w[0, 3, 7] = 0.45 means:
# "When generating French token position 3, the decoder attended
# 45% to English token position 7 in the source"
# This is essentially a SOFT ALIGNMENT between source and target!
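# Two quick follow-ups (illustrative additions, not part of the original example):
# 1) A "hard alignment" view: the most-attended source position for each target token.
print(attn_w.argmax(dim=-1)[0])  # [T_dec] source indices for the first batch element
# 2) Masking encoder padding: suppose the second source sentence has only 22 real
#    tokens; True marks PAD positions that cross-attention must ignore.
pad_mask = torch.zeros(B, T_enc, dtype=torch.bool)
pad_mask[1, 22:] = True
out_masked, attn_w_masked = cross_block(decoder_states, encoder_output,
                                        encoder_padding_mask=pad_mask)
print(attn_w_masked[1, :, 22:].abs().max())  # ~0: no attention mass on padded positions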
# WHERE CROSS-ATTENTION IS USED:
applications = {
"Machine Translation": "Decoder attends to source language encoder states",
"Image Captioning": "Text decoder attends to CNN-encoded image features",
"Stable Diffusion": "UNet decoder attends to CLIP text encoder embeddings",
"ASR (Speech-to-Text)": "Text decoder attends to audio encoder states",
"Document Summarization": "Summary decoder attends to document encoder states",
}
for app, desc in applications.items():
    print(f" {app:30s}: {desc}")
Tip
Practice cross-attention and encoder-decoder communication in small, isolated examples before integrating them into larger projects. Breaking concepts into small experiments builds genuine understanding faster than reading alone.
Practice Task
(1) Write a working cross-attention example from scratch without looking at notes. (2) Modify it to handle an edge case (empty input, a null value, or an error state). (3) Share your solution in the Priygop community for feedback.
Common Mistake
A common mistake with cross-attention is skipping edge-case testing: empty inputs, null values, and unexpected data types. Always validate boundary conditions to write robust, production-ready AI code.