Skip to content

DIYFlashAttention

用 Triton 从零构建 FlashAttention,掌握 GPU 内核优化的核心技术

⚡ 内存减少 99% 🚀 速度提升 1.6x 📖 生产级代码质量

为什么选择这个项目?

紧凑但真实:代码量控制在可完整阅读的范围内,但绝非玩具。你可以:

  • ✅ 在你的 GPU 上运行真实基准测试
  • ✅ 对比 PyTorch SDPA 的性能差异
  • ✅ 理解每一行代码背后的设计决策

你将学到什么

主题收获
GPU 内存层级数据流动:HBM → L2 → SRAM → 寄存器
Triton 编程自动分块、autotune、内核优化技巧
FlashAttention 算法在线 softmax、因果掩码、变长序列处理
性能调优块大小选择、occupancy 优化、内存分析

项目数据

2+
核心 Triton 内核
O(N)
注意力内存复杂度
6
支持的 GPU 架构
99%
长序列内存节省

快速开始

bash
# 安装
pip install diy-flash-attention

# 或从源码安装
pip install -e ".[dev]"

# 验证安装
python -c "from kernels import flash_attention; print('✓ 安装成功')"

运行示例

python
import torch
from kernels import flash_attention

# FlashAttention — 长序列内存减少 99%
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)

out = flash_attention(q, k, v, causal=True)  # GPT 风格因果掩码
print(f"输出形状: {out.shape}")  # [2, 8, 4096, 64]

学习路径

🧑‍💻
内核开发者
从教程开始,逐行理解 FlashAttention 实现
路径:教程 → API → 性能指南
🔬
研究人员
快速查阅 API 契约,复现和修改内核
路径:API 参考 → 源码
🚀
性能工程师
深入性能调优,理解块大小和架构适配
路径:性能指南 → 基准测试
📚
学习者
系统学习 GPU 编程和注意力优化
路径:教程 → 速查表 → FAQ
开始你的 FlashAttention 学习之旅
教程帮助你理解实现,API 参考确认契约,性能指南提供证据。

语言切换

Forward-only educational Triton FlashAttention project · MIT License