RFC 0008: Batch GEMM System

Status

Status: Accepted
Created: 2024
Last Updated: 2024

Overview

Design a batch GEMM system that efficiently processes multiple matrix multiplications in a single kernel launch, optimizing for inference workloads with repeated matrix operations across different layers or batches.

Motivation

  1. Multi-layer inference: Neural networks have many layers, each requiring a GEMM
  2. Batch processing: Serve multiple requests against the same model efficiently
  3. Kernel launch overhead: Amortize launch cost across many small matrices
  4. Memory efficiency: Coalesce memory accesses for small matrices

Design

Batch GEMM API

#include <cuda_runtime.h>  // cudaStream_t
#include <vector>

class BatchGemmRunner {
public:
    // Add a GEMM to the batch
    void add_gemm(const float* A, const float* B, float* C,
                  int M, int N, int K);

    // Execute all queued GEMMs
    void execute(cudaStream_t stream = 0);

    // Clear the batch without executing
    void clear();

    // Statistics
    int gemm_count() const;
    float estimated_time_ms() const;

    // Configuration
    void set_kernel_variant(int variant);
    void set_max_batch_size(int size);

private:
    struct GemmTask {
        const float* A;
        const float* B;
        float* C;
        int M, N, K;
    };

    std::vector<GemmTask> tasks_;
    int max_batch_size_;
    int kernel_variant_;
};
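
A minimal usage sketch (d_A, d_B, d_C are assumed to be arrays of pre-allocated device buffers and stream a CUDA stream created elsewhere; whether execute() also clears the queue is left to the implementation, so clear() is called explicitly here):

// Queue 100 same-shape GEMMs, then run them in one grouped pass.
BatchGemmRunner runner;
for (int i = 0; i < 100; ++i) {
    runner.add_gemm(d_A[i], d_B[i], d_C[i], /*M=*/64, /*N=*/64, /*K=*/64);
}
runner.execute(stream);  // one grouped launch instead of 100 separate ones
runner.clear();          // reset the queue for the next batch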

Execution Strategy

Strategy 1: Sequential Execution

  • Launch one kernel per GEMM
  • Pros: Simple; can use the optimal kernel for each task
  • Cons: High launch overhead for many small GEMMs

Strategy 2: Batched Kernel Launch

  • Single kernel processes all GEMMs
  • Pros: One launch, coalesced memory access
  • Cons: All GEMMs must be the same shape (M, N, K)

Strategy 3: Grouped Execution

  • Group same-size GEMMs together
  • Execute each group with batched kernel
  • Pros: Balance of efficiency and flexibility
  • Cons: More complex scheduling

Decision: Implement Strategy 3 (Grouped Execution) for production.
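
A sketch of what grouped execution could look like inside execute() (needs <map> and <tuple>; launch_batched_gemm is a hypothetical helper, not part of the API above):

// Bucket queued tasks by (M, N, K), then issue one batched launch per bucket.
void BatchGemmRunner::execute(cudaStream_t stream) {
    if (tasks_.empty()) return;  // empty batch: no kernel launch

    std::map<std::tuple<int, int, int>, std::vector<const GemmTask*>> groups;
    for (const GemmTask& t : tasks_)
        groups[{t.M, t.N, t.K}].push_back(&t);

    // Each group is uniform in shape, so a single batched launch fits.
    for (const auto& [dims, group] : groups) {
        auto [M, N, K] = dims;
        launch_batched_gemm(group, M, N, K, stream);  // hypothetical helper
    }
}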

Kernel Design

// Batched GEMM kernel: processes batch_count GEMMs of the same shape (M, N, K)
__global__ void batched_gemm(
    const float** A_array, const float** B_array, float** C_array,
    int M, int N, int K,
    int batch_count
);
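
A minimal body for this signature, shown as a sketch: one thread computes one element of C, and blockIdx.z selects which GEMM in the batch the block works on. A production kernel would add shared-memory tiling on top of this.

// Sketch: naive batched GEMM, one thread per element of C.
__global__ void batched_gemm(
    const float** A_array, const float** B_array, float** C_array,
    int M, int N, int K,
    int batch_count)
{
    int batch = blockIdx.z;  // which GEMM in the batch
    if (batch >= batch_count) return;

    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    const float* A = A_array[batch];  // M×K, row-major
    const float* B = B_array[batch];  // K×N, row-major
    float*       C = C_array[batch];  // M×N, row-major

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
        acc += A[row * K + k] * B[k * N + col];
    C[row * N + col] = acc;
}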

Memory Layout for Batched GEMM

1
2
3
4
5
6
7
8
9
10
11
12
GPU Memory:
┌──────────────────────────────────────────────────────────┐
│ A_pointers[batch_count] → [A_0, ..., A_{batch_count-1}]  │
│ B_pointers[batch_count] → [B_0, ..., B_{batch_count-1}]  │
│ C_pointers[batch_count] → [C_0, ..., C_{batch_count-1}]  │
│                                                          │
│ A_0: [M×K matrix]                                        │
│ A_1: [M×K matrix]                                        │
│ ...                                                      │
│ B_0: [K×N matrix]                                        │
│ ...                                                      │
└──────────────────────────────────────────────────────────┘
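
Building these pointer arrays on the host and launching one group might look like the sketch below (error checking omitted; tasks_ here holds a single group of same-shape tasks, and M, N, K, stream come from the surrounding code):

// Gather per-GEMM device pointers for one same-shape group.
std::vector<const float*> h_A, h_B;
std::vector<float*> h_C;
for (const GemmTask& t : tasks_) {
    h_A.push_back(t.A);
    h_B.push_back(t.B);
    h_C.push_back(t.C);
}
int batch_count = static_cast<int>(tasks_.size());

// Upload the pointer arrays so the kernel can index them by batch.
const float** d_A_array; const float** d_B_array; float** d_C_array;
cudaMalloc((void**)&d_A_array, batch_count * sizeof(const float*));
cudaMalloc((void**)&d_B_array, batch_count * sizeof(const float*));
cudaMalloc((void**)&d_C_array, batch_count * sizeof(float*));
cudaMemcpy(d_A_array, h_A.data(), batch_count * sizeof(const float*),
           cudaMemcpyHostToDevice);
cudaMemcpy(d_B_array, h_B.data(), batch_count * sizeof(const float*),
           cudaMemcpyHostToDevice);
cudaMemcpy(d_C_array, h_C.data(), batch_count * sizeof(float*),
           cudaMemcpyHostToDevice);

// One launch covers the whole group: grid.z indexes the batch.
// Note: grid.z is limited to 65535, so very large batches need chunking.
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y,
          batch_count);
batched_gemm<<<grid, block, 0, stream>>>(d_A_array, d_B_array, d_C_array,
                                         M, N, K, batch_count);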

Performance Targets

Scenario             Sequential   Batched   Speedup
10× 64×64 GEMMs      0.5 ms       0.15 ms   3.3×
100× 64×64 GEMMs     5.0 ms       0.4 ms    12.5×
10× 512×512 GEMMs    1.0 ms       0.8 ms    1.25×

Error Handling

Condition                 Behavior
Empty batch               No-op (no kernel launch)
Exceeding max_batch_size  Execute the current batch, queue the remainder
Mixed dimensions          Group by dimensions, execute each group separately
Invalid pointer           Throw std::invalid_argument
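
The invalid-pointer row can be enforced when a task is queued rather than at launch; a sketch of add_gemm with that check (needs <stdexcept>; rejecting non-positive dimensions is an extra assumption, not listed in the table):

// Validate inputs at queue time so execute() never sees a bad task.
void BatchGemmRunner::add_gemm(const float* A, const float* B, float* C,
                               int M, int N, int K) {
    if (A == nullptr || B == nullptr || C == nullptr)
        throw std::invalid_argument("batch_gemm: null matrix pointer");
    if (M <= 0 || N <= 0 || K <= 0)  // assumption: reject bad dimensions too
        throw std::invalid_argument("batch_gemm: non-positive dimension");
    tasks_.push_back({A, B, C, M, N, K});
}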

Testing Strategy

  1. Correctness: Each GEMM produces the correct output independently (see the reference-check sketch after this list)
  2. Grouping: Same-size GEMMs grouped correctly
  3. Performance: Meets speedup targets vs sequential
  4. Edge cases: Single GEMM batch, very large batch count
  5. Memory safety: No out-of-bounds accesses
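
For the correctness item, a plain CPU reference is enough to validate each output; a sketch, assuming the row-major layout used by the kernel above:

// Naive CPU reference for one GEMM; compare against each GPU result.
void cpu_gemm(const float* A, const float* B, float* C,
              int M, int N, int K) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k)
                acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
}

// After runner.execute(), copy each C_i back and compare element-wise,
// e.g. fabsf(gpu - ref) <= 1e-4f * K to allow for float accumulation error.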

Implementation Files

  • include/batch_gemm.h - BatchGemmRunner class
  • src/batch_gemm.cu - Batch GEMM kernel and runner
  • tests/test_batch_gemm.cpp - Unit tests

Future Work

  • Strided batched GEMM (cuBLAS-style; see the signature sketch below)
  • Dynamic grouping based on runtime statistics
  • Integration with StreamManager for concurrent batches
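
For reference, the cuBLAS strided-batched entry point the first item refers to replaces the pointer arrays with a fixed stride between consecutive matrices:

// cuBLAS strided-batched GEMM: matrix i lives at A + i * strideA, etc.
cublasStatus_t cublasSgemmStridedBatched(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    const float* A, int lda, long long strideA,
    const float* B, int ldb, long long strideB,
    const float* beta,
    float* C, int ldc, long long strideC,
    int batchCount);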
