Skip to content

API 参考文档

本文档提供 DIY FlashAttention 的完整 API 参考。

目录


Kernels

triton_matmul

高性能 Triton 矩阵乘法,支持 autotune 和多种数据类型。

python
from kernels import triton_matmul

def triton_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: Optional[int] = None,
    block_n: Optional[int] = None,
    block_k: Optional[int] = None,
    use_autotune: bool = True,
) -> torch.Tensor:
    """
    使用 Triton 执行矩阵乘法 C = A @ B。

    参数:
        a (torch.Tensor): 输入矩阵 A,形状 (M, K)。
            支持数据类型: float16, float32, bfloat16。
            必须是 2D CUDA tensor。

        b (torch.Tensor): 输入矩阵 B,形状 (K, N)。
            支持数据类型: float16, float32, bfloat16。
            必须是 2D CUDA tensor,且与 a 同 dtype、同 device。

        block_m (int, optional): M 维度的 block size。
            如果指定,必须同时指定 block_n 和 block_k。
            默认使用 autotune 自动选择。

        block_n (int, optional): N 维度的 block size。

        block_k (int, optional): K 维度的 block size。

        use_autotune (bool): 是否使用 autotune。
            默认 True。仅在未指定 block size 时生效。

    返回:
        torch.Tensor: 输出矩阵 C,形状 (M, N)。
            - float16/bfloat16 输入 → 保持原 dtype
            - float32 输入 → 输出 float16(内部转换为 float16 计算)

    抛出:
        ValueError:
            - 输入不是 2D tensor
            - 输入不是 CUDA tensor
            - 输入不在同一 device
            - 维度不兼容 (A.shape[1] != B.shape[0])
            - block size 非法 (非正数或超过矩阵维度)

        TypeError:
            - 输入 dtype 不受支持
            - 输入 dtype 不一致

    示例:
        基本用法::

            import torch
            from kernels import triton_matmul

            # 创建输入矩阵
            a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
            b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)

            # 使用 autotune (推荐)
            c = triton_matmul(a, b)

        指定 block size::

            # 手动指定 block size
            c = triton_matmul(a, b, block_m=128, block_n=256, block_k=64)

        不同数据类型::

            # BF16 支持
            a_bf16 = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
            b_bf16 = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)
            c_bf16 = triton_matmul(a_bf16, b_bf16)  # 输出也是 bfloat16

    注意:
        - 输入会被自动转换为连续内存布局
        - float32 输入会转换为 float16 计算,输出为 float16
        - 对于小矩阵 (< 64),使用 PyTorch 可能更快

    性能提示:
        - 使用 autotune 可获得最佳性能
        - 大矩阵 (> 2048) 性能优势更明显
        - 确保输入已在 GPU 上,避免 CPU-GPU 数据传输
    """

flash_attention

FlashAttention 前向实现,O(N) 内存复杂度。

python
from kernels import flash_attention

def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    seq_lens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    FlashAttention forward pass,实现高效的注意力计算。

    计算: softmax(Q @ K^T / sqrt(d)) @ V

    参数:
        q (torch.Tensor): Query tensor。
            形状: (batch, heads, seq_len, head_dim) 或 (batch*heads, seq_len, head_dim)。
            支持数据类型: float16, float32, bfloat16。
            必须是 CUDA tensor。

        k (torch.Tensor): Key tensor。形状和 dtype 必须与 q 相同。

        v (torch.Tensor): Value tensor。形状和 dtype 必须与 q 相同。

        causal (bool): 是否使用因果 masking(用于自回归模型)。
            默认 False。
            当 True 时,位置 i 只能关注位置 <= i 的 token。

        sm_scale (float, optional): Softmax 缩放因子。
            默认 1 / sqrt(head_dim)。
            对于标准 attention,建议使用默认值。

        seq_lens (torch.Tensor, optional): 每个样本的有效序列长度。
            形状: (batch,)。
            数据类型: int32。
            用于处理变长序列,超过有效长度的位置输出为 0。
            必须满足: 0 < seq_lens[i] <= seq_len。

    返回:
        torch.Tensor: Attention 输出,形状与输入 q 相同。
            内部使用 float16 计算。

    抛出:
        ValueError:
            - 输入不是 3D 或 4D tensor
            - Q, K, V 形状不一致
            - 输入不是 CUDA tensor
            - 输入不在同一 device
            - head_dim 不是 32 或 64
            - seq_lens 形状或值非法

        TypeError:
            - 输入 dtype 不受支持
            - 输入 dtype 不一致

    示例:
        基本用法::

            import torch
            from kernels import flash_attention

            # 创建输入 (batch=2, heads=8, seq_len=512, head_dim=64)
            q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)

            # 非因果 attention
            out = flash_attention(q, k, v)

        因果 attention::

            # 用于自回归模型 (如 GPT)
            out = flash_attention(q, k, v, causal=True)

        变长序列::

            # 处理不同长度的序列
            seq_lens = torch.tensor([256, 512], device="cuda", dtype=torch.int32)
            out = flash_attention(q, k, v, seq_lens=seq_lens)

        3D 输入::

            # 3D 输入: (batch*heads, seq_len, head_dim)
            q_3d = torch.randn(16, 512, 64, device="cuda", dtype=torch.float16)
            k_3d = torch.randn(16, 512, 64, device="cuda", dtype=torch.float16)
            v_3d = torch.randn(16, 512, 64, device="cuda", dtype=torch.float16)
            out_3d = flash_attention(q_3d, k_3d, v_3d)

    注意:
        - 当前仅支持 head_dim = 32 或 64
        - 非 float16 输入会转换为 float16 计算
        - 内存复杂度 O(N),相比标准 attention 的 O(N²)

    性能提示:
        - 长序列 (> 512) 内存节省更明显
        - 启用 causal 可减少约一半计算量
        - 建议使用 batch > 1 以充分利用 GPU
    """

reference_attention

PyTorch 参考实现,用于验证正确性。

python
from kernels.flash_attn import reference_attention

def reference_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """
    标准 attention 实现,用于验证 FlashAttention 的正确性。

    参数:
        q, k, v: 输入 tensor,任意形状兼容 torch.matmul
        causal: 是否使用因果 masking

    返回:
        torch.Tensor: Attention 输出

    示例:
        验证 FlashAttention 正确性::

            import torch
            from kernels import flash_attention
            from kernels.flash_attn import reference_attention

            q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

            flash_out = flash_attention(q, k, v, causal=True)
            ref_out = reference_attention(q, k, v, causal=True)

            assert torch.allclose(flash_out, ref_out, rtol=1e-2, atol=1e-2)
    """

GPU 检测

detect_gpu

检测 GPU 能力和特性。

python
from utils import detect_gpu

def detect_gpu(device_id: int = 0) -> GPUCapabilities:
    """
    检测指定 GPU 的能力。

    参数:
        device_id (int): CUDA 设备 ID。默认 0。

    返回:
        GPUCapabilities: GPU 能力信息对象。

    抛出:
        RuntimeError: CUDA 不可用。

    示例:
        ::

            from utils import detect_gpu, print_gpu_info

            caps = detect_gpu()
            print(f"GPU: {caps.name}")
            print(f"架构: {caps.arch.value}")
            print(f"计算能力: {caps.compute_capability}")
            print(f"显存: {caps.total_memory_gb:.1f} GB")
            print(f"TMA 支持: {caps.has_tma}")
            print(f"FP8 支持: {caps.has_fp8}")
    """

GPUCapabilities

GPU 能力信息数据类。

python
from utils import GPUCapabilities
from dataclasses import dataclass

@dataclass
class GPUCapabilities:
    """GPU 能力信息。"""

    name: str
        # GPU 名称,如 "NVIDIA GeForce RTX 4090"

    arch: GPUArch
        # GPU 架构枚举值

    compute_capability: tuple[int, int]
        # 计算能力,如 (8, 9) 表示 SM 89

    has_tma: bool
        # 是否支持 Tensor Memory Accelerator (Hopper+)

    has_fp8: bool
        # 是否支持 FP8 数据类型 (Hopper+)

    has_warpgroup_mma: bool
        # 是否支持 Warpgroup MMA (Hopper+)

    sram_per_sm: int
        # 每个 SM 的共享内存大小 (bytes)

    num_sms: int
        # SM 数量

    total_memory_gb: float
        # 总显存大小 (GB)

GPUArch

GPU 架构枚举。

python
from utils import GPUArch

class GPUArch(Enum):
    """GPU 架构枚举。"""

    VOLTA = "sm_70"      # V100
    TURING = "sm_75"     # RTX 20xx
    AMPERE = "sm_80"     # A100, RTX 30xx
    ADA = "sm_89"        # RTX 40xx
    HOPPER = "sm_90"     # H100
    BLACKWELL = "sm_100" # B100/B200
    UNKNOWN = "unknown"

Benchmark 工具

BenchmarkRunner

Benchmark 运行器。

python
from utils import BenchmarkRunner

class BenchmarkRunner:
    """
    运行和管理 benchmark。

    参数:
        device (str): 运行设备。默认 "cuda"。
        warmup (int): 预热迭代次数。默认 25。
        rep (int): 重复测量次数。默认 100。

    示例:
        矩阵乘法 benchmark::

            from utils import BenchmarkRunner
            from kernels import triton_matmul

            runner = BenchmarkRunner(warmup=10, rep=50)
            results = runner.benchmark_matmul(
                triton_matmul,
                sizes=[(1024, 1024, 1024), (2048, 2048, 2048)],
            )
            runner.print_comparison_table(results)

        FlashAttention benchmark::

            results = runner.benchmark_attention(
                flash_attention,
                seq_lengths=[512, 1024, 2048],
                batch_size=4,
                num_heads=8,
                head_dim=64,
            )
    """

    def benchmark_matmul(
        self,
        triton_fn: Callable,
        sizes: list[tuple[int, int, int]],
        block_configs: Optional[list[dict]] = None,
        dtype: torch.dtype = torch.float16,
    ) -> list[BenchmarkResult]:
        """
        运行矩阵乘法 benchmark。

        参数:
            triton_fn: Triton matmul 函数
            sizes: 矩阵大小列表 [(M, N, K), ...]
            block_configs: 可选的 block size 配置列表
            dtype: 数据类型

        返回:
            BenchmarkResult 列表
        """

    def benchmark_attention(
        self,
        flash_fn: Callable,
        seq_lengths: list[int],
        batch_size: int = 4,
        num_heads: int = 8,
        head_dim: int = 64,
        causal: bool = False,
        dtype: torch.dtype = torch.float16,
    ) -> list[BenchmarkResult]:
        """
        运行 FlashAttention benchmark。
        """

    def print_comparison_table(
        self,
        results: Optional[list[BenchmarkResult]] = None,
        title: str = "Benchmark Results",
    ) -> None:
        """打印格式化的比较表格。"""

    def clear_results(self) -> None:
        """清除已存储的结果。"""

BenchmarkResult

Benchmark 结果数据类。

python
from utils import BenchmarkResult
from dataclasses import dataclass

@dataclass
class BenchmarkResult:
    """单个 benchmark 结果。"""

    name: str
        # 实现名称,如 "PyTorch", "Triton", "FlashAttention"

    size: tuple
        # 问题大小

    time_ms: float
        # 执行时间 (毫秒)

    tflops: float
        # 计算吞吐量 (TFLOPS)

    memory_mb: float = 0.0
        # 峰值内存使用 (MB)

    block_config: Optional[dict] = None
        # Block size 配置 (如果手动指定)

benchmark_fn

单函数 benchmark 工具。

python
from utils.benchmark import benchmark_fn

def benchmark_fn(
    fn: Callable,
    *args,
    warmup: int = 25,
    rep: int = 100,
    quantiles: Optional[list[float]] = None,
    **kwargs,
) -> tuple[float, float, float]:
    """
    对单个函数进行 benchmark。

    参数:
        fn: 要测试的函数
        *args: 传递给 fn 的位置参数
        warmup: 预热次数
        rep: 测量次数
        quantiles: 分位数列表
        **kwargs: 传递给 fn 的关键字参数

    返回:
        tuple[float, float, float]: (median_ms, p20_ms, p80_ms)

    示例:
        ::

            from utils.benchmark import benchmark_fn
            from kernels import triton_matmul

            a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
            b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

            median_ms, p20_ms, p80_ms = benchmark_fn(triton_matmul, a, b)
            print(f"Median: {median_ms:.3f} ms")
    """

验证工具

validate_matmul

验证矩阵乘法正确性。

python
from utils import validate_matmul

def validate_matmul(
    triton_fn: Callable,
    m: int,
    n: int,
    k: int,
    rtol: float = 1e-3,
    atol: float = 1e-3,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    verbose: bool = False,
) -> tuple[bool, float]:
    """
    验证 Triton matmul 与 PyTorch 参考实现的一致性。

    参数:
        triton_fn: Triton matmul 函数
        m, n, k: 矩阵维度
        rtol: 相对容差
        atol: 绝对容差
        dtype: 数据类型
        device: 设备
        verbose: 是否打印详细信息

    返回:
        tuple[bool, float]: (是否通过验证, 最大差异)

    示例:
        ::

            from utils import validate_matmul
            from kernels import triton_matmul

            is_valid, max_diff = validate_matmul(
                triton_matmul, m=1024, n=1024, k=1024, verbose=True
            )
            print(f"验证{'通过' if is_valid else '失败'}, 最大差异: {max_diff:.2e}")
    """

validate_attention

验证 FlashAttention 正确性。

python
from utils import validate_attention

def validate_attention(
    flash_fn: Callable,
    batch: int,
    heads: int,
    seq_len: int,
    head_dim: int,
    causal: bool = False,
    rtol: float = 1e-3,
    atol: float = 1e-3,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    verbose: bool = False,
) -> tuple[bool, float]:
    """
    验证 FlashAttention 与 PyTorch SDPA 的一致性。

    参数:
        flash_fn: FlashAttention 函数
        batch: 批次大小
        heads: 注意力头数
        seq_len: 序列长度
        head_dim: 每个头的维度
        causal: 是否使用因果 masking
        rtol: 相对容差
        atol: 绝对容差
        dtype: 数据类型
        device: 设备
        verbose: 是否打印详细信息

    返回:
        tuple[bool, float]: (是否通过验证, 最大差异)

    示例:
        ::

            from utils import validate_attention
            from kernels import flash_attention

            # 验证非因果 attention
            is_valid, max_diff = validate_attention(
                flash_attention, batch=2, heads=8, seq_len=512, head_dim=64
            )

            # 验证因果 attention
            is_valid, max_diff = validate_attention(
                flash_attention, batch=2, heads=8, seq_len=512, head_dim=64, causal=True
            )
    """

现代特性检测

check_hopper_features

检测 Hopper+ GPU 特性。

python
from kernels import check_hopper_features

def check_hopper_features() -> dict[str, Any]:
    """
    检测当前 GPU 的 Hopper+ 特性支持情况。

    返回:
        dict: 包含以下键:
            - tma_available (bool): TMA 支持
            - fp8_available (bool): FP8 支持
            - wgmma_available (bool): Warpgroup MMA 支持
            - arch (str): 架构名称
            - compute_capability (tuple): 计算能力

    示例:
        ::

            from kernels import check_hopper_features

            features = check_hopper_features()
            print(f"TMA: {features['tma_available']}")
            print(f"FP8: {features['fp8_available']}")
            print(f"架构: {features['arch']}")
    """

AdaptiveKernelSelector

自适应 kernel 选择器。

python
from kernels import AdaptiveKernelSelector

class AdaptiveKernelSelector:
    """
    根据 GPU 架构选择最优 kernel 实现。

    示例:
        ::

            from kernels import AdaptiveKernelSelector

            selector = AdaptiveKernelSelector()

            # 获取最优配置
            matmul_config = selector.get_matmul_config()
            attention_config = selector.get_attention_config()

            # 获取最优 kernel
            matmul_fn = selector.select_matmul_kernel()
            attention_fn = selector.select_attention_kernel()
    """

    def get_matmul_config(self) -> dict[str, Any]:
        """获取最优 matmul 配置。"""

    def get_attention_config(self) -> dict[str, Any]:
        """获取最优 attention 配置。"""

    def select_matmul_kernel(self) -> Callable:
        """选择最优 matmul kernel。"""

    def select_attention_kernel(self) -> Callable:
        """选择最优 attention kernel。"""

快捷函数

python
from kernels import (
    get_matmul_config,      # 获取最优 matmul 配置
    get_attention_config,   # 获取最优 attention 配置
    get_optimal_matmul,     # 获取最优 matmul 函数
    get_optimal_attention,  # 获取最优 attention 函数
    supports_fp8,           # 检测 FP8 支持
)

错误处理

常见错误及解决方案

错误类型错误信息原因解决方案
ValueErrorExpected 2D tensorsmatmul 输入不是 2D使用 .view().reshape()
ValueErrorIncompatible dimensions矩阵维度不匹配检查 A.shape[1] == B.shape[0]
ValueErrorCUDA tensors required输入在 CPU 上使用 .to("cuda").cuda()
ValueErrorExpected 3D or 4D tensorsattention 输入维度错误检查输入形状
ValueErrorUnsupported head_dimhead_dim 不是 32/64使用 head_dim=32 或 64
TypeErrorUnsupported dtype使用了不支持的 dtype使用 float16/bfloat16/float32
TypeErrordtypes must matchQ, K, V dtype 不一致确保相同 dtype
ModuleNotFoundErrortriton is requiredTriton 未安装pip install triton

Forward-only educational Triton FlashAttention project · MIT License