Design a batch GEMM system that efficiently processes multiple matrix multiplications in a single kernel launch, optimizing for inference workloads with repeated matrix operations across different layers or batches.
Motivation
Multi-layer inference: Neural networks run many layers, each performing one or more GEMMs
Batch processing: Serve multiple requests with the same model efficiently
Kernel launch overhead: Amortize launch cost across many small matrices
Memory efficiency: Coalesce memory accesses for small matrices
Interface

#include <vector>
#include <cuda_runtime.h>

class BatchGemmRunner {
public:
    // Add a GEMM to the batch
    void add_gemm(const float* A, const float* B, float* C,
                  int M, int N, int K);

    // Execute all queued GEMMs
    void execute(cudaStream_t stream = 0);

    // Clear the batch without executing
    void clear();

    // Statistics
    int gemm_count() const;
    float estimated_time_ms() const;

    // Configuration
    void set_kernel_variant(int variant);
    void set_max_batch_size(int size);

private:
    struct GemmTask {
        const float* A;
        const float* B;
        float* C;
        int M, N, K;
    };

    std::vector<GemmTask> tasks_;
    int max_batch_size_;
    int kernel_variant_;
};
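As a usage sketch (the run_layers helper and its per-layer pointer vectors are hypothetical placeholders for caller-owned device state, not part of the interface above), an inference loop might queue one GEMM per layer and flush them with a single execute() call:

// Hypothetical usage: queue one GEMM per layer, then execute the batch.
// All pointers are assumed to be device allocations owned by the caller.
void run_layers(BatchGemmRunner& runner,
                const std::vector<const float*>& inputs,   // per-layer A
                const std::vector<const float*>& weights,  // per-layer B
                const std::vector<float*>& outputs,        // per-layer C
                int M, int N, int K, cudaStream_t stream) {
    for (size_t layer = 0; layer < weights.size(); ++layer)
        runner.add_gemm(inputs[layer], weights[layer], outputs[layer],
                        M, N, K);
    runner.execute(stream);  // one grouped pass over all queued GEMMs
    runner.clear();
}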
Execution Strategy
Strategy 1: Sequential Execution
Launch one kernel per GEMM
Pros: Simple; each task can use the optimal kernel for its shape
Cons: High launch overhead for many small GEMMs
Strategy 2: Batched Kernel Launch
Single kernel processes all GEMMs
Pros: One launch, coalesced memory access
Cons: All GEMMs must be the same size
Strategy 3: Grouped Execution
Group same-size GEMMs together
Execute each group with batched kernel
Pros: Balance of efficiency and flexibility
Cons: More complex scheduling
Decision: Implement Strategy 3 (Grouped Execution) for production.
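A minimal sketch of how execute() could implement the grouping pass follows. It assumes the batched_gemm kernel declared in the Kernel Design section below; the std::map bucketing and the per-group device allocations are illustrative only, and a production version would reuse persistent staging buffers rather than allocating and synchronizing per group.

#include <map>
#include <tuple>
#include <vector>
#include <cuda_runtime.h>

// Sketch of grouped execution: bucket tasks by (M, N, K), then launch
// one batched kernel per bucket.
void BatchGemmRunner::execute(cudaStream_t stream) {
    std::map<std::tuple<int, int, int>, std::vector<GemmTask>> groups;
    for (const GemmTask& t : tasks_)
        groups[{t.M, t.N, t.K}].push_back(t);

    for (const auto& [dims, group] : groups) {
        const auto& [M, N, K] = dims;
        int count = static_cast<int>(group.size());

        // Stage the per-GEMM pointer arrays on the host.
        std::vector<const float*> hA(count), hB(count);
        std::vector<float*> hC(count);
        for (int i = 0; i < count; ++i) {
            hA[i] = group[i].A;
            hB[i] = group[i].B;
            hC[i] = group[i].C;
        }

        // Copy pointer arrays to the device.
        const float **dA = nullptr, **dB = nullptr;
        float **dC = nullptr;
        cudaMalloc((void**)&dA, count * sizeof(float*));
        cudaMalloc((void**)&dB, count * sizeof(float*));
        cudaMalloc((void**)&dC, count * sizeof(float*));
        cudaMemcpyAsync(dA, hA.data(), count * sizeof(float*),
                        cudaMemcpyHostToDevice, stream);
        cudaMemcpyAsync(dB, hB.data(), count * sizeof(float*),
                        cudaMemcpyHostToDevice, stream);
        cudaMemcpyAsync(dC, hC.data(), count * sizeof(float*),
                        cudaMemcpyHostToDevice, stream);

        // grid.z carries the batch index (hardware caps it at 65535,
        // so larger groups would need to be split).
        dim3 block(16, 16);
        dim3 grid((N + 15) / 16, (M + 15) / 16, count);
        batched_gemm<<<grid, block, 0, stream>>>(dA, dB, dC, M, N, K, count);

        // Sketch-only: synchronize so the pointer arrays can be freed here.
        cudaStreamSynchronize(stream);
        cudaFree(dA); cudaFree(dB); cudaFree(dC);
    }
}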
Kernel Design
// Batched GEMM kernel: processes batch_count GEMMs of the same dimensions
__global__ void batched_gemm(const float** A_array,
                             const float** B_array,
                             float** C_array,
                             int M, int N, int K,
                             int batch_count);
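The declaration above leaves the body open. A minimal sketch, assuming row-major operands and a naive one-thread-per-output-element mapping with blockIdx.z selecting the GEMM within the batch (no tiling or shared memory, for clarity only):

// Naive batched GEMM sketch: grid.z indexes the GEMM within the batch;
// each thread computes one element of its C matrix. Operands are assumed
// row-major: C[M x N] = A[M x K] * B[K x N].
__global__ void batched_gemm(const float** A_array,
                             const float** B_array,
                             float** C_array,
                             int M, int N, int K,
                             int batch_count) {
    int batch = blockIdx.z;
    if (batch >= batch_count) return;

    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    const float* A = A_array[batch];
    const float* B = B_array[batch];
    float* C = C_array[batch];

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
        acc += A[row * K + k] * B[k * N + col];
    C[row * N + col] = acc;
}

With a 16x16 thread block, the matching launch configuration is grid = ((N + 15) / 16, (M + 15) / 16, batch_count), as used in the grouped-execution sketch above.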