
FlashAttention Algorithm Deep Dive

FlashAttention is an IO-aware algorithm that computes exact attention while reducing memory complexity from O(N²) to O(N) and delivering a significant speedup in practice.



Standard Attention Bottleneck

Standard self-attention computation:

S = Q × K^T           # [N, N] — Attention score matrix
P = softmax(S)        # [N, N] — Attention weight matrix
O = P × V             # [N, d] — Output

Core Problem: The intermediate matrices S and P have O(N²) size and must be stored in HBM (device memory). For large sequence lengths N:

| Issue | Impact |
| --- | --- |
| Memory Usage | N = 4096, 32 heads → ~2 GB just for the attention matrices |
| Bandwidth Bottleneck | GPU compute is much faster than HBM bandwidth, so runtime is dominated by data movement |
| IO Operations | S and P are each written to and read back from HBM: four O(N²) transfers in total |
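
For reference, a minimal NumPy sketch of this unfused computation (single head; sizes are illustrative):

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full [N, N] score and weight matrices."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)                       # [N, N] scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))    # numerically safe softmax
    P /= P.sum(axis=-1, keepdims=True)               # [N, N] weights
    return P @ V                                     # [N, d] output

N, d = 4096, 64
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))
O = standard_attention(Q, K, V)
# S and P together occupy 2 * N * N * 4 bytes ≈ 128 MB per head at N = 4096.
```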

Core FlashAttention Concepts

1. Tiling

Divide Q, K, V into blocks that fit in SRAM (shared memory):

Q = [Q_1, Q_2, ..., Q_Tr]    Each block [B_r, d]
K = [K_1, K_2, ..., K_Tc]    Each block [B_c, d]
V = [V_1, V_2, ..., V_Tc]    Each block [B_c, d]

Block Size Selection:

| GPU Architecture | SRAM Size | Typical B_r × B_c |
| --- | --- | --- |
| Volta (V100) | 96 KB | 64 × 64 |
| Ampere (A100) | 164 KB | 128 × 64 |
| Hopper (H100) | 228 KB | 128 × 128 |

Why Tiling Works:

  • Each block fits in fast SRAM (L1/shared memory)
  • Avoids repeated HBM accesses for intermediate results
  • Enables parallel processing of independent blocks
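
A minimal NumPy sketch of the partition, with illustrative block sizes (the actual kernel stages these tiles through shared memory rather than Python lists):

```python
import numpy as np

N, d = 1024, 64
B_r, B_c = 128, 64                                   # illustrative block sizes
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))

Q_blocks = [Q[i:i + B_r] for i in range(0, N, B_r)]  # T_r tiles of shape [B_r, d]
K_blocks = [K[j:j + B_c] for j in range(0, N, B_c)]  # T_c tiles of shape [B_c, d]
V_blocks = [V[j:j + B_c] for j in range(0, N, B_c)]  # T_c tiles of shape [B_c, d]

# Each (Q_i, K_j, V_j) triple of tiles is small enough to stage in SRAM, so the
# [B_r, B_c] score tile S_ij = Q_i @ K_j.T never has to be written back to HBM.
```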

2. Online Softmax

Standard (safe) softmax needs multiple passes over a full row: find the maximum, compute the exponential sum, then normalize. FlashAttention instead uses online softmax to update these statistics incrementally in a single pass over the KV blocks:

for each KV block j:
    S_ij = Q_i × K_j^T                              # Local attention scores
    m_new = max(m_old, rowmax(S_ij))                # Update running row maximum
    P = exp(S_ij - m_new)                           # Local softmax numerator
    l_new = exp(m_old - m_new) × l_old + rowsum(P)  # Update normalizer
    O_i = (exp(m_old - m_new) × l_old × O_i + P × V_j) / l_new  # Rescale previous output, fold in new block
    m_old, l_old = m_new, l_new                     # Carry state to the next block

Key Insight: When processing a new KV block, previous outputs must be corrected by exp(m_old - m_new) because the global maximum may have changed.

Numerical Stability: Tracking running maximum ensures no exp() overflow even for large attention scores.
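
A runnable NumPy version of this update for a single query row, using the same per-iteration renormalization as the pseudocode and checked against an ordinary two-pass softmax (block size and shapes are illustrative):

```python
import numpy as np

def online_softmax_attention_row(scores, values, block=4):
    """Single-row attention with the online softmax update from the pseudocode above."""
    m, l = -np.inf, 0.0
    o = np.zeros(values.shape[-1])
    for j in range(0, len(scores), block):
        s = scores[j:j + block]
        m_new = max(m, s.max())                  # running row maximum
        p = np.exp(s - m_new)                    # local softmax numerator
        corr = np.exp(m - m_new)                 # rescales earlier partial results
        l_new = corr * l + p.sum()
        o = (corr * l * o + p @ values[j:j + block]) / l_new
        m, l = m_new, l_new
    return o

rng = np.random.default_rng(0)
scores, values = rng.normal(size=16), rng.normal(size=(16, 8))
p_ref = np.exp(scores - scores.max())
p_ref /= p_ref.sum()
assert np.allclose(online_softmax_attention_row(scores, values), p_ref @ values)
```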

3. Recomputation

The standard backward pass stores the O(N²) attention matrix P for gradient computation. FlashAttention's strategy:

| Phase | Storage | Memory |
| --- | --- | --- |
| Forward | Output O and logsumexp L only | O(N) |
| Backward | Recompute attention weights from Q, K, V, O, L | O(N) |

Trade-off: Increases computation (~33% more FLOPs) but significantly reduces HBM IO, resulting in an overall speedup.
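
A small NumPy check, with illustrative shapes, that the stored logsumexp is enough to rebuild the attention weights exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
S = rng.normal(size=(4, 6))                          # one tile of attention scores

# Forward pass keeps only the per-row logsumexp L, never the full weight matrix.
m = S.max(axis=-1, keepdims=True)
L = (m + np.log(np.exp(S - m).sum(axis=-1, keepdims=True))).squeeze(-1)

# Backward pass rebuilds the weights tile by tile from the recomputed scores and L.
P_recomputed = np.exp(S - L[:, None])
P_reference = np.exp(S) / np.exp(S).sum(axis=-1, keepdims=True)
assert np.allclose(P_recomputed, P_reference)
```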


Forward Pass Algorithm

Input: Q, K, V ∈ R^(N×d), scale
Output: O ∈ R^(N×d), L ∈ R^N

Initialize: O = 0, m = -∞, l = 0

For each Q block i (parallel):
    Load Q_i to SRAM
    For each KV block j:
        Load K_j, V_j to SRAM
        S_ij = scale × Q_i × K_j^T           # Compute in SRAM
        m_new = max(m_i, rowmax(S_ij))
        P = exp(S_ij - m_new)
        l_new = exp(m_i - m_new) × l_i + rowsum(P)
        O_i = (l_i × exp(m_i - m_new) × O_i + P × V_j) / l_new
        m_i = m_new, l_i = l_new
    L_i = m_i + log(l_i)                      # Store logsumexp

Key Operations:

  1. Parallel over Q blocks: Each output block computed independently
  2. Sequential over KV blocks: Accumulate attention across all keys
  3. Output correction: Adjust running sum when new maximum found
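
A runnable NumPy sketch of this loop for a single head without masking, checked against the naive computation (block sizes are illustrative, and the Q-block loop that the kernel runs in parallel is serialized here):

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    """Tiled exact attention for one head; returns O and the per-row logsumexp L."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O, L = np.zeros((N, d)), np.zeros(N)
    for i in range(0, N, B_r):                        # parallel over Q blocks in the kernel
        Q_i = Q[i:i + B_r]
        m = np.full(Q_i.shape[0], -np.inf)
        l = np.zeros(Q_i.shape[0])
        O_i = np.zeros_like(Q_i)
        for j in range(0, N, B_c):                    # sequential over KV blocks
            S_ij = scale * (Q_i @ K[j:j + B_c].T)
            m_new = np.maximum(m, S_ij.max(axis=-1))
            P = np.exp(S_ij - m_new[:, None])
            corr = np.exp(m - m_new)
            l_new = corr * l + P.sum(axis=-1)
            O_i = ((corr * l)[:, None] * O_i + P @ V[j:j + B_c]) / l_new[:, None]
            m, l = m_new, l_new
        O[i:i + B_r], L[i:i + B_r] = O_i, m + np.log(l)
    return O, L

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
O, L = flash_attention_forward(Q, K, V)
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
assert np.allclose(O, P @ V)
```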

Backward Pass Algorithm

Input: Q, K, V, O, L, dO
Output: dQ, dK, dV

For each KV block j:
    Load K_j, V_j to SRAM
    Initialize dK_j = 0, dV_j = 0
    For each Q block i:
        Load Q_i, O_i, dO_i, L_i to SRAM
        S_ij = scale × Q_i × K_j^T
        P_ij = exp(S_ij - L_i)               # Recompute attention weights
        D_i = rowsum(dO_i ⊙ O_i)             # Diagonal term
        dV_j += P_ij^T × dO_i                # V gradient
        dP_ij = dO_i × V_j^T
        dS_ij = P_ij ⊙ (dP_ij - D_i)         # Softmax gradient
        dQ_i += scale × dS_ij × K_j          # Q gradient
        dK_j += scale × dS_ij^T × Q_i        # K gradient

Gradient Flow:

  1. dV: Weighted sum of gradients using attention weights
  2. dQ, dK: Through softmax Jacobian using recomputed P
  3. Memory efficient: No O(N²) storage needed
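
These formulas can be checked numerically. A small NumPy sketch for one untiled block, with a finite-difference check on dQ (shapes and the scalar loss are illustrative):

```python
import numpy as np

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V, P

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V, dO = (rng.normal(size=(N, d)) for _ in range(4))
scale = 1.0 / np.sqrt(d)

O, P = attention(Q, K, V)
D = (dO * O).sum(axis=-1, keepdims=True)   # D_i = rowsum(dO ⊙ O)
dV = P.T @ dO                              # V gradient
dP = dO @ V.T
dS = P * (dP - D)                          # softmax gradient via the diagonal term
dQ = scale * dS @ K                        # Q gradient
dK = scale * dS.T @ Q                      # K gradient

# Finite-difference check of dQ for the scalar loss L = sum(O ⊙ dO).
def loss(Qx):
    return (attention(Qx, K, V)[0] * dO).sum()

eps, (i, j) = 1e-6, (3, 2)
Qp, Qm = Q.copy(), Q.copy()
Qp[i, j] += eps
Qm[i, j] -= eps
assert abs((loss(Qp) - loss(Qm)) / (2 * eps) - dQ[i, j]) < 1e-5
```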

Causal Masking

For autoregressive models (like GPT), position i can only attend to positions ≤ i. FlashAttention's block structure enables efficient causal masking:

| Case | Handling |
| --- | --- |
| Full skip | KV block start column > Q block end row → skip the entire block |
| Partial mask | Apply the mask within the block (set masked scores to -∞) |

Efficiency Gain: Approximately 50% of blocks can be skipped entirely, reducing computation by half.

Implementation:

for Q block i:
    for KV block j:
        if block_start_j > block_end_i:
            continue  # Entire block masked, skip
        elif block needs partial masking:
            apply mask during softmax computation
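
A small NumPy sketch of this decision for one (Q block, KV block) pair, with illustrative block indices; the fully-visible case shown here simply needs no mask at all:

```python
import numpy as np

B_r, B_c = 4, 4
q_block, kv_block = 2, 3                                  # illustrative block indices
q_rows = np.arange(q_block * B_r, (q_block + 1) * B_r)    # rows 8..11
k_cols = np.arange(kv_block * B_c, (kv_block + 1) * B_c)  # cols 12..15

if k_cols[0] > q_rows[-1]:
    print("entire KV block lies in the future -> skip it")      # full skip
elif k_cols[-1] <= q_rows[0]:
    print("entire KV block is visible -> no mask needed")
else:
    S_ij = np.random.randn(B_r, B_c)
    mask = k_cols[None, :] > q_rows[:, None]                    # future positions
    S_ij[mask] = -np.inf                                        # partial mask before softmax
```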

FP16 Implementation

This implementation fully supports FP16 (half precision) for both forward and backward passes.

Implementation Strategy

FP16 inputs are converted to FP32 internally for computation, then converted back to FP16 for output:

Input: half* Q, K, V
Internal: float (FP32) computation
Output: half* O, L
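
A NumPy analogue of this strategy, with illustrative names and shapes (the CUDA kernel performs the same upcast and downcast per tile):

```python
import numpy as np

def attention_fp16(Q_h, K_h, V_h):
    """FP16 in / FP16 out, with all intermediate math carried out in FP32."""
    Q, K, V = (x.astype(np.float32) for x in (Q_h, K_h, V_h))   # upcast inputs
    S = Q @ K.T / np.sqrt(np.float32(Q.shape[-1]))              # FP32 scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)                          # FP32 softmax
    O = P @ V                                                   # FP32 accumulation
    return O.astype(np.float16)                                 # downcast the output

rng = np.random.default_rng(0)
N, d = 128, 64
Q_h, K_h, V_h = (rng.normal(size=(N, d)).astype(np.float16) for _ in range(3))
assert attention_fp16(Q_h, K_h, V_h).dtype == np.float16
```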

Numerical Precision

| Operation | Precision |
| --- | --- |
| Matrix multiplication (Q × K^T) | FP32 |
| Softmax computation | FP32 |
| Accumulation | FP32 |
| Final output | FP16 |

Benefits:

  • Numerical stability comparable to FP32
  • Reduced memory bandwidth (2× smaller tensors)
  • Supported on all modern GPUs (compute capability ≥ 5.3)

Memory Complexity Analysis

| Method | Forward Memory | Backward Memory | HBM IO |
| --- | --- | --- | --- |
| Standard Attention | O(N²) | O(N²) | O(N² + Nd) |
| FlashAttention | O(N) | O(N) | O(N²d²/M) |

Where M is SRAM size. When M = Θ(Nd), IO complexity approaches O(Nd), which is optimal.

Real Memory Savings

| Sequence Length | Standard Attention | FlashAttention | Savings |
| --- | --- | --- | --- |
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
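
The figures in this table follow from a short calculation (FP32, single head; FlashAttention's O(N) running state per row is the maximum m and normalizer l):

```python
bytes_per_float = 4                                    # FP32, single attention head
for N in (1024, 4096, 16384):
    standard = N * N * bytes_per_float                 # full [N, N] score matrix
    flash = N * 2 * bytes_per_float                    # running max m and normalizer l per row
    print(f"N={N:6d}: standard ~{standard / 2**20:.0f} MiB, FlashAttention ~{flash / 2**10:.0f} KiB")
```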

Implementation Highlights

Block Configuration

| head_dim | BLOCK_M | BLOCK_N | Notes |
| --- | --- | --- | --- |
| 32 | 64 | 64 | Standard configuration |
| 64 | 64 | 64 | Standard configuration |
| 128 | 32 | 32 | Reduced to fit within shared memory |
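
A rough back-of-the-envelope for why head_dim = 128 uses smaller blocks, assuming FP32 tiles of Q, K, V and the S_ij scores are all held in shared memory at once; the actual kernel layout may differ:

```python
def smem_kib(block_m, block_n, head_dim, bytes_per_elt=4):
    """Rough footprint of the Q, K, V tiles plus the [block_m, block_n] score tile."""
    elems = block_m * head_dim + 2 * block_n * head_dim + block_m * block_n
    return elems * bytes_per_elt / 1024

for head_dim, block_m, block_n in [(32, 64, 64), (64, 64, 64), (128, 32, 32)]:
    print(f"head_dim={head_dim:3d}  {block_m}x{block_n}  ~{smem_kib(block_m, block_n, head_dim):.0f} KiB")

# Keeping 64x64 blocks at head_dim=128 would need ~112 KiB of shared memory,
# more than some architectures offer per block (e.g. 96 KB on V100).
print(f"head_dim=128 with 64x64 blocks: ~{smem_kib(64, 64, 128):.0f} KiB")
```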

Optimization Techniques

| Technique | Benefit |
| --- | --- |
| Vectorized Memory Access | float4 loads/stores for better bandwidth |
| Launch Bounds | __launch_bounds__(128) controls register pressure |
| Dynamic Shared Memory | Runtime allocation based on head_dim |
| Stream Safety | Explicit workspace lifetime management |
| Warp-level Primitives | __shfl_sync for reduction operations |

Data Type Support

| Data Type | Forward | Backward |
| --- | --- | --- |
| FP32 (float) | ✓ | ✓ |
| FP16 (half) | ✓ | ✓ |

References

  1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

    • Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
    • NeurIPS 2022
    • arXiv:2205.14135
  2. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

    • Tri Dao
    • arXiv:2307.08691
  3. Online normalizer calculation for softmax

    • Maxim Milakov, Natalia Gimelshein
    • arXiv:1805.02867
  4. NVIDIA CUDA Programming Guide - Shared Memory

Stable v0.3.0 baseline • OpenSpec-driven CUDA FlashAttention reference.
