Tensor Layout Guide

Understanding tensor layout differences between FlashAttention V1 and V2.

The Critical Difference

⚠️ V1 and V2 use different tensor layouts - the heads and seq_len dimensions are swapped!

Version                  | Layout
V1 (flash_attention)     | (batch, heads, seq_len, head_dim)
V2 (flash_attention_v2)  | (batch, seq_len, heads, head_dim)

Why the Difference?

FlashAttention V2 uses row-wise (striped) parallelism across the sequence, and its kernel expects the (batch, seq_len, heads, head_dim) layout so that memory accesses follow the pattern it is tuned for on Ampere and newer GPUs.

V1 uses column-parallel processing, for which keeping each head's data contiguous is the more natural arrangement.
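
To see concretely how the two layouts differ in memory, you can inspect the strides PyTorch reports for tensors allocated in each shape. This is a quick illustration only; it does not use the kernels API at all:

python
import torch

batch, heads, seq_len, head_dim = 2, 8, 512, 64

# V1-style layout: (batch, heads, seq_len, head_dim)
v1_layout = torch.randn(batch, heads, seq_len, head_dim)
# V2-style layout: (batch, seq_len, heads, head_dim)
v2_layout = torch.randn(batch, seq_len, heads, head_dim)

# Strides (in elements) for each dimension of the two layouts
print(v1_layout.stride())  # (262144, 32768, 64, 1) for (batch, heads, seq_len, head_dim)
print(v2_layout.stride())  # (262144, 512, 64, 1)   for (batch, seq_len, heads, head_dim)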

Example: Correct Usage

V1 (flash_attention)

python
import torch
from kernels import flash_attention

batch, heads, seq_len, head_dim = 2, 8, 512, 64

# V1 expects: (batch, heads, seq_len, head_dim)
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

out = flash_attention(q, k, v, causal=True)
print(out.shape)  # (2, 8, 512, 64)

V2 (flash_attention_v2)

python
import torch
from kernels import flash_attention_v2

batch, heads, seq_len, head_dim = 2, 8, 512, 64

# V2 expects: (batch, seq_len, heads, head_dim)
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)

out = flash_attention_v2(q, k, v, causal=True)
print(out.shape)  # (2, 512, 8, 64)

Converting Between Versions

To switch between V1 and V2, you need to transpose dimensions 1 and 2:

python
# From V1 to V2 format
q_v2 = q_v1.transpose(1, 2)  # (b, h, s, d) -> (b, s, h, d)
k_v2 = k_v1.transpose(1, 2)
v_v2 = v_v1.transpose(1, 2)

# Run V2
out_v2 = flash_attention_v2(q_v2, k_v2, v_v2, causal=True)

# Convert back to V1 format
out_v1 = out_v2.transpose(1, 2)  # (b, s, h, d) -> (b, h, s, d)
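
Note that transpose returns a view with permuted strides rather than copying data. If the kernel you are calling expects physically contiguous inputs (the guide does not say, so treat this as an assumption to verify), add .contiguous() after the transpose:

python
# Assumption: the kernel wants row-major memory in its expected layout,
# so force a contiguous copy after transposing.
q_v2 = q_v1.transpose(1, 2).contiguous()
k_v2 = k_v1.transpose(1, 2).contiguous()
v_v2 = v_v1.transpose(1, 2).contiguous()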

Which Should You Use?

Scenario                                 | Recommendation
Ampere+ GPU (A100, RTX 30xx, RTX 40xx)   | V2, for 5-15% better performance
Volta/Turing GPU (V100, RTX 20xx)        | V1 (V2 is not optimized for older architectures)
Large batch + long sequences             | V2
Code compatibility priority              | V1 (standard PyTorch attention layout)
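
If you want to encode this table in code, one simple approach is to branch on the CUDA compute capability (Ampere and newer GPUs report 8.0 or higher). This is a minimal sketch built on the two public kernels, not a feature of the library itself:

python
import torch
from kernels import flash_attention, flash_attention_v2

def pick_attention_kernel():
    # Heuristic based on the table above (an illustration, not library API):
    # Ampere and newer GPUs report compute capability >= (8, 0).
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        return flash_attention_v2  # expects (batch, seq_len, heads, head_dim)
    return flash_attention         # expects (batch, heads, seq_len, head_dim)

Remember that the two kernels expect different input layouts, so the caller still has to provide tensors in the matching shape (see the conversion snippet above).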

BackendSelector

If you want automatic selection, use BackendSelector:

python
from kernels import BackendSelector, flash_attention

# The selector handles layout differences internally
selector = BackendSelector()
kernel = selector.select_attention(batch=2, heads=8, seq_len=1024, head_dim=64)

# Or use flash_attention with the variant parameter
out = flash_attention(q, k, v, causal=True, variant="auto")

Summary

  • V1: (batch, heads, seq_len, head_dim) - standard layout, universal support
  • V2: (batch, seq_len, heads, head_dim) - optimized for Ampere+, 5-15% faster
  • Always check tensor shapes when switching versions!

Forward-only educational Triton FlashAttention project · MIT License