Skip to content
v0.1.0

API 参考

CuFlash-Attn 提供简洁的 C++ API,所有函数和类型定义在 cuflash 命名空间中。


目录


头文件

cpp
#include "cuflash/flash_attention.h"

所有公共 API 均通过此单一头文件暴露。


前向传播

flash_attention_forward (FP32)

计算 FP32 精度的 FlashAttention 前向传播。

cpp
FlashAttentionError flash_attention_forward(
    const float* Q,          // 查询张量 [B, H, N, D]
    const float* K,          // 键张量 [B, H, N, D]
    const float* V,          // 值张量 [B, H, N, D]
    float* O,                // 输出张量 [B, H, N, D]
    float* L,                // logsumexp [B, H, N](反向传播需要)
    int batch_size,          // 批大小 B
    int num_heads,           // 注意力头数 H
    int seq_len,             // 序列长度 N
    int head_dim,            // 头维度 D(32、64 或 128)
    float scale,             // 缩放因子,通常 1.0f / sqrt(D)
    bool causal,             // 是否启用因果掩码
    cudaStream_t stream = 0  // CUDA 流(0 表示默认流)
);

参数说明:

参数类型说明
Qconst float*设备内存中的查询张量
Kconst float*设备内存中的键张量
Vconst float*设备内存中的值张量
Ofloat*设备内存中的输出张量
Lfloat*设备内存中的 logsumexp 值
batch_sizeint批次中的序列数量
num_headsint注意力头数
seq_lenint输入序列长度
head_dimint每个头的维度(32、64 或 128)
scalefloat注意力缩放因子
causalbool是否应用因果(自回归)掩码
streamcudaStream_t异步执行的 CUDA 流

返回值: 成功时返回 FlashAttentionError::SUCCESS,否则返回错误代码。


flash_attention_forward (FP16)

计算 FP16 精度的 FlashAttention 前向传播。内部计算使用 FP32 以确保数值稳定性,输出转换回 FP16。

cpp
FlashAttentionError flash_attention_forward(
    const half* Q,           // 查询张量 [B, H, N, D]
    const half* K,           // 键张量 [B, H, N, D]
    const half* V,           // 值张量 [B, H, N, D]
    half* O,                 // 输出张量 [B, H, N, D]
    half* L,                 // logsumexp [B, H, N]
    int batch_size,
    int num_heads,
    int seq_len,
    int head_dim,
    float scale,
    bool causal,
    cudaStream_t stream = 0
);

精度处理:

  • 输入/输出:FP16(16 位半精度)
  • 内部计算:FP32(32 位单精度)
  • 最终结果:FP16

此方法在减少内存带宽需求的同时,提供了与 FP32 相当的数值稳定性。


反向传播

flash_attention_backward (FP32)

计算 FP32 精度的 FlashAttention 反向传播梯度。

cpp
FlashAttentionError flash_attention_backward(
    const float* Q,          // 前向传播的查询张量
    const float* K,          // 前向传播的键张量
    const float* V,          // 前向传播的值张量
    const float* O,          // 前向传播的输出张量
    const float* L,          // 前向传播的 logsumexp
    const float* dO,         // 上游梯度 [B, H, N, D]
    float* dQ,               // Q 的梯度(输出)
    float* dK,               // K 的梯度(输出)
    float* dV,               // V 的梯度(输出)
    int batch_size,
    int num_heads,
    int seq_len,
    int head_dim,
    float scale,
    bool causal,
    cudaStream_t stream = 0
);

梯度计算:

  • 使用重计算策略(反向传播期间重新计算注意力权重)
  • 不存储 O(N²) 的注意力矩阵
  • 内存复杂度:O(N) 而非 O(N²)

要求:

  • OL 必须来自相应的前向传播调用
  • dQdKdV 必须在设备内存中预先分配

flash_attention_backward (FP16)

计算 FP16 精度的 FlashAttention 反向传播梯度。

cpp
FlashAttentionError flash_attention_backward(
    const half* Q,
    const half* K,
    const half* V,
    const half* O,
    const half* L,
    const half* dO,
    half* dQ,
    half* dK,
    half* dV,
    int batch_size,
    int num_heads,
    int seq_len,
    int head_dim,
    float scale,
    bool causal,
    cudaStream_t stream = 0
);

实现说明:

  • 内部累加使用 FP32 以防止溢出
  • 最终梯度转换为 FP16
  • 数值稳定性与 FP32 反向传播相当

C ABI 接口

用于通过 ctypes 或其他语言调用的 C 兼容函数。

FP32 接口

c
// 前向传播 - C ABI
int cuflash_attention_forward_f32(
    const float* Q, const float* K, const float* V,
    float* O, float* L,
    int batch_size, int num_heads, int seq_len, int head_dim,
    float scale, bool causal, cudaStream_t stream
);

// 反向传播 - C ABI
int cuflash_attention_backward_f32(
    const float* Q, const float* K, const float* V,
    const float* O, const float* L, const float* dO,
    float* dQ, float* dK, float* dV,
    int batch_size, int num_heads, int seq_len, int head_dim,
    float scale, bool causal, cudaStream_t stream
);

FP16 接口

c
// 前向传播 - C ABI
int cuflash_attention_forward_f16(
    const half* Q, const half* K, const half* V,
    half* O, half* L,
    int batch_size, int num_heads, int seq_len, int head_dim,
    float scale, bool causal, cudaStream_t stream
);

// 反向传播 - C ABI
int cuflash_attention_backward_f16(
    const half* Q, const half* K, const half* V,
    const half* O, const half* L, const half* dO,
    half* dQ, half* dK, half* dV,
    int batch_size, int num_heads, int seq_len, int head_dim,
    float scale, bool causal, cudaStream_t stream
);

返回值: FlashAttentionError 枚举的整数值。


张量布局

所有张量使用**行优先(C 风格)**内存布局。

张量形状

张量形状说明
QKVO[batch_size, num_heads, seq_len, head_dim]输入/输出张量
dQdKdVdO[batch_size, num_heads, seq_len, head_dim]梯度张量
L[batch_size, num_heads, seq_len]logsumexp 值

内存偏移计算

cpp
// 访问 Q[b][h][s][d]
size_t offset = ((b * num_heads + h) * seq_len + s) * head_dim + d;

// 访问 L[b][h][s]
size_t offset = (b * num_heads + h) * seq_len + s;

数据类型详情

  • float:32 位 IEEE 754 单精度浮点数
  • half:16 位 IEEE 754 半精度浮点数(CUDA 原生)
  • 所有指针必须指向连续的设备内存

错误处理

FlashAttentionError 枚举

cpp
enum class FlashAttentionError {
    SUCCESS = 0,                   // 操作成功完成
    INVALID_DIMENSION,             // 维度参数无效(≤ 0)
    DIMENSION_MISMATCH,            // 预留,将来使用
    NULL_POINTER,                  // 输入或输出指针为空
    CUDA_ERROR,                    // CUDA 运行时错误
    OUT_OF_MEMORY,                 // GPU 显存不足
    UNSUPPORTED_HEAD_DIM,          // head_dim 不在 {32, 64, 128} 中
    UNSUPPORTED_DTYPE              // 不支持的数据类型
};

get_error_string

cpp
const char* get_error_string(FlashAttentionError error);

返回错误代码对应的人类可读字符串。

错误处理示例

cpp
#include "cuflash/flash_attention.h"
#include <iostream>

int main() {
    // ... 为 d_Q、d_K、d_V、d_O、d_L 分配设备内存 ...
    
    float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    
    auto err = cuflash::flash_attention_forward(
        d_Q, d_K, d_V, d_O, d_L,
        batch_size, num_heads, seq_len, head_dim,
        scale,
        /*causal=*/true
    );
    
    if (err != cuflash::FlashAttentionError::SUCCESS) {
        std::cerr << "FlashAttention 错误: "
                  << cuflash::get_error_string(err) << std::endl;
        return 1;
    }
    
    // 反向传播
    err = cuflash::flash_attention_backward(
        d_Q, d_K, d_V, d_O, d_L, d_dO,
        d_dQ, d_dK, d_dV,
        batch_size, num_heads, seq_len, head_dim,
        scale, true
    );
    
    if (err != cuflash::FlashAttentionError::SUCCESS) {
        std::cerr << "反向传播错误: "
                  << cuflash::get_error_string(err) << std::endl;
        return 1;
    }
    
    return 0;
}

类型支持

支持的配置

参数支持的值
head_dim32、64、128
数据类型float (FP32)、half (FP16)
因果掩码可选(bool causal
批大小≥ 1
注意力头数≥ 1
序列长度≥ 1

数据类型支持矩阵

数据类型前向传播反向传播
float (FP32)✅ 完全支持✅ 完全支持
half (FP16)✅ 完全支持✅ 完全支持

构建选项

CMake 选项默认值说明
BUILD_TESTSON构建 GoogleTest 测试套件
ENABLE_RAPIDCHECKOFF启用 RapidCheck 基于属性的测试
BUILD_SHARED_LIBSON构建为共享库(*.so/.dll/.dylib
BUILD_EXAMPLESON构建示例程序
ENABLE_FAST_MATHOFF启用 --use_fast_math 编译器标志

示例配置

bash
# 高性能发布版本构建
cmake --preset release-fast-math \
      -DBUILD_SHARED_LIBS=OFF
cmake --build --preset release-fast-math

# 带所有测试的调试版本
cmake --preset default \
      -DENABLE_RAPIDCHECK=ON
cmake --build --preset default

# 仅静态库
cmake --preset minimal \
      -DBUILD_SHARED_LIBS=OFF
cmake --build --preset minimal

GPU 架构支持

支持的 CUDA 架构

架构计算能力代表 GPU
Voltasm_70V100
Turingsm_75RTX 2080 Ti
Amperesm_80、sm_86A100、RTX 3090
Ada Lovelacesm_89RTX 4090
Hoppersm_90H100

架构特定调优

默认构建支持所有架构。针对特定部署:

bash
# 仅支持 RTX 3090 / A100
cmake --preset release -DCMAKE_CUDA_ARCHITECTURES=86

# 支持多个架构
cmake --preset release -DCMAKE_CUDA_ARCHITECTURES="80;86;89"

共享内存需求

head_dimSRAM 需求典型块大小
32~32 KB64 × 64
64~64 KB64 × 64
128~128 KB32 × 32

注意:head_dim=128 需要支持扩展共享内存的 GPU。


线程安全

  • 使用不同流调用时,所有函数都是线程安全的
  • 支持多个并发调用(使用不同流)
  • 当流共享资源时,同步由调用者负责

内存管理

  • 所有张量分配由调用者负责
  • 内核执行期间不进行动态内存分配
  • 工作空间内存由内部使用流安全的分配管理

Stable v0.3.0 baseline • OpenSpec-driven CUDA FlashAttention reference.

Contributors