Skip to content

张量布局指南

理解 FlashAttention V1 和 V2 之间的张量布局差异。

关键差异

⚠️ V1 和 V2 使用不同的张量布局 - headsseq_len 维度互换了位置!

版本布局
V1 (flash_attention)(batch, heads, seq_len, head_dim)
V2 (flash_attention_v2)(batch, seq_len, heads, head_dim)

为什么存在差异?

FlashAttention V2 使用行级(条纹)并行,这要求 seq_len 维度在内存中连续,以在 Ampere+ GPU 上获得最优的内存访问模式。

V1 使用列并行处理,heads 连续存储更自然。

示例:正确用法

V1 (flash_attention)

python
import torch
from kernels import flash_attention

batch, heads, seq_len, head_dim = 2, 8, 512, 64

# V1 期望: (batch, heads, seq_len, head_dim)
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

out = flash_attention(q, k, v, causal=True)
print(out.shape)  # (2, 8, 512, 64)

V2 (flash_attention_v2)

python
import torch
from kernels import flash_attention_v2

batch, heads, seq_len, head_dim = 2, 8, 512, 64

# V2 期望: (batch, seq_len, heads, head_dim)
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)

out = flash_attention_v2(q, k, v, causal=True)
print(out.shape)  # (2, 512, 8, 64)

版本间转换

在 V1 和 V2 之间切换时,需要转置第 1 和第 2 维度:

python
# 从 V1 格式转换为 V2 格式
q_v2 = q_v1.transpose(1, 2)  # (b, h, s, d) -> (b, s, h, d)
k_v2 = k_v1.transpose(1, 2)
v_v2 = v_v1.transpose(1, 2)

# 运行 V2
out_v2 = flash_attention_v2(q_v2, k_v2, v_v2, causal=True)

# 转换回 V1 格式
out_v1 = out_v2.transpose(1, 2)  # (b, s, h, d) -> (b, h, s, d)

应该使用哪个版本?

场景推荐
Ampere+ GPU (A100, RTX 30xx, RTX 40xx)V2,性能提升 5-15%
Volta/Turing GPU (V100, RTX 20xx)V1(V2 未针对旧架构优化)
大批次 + 长序列V2
代码兼容性优先V1(标准 PyTorch 注意力布局)

BackendSelector

如果需要自动选择,可以使用 BackendSelector

python
from kernels import BackendSelector, flash_attention, flash_attention_v2

# 选择器内部处理布局差异
selector = BackendSelector()
kernel = selector.select_attention(batch=2, heads=8, seq_len=1024, head_dim=64)

# 或使用带 variant 参数的 flash_attention
from kernels import flash_attention
out = flash_attention(q, k, v, causal=True, variant="auto")

总结

  • V1: (batch, heads, seq_len, head_dim) - 标准布局,通用支持
  • V2: (batch, seq_len, heads, head_dim) - 针对 Ampere+ 优化,快 5-15%
  • 切换版本时务必检查张量形状!

Forward-only educational Triton FlashAttention project · MIT License