架构设计
LLM-Speed 的深度技术文档。
目录
项目概述
LLM-Speed 是专为 LLM 推理优化的高性能 CUDA 算子库。采用渐进式优化策略:
Naive → Tiled → FlashAttention → Tensor Core
核心目标
| 目标 | 指标 |
|---|---|
| GEMM 性能 | ≥cuBLAS 的 90% |
| FlashAttention 显存 | O(N) 复杂度 |
| 流水线改进 | ≥20% 性能提升 |
| 精度支持 | FP32/FP16/INT8 |
优化哲学
我们遵循先正确再优化的原则:
- 正确性: 与 PyTorch 参考实现对比验证基准实现
- 优化: 可测量的渐进式改进
- 硬件利用: 利用 Tensor Core 和内存层次结构
- 生产就绪: 全面的错误处理和输入验证
系统架构
三层架构
┌─────────────────────────────────────────────────────────────────┐
│ Python 接口层 │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ flash_attention │ │ gemm_kernel │ │ profiler │ │
│ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │
└───────────┼─────────────────────┼─────────────────────┼──────────┘
│ │ │
┌───────────┼─────────────────────┼─────────────────────┼──────────┐
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ CUDA Kernel 层 │ │
│ │ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │ │
│ │ │ Attention │ │ GEMM │ │ Warp │ │ │
│ │ │ Kernels │ │ Kernels │ │ Primitives │ │ │
│ │ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │ │
│ └──────────┼──────────────────┼──────────────────┼───────────┘ │
│ │ │ │ │
│ ┌──────────┼──────────────────┼──────────────────┼───────────┐ │
│ │ ▼ ▼ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ 优化组件 │ │ │
│ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ │
│ │ │ │ Tiling │ │ Tensor Core │ │ Pipeline │ │ │ │
│ │ │ │ Manager │ │ Accelerator │ │ Scheduler │ │ │ │
│ │ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ └────────────────────────────────────────────────────────────┘ │
│ CUDA Runtime │
└─────────────────────────────────────────────────────────────────┘
优化路线图
┌─────────────────┐
│ Naive Kernel │
│ O(N²) 显存 │
└────────┬────────┘
│ 共享内存分块
┌────────▼────────┐
│ Tiled Kernel │
│ 减少全局访存 │
└────────┬────────┘
│ 在线 Softmax
┌────────▼────────┐
│ FlashAttention │
│ O(N) 显存 │
└────────┬────────┘
│ 双缓冲流水线
┌────────▼────────┐
│ Optimized Flash │
│ 计算/访存重叠 │
└─────────────────┘
Attention 内核
1. Naive Attention
用于正确性验证和性能对比的基准实现。
算法:
Attention(Q,K,V) = softmax(QK^T / √d_k)V
计算流程:
S = Q @ K^T * scale→[seq_len, seq_len]P = softmax(S, dim=-1)→[seq_len, seq_len]O = P @ V→[seq_len, head_dim]
关键实现细节:
// 每个 block 处理一个 (batch, head, row)
__global__ void naive_attention_simple_kernel(
const T* Q, const T* K, const T* V, T* O,
int batch_size, int num_heads, int seq_len, int head_dim, float scale
) {
// 共享内存存储 Attention 分数
extern __shared__ float shared_mem[];
float* scores = shared_mem;
// Warp 归约计算 Softmax
float reduced_max = block_reduce_max<float, 256>(local_max, reduce_smem);
float reduced_sum = block_reduce_sum<float, 256>(local_sum, reduce_smem);
}
复杂度分析:
- 时间: O(N² × d)
- 显存: O(N²)
使用场景:
- 与参考实现对比验证正确性
- 短序列(N <= 64)
- 理解基准行为
2. Tiled Attention
共享内存分块减少全局内存访问。
分块配置:
BLOCK_M = 32 // Q 行分块大小
BLOCK_N = 32 // K/V 行分块大小
共享内存布局:
┌────────────────────────────────────────────┐
│ smem_Q [BLOCK_M × (head_dim+1)] │ ← 带填充的 Q 分块
├────────────────────────────────────────────┤
│ smem_K [BLOCK_N × (head_dim+1)] │ ← 带填充的 K 分块
├────────────────────────────────────────────┤
│ smem_V [BLOCK_N × (head_dim+1)] │ ← 带填充的 V 分块
├────────────────────────────────────────────┤
│ smem_S [BLOCK_M × (BLOCK_N+1)] │ ← Attention 分数
├────────────────────────────────────────────┤
│ output [BLOCK_M × head_dim] │ ← 输出累加器
└────────────────────────────────────────────┘
注: +1 填充消除 Bank 冲突
性能提升:
- 全局内存流量减少约 75%
- 更好的缓存利用率
- 适用于序列长度 128-2048
3. FlashAttention
核心创新: 避免存储 N×N Attention 矩阵,实现 O(N) 显存复杂度。
在线 Softmax 公式:
对于每个分块 t:
S_t = Q_tile @ K_tile^T * scale
m_t = max(m_{t-1}, row_max(S_t))
// 重缩放
scale_factor = exp(m_{t-1} - m_t)
l_t = l_{t-1} * scale_factor + sum(exp(S_t - m_t))
// 输出更新
O_t = O_{t-1} * scale_factor + exp(S_t - m_t) @ V_tile
最终: O = O_T / l_T
状态维护:
float row_max[BLOCK_M]; // 当前行最大值 m_i
float row_sum[BLOCK_M]; // 当前行指数和 l_i
float rescale[BLOCK_M]; // 每行重缩放因子
双缓冲实现:
// 共享内存布局(K/V 双缓冲)
smem_Q [BLOCK_M × (head_dim+1)] — Q 分块
smem_K[2] [2 × BLOCK_N × (head_dim+1)] — K 双缓冲
smem_V[2] [2 × BLOCK_N × (head_dim+1)] — V 双缓冲
smem_S [BLOCK_M × (BLOCK_N+1)] — Attention 分数
output [BLOCK_M × head_dim] — 输出累加器
// 流水线流程
// Prologue: 加载第一个 K/V 分块到缓冲 0
// 主循环: 计算当前缓冲,预取下一个到交替缓冲
// Causal 早退: 当下一个分块超出因果窗口时跳过预取
两阶段计算:
// 阶段 1: 每行单线程计算 Softmax 状态(轻量)
if (tid < BLOCK_M) {
// 计算 rowmax(scores)
// 更新 max/sum 状态
// 计算重缩放因子
}
// 阶段 2: 全线程协作更新输出(重量)
for (int i = tid; i < BLOCK_M * head_dim; i += blockDim.x) {
// 重缩放旧输出
// 计算新贡献: exp_scores @ V
// 更新输出
}
因果掩码:
if (is_causal && global_col > global_row) {
score = -FLT_MAX; // 掩码未来位置
}
// 早退优化: 当 col_start >= row_start + BLOCK_M 时退出
if (is_causal && col_start >= row_start + BLOCK_M) break;
性能:
- 显存: O(N) vs O(N²) naive
- 吞吐量: 长序列快 2-4 倍
- 可扩展到 100K+ 序列长度
GEMM 内核
1. Tensor Core GEMM
使用 WMMA API 利用 Tensor Core 硬件加速。
WMMA 片段:
#include <mma.h>
using namespace nvcuda;
// 16×16×16 矩阵分块
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
// 加载 → 计算 → 存储
wmma::load_matrix_sync(a_frag, A + offset, K);
wmma::load_matrix_sync(b_frag, B + offset, N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(C + offset, c_frag, N, wmma::mem_row_major);
分块版本:
// 带填充的共享内存分块
__shared__ half smem_A[BLOCK_M][BLOCK_K + 8]; // +8 half 填充
__shared__ half smem_B[BLOCK_K][BLOCK_N + 8];
// 多 warp 协作
constexpr int WARPS_M = BLOCK_M / WMMA_M; // 4 warps
constexpr int WARPS_N = BLOCK_N / WMMA_N; // 4 warps
INT8 支持(Turing+ SM≥7.2):
// INT8 WMMA 维度: 8×32×16
wmma::fragment<wmma::matrix_a, 8, 32, 16, int8_t, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, int8_t, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, int32_t> c_frag;
2. 高性能 GEMM(寄存器分块)
三级分块策略:
Block 级: BLOCK_M=128, BLOCK_N=128, BLOCK_K=32
Warp 级: WARP_M=32, WARP_N=64
Thread级: THREAD_M=8, THREAD_N=8
寄存器分块:
// 每个线程持有 THREAD_M × THREAD_N 输出分块
float reg_C[THREAD_M][THREAD_N] = {0};
float reg_A[THREAD_M];
float reg_B[THREAD_N];
// 外积算法
for (int k = 0; k < BLOCK_K; k++) {
// 加载 A/B 元素到寄存器
for (int m = 0; m < THREAD_M; m++)
reg_A[m] = smem_A[warp_row + thread_row + m][k];
for (int n = 0; n < THREAD_N; n++)
reg_B[n] = smem_B[k][warp_col + thread_col + n];
// 寄存器内矩阵乘
for (int m = 0; m < THREAD_M; m++)
for (int n = 0; n < THREAD_N; n++)
reg_C[m][n] += reg_A[m] * reg_B[n];
}
双缓冲:
__shared__ float smem_A[2][BLOCK_M][BLOCK_K + 1];
__shared__ float smem_B[2][BLOCK_K][BLOCK_N + 1];
// 主循环: 计算当前缓冲,预取下一个到交替缓冲
for (int tile = 0; tile < num_k_tiles; tile++) {
int cur_buf = tile % 2;
int next_buf = 1 - cur_buf;
// 预取下一块
if (tile + 1 < num_k_tiles) {
LOAD_TILE_A(next_buf, next_k_tile);
LOAD_TILE_B(next_buf, next_k_tile);
}
// 计算当前块
COMPUTE_TILE(cur_buf);
__syncthreads();
}
性能目标: ≥cuBLAS 的 90%,用于矩阵 ≥1024×1024
头文件原语库
common.cuh
核心类型:
enum class Precision { FP32, FP16, BF16, INT8 };
enum class MatrixLayout { RowMajor, ColMajor, RowMajorPadded };
struct AttentionConfig {
int batch_size, num_heads, seq_len, head_dim;
float scale;
bool is_causal;
int block_m, block_n;
Precision precision;
};
struct GemmConfig {
int M, N, K;
float alpha, beta;
MatrixLayout layout_a, layout_b;
int block_m, block_n, block_k;
int warp_m, warp_n;
int thread_m, thread_n;
Precision precision;
};
工具宏:
#define CUDA_CHECK(call) do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
throw std::runtime_error(std::string("CUDA 错误: ") + \
cudaGetErrorString(err) + " 在 " + __FILE__ + ":" + std::to_string(__LINE__)); \
} \
} while(0)
inline int div_ceil(int a, int b) { return (a + b - 1) / b; }
inline bool is_tensor_core_aligned(int dim, int alignment = 16) { return (dim % alignment) == 0; }
warp_primitives.cuh
Warp 级归约:
template<typename T>
__device__ T warp_reduce_sum(T val) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(0xffffffff, val, offset);
return val;
}
template<typename T>
__device__ T warp_reduce_max(T val) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2)
val = max(val, __shfl_down_sync(0xffffffff, val, offset));
return val;
}
Block 级归约:
template<typename T, int BLOCK_SIZE>
__device__ T block_reduce_sum(T val, T* smem) {
int lane = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
// Warp 内归约
val = warp_reduce_sum(val);
// 写入共享内存
if (lane == 0) smem[warp_id] = val;
__syncthreads();
// 第一个 warp 完成最终归约
constexpr int num_warps = BLOCK_SIZE / 32;
if (warp_id == 0) {
val = (lane < num_warps) ? smem[lane] : T(0);
val = warp_reduce_sum(val);
}
return val;
}
online_softmax.cuh
在线 Softmax 状态:
struct OnlineSoftmaxState {
float max_val; // 当前最大值 m_i
float sum_exp; // 当前指数和 l_i
};
状态更新:
__device__ void online_softmax_update(
OnlineSoftmaxState& state, float new_val
) {
float new_max = fmaxf(state.max_val, new_val);
float old_scale = expf(state.max_val - new_max);
float new_scale = expf(new_val - new_max);
state.sum_exp = state.sum_exp * old_scale + new_scale;
state.max_val = new_max;
}
Python 绑定
接口设计
// cuda_llm_ops/bindings.cpp
PYBIND11_MODULE(cuda_llm_ops, m) {
m.doc() = "LLM-Speed";
// Attention 函数
m.def("naive_attention", &naive_attention,
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("scale") = 0.0f,
"Naive attention (baseline)");
m.def("tiled_attention", &tiled_attention,
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("scale") = 0.0f,
"Tiled attention with shared memory");
m.def("flash_attention", &flash_attention,
py::arg("q"), py::arg("k"), py::arg("v"),
py::arg("scale") = 0.0f, py::arg("is_causal") = false,
"FlashAttention with online softmax");
// GEMM 函数
m.def("gemm", &gemm,
py::arg("a"), py::arg("b"),
py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f,
py::arg("trans_a") = false, py::arg("trans_b") = false,
"High-performance GEMM");
m.def("tensor_core_gemm", &tensor_core_gemm,
py::arg("a"), py::arg("b"),
py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f,
"Tensor Core GEMM (FP16 in, FP32 out)");
m.def("tensor_core_gemm_int8", &tensor_core_gemm_int8_wrapper,
py::arg("a"), py::arg("b"),
"INT8 Tensor Core GEMM (SM>=7.2 required)");
}
输入验证
void validate_attention_inputs(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v) {
TORCH_CHECK(q.dim() == 4, "Q 必须是 4D 张量 [batch, heads, seq_len, head_dim]");
TORCH_CHECK(q.sizes() == k.sizes(), "Q 和 K 必须具有相同形状");
TORCH_CHECK(q.is_cuda(), "Q 必须在 CUDA 设备上");
TORCH_CHECK(q.is_contiguous(), "Q 必须是连续的");
TORCH_CHECK(q.scalar_type() == torch::kFloat32 || q.scalar_type() == torch::kFloat16,
"仅支持 float32 和 float16");
TORCH_CHECK(q.size(0) > 0 && q.size(1) > 0 && q.size(2) > 0 && q.size(3) > 0,
"张量维度必须为正");
}
性能优化技术
技术总结
| 技术 | 目标 | 实现 |
|---|---|---|
| 共享内存分块 | 减少全局内存访存 | tiled_attention, hgemm |
| Bank 冲突避免 | +1 填充 | shared_memory.cuh |
| 在线 Softmax | O(N) 显存 | flash_attention |
| Warp Shuffle | 快速归约 | warp_primitives.cuh |
| 寄存器分块 | 数据重用 | hgemm_kernel |
| Tensor Core | 硬件加速 | tensor_core_gemm |
| 双缓冲 | 隐藏延迟 | pipeline.cuh |
| 异步拷贝 | 计算/传输重叠 | pipeline.cuh (Ampere+) |
瓶颈检测
compute_intensity = flops / memory_bytes # FLOPs/byte
if compute_intensity > 100:
bottleneck = "COMPUTE_BOUND"
else:
bottleneck = "MEMORY_BOUND"
优化检查清单
- 对齐维度(Tensor Core 的 16 倍数)
- Bank 冲突自由的共享内存布局
- Warp shuffle 用于归约操作
- 双缓冲用于流水线优化
- 循环展开(编译器提示)
- Ampere+ 异步拷贝(可选)
测试策略
基于属性的测试
使用 Hypothesis 进行全面的正确性验证:
@pytest.mark.cuda
@pytest.mark.property
@settings(max_examples=100, deadline=None)
@given(
batch=st.integers(1, 4),
heads=st.integers(1, 8),
seq_len=st.integers(16, 256),
head_dim=st.sampled_from([32, 64, 128])
)
def test_flash_attention_correctness(batch, heads, seq_len, head_dim, device):
q = torch.randn(batch, heads, seq_len, head_dim, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
output = flash_attention(q, k, v)
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert_close(output, reference, rtol=1e-3, atol=1e-3)
测试覆盖
| 类别 | 内容 |
|---|---|
| 正确性 | 与 PyTorch 参考实现对比 |
| 数值稳定性 | FP16/FP32 精度验证 |
| 边界条件 | 最小维度、大序列、未对齐 |
| 布局等价 | NN/NT/TN/TT 矩阵布局 |
| 错误处理 | 维度不匹配、数据类型错误、空张量 |
参考资料
- FlashAttention: Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, NeurIPS 2022
- FlashAttention-2: Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”, 2023
- CUTLASS: NVIDIA CUTLASS - CUDA Templates for Linear Algebra Subroutines
- cuBLAS: NVIDIA cuBLAS Library Documentation
- CUDA Programming Guide: NVIDIA CUDA C++ Programming Guide