DIY FlashAttention Tutorial
A comprehensive guide to understanding and implementing FlashAttention from scratch. Whether you're new to GPU programming or an experienced developer, you'll find value here.
Learning Path
 Basics  ────────►  Advanced  ────────►  Hands-on
   │                   │                    │
   ▼                   ▼                    ▼
GPU Basics       FlashAttention        Performance
Triton           Implementation        Benchmarking

Part 1: GPU Programming Fundamentals
1.1 Why GPU Acceleration?
In the era of Large Language Models (LLMs), attention is one of the core computations:
# Standard Attention computation
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V

For sequence length N=8192:
- Attention matrix size: 8192 × 8192 × 2 bytes = 128 MB
- Required for backpropagation → memory explosion!
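A quick sanity check of that figure (one attention matrix, one head, stored in FP16 at 2 bytes per element):

N = 8192
mem_bytes = N * N * 2                  # score-matrix entries x bytes per FP16 value
print(f"{mem_bytes / 2**20:.0f} MB")   # -> 128 MB, per head, per layer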
The GPU's parallel computing power is key to solving this problem.
1.2 GPU Memory Hierarchy
Understanding GPU memory hierarchy is fundamental for optimization:
| Level | Capacity | Bandwidth | Latency / Notes |
|---|---|---|---|
| HBM (High Bandwidth Memory) | 40-80 GB (A100/H100) | 1.5-3.35 TB/s | ~500 cycles (slow!) |
| L2 Cache (shared across SMs) | 40-60 MB | ~4 TB/s | |
| SRAM (Shared Memory, per SM) | 164-228 KB per SM | ~19 TB/s (fastest!) | ⚡ Key optimization target for FlashAttention |
| Registers | ~256 KB per SM | | 1 cycle |

Key Insight: HBM has large capacity but is slow; SRAM is small but fast. FlashAttention's core idea: keep data in SRAM as much as possible.
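You can query the coarse numbers for your own GPU through PyTorch (a minimal sketch; per-SM shared-memory and register sizes are not exposed here and come from vendor documentation):

import torch

# Print the device facts PyTorch exposes for the current GPU
props = torch.cuda.get_device_properties(0)
print(f"GPU:  {props.name}")
print(f"HBM:  {props.total_memory / 2**30:.1f} GiB")
print(f"SMs:  {props.multi_processor_count}")
print(f"Compute capability: {props.major}.{props.minor}")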
Part 2: Getting Started with Triton
2.1 Why Triton?
| Feature | CUDA C++ | Triton |
|---|---|---|
| Memory tiling | Manual | Automatic |
| Coalesced access | Requires careful design | Auto-optimized |
| Shared memory | Manual allocation | Auto-managed |
| Synchronization | Manual __syncthreads() | Auto-handled |
| Learning curve | Steep | Gentle |
Conclusion: Triton lets you focus on the algorithm, not low-level optimizations.
2.2 Your First Triton Kernel
A simple vector addition example:
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Vector addition: output = x + y

    Key concepts:
    1. tl.program_id(0): Get current block ID
    2. tl.arange(): Create index sequence
    3. mask: Handle boundary conditions
    4. tl.load/tl.store: Memory read/write
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    output = x + y

    tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Vector addition wrapper"""
    output = torch.empty_like(x)
    n_elements = output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](
        x, y, output, n_elements,
        BLOCK_SIZE=1024,
    )
    return output
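A quick way to sanity-check the kernel against PyTorch (assumes a CUDA-capable GPU is available):

x = torch.randn(98432, device="cuda")   # deliberately not a multiple of BLOCK_SIZE, so the mask path is exercised
y = torch.randn(98432, device="cuda")
out = add(x, y)
print(torch.allclose(out, x + y))       # True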
Part 3: FlashAttention Principles

3.1 The Problem with Standard Attention
import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: (batch, heads, seq_len, head_dim)
    d = Q.shape[-1]

    # Step 1: Compute attention scores
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # O(N^2) memory
    # Step 2: Softmax
    P = torch.softmax(S, dim=-1)                # O(N^2) memory
    # Step 3: Weighted sum
    O = P @ V
    return O

Memory complexity: O(N² × batch × heads) for the attention matrix alone.
For LLM training, this is unacceptable!
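You can watch the blow-up directly by measuring peak GPU memory around the naive implementation (a minimal sketch; assumes a CUDA device with a few GB free and the standard_attention function above):

torch.cuda.reset_peak_memory_stats()

Q = K = V = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()

# Dominated by the (batch, heads, N, N) score and probability matrices
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")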
3.2 FlashAttention's Core Innovation
Core idea: never store the full attention matrix; process K and V in tiles and combine the partial results with online softmax.
FlashAttention vs Standard:

Standard Attention:                FlashAttention:
┌─────────────┐                    ┌────┬────┬────┬────┐
│             │                    │ Q₁ │ Q₂ │ Q₃ │ Q₄ │
│   N × N     │                    ├────┼────┼────┼────┤
│  Attention  │        ──►         │    │    │    │    │   Tiled
│   Matrix    │                    │ K  │ V  │    │    │   Compute
│   Stored    │                    │    │    │    │    │
│   in HBM    │                    └────┴────┴────┴────┘
└─────────────┘
  O(N²) memory                       O(N) memory

3.3 Online Softmax Algorithm
Standard softmax requires two passes:
- Find max value (numerical stability)
- Compute exp and normalize
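In code, the standard two-pass version looks like this (a reference sketch):

import torch

def two_pass_softmax(x):
    # Pass 1: find the row-wise max (numerical stability)
    m = x.max(dim=-1, keepdim=True).values
    # Pass 2: exponentiate relative to the max, then normalize
    e = torch.exp(x - m)
    return e / e.sum(dim=-1, keepdim=True)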
Online Softmax does it in one pass:
import math
import torch

def online_softmax(Q, K, V, BLOCK_SIZE=128):
    """
    Online Softmax Algorithm (reference implementation, one K/V block at a time)

    Key insight: maintain a running max (m), running sum of exp (l), and
    running output (O). All three can be updated incrementally, so the full
    N x N score matrix is never stored.
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    m = torch.full((N,), float("-inf"))  # running max per query row
    l = torch.zeros(N)                   # running sum of exp per query row
    O = torch.zeros(N, d)                # running (already normalized) output

    for start in range(0, K.shape[0], BLOCK_SIZE):
        K_j = K[start:start + BLOCK_SIZE]
        V_j = V[start:start + BLOCK_SIZE]

        # 1. Compute current block's attention scores
        S_j = (Q @ K_j.T) * scale

        # 2. Update running max
        m_new = torch.maximum(m, S_j.max(dim=1).values)

        # 3. Update running sum (correct previously accumulated values)
        alpha = torch.exp(m - m_new)             # correction factor
        P_j = torch.exp(S_j - m_new[:, None])
        l_new = alpha * l + P_j.sum(dim=1)

        # 4. Update output (previous output also requires correction)
        O = (alpha[:, None] * O * l[:, None] + P_j @ V_j) / l_new[:, None]

        m, l = m_new, l_new

    return O
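A quick check that the blockwise result matches a direct two-pass softmax attention (small sizes, CPU is fine):

torch.manual_seed(0)
N, d = 256, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

ref = torch.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V
out = online_softmax(Q, K, V, BLOCK_SIZE=64)
print(torch.allclose(out, ref, atol=1e-5))  # True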
3.4 Memory Complexity Comparison

| Method | Scaling | N=1024 | N=4096 | N=8192 |
|---|---|---|---|---|
| Standard | O(N²) | 8 MB | 128 MB | 512 MB |
| FlashAttention | O(N) | 0.5 MB | 2 MB | 4 MB |
Up to 99% memory savings!
3.5 Causal Masking
For autoregressive models (like GPT), position i can only see positions ≤ i:

Causal Mask Example (seq_len = 4):

        j=0  j=1  j=2  j=3
  i=0 [  ✓    ✗    ✗    ✗  ]
  i=1 [  ✓    ✓    ✗    ✗  ]
  i=2 [  ✓    ✓    ✓    ✗  ]
  i=3 [  ✓    ✓    ✓    ✓  ]
✓ = Visible (attention score kept)
✗ = Invisible (attention score = -inf)
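The same rule in plain PyTorch, applied to a full score matrix (a small illustrative sketch; FlashAttention itself never materializes this matrix):

import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)

# Lower-triangular mask: position i may attend to positions j <= i
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
probs = torch.softmax(scores, dim=-1)  # each row sums to 1 over visible positions

Inside the Triton kernel, the same condition is applied directly to each score tile: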
# Triton implementation
if IS_CAUSAL:
    causal_mask = offs_m[:, None] >= (start_n + offs_n[None, :])
    qk = tl.where(causal_mask, qk, float("-inf"))
Part 4: Performance Optimization

4.1 Block Size Tuning Guide
Block Size is the most critical parameter for Triton kernel performance.
Trade-offs:
| Block Size | Pros | Cons | Use Case |
|---|---|---|---|
| Small (32×32) | More parallel blocks | More HBM access | Small matrices |
| Medium (128×128) | Balanced | Balanced | General purpose |
| Large (256×256) | Better data reuse | May exceed SRAM | Large matrices |
Recommended configuration:
| Matrix Size | BLOCK_M | BLOCK_N | BLOCK_K |
|---|---|---|---|
| < 512 | 32 | 32 | 32 |
| 512-2048 | 64 | 128 | 32 |
| 2048-4096 | 128 | 128 | 64 |
| > 4096 | 128 | 256 | 64 |
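In practice you rarely hard-code one choice; Triton's @triton.autotune decorator can benchmark several candidates and cache the fastest per problem shape. A minimal 1-D sketch (the kernel and config values here are illustrative, not tuned recommendations):

import torch
import triton
import triton.language as tl

# Autotune sketch: Triton times each config the first time it sees a new value
# of the `key` arguments and caches the fastest choice for later launches.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)

def scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    # BLOCK_SIZE is chosen by the autotuner, so it is not passed at launch
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    scale_kernel[grid](x, out, n_elements)
    return out

The same mechanism applies to BLOCK_M / BLOCK_N in the attention kernel (see the "Add new autotune configs" exercise under Next Steps).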
4.2 Data Type Selection
| Type | Range | Precision | Performance | Recommended |
|---|---|---|---|---|
| FP32 | ±3.4e38 | High | 1x | High precision needs |
| FP16 | ±65504 | Medium | 2x | Training/Inference |
| BF16 | ±3.4e38 | Medium | 2x | Training (more stable) |
| FP8 | ±448 | Low | 4x | Inference (Hopper+) |
# Recommended: FP16
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Training: BF16 (avoids overflow)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

Practice Exercises
Exercise 1: Run Quick Demo
make demo

Exercise 2: Block Size Experiment
python examples/block_size_experiment.py

Observe how different block sizes affect performance.
Exercise 3: Memory Comparison
python benchmarks/bench_flash.py --memory-test

Verify FlashAttention's O(N) memory complexity.
Exercise 4: Run Benchmarks
make bench-all
make report

Next Steps
Read the Papers
Experiment with Source Code
- Try different Block Sizes
- Add new autotune configs
- Compare benchmark results against your own GPU
Explore Advanced Topics Outside This Repo's Current Scope
- FlashAttention Backward Pass
- TMA (Tensor Memory Accelerator)
- FP8 computation