Skip to the content.

CUDA LLM 内核优化项目 - DeepWiki

项目概述

高性能 CUDA 算子库,专为 LLM 推理优化设计。实现渐进式优化策略:Naive → Tiled → FlashAttention,提供完整的优化技术栈。


1. CUDA 内核实现

1.1 Naive Attention (src/naive_attention.cu)

算法: 标准 Scaled Dot-Product Attention

Attention(Q,K,V) = softmax(QK^T / √d_k)V

优化技术:

复杂度: 时间 O(N²×d),显存 O(N²)


1.2 Tiled Attention (src/tiled_attention.cu)

算法: 分块注意力计算,按块加载 Q/K/V 到共享内存

分块配置:

BLOCK_M=32, BLOCK_N=32, BLOCK_K=64
smem_Q[BLOCK_M][BLOCK_K + 1]  // +1 填充避免 bank 冲突

优化技术:


1.3 FlashAttention (src/flash_attention.cu)

核心创新: 避免存储 N×N 注意力矩阵,O(N) 显存复杂度

在线 Softmax 更新公式:

O_new = (l_old × exp(m_old - m_new) × O_old + exp(S - m_new) @ V) / l_new

状态维护:

float row_max[BLOCK_M];   // 当前行最大值 m_i
float row_sum[BLOCK_M];   // 当前行指数和 l_i

特性:


1.4 Tensor Core GEMM (src/tensor_core_gemm.cu)

硬件加速: 使用 WMMA API 调用 Tensor Core

wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

支持精度: FP16/INT8 输入,FP32 累加


1.5 高性能 GEMM (src/hgemm_kernel.cu)

三级分块策略:

Block 级: BLOCK_M=128, BLOCK_N=128, BLOCK_K=32
Warp 级:  WARP_M=32,   WARP_N=64
Thread级: THREAD_M=8,  THREAD_N=8

外积算法:

float reg_C[THREAD_M][THREAD_N];  // 寄存器累积
for (m : THREAD_M)
    for (n : THREAD_N)
        reg_C[m][n] += reg_A[m] * reg_B[n];

2. 头文件原语

2.1 common.cuh

核心类型:

enum class Precision { FP32, FP16, BF16, INT8 };
struct AttentionConfig { batch_size, num_heads, seq_len, head_dim, scale, is_causal, ... };
struct GemmConfig { M, N, K, alpha, beta, layout_a, layout_b, ... };

工具宏:

#define CUDA_CHECK(call) ...  // 错误检查
inline int div_ceil(int a, int b);
inline bool is_tensor_core_aligned(int dim, int alignment = 16);

2.2 warp_primitives.cuh

Warp 级归约:

template<typename T>
__device__ T warp_reduce_sum(T val) {
    for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}

块级归约: Warp 归约 → 共享内存 → 最终归约


2.3 shared_memory.cuh

带填充的共享内存瓦片:

template<typename T, int ROWS, int COLS, int PAD = 1>
struct SharedMemoryTile {
    T data[ROWS][COLS + PAD];  // 填充避免 bank 冲突
};

功能: 合并访问、边界检查、向量化加载、转置优化


2.4 online_softmax.cuh

FlashAttention 核心:

struct OnlineSoftmaxState {
    float max_val;  // m_i
    float sum_exp;  // l_i
};

void online_softmax_update(state, new_val) {
    new_max = fmaxf(state.max_val, new_val);
    old_scale = expf(state.max_val - new_max);
    state.sum_exp = state.sum_exp * old_scale + expf(new_val - new_max);
    state.max_val = new_max;
}

2.5 pipeline.cuh

双缓冲:

template<typename T, int BLOCK_SIZE>
struct DoubleBuffer {
    T* buffer[2];
    int current;
    T* get_load_buffer()    { return buffer[current]; }
    T* get_compute_buffer() { return buffer[1 - current]; }
    void swap()             { current = 1 - current; }
};

异步拷贝 (Ampere+):

asm volatile("cp.async.ca.shared.global [%0], [%1], 16;" :: "r"(dst), "l"(src));

3. Python 绑定

3.1 bindings.cpp

暴露接口:

PYBIND11_MODULE(cuda_llm_ops, m) {
    m.def("flash_attention", &flash_attention,
          py::arg("q"), py::arg("k"), py::arg("v"),
          py::arg("scale") = 0.0f, py::arg("is_causal") = false);

    m.def("gemm", &gemm,
          py::arg("a"), py::arg("b"),
          py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f,
          py::arg("trans_a") = false, py::arg("trans_b") = false);
}

输入验证: 4D 张量、CUDA 设备、连续内存、dtype 检查


3.2 profiler.py

FLOPs 计算:

def compute_attention_flops(batch, heads, seq_len, head_dim):
    return batch * heads * (
        4 * seq_len * seq_len * head_dim +  # Q@K^T + P@V
        5 * seq_len * seq_len               # softmax
    )

瓶颈检测: 计算密集型 vs 内存带宽限制


4. 测试策略

4.1 测试框架 (conftest.py)

Fixtures:

自定义断言: assert_close(actual, expected, rtol, atol)


4.2 测试类型

属性测试 (Hypothesis):

@pytest.mark.property
@settings(max_examples=100)
@given(batch=st.integers(1, 4), seq_len=st.integers(16, 256), ...)
def test_attention_correctness(batch, seq_len, ...):
    ...

测试覆盖: | 类别 | 内容 | |——|——| | 正确性 | 与 PyTorch 参考实现对比 | | 数值稳定性 | FP16/FP32 精度验证 | | 边界条件 | 最小维度、大序列 | | 布局等价性 | NN/NT/TN/TT 矩阵布局 | | 错误处理 | 维度不匹配、dtype 错误 |


5. 基准测试

5.1 Attention Benchmark

测试配置:

输出示例:

Seq Len | PyTorch(ms) | Naive(ms) | Tiled(ms) | Flash(ms) | Speedup
   2048 |      45.123 |    62.345 |    30.456 |    25.789 |   1.75x

5.2 GEMM Benchmark

测试矩阵: 1024², 2048², 4096², 8192²

性能目标: 达到 cuBLAS 90%+ 性能

输出示例:

Size           | cuBLAS(ms) | Custom(ms) | TC GEMM(ms) | Custom % | TC %
4096x4096x4096 |    125.678 |    139.234 |     127.890 |    90.2% | 98.2%

6. 支持的 GPU 架构

架构 SM 版本 Tensor Core
Volta SM 7.0 FP16
Turing SM 7.5 FP16, INT8
Ampere SM 8.0, 8.6 FP16, BF16, INT8, TF32
Ada Lovelace SM 8.9 FP16, BF16, INT8, FP8
Hopper SM 9.0 FP16, BF16, INT8, FP8

7. 性能优化总结

优化技术 目标 实现位置
共享内存分块 减少全局内存访问 tiled_attention, hgemm
Bank 冲突避免 +1 填充 shared_memory.cuh
在线 Softmax O(N) 显存 flash_attention
Warp Shuffle 快速归约 warp_primitives.cuh
寄存器平铺 最大化数据重用 hgemm_kernel
Tensor Core 硬件加速 tensor_core_gemm
双缓冲流水线 隐藏内存延迟 pipeline.cuh
异步拷贝 计算/传输重叠 pipeline.cuh (Ampere+)