
DIY FlashAttention Tutorial

A comprehensive guide to understanding and implementing FlashAttention from scratch. Whether you're new to GPU programming or an experienced developer, you'll find value here.

Learning Path

 Basics ──→ Advanced ──→ Hands-on
    │           │            │
    ▼           ▼            ▼
GPU Basics  FlashAttention  Performance
  Triton    Implementation  Benchmarking

Part 1: GPU Programming Fundamentals

1.1 Why GPU Acceleration?

In the era of Large Language Models (LLMs), attention is one of the core computations:

python
# Standard Attention computation
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V

For sequence length N=8192:

  • Attention matrix size: 8192 × 8192 × 2 bytes = 128 MB (per head, per batch element)
  • The full matrix must be kept for backpropagation → memory explodes!
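
A quick back-of-envelope check of this arithmetic (a minimal sketch; the 2 bytes assume FP16 elements):

python
# Size of the full (N x N) attention score matrix in FP16 (2 bytes/element),
# per head and per batch element
for n in (1024, 4096, 8192, 32768):
    mib = n * n * 2 / 2**20
    print(f"N = {n:>5}: {mib:6.0f} MiB")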

A GPU's parallel computing power is key to solving this problem.

1.2 GPU Memory Hierarchy

Understanding GPU memory hierarchy is fundamental for optimization:

Ampere (A100, RTX 30xx), SM80:

| Memory level | Capacity | Bandwidth |
|---|---|---|
| HBM2e / GDDR6X | 40-80 GB | 2 TB/s |
| L2 Cache | 40 MB | ~4 TB/s |
| Shared Memory | 164 KB/SM | ~19 TB/s |

Supported: 3rd-gen Tensor Cores, MIG, NVLink 3.0, FP16, BF16, TF32
Not supported: FP8, TMA


┌──────────────────────────────────────────────────────────┐
│                HBM (High Bandwidth Memory)               │
│  ┌────────────────────────────────────────────────────┐  │
│  │  Capacity:  40-80 GB (A100/H100)                   │  │
│  │  Bandwidth: 1.5-3.35 TB/s                          │  │
│  │  Latency:   ~500 cycles (slow!)                    │  │
│  └────────────────────────────────────────────────────┘  │
├──────────────────────────────────────────────────────────┤
│                    L2 Cache (shared)                     │
│  ┌────────────────────────────────────────────────────┐  │
│  │  Capacity:  40-60 MB                               │  │
│  │  Bandwidth: ~4 TB/s                                │  │
│  └────────────────────────────────────────────────────┘  │
├──────────────────────────────────────────────────────────┤
│               SRAM (shared memory, per SM)               │
│  ┌────────────────────────────────────────────────────┐  │
│  │  Capacity:  164-228 KB per SM                      │  │
│  │  Bandwidth: ~19 TB/s (fastest!)                    │  │
│  │  ⚡ Key optimization target for FlashAttention      │  │
│  └────────────────────────────────────────────────────┘  │
├──────────────────────────────────────────────────────────┤
│                        Registers                         │
│  ┌────────────────────────────────────────────────────┐  │
│  │  Capacity:  ~256 KB per SM                         │  │
│  │  Latency:   1 cycle                                │  │
│  └────────────────────────────────────────────────────┘  │
└──────────────────────────────────────────────────────────┘

Key Insight: HBM has large capacity but is slow; SRAM is small but fast. FlashAttention's core idea: keep data in SRAM as much as possible.
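
To get a feel for the gap, here is a rough streaming-time estimate using the Ampere bandwidth figures above. It is deliberately crude (real kernels overlap memory traffic with compute, and shared memory could never hold 128 MB at once), so read it as relative speeds only:

python
# Time to stream a 128 MiB attention matrix at each level's bandwidth
bytes_moved = 8192 * 8192 * 2  # FP16 scores, ~128 MiB

for level, tb_per_s in [("HBM", 2.0), ("L2 cache", 4.0), ("Shared memory", 19.0)]:
    micros = bytes_moved / (tb_per_s * 1e12) * 1e6
    print(f"{level:>13}: {micros:5.1f} us")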


Part 2: Getting Started with Triton

2.1 Why Triton?

| Feature | CUDA C++ | Triton |
|---|---|---|
| Memory tiling | Manual | Automatic |
| Coalesced access | Requires careful design | Auto-optimized |
| Shared memory | Manual allocation | Auto-managed |
| Synchronization | Manual `__syncthreads()` | Auto-handled |
| Learning curve | Steep | Gentle |

Conclusion: Triton lets you focus on the algorithm, not low-level optimizations.

2.2 Your First Triton Kernel

A simple vector addition example:

python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Vector addition: output = x + y

    Key concepts:
    1. tl.program_id(0): Get current block ID
    2. tl.arange(): Create index sequence
    3. mask: Handle boundary conditions
    4. tl.load/tl.store: Memory read/write
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Vector addition wrapper"""
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    add_kernel[grid](
        x, y, output, n_elements,
        BLOCK_SIZE=1024,
    )
    return output
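
A quick correctness check for the wrapper (assumes a CUDA-capable GPU with Triton installed; the tensor size is arbitrary and deliberately not a multiple of BLOCK_SIZE, to exercise the boundary mask):

python
x = torch.randn(98432, device="cuda")
y = torch.randn(98432, device="cuda")
torch.testing.assert_close(add(x, y), x + y)
print("vector add: OK")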

Part 3: FlashAttention Principles

3.1 The Problem with Standard Attention

python
import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: (batch, heads, seq_len, head_dim)
    d = Q.shape[-1]

    # Step 1: Compute attention scores (materializes the full N x N matrix)
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # O(N^2) memory

    # Step 2: Softmax over the key dimension
    P = torch.softmax(S, dim=-1)  # O(N^2) memory

    # Step 3: Weighted sum of values
    O = P @ V

    return O
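
To watch the quadratic blow-up directly, you can measure peak GPU memory around a call. A rough sketch (requires a CUDA device; exact numbers depend on the allocator):

python
torch.cuda.reset_peak_memory_stats()
Q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
standard_attention(Q, Q, Q)  # S and P are each (1, 8, 4096, 4096): 256 MiB apiece
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")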

Memory complexity: O(N² × batch × heads) just for the attention score matrix

For LLM training, this is unacceptable!

3.2 FlashAttention's Core Innovation

Core idea: never materialize the full N × N attention matrix; process it tile by tile with online softmax.

┌─────────────────────────────────────────────────────────────┐
│                  FlashAttention vs Standard                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Standard Attention:          FlashAttention:               │
│  ┌─────────────┐              ┌───┬───┬───┬───┐             │
│  │             │              │ Q₁│ Q₂│ Q₃│ Q₄│             │
│  │   N × N     │              ├───┼───┼───┼───┤             │
│  │  Attention  │    ──→       │   │   │   │   │  Tiled      │
│  │   Matrix    │              │ K │ V │   │   │  Compute    │
│  │   Stored    │              │   │   │   │   │             │
│  │   in HBM    │              └───┴───┴───┴───┘             │
│  └─────────────┘                     ↓                      │
│   O(N²) memory                  O(N) memory                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3.3 Online Softmax Algorithm

Standard softmax requires two passes:

  1. Find max value (numerical stability)
  2. Compute exp and normalize

Online Softmax does it in one pass:

python
import math
import torch

def online_softmax_attention(Q, K, V, block_size=128):
    """
    Online softmax attention for a single head.

    Q: (M, d), K: (N, d), V: (N, d)

    Key insight: maintain a running max m, running sum l, and running
    output O per query row, and rescale them by exp(m - m_new) whenever
    the running max changes. The full (M, N) score matrix is never stored.
    """
    M, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    m = torch.full((M,), float("-inf"))  # running max
    l = torch.zeros(M)                   # running sum of exp
    O = torch.zeros(M, V.shape[1])       # running (normalized) output

    for start in range(0, K.shape[0], block_size):
        K_j = K[start:start + block_size]
        V_j = V[start:start + block_size]

        # 1. Compute current block's attention scores
        S_j = (Q @ K_j.T) * scale

        # 2. Update running max
        m_new = torch.maximum(m, S_j.max(dim=1).values)

        # 3. Update running sum (rescale previously accumulated values)
        alpha = torch.exp(m - m_new)
        P_j = torch.exp(S_j - m_new[:, None])
        l_new = alpha * l + P_j.sum(dim=1)

        # 4. Update output (un-normalize, rescale, add block, re-normalize)
        O = (alpha[:, None] * (O * l[:, None]) + P_j @ V_j) / l_new[:, None]

        m, l = m_new, l_new

    return O
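
A quick numerical check against the naive formulation (small sizes, CPU; tolerances are loose because blockwise accumulation reorders floating-point operations):

python
torch.manual_seed(0)
Q, K, V = torch.randn(64, 32), torch.randn(256, 32), torch.randn(256, 32)
ref = torch.softmax(Q @ K.T / math.sqrt(32), dim=-1) @ V
out = online_softmax_attention(Q, K, V, block_size=64)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)
print("online softmax: OK")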

3.4 Memory Complexity Comparison

Per head and per batch element, assuming FP16 (2 bytes), matching the arithmetic in Section 1.1:

| Method | Memory | N=1024 | N=4096 | N=8192 |
|---|---|---|---|---|
| Standard | O(N²) | 2 MB | 32 MB | 128 MB |
| FlashAttention | O(N) | 0.5 MB | 2 MB | 4 MB |

At N = 8192 that is a ~97% memory saving, and the gap keeps widening as N grows!

3.5 Causal Masking

For autoregressive models (like GPT), position i can only see positions ≤ i:

Causal Mask Example (seq_len = 4):

     j=0  j=1  j=2  j=3
i=0 [ ✓    ✗    ✗    ✗  ]
i=1 [ ✓    ✓    ✗    ✗  ]
i=2 [ ✓    ✓    ✓    ✗  ]
i=3 [ ✓    ✓    ✓    ✓  ]

✓ = visible (attention score kept)
✗ = masked (attention score set to -inf)
python
# Triton implementation (inside the kernel's inner loop over key blocks)
if IS_CAUSAL:
    causal_mask = offs_m[:, None] >= (start_n + offs_n[None, :])
    qk = tl.where(causal_mask, qk, float("-inf"))
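
For intuition, here is the same lower-triangular mask built in plain PyTorch (illustration only; the kernel above applies it per tile and never materializes an N × N mask):

python
import torch

N = 4
S = torch.randn(N, N)                                  # raw attention scores
mask = torch.tril(torch.ones(N, N, dtype=torch.bool))  # the ✓ region above
S_masked = S.masked_fill(~mask, float("-inf"))         # ✗ positions get -inf
print(torch.softmax(S_masked, dim=-1))                 # row i attends only to j <= i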

Part 4: Performance Optimization

4.1 Block Size Tuning Guide

Block Size is the most critical parameter for Triton kernel performance.

Trade-offs:

| Block size | Pros | Cons | Use case |
|---|---|---|---|
| Small (32×32) | More parallel blocks | More HBM accesses | Small matrices |
| Medium (128×128) | Balanced | Balanced | General purpose |
| Large (256×256) | Better data reuse | May exceed SRAM | Large matrices |

Recommended configuration:

| Matrix size | BLOCK_M | BLOCK_N | BLOCK_K |
|---|---|---|---|
| < 512 | 32 | 32 | 32 |
| 512-2048 | 64 | 128 | 32 |
| 2048-4096 | 128 | 128 | 64 |
| > 4096 | 128 | 256 | 64 |
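
In Triton, sweeps like this are usually delegated to `triton.autotune`, which benchmarks each config the first time the kernel runs. A minimal sketch on a toy kernel (the configs and tuning key are illustrative, not this repo's actual setup):

python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune whenever the problem size changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) * 2.0, mask=mask)

At launch, BLOCK_SIZE is omitted (the autotuner supplies it) and the grid is computed from meta, exactly as in the vector-add wrapper earlier: grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),).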

4.2 Data Type Selection

| Type | Range | Precision | Performance | Recommended for |
|---|---|---|---|---|
| FP32 | ±3.4e38 | High | 1× | High-precision workloads |
| FP16 | ±65504 | Medium | 2× | Training / inference |
| BF16 | ±3.4e38 | Medium | 2× | Training (more stable) |
| FP8 | ±448 | Low | 4× | Inference (Hopper+) |
python
# Recommended default: FP16
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Training: BF16 (FP32-like exponent range avoids overflow)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
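
A tiny demonstration of the range difference (FP16's maximum finite value is 65504; BF16 keeps FP32's exponent range but rounds coarsely due to its short mantissa):

python
import torch

x = torch.tensor(70000.0)
print(x.to(torch.float16))   # inf: overflows FP16
print(x.to(torch.bfloat16))  # ~70144: imprecise, but finite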

Practice Exercises

Exercise 1: Run Quick Demo

bash
make demo

Exercise 2: Block Size Experiment

bash
python examples/block_size_experiment.py

Observe how different Block Sizes affect performance.

Exercise 3: Memory Comparison

bash
python benchmarks/bench_flash.py --memory-test

Verify FlashAttention's O(N) memory complexity.

Exercise 4: Run Benchmarks

bash
make bench-all
make report

Next Steps

  1. Read the Papers

  2. Experiment with Source Code

    • Try different Block Sizes
    • Add new autotune configs
    • Compare benchmark results on your own GPU
  3. Explore Advanced Topics Beyond This Repo's Scope

    • FlashAttention Backward Pass
    • TMA (Tensor Memory Accelerator)
    • FP8 computation


Forward-only educational Triton FlashAttention project · MIT License