⚡
线性内存
通过 FlashAttention 分块处理 16K+ token 序列,内存复杂度 O(N) —— 比标准注意力节省 99.9%。
适用场景
你想理解 FlashAttention 内部原理,实验注意力机制,或在没有重型框架依赖的情况下集成到项目中。
| 特性 | CuFlash-Attn | PyTorch SDPA | FlashAttention-2 |
|---|---|---|---|
| 教育性代码 | ✅ | ❌ | ⚠️ |
| 无依赖 | ✅ | ❌ PyTorch | ❌ |
| Python 绑定 | ✅ ctypes | ✅ 原生 | ✅ |
| 训练支持 | ✅ | ✅ | ✅ |
| 可定制 | ✅ 简单 | ⚠️ 困难 | ⚠️ |
5 分钟内运行:
git clone https://github.com/LessUp/cuflash-attn.git
cd cuflash-attn
cmake --preset release
cmake --build --preset release
ctest --preset release --output-on-failure#include "cuflash/flash_attention.h"
auto err = cuflash::flash_attention_forward(
d_Q, d_K, d_V, d_O, d_L,
batch_size, num_heads, seq_len, head_dim,
scale, true, stream
);import ctypes
lib = ctypes.CDLL("./build/release/libcuflash_attn.so")
# 通过 C ABI 调用
lib.cuflash_attention_forward_f32(
q_ptr, k_ptr, v_ptr, o_ptr, l_ptr,
B, H, N, D, scale, True, None
)| 序列长度 | 标准注意力 | FlashAttention | 节省 |
|---|---|---|---|
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
| 资源 | 描述 |
|---|---|
| 快速开始指南 | Preset 构建路径 |
| 从源码构建 | 平台、presets、覆盖参数 |
| API 参考 | 完整 C++ 和 C ABI 文档 |
| 算法详解 | 分块、online softmax、重计算 |
| 故障排除 | 常见问题与解决方案 |
稳定的 v0.3.0 基线 —— 可归档级参考实现。当前重点:文档质量、工作流简化、Bug 修复。
详见 项目状态 了解维护姿态与治理规则。
本项目遵循 OpenSpec 规范驱动方法。权威需求定义: