目录 (Table of Contents)
GEMM 基础理论
数学定义
GEMM(General Matrix Multiply)计算公式:
简化版本(α=1, β=0):
1
2
3
K
C(m,n) = Σ A(m,k) × B(k,n)
k=1
其中:
A : M × K 矩阵
B : K × N 矩阵
C : M × N 矩阵
计算复杂度分析
指标
公式
说明
浮点运算
2 × M × N × K
乘法 + 加法
内存读取
M×K + K×N + M×N floats
读取 A、B,写入 C
计算强度
2×M×N×K / (M×K + K×N + M×N) / 4
FLOP/byte
为什么 GEMM 重要?
神经网络中的核心计算:
神经网络层
GEMM 形式
全连接层
Y = X × W + b
卷积层(im2col)
转换为 GEMM
注意力机制
Q × K^T, Score × V
LSTM/GRU
多个 GEMM 组合
Level 1: Naive 实现
算法描述
每个线程计算输出矩阵的一个元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
__global__ void naive_matmul_kernel ( const float * A , const float * B , float * C ,
int M , int N , int K ) {
int row = blockIdx . y * blockDim . y + threadIdx . y ;
int col = blockIdx . x * blockDim . x + threadIdx . x ;
if ( row < M && col < N ) {
float sum = 0.0 f ;
for ( int k = 0 ; k < K ; k ++ ) {
sum += A [ row * K + k ] * B [ k * N + col ];
}
C [ row * N + col ] = sum ;
}
}
void launch_naive_matmul ( const float * A , const float * B , float * C ,
int M , int N , int K , cudaStream_t stream ) {
dim3 block ( 16 , 16 );
dim3 grid (( N + 15 ) / 16 , ( M + 15 ) / 16 );
naive_matmul_kernel <<< grid , block , 0 , stream >>> ( A , B , C , M , N , K );
}
内存访问模式
1
2
3
4
5
6
7
线程 (row, col) 的访问模式:
═════════════════════════════════════════════════════════════
A: A[row, 0], A[row, 1], ..., A[row, K-1] ← K 次读取(连续)
B: B[0, col], B[1, col], ..., B[K-1, col] ← K 次读取(跨步)
C: C[row, col] ← 1 次写入
═════════════════════════════════════════════════════════════
总计: 2K 次读取 + 1 次写入 / 每个输出元素
性能瓶颈
问题
影响
解决方案
全局内存访问过多
带宽成为瓶颈
使用共享内存
B 矩阵列优先访问
非合并访问,效率低
转置读取或调整布局
无数据重用
每个元素从全局内存读取 K 次
分块加载到共享内存
性能 : ~5-10% of cuBLAS
Level 2: Tiled GEMM
核心思想
将矩阵分成小块(tiles),加载到共享内存中重用。
1
2
3
4
5
6
全局内存访问减少原理:
═════════════════════════════════════════════════════════════
原始: 每个元素从全局内存读取 K 次
分块: 每个元素从全局内存读取 K/TILE_SIZE 次
═════════════════════════════════════════════════════════════
减少倍数: TILE_SIZE 倍
算法实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
constexpr int TILE_SIZE = 32 ;
__global__ void tiled_gemm_kernel ( const float * A , const float * B , float * C ,
int M , int N , int K ) {
__shared__ float As [ TILE_SIZE ][ TILE_SIZE ];
__shared__ float Bs [ TILE_SIZE ][ TILE_SIZE ];
int row = blockIdx . y * TILE_SIZE + threadIdx . y ;
int col = blockIdx . x * TILE_SIZE + threadIdx . x ;
float sum = 0.0 f ;
int num_tiles = ( K + TILE_SIZE - 1 ) / TILE_SIZE ;
for ( int t = 0 ; t < num_tiles ; t ++ ) {
// 协作加载 A tile
int a_col = t * TILE_SIZE + threadIdx . x ;
if ( row < M && a_col < K ) {
As [ threadIdx . y ][ threadIdx . x ] = A [ row * K + a_col ];
} else {
As [ threadIdx . y ][ threadIdx . x ] = 0.0 f ;
}
// 协作加载 B tile
int b_row = t * TILE_SIZE + threadIdx . y ;
if ( b_row < K && col < N ) {
Bs [ threadIdx . y ][ threadIdx . x ] = B [ b_row * N + col ];
} else {
Bs [ threadIdx . y ][ threadIdx . x ] = 0.0 f ;
}
__syncthreads ();
// 计算部分和
for ( int k = 0 ; k < TILE_SIZE ; k ++ ) {
sum += As [ threadIdx . y ][ k ] * Bs [ k ][ threadIdx . x ];
}
__syncthreads ();
}
if ( row < M && col < N ) {
C [ row * N + col ] = sum ;
}
}
矩阵分块可视化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
K K K
┌──────┐ ┌──────┐ ┌──────┐
│ │ │┌────┐│ │ │
M │ A │ M ││A_t ││ M │ │
│ │ │└────┘│ │ │
└──────┘ └──────┘ └──────┘
× × ×
┌──────┐ ┌──────┐ ┌──────┐
│ │ │┌────┐│ │ │
K │ B │ K ││B_t ││ K │ │
│ │ │└────┘│ │ │
└──────┘ └──────┘ └──────┘
= = =
┌──────┐ ┌──────┐ ┌──────┐
│ │ │┌────┐│ │ │
M │ C │ ←── M ││C_t ││ 累加多个 │ │
│ │ │└────┘│ A_t × B_t │ │
└──────┘ └──────┘ └──────┘
N N N
共享内存使用: 2 × 32 × 32 × 4 = 8 KB
性能分析
优化点
效果
共享内存重用
全局内存访问减少 32 倍
协作加载
每个线程加载 1 个元素到 tile
块内同步
确保数据正确性
性能 : ~20-30% of cuBLAS
Level 3: Memory Coalescing
核心思想
确保同一 warp 中的线程访问连续的内存地址,最大化内存带宽利用率。
合并访问 vs 非合并访问
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
合并访问(Coalesced)- 高效:
═════════════════════════════════════════════════════════════
Thread 0 1 2 3 ... 31
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
地址 0x1000 0x1004 0x1008 0x100C ... 0x107C
═════════════════════════════════════════════════════════════
→ 1 次 128 字节内存事务
非合并访问(Strided)- 低效:
═════════════════════════════════════════════════════════════
Thread 0 1 2 3 ... 31
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
地址 0x1000 0x2000 0x3000 0x4000 ... 0x20000
═════════════════════════════════════════════════════════════
→ 32 次独立内存事务
Bank Conflict 避免
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
共享内存 Bank 布局(32 个 bank):
地址 Bank 0 Bank 1 Bank 2 ... Bank 31
┌────┐ ┌────┐ ┌────┐ ┌────┐
偏移 0 │ 0 │ │ 1 │ │ 2 │ │ 31 │
偏移 32 │ 32 │ │ 33 │ │ 34 │ │ 63 │
偏移 64 │ 64 │ │ 65 │ │ 66 │ │ 95 │
└────┘ └────┘ └────┘ └────┘
无 Padding(有 Bank Conflict):
As[0][0], As[1][0], As[2][0]... 都在 Bank 0
→ 同一 warp 访问同一 bank = 串行化
有 Padding(无 Bank Conflict):
As[0][0] 在 Bank 0, As[1][0] 在 Bank 1, ...
→ 不同 bank = 并行访问
优化技巧
1
2
3
4
5
6
// 优化共享内存访问,避免 Bank Conflict
__shared__ float As [ TILE_SIZE ][ TILE_SIZE + 1 ]; // +1 padding
__shared__ float Bs [ TILE_SIZE ][ TILE_SIZE + 1 ];
// 这样同一列的线程访问不同 bank
float val = As [ threadIdx . y ][ k ]; // 无 bank conflict
性能 : ~30-40% of cuBLAS
Level 4: Double Buffering
核心思想
使用两组共享内存缓冲区,在计算当前 tile 时预取下一个 tile。
时间线对比
1
2
3
4
5
6
7
8
9
10
11
12
无 Double Buffering:
═══════════════════════��══════════════════════════════════════
├─Load 0─┼─Comp 0─┼─Load 1─┼─Comp 1─┼─Load 2─┼─Comp 2─┤
计算等待加载 计算等待加载
有 Double Buffering:
═════════════════════════════════════════════════════════════
├─Load 0─┼─────────────────────────────────────────────┤
├─Comp 0─┼─Comp 1─┼─Comp 2─┤
├─Load 1─┤
├─Load 2─┤
加载与计算重叠
算法实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
__global__ void double_buffer_gemm_kernel ( const float * A , const float * B ,
float * C , int M , int N , int K ) {
__shared__ float As [ 2 ][ TILE_SIZE ][ TILE_SIZE + 1 ];
__shared__ float Bs [ 2 ][ TILE_SIZE ][ TILE_SIZE + 1 ];
int row = blockIdx . y * TILE_SIZE + threadIdx . y ;
int col = blockIdx . x * TILE_SIZE + threadIdx . x ;
float sum = 0.0 f ;
// 预加载第一个 tile
load_tile ( As [ 0 ], Bs [ 0 ], 0 );
__syncthreads ();
for ( int t = 0 ; t < num_tiles ; t ++ ) {
int next = ( t + 1 ) % 2 ;
int curr = t % 2 ;
// 异步预取下一个 tile
if ( t + 1 < num_tiles ) {
load_tile ( As [ next ], Bs [ next ], t + 1 );
}
// 计算当前 tile
for ( int k = 0 ; k < TILE_SIZE ; k ++ ) {
sum += As [ curr ][ threadIdx . y ][ k ] * Bs [ curr ][ k ][ threadIdx . x ];
}
__syncthreads ();
}
if ( row < M && col < N ) {
C [ row * N + col ] = sum ;
}
}
共享内存使用
1
2
3
4
单缓冲: 2 × 32 × 32 × 4 = 8 KB
双缓冲: 2 × 2 × 32 × 32 × 4 = 16 KB
注意: 需要确保共享内存足够
性能 : ~40-50% of cuBLAS
Level 5: Register Blocking
核心思想
每个线程计算多个输出元素,增加计算密度,将数据保持在寄存器中。
参数配置
1
2
3
4
5
6
7
8
template <
int BM , // Block 处理的 M 维度 (128)
int BN , // Block 处理的 N 维度 (128)
int BK , // 每次迭代的 K 维度 (8)
int TM , // 每线程处理的 M 维度 (8)
int TN // 每线程处理的 N 维度 (8)
>
__global__ void optimized_gemm (...);
参数约束
1
2
3
4
5
6
7
8
9
10
11
约束条件:
═════════════════════════════════════════════════════════════
线程数: (BM / TM) × (BN / TN) ≤ 1024
= (128/8) × (128/8) = 16 × 16 = 256 ✓
共享内存: (BM × BK + BK × BN) × sizeof(float)
= (128 × 8 + 8 × 128) × 4 = 8 KB ✓
寄存器: TM × TN + TM + TN + overhead ≤ 255
= 64 + 8 + 8 + 20 = 100 ✓
═════════════════════════════════════════════════════════════
算法实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
template < int BM , int BN , int BK , int TM , int TN >
__global__ void optimized_gemm ( const float * A , const float * B , float * C ,
int M , int N , int K ) {
// 共享内存(转置存储 A 以实现合并访问)
__shared__ float As [ BK ][ BM + 1 ];
__shared__ float Bs [ BK ][ BN + 1 ];
// 线程索引
const int tx = threadIdx . x , ty = threadIdx . y ;
const int row_start = blockIdx . y * BM + ty * TM ;
const int col_start = blockIdx . x * BN + tx * TN ;
// 寄存器存储
float regC [ TM ][ TN ] = { 0.0 f };
float regA [ TM ];
float regB [ TN ];
// 迭代计算
for ( int tile = 0 ; tile < ( K + BK - 1 ) / BK ; tile ++ ) {
// 加载到共享内存
// ... (省略加载代码)
__syncthreads ();
// 计算当前 tile
for ( int k = 0 ; k < BK ; k ++ ) {
// 加载到寄存器
for ( int m = 0 ; m < TM ; m ++ ) {
regA [ m ] = As [ k ][ ty * TM + m ];
}
for ( int n = 0 ; n < TN ; n ++ ) {
regB [ n ] = Bs [ k ][ tx * TN + n ];
}
// 外积累加
for ( int m = 0 ; m < TM ; m ++ ) {
for ( int n = 0 ; n < TN ; n ++ ) {
regC [ m ][ n ] += regA [ m ] * regB [ n ];
}
}
}
__syncthreads ();
}
// 写回结果
for ( int m = 0 ; m < TM ; m ++ ) {
for ( int n = 0 ; n < TN ; n ++ ) {
int out_row = row_start + m ;
int out_col = col_start + n ;
if ( out_row < M && out_col < N ) {
C [ out_row * N + out_col ] = regC [ m ][ n ];
}
}
}
}
数据流图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
全局内存 → 共享内存 → 寄存器 → 计算 → 写回
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ 全局内存 │────▶│ 共享内存 │────▶│ 寄存器 │
│ A, B │ │ As, Bs │ │ regA, regB │
└─────────────┘ └─────────────┘ └─────────────┘
│ │ │
│ 合并访问 │ 高速访问 │ 最高速访问
│ (慢) │ (中) │ (快)
▼ ▼ ▼
带宽瓶颈 Bank Conflict 无冲突
需优化
计算流程:
═════════════════════════════════════════════════════════════
regC[m][n] += regA[m] × regB[n]
= 外积 (Outer Product)
= TM × TN 次乘加
═════════════════════════════════════════════════════════════
每线程计算: 8 × 8 = 64 个输出元素
性能 : ~70-80% of cuBLAS
Level 6: Kernel Fusion
核心思想
将多个操作合并到一个 kernel 中,消除中间结果的内存读写。
融合效果对比
1
2
3
4
5
6
7
8
9
10
11
12
13
14
分离执行:
═════════════════════════════════════════════════════════════
GEMM: C = A × B → 读 A, B;写 C
Bias: C = C + bias → 读 C, bias;写 C
ReLU: C = max(0, C) → 读 C;写 C
═════════════════════════════════════════════════════════════
总计: 3 次读 C + 3 次写 C = 6 次 C 的内存访问
融合执行:
═════════════════════════════════════════════════════════════
Fused: C = ReLU(A × B + bias)
═════════════════════════════════════════════════════════════
总计: 0 次中间内存访问
节省: 2 × M × N × sizeof(float) bytes
算法实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
template < int BM , int BN , int BK , int TM , int TN ,
bool ADD_BIAS , bool APPLY_RELU >
__global__ void fused_gemm_bias_relu ( const float * A , const float * B ,
float * C , const float * bias ,
int M , int N , int K ) {
// ... GEMM 计算代码 ...
// 写回时应用融合操作
for ( int m = 0 ; m < TM ; m ++ ) {
for ( int n = 0 ; n < TN ; n ++ ) {
float val = regC [ m ][ n ];
// 编译时分支
if constexpr ( ADD_BIAS ) {
val += bias [ col_start + n ];
}
if constexpr ( APPLY_RELU ) {
val = fmaxf ( 0.0 f , val );
}
C [ out_row * N + out_col ] = val ;
}
}
}
融合选项
1
2
3
4
5
6
7
8
9
10
11
// 完整融合: GEMM + Bias + ReLU
launch_fused_gemm ( A , B , C , bias , M , N , K , true , true );
// 仅加 bias
launch_fused_gemm ( A , B , C , bias , M , N , K , true , false );
// 仅 ReLU
launch_fused_gemm ( A , B , C , nullptr , M , N , K , false , true );
// 无融合 (纯 GEMM)
launch_fused_gemm ( A , B , C , nullptr , M , N , K , false , false );
内存带宽节省
1
2
3
4
5
6
7
8
矩阵大小 1024 × 1024:
═════════════════════════════════════════════════════════════
分离执行: 3 × 4 MB = 12 MB 内存带宽
融合执行: 0 MB 额外带宽
节省: 12 MB 内存带宽
═════════════════════════════════════════════════════════════
对内存受限的神经网络,融合可提升 20-40% 性能
性能 : ~80-85% of cuBLAS
Level 7: Vectorized Loads
核心思想
使用 128 位向量加载(float4),减少内存事务数量。
向量加载原理
1
2
3
4
5
6
7
8
9
// 标量加载: 4 次 32 位事务
float a = A [ idx ];
float b = A [ idx + 1 ];
float c = A [ idx + 2 ];
float d = A [ idx + 3 ];
// 向量加载: 1 次 128 位事务
float4 vec = * reinterpret_cast < const float4 *> ( & A [ idx ]);
// vec.x, vec.y, vec.z, vec.w
实现代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// 向量加载辅助函数
__device__ __forceinline__ float4 load_float4 ( const float * ptr ) {
return * reinterpret_cast < const float4 *> ( ptr );
}
__device__ __forceinline__ void store_float4 ( float * ptr , float4 val ) {
* reinterpret_cast < float4 *> ( ptr ) = val ;
}
// 在 kernel 中使用
// 写回时使用向量存储
for ( int n = 0 ; n < TN ; n += 4 ) {
int out_col = col_start + n ;
if ( out_col + 3 < N ) {
float4 result = make_float4 (
regC [ m ][ n ], regC [ m ][ n + 1 ],
regC [ m ][ n + 2 ], regC [ m ][ n + 3 ]);
store_float4 ( & C [ out_row * N + out_col ], result );
}
}
对齐要求
1
2
3
4
5
6
7
地址必须 16 字节对齐:
═════════════════════════════════════════════════════════════
A[idx] 的地址: (A + idx) % 16 == 0
对于 float 数组:
idx % 4 == 0(每 4 个元素对齐)
═════════════════════════════════════════════════════════════
性能 : ~85-90% of cuBLAS
性能对比总结
理论性能分析
Level
全局内存访问
共享内存访问
寄存器使用
关键优化
1. Naive
2MNK
0
低
基准
2. Tiled
2MNK/T
2MNK
低
共享内存
3. Coalesced
2MNK/T
2MNK
低
合并访问
4. Double Buffer
2MNK/T
4MNK
低
延迟隐藏
5. Register Blocked
2MNK/T
2MNK/(TM×TN)
高
寄存器分块
6. Fused
减少
2MNK/(TM×TN)
高
算子融合
7. Vectorized
减少
减少
高
向量化
实测性能(RTX 3080,1024×1024×1024)
Kernel
时间 (ms)
GFLOPS
vs cuBLAS
cuBLAS
0.31
6920
100%
Naive
3.10
694
10%
Tiled
1.55
1388
20%
Coalesced
1.03
2088
30%
Double Buffer
0.78
2768
40%
Optimized
0.44
4870
70%
Fused
0.38
5630
81%
Vectorized
0.35
6130
89%
矩阵大小选择策略
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
═════════════════════════════════════════════════════════════
小矩阵 (< 512):
- 使用小 block size (64 × 64)
- 避免使用 double buffer
- 考虑 batched GEMM
中等矩阵 (512 - 2048):
- 标准配置 (128 × 128)
- 启用所有优化
- 使用 AutoTuner 选择最佳配置
大矩阵 (> 2048):
- 大 block size (128 × 256)
- 向量化加载
- 使用异步拷贝 (Ampere+)
═════════════════════════════════════════════════════════════
进一步优化方向
1. Tensor Core(WMMA)
使用 FP16/BF16 精度
利用专用矩阵计算单元
可达 2-4x FP32 性能
2. 异步拷贝(Ampere+)
1
2
3
4
// Ampere 异步拷贝
__pipeline_memcpy_async ( & smem [ idx ], & gmem [ idx ], sizeof ( float4 ));
__pipeline_commit ();
__pipeline_wait_prior ( 0 );
3. Warp 级优化
使用 warp shuffle 指令
warp 级矩阵乘法
4. 多 GPU
优化路径总结
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
GEMM 优化路径:
═════════════════════════════════════════════════════════════
Naive (10%)
│
▼ 共享内存分块
Tiled (20%)
│
▼ 内存合并
Coalesced (30%)
│
▼ 双缓冲隐藏延迟
Double Buffer (40%)
│
▼ 寄存器分块增加计算密度
Register Blocked (70%)
│
▼ 算子融合减少内存访问
Fused (80%)
│
▼ 向量化加载优化带宽
Vectorized (85-90%)
│
▼ Tensor Core / 异步拷贝 / ...
接近 cuBLAS
═════════════════════════════════════════════════════════════
关键要点:
1. 内存带宽是 GEMM 的主要瓶颈
2. 分块是减少全局内存访问的核心技术
3. 寄存器分块最大化计算密度
4. 融合减少中间结果的内存访问
5. 向量化提升内存带宽利用率
相关链接
*最后更新:2025-04-16
文档版本:v1.1.0*