
Performance Guide

This guide covers optimization techniques for DIY FlashAttention, including configuration tuning, best practices, and common pitfalls.

Performance Benchmarks

Matrix Multiplication (MatMul)

Typical performance (RTX 4090, FP16):

| Matrix Size | PyTorch (TFLOPS) | Triton (TFLOPS) | Speedup |
|-------------|------------------|-----------------|---------|
| 512×512     | 25               | 28              | 1.12x   |
| 1024×1024   | 45               | 48              | 1.07x   |
| 2048×2048   | 85               | 95              | 1.12x   |
| 4096×4096   | 120              | 140             | 1.17x   |
| 8192×8192   | 150              | 175             | 1.17x   |

FlashAttention

Typical performance (RTX 4090, FP16, batch=4, heads=8, head_dim=64):

| Seq Length | PyTorch SDPA (ms) | FlashAttention (ms) | Speedup | Memory Saved |
|------------|-------------------|---------------------|---------|--------------|
| 512        | 0.8               | 0.7                 | 1.14x   | 94%          |
| 1024       | 2.5               | 2.0                 | 1.25x   | 97%          |
| 2048       | 9.0               | 6.5                 | 1.38x   | 98%          |
| 4096       | 35.0              | 22.0                | 1.59x   | 99%          |

Memory Usage Comparison

| Seq Length | Standard Attention | FlashAttention | Savings |
|------------|--------------------|----------------|---------|
| 512        | 2 MB               | 0.25 MB        | 88%     |
| 1024       | 8 MB               | 0.5 MB         | 94%     |
| 2048       | 32 MB              | 1 MB           | 97%     |
| 4096       | 128 MB             | 2 MB           | 98%     |
| 8192       | 512 MB             | 4 MB           | 99%     |
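The scaling behind these numbers can be sketched in a few lines: standard attention materializes a full seq_len × seq_len score matrix per batch-head pair, while FlashAttention only keeps O(seq_len) tiles and running statistics. The helper below is a hypothetical illustration, assuming FP16 (2 bytes per element), a batch×heads product of 4, and head_dim=64, matching the benchmark setup above; it is not part of the project API.

```python
def attention_scratch_mb(seq_len, bh=4, head_dim=64, bytes_per_el=2):
    """Rough model of attention scratch memory (hypothetical helper).

    Standard attention materializes a (seq_len x seq_len) score matrix per
    batch-head pair; FlashAttention keeps only O(seq_len) state, modeled
    here as one (seq_len x head_dim) buffer per batch-head pair.
    """
    mb = 1024 ** 2
    standard = bh * seq_len * seq_len * bytes_per_el / mb
    flash = bh * seq_len * head_dim * bytes_per_el / mb
    return standard, flash

for n in (512, 1024, 2048, 4096, 8192):
    std, fl = attention_scratch_mb(n)
    print(f"{n:>5}: standard {std:8.2f} MB, flash {fl:5.2f} MB, "
          f"saved {100 * (1 - fl / std):.0f}%")
```

Because the standard term grows quadratically and the flash term linearly, the savings approach 100% as the sequence length grows, which is exactly the trend in the table.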

Block Size Tuning

Block size is the most critical parameter affecting Triton kernel performance.

Core Principles

┌─────────────────────────────────────────────────────────────┐
│                    Block Size Trade-offs                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Small Block Size:            Large Block Size:             │
│  ┌───┬───┬───┐                ┌───────────┐                 │
│  │   │   │   │                │           │                 │
│  ├───┼───┼───┤                │           │                 │
│  │   │   │   │                │   Single  │                 │
│  ├───┼───┼───┤                │   Block   │                 │
│  │   │   │   │                │           │                 │
│  └───┴───┴───┘                └───────────┘                 │
│                                                             │
│  ✅ More parallel blocks     ✅ Better data reuse           │
│  ✅ Good for small matrices  ✅ Less HBM access             │
│  ❌ More HBM access          ❌ May exceed SRAM             │
│  ❌ Lower data reuse         ❌ Less parallelism            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
| Matrix Size Range | BLOCK_M | BLOCK_N | BLOCK_K | num_stages | num_warps |
|-------------------|---------|---------|---------|------------|-----------|
| < 512             | 32      | 32      | 32      | 4          | 4         |
| 512 - 1024        | 64      | 64      | 32      | 4          | 4         |
| 1024 - 2048       | 64      | 128     | 32      | 4          | 4         |
| 2048 - 4096       | 128     | 128     | 32      | 4          | 4         |
| > 4096            | 128     | 256     | 64     | 3          | 8         |
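The heuristics above can be turned into a simple lookup helper. This is a sketch for illustration only (`choose_block_config` is a hypothetical name, not part of the project API); in practice, prefer the built-in autotune described below.

```python
def choose_block_config(matrix_size):
    """Pick a starting Triton config from the matrix-size heuristics above.

    Hypothetical helper for illustration; real code should rely on autotune.
    """
    if matrix_size < 512:
        blocks = (32, 32, 32, 4, 4)
    elif matrix_size <= 1024:
        blocks = (64, 64, 32, 4, 4)
    elif matrix_size <= 2048:
        blocks = (64, 128, 32, 4, 4)
    elif matrix_size <= 4096:
        blocks = (128, 128, 32, 4, 4)
    else:
        blocks = (128, 256, 64, 3, 8)
    m, n, k, stages, warps = blocks
    return {"BLOCK_M": m, "BLOCK_N": n, "BLOCK_K": k,
            "num_stages": stages, "num_warps": warps}

print(choose_block_config(2048))
```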

Using Autotune

Recommended: use the built-in autotune.

```python
from kernels import triton_matmul

# Don't specify block sizes - autotune selects the optimal configuration
c = triton_matmul(a, b)
```

Experimental: test different configurations manually.

```bash
python examples/block_size_experiment.py
```

Data Type Selection

Type Comparison

| Data Type | Range   | Mantissa | Performance   | Recommended For      |
|-----------|---------|----------|---------------|----------------------|
| FP32      | ±3.4e38 | 23 bits  | 1x (baseline) | High precision/debug |
| FP16      | ±65504  | 10 bits  | ~2x           | Training/Inference   |
| BF16      | ±3.4e38 | 7 bits   | ~2x           | Training (stable)    |
| FP8 E4M3  | ±448    | 3 bits   | ~4x           | Inference (Hopper+)  |
| FP8 E5M2  | ±57344  | 2 bits   | ~4x           | Gradient storage     |
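The ranges in the table follow directly from each format's exponent/mantissa split. As a sanity check, the standard IEEE-style formula for the largest finite value reproduces them; note that FP8 E4M3 (per the OCP FP8 spec) is a special case that drops infinities and reclaims those encodings, so its maximum is 448 rather than what the naive formula would give.

```python
def max_normal(exp_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary float format."""
    emax = 2 ** (exp_bits - 1) - 1          # maximum biased exponent
    return (2 - 2 ** -mantissa_bits) * 2.0 ** emax

print(max_normal(5, 10))   # FP16     -> 65504.0
print(max_normal(8, 7))    # BF16     -> ~3.39e38
print(max_normal(8, 23))   # FP32     -> ~3.40e38
print(max_normal(5, 2))    # FP8 E5M2 -> 57344.0
# FP8 E4M3 does not follow this formula: it has no infinities and
# reuses those bit patterns, extending its maximum to 448.
```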

Selection Guide

Use FP16 for:

  • Most training and inference scenarios
  • Standard LLM workloads

Use BF16 for:

  • Training with large models (avoids overflow)
  • When FP16 causes NaN issues

Use FP32 for:

  • Debugging
  • Numerical verification

Memory Optimization

Ensure Contiguous Memory

```python
# ❌ Bad: a non-contiguous tensor triggers an extra copy
a = some_tensor.transpose(0, 1)
c = triton_matmul(a, b)  # Internally calls .contiguous()

# ✅ Good: explicitly make the tensor contiguous first
a = some_tensor.transpose(0, 1).contiguous()
c = triton_matmul(a, b)
```

Monitor Memory Usage

```python
import torch

def memory_report():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3

    print(f"Allocated: {allocated:.2f} GB")
    print(f"Reserved:  {reserved:.2f} GB")
    print(f"Peak:      {peak:.2f} GB")

# Clear the cache and reset peak stats before benchmarks
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
```

GPU Architecture Optimization

Ampere (SM80) - A100, RTX 30 Series

```python
ampere_config = {
    "BLOCK_M": 128,
    "BLOCK_N": 256,
    "BLOCK_K": 64,
    "num_stages": 3,
    "num_warps": 8,
}
# SRAM: ~164 KB per SM
```

Ada (SM89) - RTX 40 Series

```python
ada_config = {
    "BLOCK_M": 128,
    "BLOCK_N": 256,
    "BLOCK_K": 64,
    "num_stages": 4,  # Larger SRAM
    "num_warps": 8,
}
# SRAM: ~192 KB per SM
```

Hopper (SM90) - H100

```python
from kernels import check_hopper_features

features = check_hopper_features()

if features["tma_available"]:
    print("TMA available - async loading possible")

if features["fp8_available"]:
    print("FP8 available - low-precision compute possible")

# SRAM: ~228 KB per SM
```
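The per-SM SRAM figures above bound how large the blocks and pipeline depth can be: a pipelined matmul kernel stages roughly (BLOCK_M·BLOCK_K + BLOCK_K·BLOCK_N)·bytes·num_stages of tile data in shared memory. The helper below is a rough back-of-the-envelope check (a hypothetical name, and the exact footprint depends on the kernel), assuming only the A and B input tiles are staged:

```python
def tile_sram_kb(block_m, block_n, block_k, num_stages, bytes_per_el=2):
    """Approximate shared-memory footprint of a pipelined matmul kernel (KB)."""
    per_stage = (block_m * block_k + block_k * block_n) * bytes_per_el
    return per_stage * num_stages / 1024

# Ampere config above: 3 stages -> 144 KB, under the ~164 KB limit
print(tile_sram_kb(128, 256, 64, 3))
# Ada config above: 4 stages -> 192 KB, right at the ~192 KB limit
print(tile_sram_kb(128, 256, 64, 4))
```

This is why the Ampere config uses `num_stages=3` while Ada, with more SRAM, can afford `num_stages=4`: one extra pipeline stage costs a full set of input tiles.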

Profiling Tools

PyTorch Profiler

```python
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    result = triton_matmul(a, b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")
```

Triton Built-in Benchmarking

```python
from utils.benchmark import benchmark_fn

median_ms, p20_ms, p80_ms = benchmark_fn(
    triton_matmul, a, b,
    warmup=25,
    rep=100,
)
print(f"Median: {median_ms:.3f} ms, P20: {p20_ms:.3f} ms, P80: {p80_ms:.3f} ms")
```
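If the project's `benchmark_fn` is not available, the same warmup-then-percentiles pattern can be sketched with the standard library. This is a rough CPU-side stand-in (`simple_benchmark` is a hypothetical name): it measures wall-clock time, whereas the real helper presumably uses CUDA events or `triton.testing.do_bench` for accurate GPU timing.

```python
import statistics
import time

def simple_benchmark(fn, *args, warmup=25, rep=100):
    """Return (median, p20, p80) wall-clock times in ms. CPU-side sketch only;
    GPU kernels additionally need torch.cuda.synchronize() around the timer."""
    for _ in range(warmup):              # warm up caches / JIT compilation
        fn(*args)
    times = []
    for _ in range(rep):
        start = time.perf_counter()
        fn(*args)
        times.append((time.perf_counter() - start) * 1000.0)
    deciles = statistics.quantiles(times, n=10)   # 9 cut points
    return statistics.median(times), deciles[1], deciles[7]

median_ms, p20_ms, p80_ms = simple_benchmark(sum, range(10_000))
print(f"Median: {median_ms:.3f} ms")
```

Reporting the 20th/80th percentiles alongside the median, as `benchmark_fn` does, makes run-to-run jitter visible instead of hiding it in an average.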

Common Pitfalls

1. Too Small Matrices

```python
# ❌ Bad: kernel launch overhead dominates
a = torch.randn(32, 32, device="cuda")
for _ in range(1000):
    c = triton_matmul(a, a)

# ✅ Good: use appropriately sized matrices
a = torch.randn(1024, 1024, device="cuda")
c = triton_matmul(a, a)
```

2. Frequent CPU-GPU Synchronization

```python
# ❌ Bad: synchronizing after every operation
for _ in range(100):
    result = triton_matmul(a, b)
    torch.cuda.synchronize()  # Blocks!

# ✅ Good: batch the operations, then sync once
for _ in range(100):
    result = triton_matmul(a, b)
torch.cuda.synchronize()
```

3. Cold Start on First Run

```python
# ❌ Bad: the first run includes JIT compilation time
import time
start = time.time()
result = triton_matmul(a, b)  # Includes JIT compilation!
print(f"Time: {time.time() - start:.3f}s")

# ✅ Good: warm up before timing
for _ in range(10):
    _ = triton_matmul(a, b)
torch.cuda.synchronize()

start = time.time()
for _ in range(100):
    result = triton_matmul(a, b)
torch.cuda.synchronize()
print(f"Time: {(time.time() - start) / 100 * 1000:.3f} ms")
```

Performance Checklist

Before running benchmarks, ensure:

□ Data Types
  ├─ ☑ Use FP16 or BF16
  ├─ ☑ Avoid FP32 (unless high precision is needed)
  └─ ☑ Consistent input/output dtypes

□ Memory
  ├─ ☑ Input tensors are contiguous
  ├─ ☑ Data already on GPU
  └─ ☑ Clear cache before benchmarking

□ Configuration
  ├─ ☑ Use autotune (don't specify block sizes)
  ├─ ☑ Or choose a block size appropriate for the matrix size
  └─ ☑ Check SRAM capacity limits

□ Measurement
  ├─ ☑ Warm up (10+ iterations)
  ├─ ☑ Measure multiple times and average
  ├─ ☑ Use torch.cuda.synchronize()
  └─ ☑ Use GPU time, not CPU time

□ Code
  ├─ ☑ Avoid synchronization inside loops
  ├─ ☑ Avoid CPU-GPU data movement inside loops
  └─ ☑ Matrix size > 512

Run Benchmarks

```bash
# Matrix multiplication benchmark
make bench-matmul

# FlashAttention benchmark
make bench-flash

# All benchmarks
make bench-all

# Generate report
make report
```

References

Forward-only educational Triton FlashAttention project · MIT License