Language: 简体中文 English

目录 (Table of Contents)


GEMM 基础理论

数学定义

GEMM(General Matrix Multiply)计算公式:

1
C = α × A × B + β × C

简化版本(α=1, β=0):

1
2
3
        K
C(m,n) = Σ A(m,k) × B(k,n)
       k=1

其中:

  • A: M × K 矩阵
  • B: K × N 矩阵
  • C: M × N 矩阵

计算复杂度分析

指标 公式 说明
浮点运算 2 × M × N × K 乘法 + 加法
内存读取 M×K + K×N + M×N floats 读取 A、B,写入 C
计算强度 2×M×N×K / (M×K + K×N + M×N) / 4 FLOP/byte

为什么 GEMM 重要?

神经网络中的核心计算:

神经网络层 GEMM 形式
全连接层 Y = X × W + b
卷积层(im2col) 转换为 GEMM
注意力机制 Q × K^T, Score × V
LSTM/GRU 多个 GEMM 组合

Level 1: Naive 实现

算法描述

每个线程计算输出矩阵的一个元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
__global__ void naive_matmul_kernel(const float* A, const float* B, float* C,
                                     int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

void launch_naive_matmul(const float* A, const float* B, float* C,
                         int M, int N, int K, cudaStream_t stream) {
    dim3 block(16, 16);
    dim3 grid((N + 15) / 16, (M + 15) / 16);
    naive_matmul_kernel<<<grid, block, 0, stream>>>(A, B, C, M, N, K);
}

内存访问模式

1
2
3
4
5
6
7
线程 (row, col) 的访问模式:
═════════════════════════════════════════════════════════════
A: A[row, 0], A[row, 1], ..., A[row, K-1]   ← K 次读取(连续)
B: B[0, col], B[1, col], ..., B[K-1, col]   ← K 次读取(跨步)
C: C[row, col]                               ← 1 次写入
═════════════════════════════════════════════════════════════
总计: 2K 次读取 + 1 次写入 / 每个输出元素

性能瓶颈

问题 影响 解决方案
全局内存访问过多 带宽成为瓶颈 使用共享内存
B 矩阵列优先访问 非合并访问,效率低 转置读取或调整布局
无数据重用 每个元素从全局内存读取 K 次 分块加载到共享内存

性能: ~5-10% of cuBLAS


Level 2: Tiled GEMM

核心思想

将矩阵分成小块(tiles),加载到共享内存中重用。

1
2
3
4
5
6
全局内存访问减少原理:
═════════════════════════════════════════════════════════════
原始: 每个元素从全局内存读取 K 次
分块: 每个元素从全局内存读取 K/TILE_SIZE 次
═════════════════════════════════════════════════════════════
减少倍数: TILE_SIZE 倍

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
constexpr int TILE_SIZE = 32;

__global__ void tiled_gemm_kernel(const float* A, const float* B, float* C,
                                   int M, int N, int K) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
    
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    
    float sum = 0.0f;
    int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    
    for (int t = 0; t < num_tiles; t++) {
        // 协作加载 A tile
        int a_col = t * TILE_SIZE + threadIdx.x;
        if (row < M && a_col < K) {
            As[threadIdx.y][threadIdx.x] = A[row * K + a_col];
        } else {
            As[threadIdx.y][threadIdx.x] = 0.0f;
        }
        
        // 协作加载 B tile
        int b_row = t * TILE_SIZE + threadIdx.y;
        if (b_row < K && col < N) {
            Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col];
        } else {
            Bs[threadIdx.y][threadIdx.x] = 0.0f;
        }
        
        __syncthreads();
        
        // 计算部分和
        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        
        __syncthreads();
    }
    
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

矩阵分块可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
         K                    K                    K
      ┌──────┐            ┌──────┐            ┌──────┐
      │      │            │┌────┐│            │      │
    M │  A   │          M ││A_t ││          M │      │
      │      │            │└────┘│            │      │
      └──────┘            └──────┘            └──────┘
           ×                   ×                   ×
      ┌──────┐            ┌──────┐            ┌──────┐
      │      │            │┌────┐│            │      │
    K │  B   │          K ││B_t ││          K │      │
      │      │            │└────┘│            │      │
      └──────┘            └──────┘            └──────┘
           =                   =                   =
      ┌──────┐            ┌──────┐            ┌──────┐
      │      │            │┌────┐│            │      │
    M │  C   │    ←──    M ││C_t ││  累加多个    │      │
      │      │            │└────┘│  A_t × B_t   │      │
      └──────┘            └──────┘            └──────┘
         N                    N                    N

共享内存使用: 2 × 32 × 32 × 4 = 8 KB

性能分析

优化点 效果
共享内存重用 全局内存访问减少 32 倍
协作加载 每个线程加载 1 个元素到 tile
块内同步 确保数据正确性

性能: ~20-30% of cuBLAS


Level 3: Memory Coalescing

核心思想

确保同一 warp 中的线程访问连续的内存地址,最大化内存带宽利用率。

合并访问 vs 非合并访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
合并访问(Coalesced)- 高效:
═════════════════════════════════════════════════════════════
Thread  0  1  2  3  ...  31
        │  │  │  │       │
        ▼  ▼  ▼  ▼       ▼
地址  0x1000 0x1004 0x1008 0x100C ... 0x107C
═════════════════════════════════════════════════════════════
→ 1 次 128 字节内存事务

非合并访问(Strided)- 低效:
═════════════════════════════════════════════════════════════
Thread  0     1     2     3     ...  31
        │     │     │     │          │
        ▼     ▼     ▼     ▼          ▼
地址  0x1000 0x2000 0x3000 0x4000 ... 0x20000
═════════════════════════════════════════════════════════════
→ 32 次独立内存事务

Bank Conflict 避免

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
共享内存 Bank 布局(32 个 bank):

地址      Bank 0  Bank 1  Bank 2  ... Bank 31
          ┌────┐  ┌────┐  ┌────┐      ┌────┐
偏移 0    │ 0  │  │ 1  │  │ 2  │      │ 31 │
偏移 32   │ 32 │  │ 33 │  │ 34 │      │ 63 │
偏移 64   │ 64 │  │ 65 │  │ 66 │      │ 95 │
          └────┘  └────┘  └────┘      └────┘

无 Padding(有 Bank Conflict):
As[0][0], As[1][0], As[2][0]... 都在 Bank 0
→ 同一 warp 访问同一 bank = 串行化

有 Padding(无 Bank Conflict):
As[0][0] 在 Bank 0, As[1][0] 在 Bank 1, ...
→ 不同 bank = 并行访问

优化技巧

1
2
3
4
5
6
// 优化共享内存访问,避免 Bank Conflict
__shared__ float As[TILE_SIZE][TILE_SIZE + 1];  // +1 padding
__shared__ float Bs[TILE_SIZE][TILE_SIZE + 1];

// 这样同一列的线程访问不同 bank
float val = As[threadIdx.y][k];  // 无 bank conflict

性能: ~30-40% of cuBLAS


Level 4: Double Buffering

核心思想

使用两组共享内存缓冲区,在计算当前 tile 时预取下一个 tile。

时间线对比

1
2
3
4
5
6
7
8
9
10
11
12
无 Double Buffering:
═══════════════════════��══════════════════════════════════════
├─Load 0─┼─Comp 0─┼─Load 1─┼─Comp 1─┼─Load 2─┼─Comp 2─┤
          计算等待加载           计算等待加载

有 Double Buffering:
═════════════════════════════════════════════════════════════
├─Load 0─┼─────────────────────────────────────────────┤
          ├─Comp 0─┼─Comp 1─┼─Comp 2─┤
                    ├─Load 1─┤
                              ├─Load 2─┤
          加载与计算重叠

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
__global__ void double_buffer_gemm_kernel(const float* A, const float* B, 
                                           float* C, int M, int N, int K) {
    __shared__ float As[2][TILE_SIZE][TILE_SIZE + 1];
    __shared__ float Bs[2][TILE_SIZE][TILE_SIZE + 1];
    
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;
    
    // 预加载第一个 tile
    load_tile(As[0], Bs[0], 0);
    __syncthreads();
    
    for (int t = 0; t < num_tiles; t++) {
        int next = (t + 1) % 2;
        int curr = t % 2;
        
        // 异步预取下一个 tile
        if (t + 1 < num_tiles) {
            load_tile(As[next], Bs[next], t + 1);
        }
        
        // 计算当前 tile
        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[curr][threadIdx.y][k] * Bs[curr][k][threadIdx.x];
        }
        
        __syncthreads();
    }
    
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

共享内存使用

1
2
3
4
单缓冲: 2 × 32 × 32 × 4 = 8 KB
双缓冲: 2 × 2 × 32 × 32 × 4 = 16 KB

注意: 需要确保共享内存足够

性能: ~40-50% of cuBLAS


Level 5: Register Blocking

核心思想

每个线程计算多个输出元素,增加计算密度,将数据保持在寄存器中。

参数配置

1
2
3
4
5
6
7
8
template<
    int BM,    // Block 处理的 M 维度 (128)
    int BN,    // Block 处理的 N 维度 (128)
    int BK,    // 每次迭代的 K 维度 (8)
    int TM,    // 每线程处理的 M 维度 (8)
    int TN     // 每线程处理的 N 维度 (8)
>
__global__ void optimized_gemm(...);

参数约束

1
2
3
4
5
6
7
8
9
10
11
约束条件:
═════════════════════════════════════════════════════════════
线程数:    (BM / TM) × (BN / TN) ≤ 1024
          = (128/8) × (128/8) = 16 × 16 = 256 ✓

共享内存:  (BM × BK + BK × BN) × sizeof(float)
          = (128 × 8 + 8 × 128) × 4 = 8 KB ✓

寄存器:   TM × TN + TM + TN + overhead ≤ 255
          = 64 + 8 + 8 + 20 = 100 ✓
═════════════════════════════════════════════════════════════

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
template<int BM, int BN, int BK, int TM, int TN>
__global__ void optimized_gemm(const float* A, const float* B, float* C,
                                int M, int N, int K) {
    // 共享内存(转置存储 A 以实现合并访问)
    __shared__ float As[BK][BM + 1];
    __shared__ float Bs[BK][BN + 1];
    
    // 线程索引
    const int tx = threadIdx.x, ty = threadIdx.y;
    const int row_start = blockIdx.y * BM + ty * TM;
    const int col_start = blockIdx.x * BN + tx * TN;
    
    // 寄存器存储
    float regC[TM][TN] = {0.0f};
    float regA[TM];
    float regB[TN];
    
    // 迭代计算
    for (int tile = 0; tile < (K + BK - 1) / BK; tile++) {
        // 加载到共享内存
        // ... (省略加载代码)
        
        __syncthreads();
        
        // 计算当前 tile
        for (int k = 0; k < BK; k++) {
            // 加载到寄存器
            for (int m = 0; m < TM; m++) {
                regA[m] = As[k][ty * TM + m];
            }
            for (int n = 0; n < TN; n++) {
                regB[n] = Bs[k][tx * TN + n];
            }
            
            // 外积累加
            for (int m = 0; m < TM; m++) {
                for (int n = 0; n < TN; n++) {
                    regC[m][n] += regA[m] * regB[n];
                }
            }
        }
        
        __syncthreads();
    }
    
    // 写回结果
    for (int m = 0; m < TM; m++) {
        for (int n = 0; n < TN; n++) {
            int out_row = row_start + m;
            int out_col = col_start + n;
            if (out_row < M && out_col < N) {
                C[out_row * N + out_col] = regC[m][n];
            }
        }
    }
}

数据流图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
全局内存 → 共享内存 → 寄存器 → 计算 → 写回

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  全局内存    │────▶│  共享内存    │────▶│   寄存器     │
│  A, B       │     │  As, Bs     │     │  regA, regB  │
└─────────────┘     └─────────────┘     └─────────────┘
       │                   │                   │
       │ 合并访问          │ 高速访问           │ 最高速访问
       │ (慢)              │ (中)               │ (快)
       ▼                   ▼                   ▼
    带宽瓶颈           Bank Conflict        无冲突
                      需优化

计算流程:
═════════════════════════════════════════════════════════════
regC[m][n] += regA[m] × regB[n]
           = 外积 (Outer Product)
           = TM × TN 次乘加
═════════════════════════════════════════════════════════════
每线程计算: 8 × 8 = 64 个输出元素

性能: ~70-80% of cuBLAS


Level 6: Kernel Fusion

核心思想

将多个操作合并到一个 kernel 中,消除中间结果的内存读写。

融合效果对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
分离执行:
═════════════════════════════════════════════════════════════
GEMM:  C = A × B           → 读 A, B;写 C
Bias:  C = C + bias        → 读 C, bias;写 C
ReLU:  C = max(0, C)       → 读 C;写 C
═════════════════════════════════════════════════════════════
总计: 3 次读 C + 3 次写 C = 6 次 C 的内存访问

融合执行:
═════════════════════════════════════════════════════════════
Fused: C = ReLU(A × B + bias)
═════════════════════════════════════════════════════════════
总计: 0 次中间内存访问
节省: 2 × M × N × sizeof(float) bytes

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
template<int BM, int BN, int BK, int TM, int TN, 
         bool ADD_BIAS, bool APPLY_RELU>
__global__ void fused_gemm_bias_relu(const float* A, const float* B, 
                                      float* C, const float* bias,
                                      int M, int N, int K) {
    // ... GEMM 计算代码 ...
    
    // 写回时应用融合操作
    for (int m = 0; m < TM; m++) {
        for (int n = 0; n < TN; n++) {
            float val = regC[m][n];
            
            // 编译时分支
            if constexpr (ADD_BIAS) {
                val += bias[col_start + n];
            }
            
            if constexpr (APPLY_RELU) {
                val = fmaxf(0.0f, val);
            }
            
            C[out_row * N + out_col] = val;
        }
    }
}

融合选项

1
2
3
4
5
6
7
8
9
10
11
// 完整融合: GEMM + Bias + ReLU
launch_fused_gemm(A, B, C, bias, M, N, K, true, true);

// 仅加 bias
launch_fused_gemm(A, B, C, bias, M, N, K, true, false);

// 仅 ReLU
launch_fused_gemm(A, B, C, nullptr, M, N, K, false, true);

// 无融合 (纯 GEMM)
launch_fused_gemm(A, B, C, nullptr, M, N, K, false, false);

内存带宽节省

1
2
3
4
5
6
7
8
矩阵大小 1024 × 1024:
═════════════════════════════════════════════════════════════
分离执行: 3 × 4 MB = 12 MB 内存带宽
融合执行: 0 MB 额外带宽
节省: 12 MB 内存带宽
═════════════════════════════════════════════════════════════

对内存受限的神经网络,融合可提升 20-40% 性能

性能: ~80-85% of cuBLAS


Level 7: Vectorized Loads

核心思想

使用 128 位向量加载(float4),减少内存事务数量。

向量加载原理

1
2
3
4
5
6
7
8
9
// 标量加载: 4 次 32 位事务
float a = A[idx];
float b = A[idx + 1];
float c = A[idx + 2];
float d = A[idx + 3];

// 向量加载: 1 次 128 位事务
float4 vec = *reinterpret_cast<const float4*>(&A[idx]);
// vec.x, vec.y, vec.z, vec.w

实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// 向量加载辅助函数
__device__ __forceinline__ float4 load_float4(const float* ptr) {
    return *reinterpret_cast<const float4*>(ptr);
}

__device__ __forceinline__ void store_float4(float* ptr, float4 val) {
    *reinterpret_cast<float4*>(ptr) = val;
}

// 在 kernel 中使用
// 写回时使用向量存储
for (int n = 0; n < TN; n += 4) {
    int out_col = col_start + n;
    if (out_col + 3 < N) {
        float4 result = make_float4(
            regC[m][n], regC[m][n+1], 
            regC[m][n+2], regC[m][n+3]);
        store_float4(&C[out_row * N + out_col], result);
    }
}

对齐要求

1
2
3
4
5
6
7
地址必须 16 字节对齐:
═════════════════════════════════════════════════════════════
A[idx] 的地址: (A + idx) % 16 == 0

对于 float 数组:
idx % 4 == 0(每 4 个元素对齐)
═════════════════════════════════════════════════════════════

性能: ~85-90% of cuBLAS


性能对比总结

理论性能分析

Level 全局内存访问 共享内存访问 寄存器使用 关键优化
1. Naive 2MNK 0 基准
2. Tiled 2MNK/T 2MNK 共享内存
3. Coalesced 2MNK/T 2MNK 合并访问
4. Double Buffer 2MNK/T 4MNK 延迟隐藏
5. Register Blocked 2MNK/T 2MNK/(TM×TN) 寄存器分块
6. Fused 减少 2MNK/(TM×TN) 算子融合
7. Vectorized 减少 减少 向量化

实测性能(RTX 3080,1024×1024×1024)

Kernel 时间 (ms) GFLOPS vs cuBLAS
cuBLAS 0.31 6920 100%
Naive 3.10 694 10%
Tiled 1.55 1388 20%
Coalesced 1.03 2088 30%
Double Buffer 0.78 2768 40%
Optimized 0.44 4870 70%
Fused 0.38 5630 81%
Vectorized 0.35 6130 89%

矩阵大小选择策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
═════════════════════════════════════════════════════════════
小矩阵 (< 512):
  - 使用小 block size (64 × 64)
  - 避免使用 double buffer
  - 考虑 batched GEMM

中等矩阵 (512 - 2048):
  - 标准配置 (128 × 128)
  - 启用所有优化
  - 使用 AutoTuner 选择最佳配置

大矩阵 (> 2048):
  - 大 block size (128 × 256)
  - 向量化加载
  - 使用异步拷贝 (Ampere+)
═════════════════════════════════════════════════════════════

进一步优化方向

1. Tensor Core(WMMA)

  • 使用 FP16/BF16 精度
  • 利用专用矩阵计算单元
  • 可达 2-4x FP32 性能

2. 异步拷贝(Ampere+)

1
2
3
4
// Ampere 异步拷贝
__pipeline_memcpy_async(&smem[idx], &gmem[idx], sizeof(float4));
__pipeline_commit();
__pipeline_wait_prior(0);

3. Warp 级优化

  • 使用 warp shuffle 指令
  • warp 级矩阵乘法

4. 多 GPU

  • 数据并行
  • 模型并行

优化路径总结

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
GEMM 优化路径:
═════════════════════════════════════════════════════════════
Naive (10%)
    │
    ▼ 共享内存分块
Tiled (20%)
    │
    ▼ 内存合并
Coalesced (30%)
    │
    ▼ 双缓冲隐藏延迟
Double Buffer (40%)
    │
    ▼ 寄存器分块增加计算密度
Register Blocked (70%)
    │
    ▼ 算子融合减少内存访问
Fused (80%)
    │
    ▼ 向量化加载优化带宽
Vectorized (85-90%)
    │
    ▼ Tensor Core / 异步拷贝 / ...
接近 cuBLAS
═════════════════════════════════════════════════════════════

关键要点:
1. 内存带宽是 GEMM 的主要瓶颈
2. 分块是减少全局内存访问的核心技术
3. 寄存器分块最大化计算密度
4. 融合减少中间结果的内存访问
5. 向量化提升内存带宽利用率

相关链接


*最后更新:2025-04-16 文档版本:v1.1.0*

Back to top

MIT License | A learning project for the CUDA community