🔷
真实可运行的 Triton 内核
不是玩具示例——真实的 matmul 和 FlashAttention 内核,可直接运行、基准测试、逐行研读。代码紧凑,注释详尽。
用 Triton 从零构建 FlashAttention,掌握 GPU 内核优化的核心技术
紧凑但真实:代码量控制在可完整阅读的范围内,但绝非玩具。你可以:
| 主题 | 收获 |
|---|---|
| GPU 内存层级 | 数据流动:HBM → L2 → SRAM → 寄存器 |
| Triton 编程 | 自动分块、autotune、内核优化技巧 |
| FlashAttention 算法 | 在线 softmax、因果掩码、变长序列处理 |
| 性能调优 | 块大小选择、occupancy 优化、内存分析 |
# 安装
pip install diy-flash-attention
# 或从源码安装
pip install -e ".[dev]"
# 验证安装
python -c "from kernels import flash_attention; print('✓ 安装成功')"import torch
from kernels import flash_attention
# FlashAttention — 长序列内存减少 99%
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
out = flash_attention(q, k, v, causal=True) # GPT 风格因果掩码
print(f"输出形状: {out.shape}") # [2, 8, 4096, 64]