FlashAttention — Efficient Attention for Long Contexts
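Standard attention materializes a T×T score matrix, so its memory grows as O(T²) in the sequence length T. For a 32K-token context that matrix has 32,768² ≈ 1.07 billion entries, about 4.3 GB in fp32 (2.1 GB in fp16) for a single head of a single sequence, before multiplying by the number of heads and the batch size. FlashAttention (Dao et al., 2022) computes exact attention without ever materializing that matrix. It splits Q, K and V into tiles that fit in fast on-chip SRAM, fuses the matrix multiply, softmax and weighted sum into one kernel, and keeps a running (online) softmax, so its extra memory is only O(T). This is one of the techniques that make long-context LLMs (128K+ tokens) practical.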
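FlashAttention's tiling relies on an online softmax. Keys and values are visited block by block, and the partial result is rescaled whenever a new block raises the running maximum. The pure-Python sketch below illustrates that recurrence for a single query row and checks it against a naive softmax. It is a teaching aid under simplifying assumptions (one query, no batching, no SRAM, no backward pass), not the fused CUDA kernel; the function names and the `block_size` parameter are hypothetical.

```python
import math
import random

def naive_attention_row(q, keys, values):
    """Reference: softmax(q·K^T / sqrt(d)) V for one query, scoring all keys at once."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def tiled_attention_row(q, keys, values, block_size=4):
    """Online softmax: process K/V in blocks, keeping only a running max (m),
    a running normalizer (l) and an unnormalized output accumulator (acc)."""
    d = len(q)
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # rescales old stats; exp(-inf) == 0.0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalize once at the end

random.seed(0)
d, n = 8, 10
q = [random.gauss(0, 1) for _ in range(d)]
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
ref = naive_attention_row(q, keys, values)
out = tiled_attention_row(q, keys, values, block_size=3)
print(f"max abs difference: {max(abs(a - b) for a, b in zip(ref, out)):.2e}")
```

Between blocks the sketch keeps only `m`, `l` and `acc` (O(D) state per query row), so the full row of T scores never exists at once. The result matches the naive version up to floating-point rounding, which is the sense in which FlashAttention is exact rather than approximate.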
FlashAttention in Practice
```python
import torch
import torch.nn.functional as F

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Standard vs FlashAttention memory comparison
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

def attention_memory(seq_len: int, n_heads: int = 32,
                     dtype_bytes: int = 2) -> dict:
    """Estimate attention memory for one sequence (B=1).

    Standard attention materializes the score matrix [B=1, H, T, T].
    FlashAttention never stores it; beyond Q/K/V/O its forward pass keeps
    roughly one fp32 softmax statistic (logsumexp) per query row: [B=1, H, T].
    """
    attn_matrix_bytes = n_heads * seq_len * seq_len * dtype_bytes  # O(T^2)
    flash_extra_bytes = n_heads * seq_len * 4                        # O(T), fp32 stats
    return {
        "seq_len": seq_len,
        "attention_matrix_GB": attn_matrix_bytes / 1e9,
        "flash_extra_MB": flash_extra_bytes / 1e6,
    }

for context_len in [2048, 8192, 32768, 131072]:
    stats = attention_memory(context_len)
    print(f"  Context {stats['seq_len']:7d}: "
          f"standard scores={stats['attention_matrix_GB']:8.2f} GB | "
          f"FlashAttention extra≈{stats['flash_extra_MB']:6.2f} MB")
#   Context    2048: standard scores=    0.27 GB | FlashAttention extra≈  0.26 MB
#   Context    8192: standard scores=    4.29 GB | FlashAttention extra≈  1.05 MB
#   Context   32768: standard scores=   68.72 GB | FlashAttention extra≈  4.19 MB  <- one layer fills most of an 80 GB GPU
#   Context  131072: standard scores= 1099.51 GB | FlashAttention extra≈ 16.78 MB  <- far beyond any single GPU

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Using FlashAttention in PyTorch
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# PyTorch 2.0+ exposes fused attention through F.scaled_dot_product_attention (SDPA).
# On CUDA it picks the FlashAttention kernel automatically when the inputs qualify
# (fp16/bf16, supported GPU and head dimension, no explicit attn_mask); otherwise it
# falls back to the memory-efficient or plain "math" kernel.
B, H, T, D = 2, 8, 2048, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # Flash needs fp16/bf16
q = torch.randn(B, H, T, D, device=device, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, dtype=dtype)

if device == "cuda":
    # Allow only the FlashAttention backend (PyTorch 2.3+ API; older releases use the
    # now-deprecated torch.backends.cuda.sdp_kernel). Raises an error if Flash can't run.
    from torch.nn.attention import SDPBackend, sdpa_kernel
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # CPU fallback
print(f"Attention output: {out.shape}")  # torch.Size([2, 8, 2048, 64])

# In Hugging Face Transformers (requires the flash-attn package and a supported GPU):
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Meta-Llama-3-8B",
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16,
# )
# Install the FlashAttention-2 kernels with: pip install flash-attn --no-build-isolation
# Reported gains: 2-4x speedup and ~10x lower attention memory for long sequences
# (exact numbers depend on sequence length and GPU).

print("\nFlashAttention advantages:")
print("  Memory: O(T) extra memory vs O(T²) for the score matrix; key to 128K+ contexts")
print("  Speed:  typically 2-4x faster than standard attention on GPU (fewer HBM reads/writes)")
print("  Exact:  same result as standard attention up to floating-point rounding, not an approximation")
print("  Used by: widely adopted in LLM training and inference, e.g. with LLaMA 3 and Mistral models")
```
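The Hugging Face call above hard-codes `attn_implementation="flash_attention_2"`, which fails when the flash-attn package or a CUDA GPU is missing. A minimal sketch of a fallback, assuming a Transformers version that accepts both `"flash_attention_2"` and `"sdpa"`; the helper name `pick_attn_implementation` is hypothetical:

```python
import importlib.util

def pick_attn_implementation() -> str:
    """Hypothetical helper: choose a Transformers attn_implementation string.

    Prefers FlashAttention-2 when the flash-attn package is installed and a CUDA
    GPU is visible; otherwise falls back to PyTorch's SDPA path.
    """
    if importlib.util.find_spec("flash_attn") is None:
        return "sdpa"
    import torch  # flash-attn depends on torch, so it is importable here
    return "flash_attention_2" if torch.cuda.is_available() else "sdpa"

impl = pick_attn_implementation()
print(f"attn_implementation={impl!r}")
# model = AutoModelForCausalLM.from_pretrained(..., attn_implementation=impl,
#                                              torch_dtype=torch.bfloat16)
```

Probing with `importlib.util.find_spec` avoids importing flash-attn just to test for it. A stricter version could also inspect `torch.cuda.get_device_capability()`, since FlashAttention-2 targets Ampere or newer GPUs.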
Tip
Practice FlashAttention in small, isolated examples before integrating it into larger projects. Breaking a concept into small experiments builds genuine understanding faster than reading alone.
Practice Task
(1) Write a working FlashAttention example from scratch without looking at your notes. (2) Modify it to handle an edge case (empty input, null value, or error state). (3) Share your solution in the Priygop community for feedback.
Common Mistake
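A common mistake with FlashAttention is skipping edge-case testing: empty inputs, null values, and unexpected data types. Dtypes are a quiet failure here, because fp32 inputs make PyTorch's scaled_dot_product_attention fall back from the FlashAttention kernel to a slower one without raising an error. Always validate boundary conditions to write robust, production-ready AI code.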
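One way to act on this advice is a pre-flight check before calling attention. The sketch below is hypothetical: the function name, the use of plain shape tuples and dtype strings (so it runs without a GPU), and the choice of which conditions raise versus warn are all assumptions, not a PyTorch API. It also assumes q, k and v have equal head counts (no grouped-query attention).

```python
from typing import List, Sequence

FLASH_DTYPES = {"float16", "bfloat16"}  # assumption: dtypes the Flash kernel accepts

def validate_attention_inputs(q_shape: Sequence[int], k_shape: Sequence[int],
                              v_shape: Sequence[int], dtype: str) -> List[str]:
    """Hypothetical pre-flight check for [B, H, T, D] attention inputs.

    Raises ValueError for malformed or empty shapes; returns a list of issues
    that only cost speed (e.g. not qualifying for FlashAttention).
    """
    q, k, v = tuple(q_shape), tuple(k_shape), tuple(v_shape)
    for name, shape in (("q", q), ("k", k), ("v", v)):
        if len(shape) != 4:
            raise ValueError(f"{name} must be 4-D [B, H, T, D], got {shape}")
        if 0 in shape:
            raise ValueError(f"{name} has an empty dimension: {shape}")
    if q[:2] != k[:2] or k[:3] != v[:3]:
        raise ValueError("q, k, v must share B and H, and k, v must share T")
    if q[3] != k[3]:
        raise ValueError("q and k must share the head dimension D")
    issues = []
    if dtype not in FLASH_DTYPES:
        issues.append(f"dtype {dtype!r} is not fp16/bf16: FlashAttention kernel will be skipped")
    return issues

print(validate_attention_inputs((2, 8, 2048, 64), (2, 8, 2048, 64),
                                (2, 8, 2048, 64), "float32"))
```

Raising on empty or mismatched shapes surfaces bugs early, while the dtype case is returned as a warning because the call still succeeds, only more slowly.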