GEMM 优化详解
本文档详细介绍 GEMM (General Matrix Multiplication) 的 7 步优化路径。
概述
GEMM 计算: C = α * A × B + β * C
其中:
- A: M × K 矩阵
- B: K × N 矩阵
- C: M × N 矩阵
- α, β: 标量系数
Step 1: Naive Global Memory
实现思路
每个线程计算输出矩阵 C 的一个元素。
__global__ void gemm_naive_kernel(const float* A, const float* B, float* C,
int M, int N, int K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; ++k) {
sum += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = sum;
}
}
性能分析
- 问题: 每个元素需要 2K 次全局内存访问
- 带宽利用率: ~5-10%
- TFLOPS: ~0.5 (FP32, RTX 4090)
内存访问模式
Thread (0,0): A[0,0], A[0,1], ..., A[0,K-1] ← 连续访问 ✓
B[0,0], B[1,0], ..., B[K-1,0] ← 跨步访问 ✗
Step 2: Shared Memory Tiling
优化思路
将 A 和 B 的子块加载到 Shared Memory,减少全局内存访问。
constexpr int TILE_SIZE = 32;
__global__ void gemm_shared_kernel(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
float sum = 0.0f;
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
// 协作加载 Tile 到 Shared Memory
int a_col = t * TILE_SIZE + threadIdx.x;
int b_row = t * TILE_SIZE + threadIdx.y;
As[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
Bs[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
__syncthreads();
// 计算部分点积
for (int k = 0; k < TILE_SIZE; ++k) {
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
}
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = sum;
}
}
性能提升
- 全局内存访问减少: K → K/TILE_SIZE
- 带宽利用率: ~30-40%
- TFLOPS: ~2.0
Tiling 示意图
K K
┌───────┐ ┌───────┐
│ │ │ │
M │ A │ K │ B │
│ │ │ │
└───────┘ └───────┘
↓ ↓
┌───┬───┬───┐ ┌───┬───┬───┐
│T1 │T2 │T3 │ │T1 │T2 │T3 │
├───┼───┼───┤ ├───┼───┼───┤
│T4 │T5 │T6 │ │T4 │T5 │T6 │
└───┴───┴───┘ └───┴───┴───┘
每个 Block 处理一个 TILE_SIZE × TILE_SIZE 的输出块
Step 3: Double Buffering
优化思路
使用双缓冲技术,在计算当前 Tile 的同时预取下一个 Tile。
__global__ void gemm_double_buffer_kernel(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[2][TILE_SIZE][TILE_SIZE]; // 双缓冲
__shared__ float Bs[2][TILE_SIZE][TILE_SIZE];
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
float sum = 0.0f;
int write_stage = 0;
int read_stage = 0;
// 预取第一个 Tile
load_tile(As[write_stage], Bs[write_stage], A, B, 0, row, col, M, N, K);
__syncthreads();
for (int t = 0; t < num_tiles; ++t) {
read_stage = write_stage;
write_stage = 1 - write_stage;
// 异步加载下一个 Tile
if (t + 1 < num_tiles) {
load_tile(As[write_stage], Bs[write_stage], A, B, t + 1, row, col, M, N, K);
}
// 计算当前 Tile
for (int k = 0; k < TILE_SIZE; ++k) {
sum += As[read_stage][threadIdx.y][k] * Bs[read_stage][k][threadIdx.x];
}
__syncthreads();
}
// ...
}
性能提升
- 隐藏内存延迟: 计算与加载重叠
- TFLOPS: ~3.5
时间线对比
Without Double Buffering:
|--Load T1--|--Compute T1--|--Load T2--|--Compute T2--|
With Double Buffering:
|--Load T1--|--Compute T1--|--Compute T2--|--Compute T3--|
|--Load T2----|--Load T3----|--Load T4----|
Step 4: Register Tiling
优化思路
每个线程计算多个输出元素,减少 Shared Memory 访问。
constexpr int REG_TILE_M = 8; // 每个线程计算 8×8 个元素
constexpr int REG_TILE_N = 8;
__global__ void gemm_register_tiling_kernel(...) {
// 寄存器累加器
float reg_c[REG_TILE_M][REG_TILE_N] = {0.0f};
for (int k_tile = 0; k_tile < K; k_tile += BLK_K) {
// 加载到 Shared Memory
// ...
// 计算使用寄存器 Tiling
for (int k = 0; k < BLK_K; ++k) {
float reg_a[REG_TILE_M];
float reg_b[REG_TILE_N];
// 从 Shared Memory 加载到寄存器
for (int m = 0; m < REG_TILE_M; ++m)
reg_a[m] = As[k][thread_m * REG_TILE_M + m];
for (int n = 0; n < REG_TILE_N; ++n)
reg_b[n] = Bs[k][thread_n * REG_TILE_N + n];
// 外积计算
for (int m = 0; m < REG_TILE_M; ++m)
for (int n = 0; n < REG_TILE_N; ++n)
reg_c[m][n] += reg_a[m] * reg_b[n];
}
}
// ...
}
性能提升
- Shared Memory 访问减少: 8× (REG_TILE_M)
- 指令级并行: 更多独立计算
- TFLOPS: ~6.0
Step 5: Tensor Core (WMMA API)
优化思路
使用 NVIDIA Tensor Core 进行矩阵乘法,利用专用硬件加速。
#include <mma.h>
using namespace nvcuda;
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;
__global__ void gemm_wmma_kernel(const __half* A, const __half* B, float* C,
int M, int N, int K) {
// 声明 Fragment
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
for (int k = 0; k < K; k += WMMA_K) {
// 加载 Fragment
wmma::load_matrix_sync(a_frag, A + row * K + k, K);
wmma::load_matrix_sync(b_frag, B + k * N + col, N);
// Tensor Core MMA
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 存储结果
wmma::store_matrix_sync(C + row * N + col, c_frag, N, wmma::mem_row_major);
}
性能提升
- Tensor Core 吞吐: 比 CUDA Core 高 8-16×
- TFLOPS: ~50+ (FP16)
Tensor Core 架构
Tensor Core (每个 SM 有多个):
┌─────────────────────────────────────┐
│ 16×16×16 Matrix Multiply-Accumulate │
│ │
│ A (16×16, FP16) × B (16×16, FP16) │
│ ↓ │
│ C (16×16, FP32) │
└─────────────────────────────────────┘
Step 6: Tensor Core (MMA PTX)
优化思路
使用 PTX 指令直接控制 Tensor Core,获得更细粒度的控制。
__device__ __forceinline__ void mma_m16n8k16_fp16(
uint32_t* d, const uint32_t* a, const uint32_t* b, const uint32_t* c) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3, %4, %5}, "
"{%6, %7}, "
"{%8, %9};\n"
: "=r"(d[0]), "=r"(d[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1])
);
}
性能提升
- 更精细的寄存器控制
- TFLOPS: ~60+
Step 7: Software Pipelining
优化思路
使用多阶段流水线隐藏指令延迟。
constexpr int PIPE_STAGES = 3;
__global__ void gemm_software_pipeline_kernel(...) {
__shared__ float As[PIPE_STAGES][TILE_K][TILE_M + 1];
__shared__ float Bs[PIPE_STAGES][TILE_K][TILE_N + 1];
// Prologue: 填充流水线
for (int stage = 0; stage < PIPE_STAGES - 1; ++stage) {
load_tile(As[stage], Bs[stage], ...);
}
// Main loop: 流水线执行
for (int k_tile = 0; k_tile < num_tiles; ++k_tile) {
int compute_stage = k_tile % PIPE_STAGES;
int load_stage = (k_tile + PIPE_STAGES - 1) % PIPE_STAGES;
// 异步加载下一个 Tile
if (k_tile + PIPE_STAGES - 1 < num_tiles) {
load_tile(As[load_stage], Bs[load_stage], ...);
}
// 计算当前 Tile
compute_tile(As[compute_stage], Bs[compute_stage], reg_c);
}
}
性能提升
- 隐藏指令延迟: 多阶段重叠
- TFLOPS: ~70+
流水线示意图
Stage 0: |--Load--|--Compute--|--Load--|--Compute--|
Stage 1: |--Load--|--Compute--|--Load--|--Compute--|
Stage 2: |--Load--|--Compute--|--Load--|--Compute--|
性能对比总结
| Step | 优化技术 | TFLOPS (FP32) | 相对提升 |
|---|---|---|---|
| 1 | Naive | 0.5 | 1.0× |
| 2 | Shared Memory | 2.0 | 4.0× |
| 3 | Double Buffer | 3.5 | 7.0× |
| 4 | Register Tiling | 6.0 | 12.0× |
| 5 | WMMA | 50+ | 100× |
| 6 | MMA PTX | 60+ | 120× |
| 7 | Pipeline | 70+ | 140× |