
# DIY FlashAttention Cheatsheet

Quick reference for common APIs, commands, and configurations.

## Quick Start

```bash
# Installation
pip install -e ".[dev]"

# Run demo
make demo

# Run tests
make test
```

## Core APIs

### Matrix Multiplication

```python
import torch

from kernels import triton_matmul

# Inputs: A is (M, K), B is (K, N); the result C is (M, N)
a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 1024, device="cuda", dtype=torch.float16)

# Basic usage (auto-selects optimal config)
c = triton_matmul(a, b)

# Specify block sizes explicitly
c = triton_matmul(a, b, block_m=128, block_n=256, block_k=64)

# Supported dtypes
a = torch.randn(..., dtype=torch.float16)   # ✅ Recommended
a = torch.randn(..., dtype=torch.bfloat16)  # ✅ Supported
a = torch.randn(..., dtype=torch.float32)   # ⚠️ Converted to float16 internally
```
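
For a quick spot-check outside the Validation helpers, you can compare against PyTorch's reference matmul. This is an illustrative sketch, not part of the project's API; the loose tolerance absorbs float16 accumulation differences:

```python
import torch
from kernels import triton_matmul

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)

# Compute a float32 reference, cast back, and compare within tolerance.
ref = torch.matmul(a.float(), b.float()).to(torch.float16)
assert torch.allclose(triton_matmul(a, b), ref, atol=1e-2, rtol=1e-2)
```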

### FlashAttention

```python
import torch

from kernels import flash_attention

# q, k, v: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Basic usage
out = flash_attention(q, k, v)

# Causal attention (for autoregressive models)
out = flash_attention(q, k, v, causal=True)

# Variable-length sequences
seq_lens = torch.tensor([256, 512], device="cuda", dtype=torch.int32)
out = flash_attention(q, k, v, seq_lens=seq_lens)

# 3D input: (batch*heads, seq_len, head_dim)
q_3d = torch.randn(16, 512, 64, device="cuda", dtype=torch.float16)
k_3d = torch.randn_like(q_3d)
v_3d = torch.randn_like(q_3d)
out = flash_attention(q_3d, k_3d, v_3d)
```
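
As a correctness spot-check (an illustrative sketch, not a project API), the output can be compared against PyTorch's built-in scaled dot-product attention:

```python
import torch
import torch.nn.functional as F
from kernels import flash_attention

q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# PyTorch's fused attention as a reference; is_causal mirrors causal=True.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = flash_attention(q, k, v, causal=True)
print((out - ref).abs().max().item())  # expect a small float16-level difference
```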

### GPU Detection

```python
from utils import detect_gpu, print_gpu_info

caps = detect_gpu()
print_gpu_info(caps)

# Check features
print(f"TMA: {caps.has_tma}")
print(f"FP8: {caps.has_fp8}")
```

### Benchmark

```python
from kernels import flash_attention, triton_matmul
from utils import BenchmarkRunner

runner = BenchmarkRunner(warmup=10, rep=50)

# Matrix multiplication
results = runner.benchmark_matmul(
    triton_matmul,
    sizes=[(1024, 1024, 1024), (2048, 2048, 2048)],
)
runner.print_comparison_table(results)

# FlashAttention
results = runner.benchmark_attention(
    flash_attention,
    seq_lengths=[512, 1024, 2048],
)
```

### Validation

```python
from kernels import flash_attention, triton_matmul
from utils import validate_matmul, validate_attention

# Validate matrix multiplication
is_valid, max_diff = validate_matmul(
    triton_matmul, m=1024, n=1024, k=1024
)

# Validate FlashAttention
is_valid, max_diff = validate_attention(
    flash_attention, batch=2, heads=8, seq_len=512, head_dim=64
)
```

## Useful Commands

| Command | Description |
| --- | --- |
| `make demo` | Quick demo |
| `make test` | Run all tests |
| `make bench-all` | Run all benchmarks |
| `make gpu-info` | Show GPU info |
| `make experiment` | Block size experiment |
| `make lint` | Code linting |
| `make format` | Code formatting |
| `make clean` | Clean caches |

## Input Shapes

### Matrix Multiplication

A: (M, K) × B: (K, N) → C: (M, N)

### FlashAttention

- 4D input: `(batch, heads, seq_len, head_dim)`
- 3D input: `(batch*heads, seq_len, head_dim)`
- `seq_lens`: `(batch,)`, per-sample effective length
- `head_dim`: only 32 or 64 are supported
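
A small sketch of moving between the two layouts (shapes follow the conventions above):

```python
import torch
from kernels import flash_attention

batch, heads, seq_len, head_dim = 2, 8, 512, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Collapse batch and heads for the 3D path, then restore the 4D layout.
out_3d = flash_attention(
    q.reshape(batch * heads, seq_len, head_dim),
    k.reshape(batch * heads, seq_len, head_dim),
    v.reshape(batch * heads, seq_len, head_dim),
)
out = out_3d.reshape(batch, heads, seq_len, head_dim)
```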

## Block Size Recommendations

| Matrix Size | BLOCK_M | BLOCK_N | BLOCK_K |
| --- | --- | --- | --- |
| < 512 | 32 | 32 | 32 |
| 512-1024 | 64 | 64 | 32 |
| 1024-2048 | 64 | 128 | 32 |
| 2048-4096 | 128 | 128 | 32 |
| > 4096 | 128 | 256 | 64 |

Tip: Use autotune (don't specify block size) for automatic optimal selection.
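
If you do want to pin block sizes, a hypothetical helper encoding the table above might look like this (illustrative only; autotuning remains the recommended default):

```python
from kernels import triton_matmul

def matmul_with_table_blocks(a, b):
    """Pick block sizes from the recommendation table (hypothetical helper)."""
    size = max(a.shape[0], a.shape[1], b.shape[1])  # largest of M, K, N
    if size < 512:
        bm, bn, bk = 32, 32, 32
    elif size <= 1024:
        bm, bn, bk = 64, 64, 32
    elif size <= 2048:
        bm, bn, bk = 64, 128, 32
    elif size <= 4096:
        bm, bn, bk = 128, 128, 32
    else:
        bm, bn, bk = 128, 256, 64
    return triton_matmul(a, b, block_m=bm, block_n=bn, block_k=bk)
```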


## Data Types

| Type | Precision | Performance | Recommended For |
| --- | --- | --- | --- |
| float16 | Medium | 2x | Training/Inference (recommended) |
| bfloat16 | Medium | 2x | Training (more stable) |
| float32 | High | 1x | Debugging/High precision |
| float8 | Low | 4x | Inference (Hopper+) |
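
A minimal sketch of switching dtypes before a kernel call (per the table, bfloat16 trades a little precision for float16-class throughput):

```python
import torch
from kernels import triton_matmul

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)

# Cast once, outside any hot loop, to avoid repeated conversion cost.
c = triton_matmul(a.to(torch.bfloat16), b.to(torch.bfloat16))
```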

## GPU Architectures

| Architecture | SM | GPUs | Features |
| --- | --- | --- | --- |
| Ampere | 80 | A100, RTX 30xx | Full support |
| Ada | 89 | RTX 40xx | Full support |
| Hopper | 90 | H100 | TMA, FP8 |
| Blackwell | 100 | B100 | Latest |
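
Architecture-specific paths can be gated on the capability flags from `detect_gpu` (a sketch; only `has_tma` and `has_fp8` are documented above):

```python
from utils import detect_gpu

caps = detect_gpu()

# Enable Hopper-only paths only when the hardware reports them.
if caps.has_fp8:
    print("FP8 kernels available (Hopper or newer)")
if caps.has_tma:
    print("TMA-based data movement available")
```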

## Memory Complexity

| Method | Memory | Description |
| --- | --- | --- |
| Standard Attention | O(N²) | Stores full attention matrix |
| FlashAttention | O(N) | Tiled computation, no full matrix |

Memory Savings: Up to 99% for long sequences!
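
Back-of-the-envelope numbers make the gap concrete (a sketch; assumes float16 scores and a single head):

```python
N = 8192  # sequence length, single head

# Standard attention materializes the full N x N score matrix.
full_bytes = N * N * 2   # float16: 2 bytes/element -> 128 MiB

# FlashAttention keeps only O(N) running statistics (row max and row sum).
flash_bytes = N * 2 * 4  # two float32 values per row -> 64 KiB

print(f"{full_bytes / 2**20:.0f} MiB vs {flash_bytes / 2**10:.0f} KiB")
```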


## Common Errors

| Error | Cause | Solution |
| --- | --- | --- |
| Expected 2D tensors | Non-2D matmul input | Use `.view()` or `.reshape()` |
| Incompatible dimensions | `A.shape[1] != B.shape[0]` | Check matrix dimensions |
| CUDA tensors required | Input on CPU | Use `.cuda()` or `.to("cuda")` |
| Q, K, V shapes must match | Shape mismatch | Ensure identical shapes |
| Expected 3D or 4D tensors | Wrong attention dims | Check input shapes |
| Unsupported head_dim | head_dim not 32/64 | Use 32 or 64 |
| Unsupported dtype | Wrong dtype | Use float16/bfloat16/float32 |
| dtypes must match | Dtype mismatch | Use consistent dtype |
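
Most of these reduce to a few normalization steps on the inputs (an illustrative sketch, not a project helper):

```python
import torch

def normalize_input(t: torch.Tensor) -> torch.Tensor:
    """Apply the most common fixes: move to GPU, use float16, stay contiguous."""
    return t.to(device="cuda", dtype=torch.float16).contiguous()
```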

## Performance Checklist

☑ Use FP16 or BF16
☑ Use autotune (don't specify block sizes)
☑ Warm up the kernel (run it a few times before timing; see the timing sketch below)
☑ Ensure inputs are contiguous (`.is_contiguous()`)
☑ Keep data on the GPU
☑ Use matrix sizes > 512
☑ Avoid synchronization inside loops
☑ Avoid CPU-GPU data movement inside loops
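
A minimal timing sketch that follows this checklist (warmup runs, then CUDA-event timing via standard PyTorch):

```python
import torch
from kernels import triton_matmul

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)

# Warmup: the first calls include compilation and autotuning.
for _ in range(10):
    triton_matmul(a, b)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(50):
    triton_matmul(a, b)
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / 50:.3f} ms per call")
```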

## File Structure

```
kernels/
├── matmul.py          # Matrix multiplication
├── flash_attn.py      # FlashAttention
└── modern_features.py # Modern GPU features

utils/
├── benchmark.py       # Benchmark tools
├── validation.py      # Validation tools
└── gpu_detect.py      # GPU detection

tests/
├── test_matmul.py     # Matrix multiplication tests
├── test_flash.py      # FlashAttention tests
├── test_properties.py # Property-based tests
└── test_error_handling.py # Error handling tests
```

Forward-only educational Triton FlashAttention project · MIT License