Skip to content

DIY FlashAttention 速查表

快速查找常用 API、命令和配置。

🚀 快速开始

bash
# 安装
pip install -e ".[dev]"

# 运行演示
make demo

# 运行测试
make test

📦 核心 API

矩阵乘法

python
from kernels import triton_matmul

# 基本用法 (自动选择最优配置)
c = triton_matmul(a, b)

# 指定 block size
c = triton_matmul(a, b, block_m=128, block_n=256, block_k=64)

# 支持的数据类型
a = torch.randn(..., dtype=torch.float16)   # ✅ 推荐
a = torch.randn(..., dtype=torch.bfloat16)  # ✅ 支持
a = torch.randn(..., dtype=torch.float32)   # ⚠️ 内部转 float16

FlashAttention

python
from kernels import flash_attention

# 基本用法
out = flash_attention(q, k, v)

# 因果注意力 (用于自回归模型)
out = flash_attention(q, k, v, causal=True)

# 变长序列
seq_lens = torch.tensor([256, 512], device="cuda", dtype=torch.int32)
out = flash_attention(q, k, v, seq_lens=seq_lens)

# 3D 输入: (batch*heads, seq_len, head_dim)
q_3d = torch.randn(16, 512, 64, device="cuda", dtype=torch.float16)
out = flash_attention(q_3d, k_3d, v_3d)

GPU 检测

python
from utils import detect_gpu, print_gpu_info

caps = detect_gpu()
print_gpu_info(caps)

# 检查特性
print(f"TMA: {caps.has_tma}")
print(f"FP8: {caps.has_fp8}")

Benchmark

python
from utils import BenchmarkRunner

runner = BenchmarkRunner(warmup=10, rep=50)

# 矩阵乘法
results = runner.benchmark_matmul(
    triton_matmul,
    sizes=[(1024, 1024, 1024), (2048, 2048, 2048)],
)
runner.print_comparison_table(results)

# FlashAttention
results = runner.benchmark_attention(
    flash_attention,
    seq_lengths=[512, 1024, 2048],
)

验证

python
from utils import validate_matmul, validate_attention

# 验证矩阵乘法
is_valid, max_diff = validate_matmul(
    triton_matmul, m=1024, n=1024, k=1024
)

# 验证 FlashAttention
is_valid, max_diff = validate_attention(
    flash_attention, batch=2, heads=8, seq_len=512, head_dim=64
)

🔧 常用命令

命令说明
make demo快速演示
make test运行所有测试
make bench-all运行所有 benchmark
make gpu-info显示 GPU 信息
make experimentBlock Size 实验
make lint代码检查
make format代码格式化
make clean清理缓存

📐 输入形状

矩阵乘法

A: (M, K)  ×  B: (K, N)  →  C: (M, N)

FlashAttention

4D 输入: (batch, heads, seq_len, head_dim)
3D 输入: (batch*heads, seq_len, head_dim)

seq_lens: (batch,) 指定每个样本的有效长度
head_dim: 支持 32 或 64

📊 Block Size 推荐

矩阵大小BLOCK_MBLOCK_NBLOCK_K
< 512323232
512-1024646432
1024-20486412832
2048-409612812832
> 409612825664

提示: 使用 autotune (不指定 block size) 可自动选择最优配置。

🎨 数据类型

类型精度性能推荐场景
float162x训练/推理 (推荐)
bfloat162x训练 (更稳定)
float321x调试/高精度
float84x推理 (Hopper+)

🖥️ GPU 架构

架构SMGPU特性
Ampere80A100, RTX 30xx完整支持
Ada89RTX 40xx完整支持
Hopper90H100TMA, FP8
Blackwell100B100最新

💾 内存复杂度

方法内存说明
标准 AttentionO(N²)存储完整 attention matrix
FlashAttentionO(N)分块计算,不存储完整矩阵

内存节省: 长序列可节省 99% 内存!

⚠️ 常见错误

错误原因解决方案
Expected 2D tensorsmatmul 输入不是 2D使用 .view().reshape()
Incompatible dimensionsA.shape[1] != B.shape[0]检查矩阵维度
CUDA tensors required输入在 CPU 上使用 .cuda().to("cuda")
Q, K, V shapes must match形状不一致确保 Q, K, V 形状相同
Expected 3D or 4D tensorsattention 输入维度错误检查输入形状
Unsupported head_dimhead_dim 不是 32/64使用 32 或 64
Unsupported dtypedtype 不支持使用 float16/bfloat16/float32
dtypes must matchdtype 不一致统一 dtype

✅ 性能检查清单

□ 使用 FP16 或 BF16
□ 使用 autotune (不指定 block size)
□ 预热 kernel (运行几次后计时)
□ 确保输入是连续的 (.is_contiguous())
□ 数据保持在 GPU 上
□ 矩阵足够大 (> 512)
□ 避免循环内同步
□ 避免循环内 CPU-GPU 数据移动

📁 文件结构

kernels/
├── matmul.py          # 矩阵乘法
├── flash_attn.py      # FlashAttention
└── modern_features.py # 现代特性

utils/
├── benchmark.py       # Benchmark 工具
├── validation.py      # 验证工具
└── gpu_detect.py      # GPU 检测

tests/
├── test_matmul.py     # 矩阵乘法测试
├── test_flash.py      # FlashAttention 测试
├── test_properties.py # 属性测试
└── test_error_handling.py # 错误处理测试

🔗 链接

Forward-only educational Triton FlashAttention project · MIT License