RFC: Mini-Inference Engine Architecture Design
Status
Status: Accepted
Created: 2024
Last Updated: 2024
Overview
Mini-Inference Engine is a lightweight neural-network inference engine focused on GEMM (General Matrix Multiply) optimization. Through a progressive series of optimizations, it reaches roughly 70%-85% of cuBLAS performance.
Tech Stack
Component        Version / Notes
────────────────────────────────────────────────
CUDA C++         Compute Capability 7.0+
CMake            3.18+
CUDA Toolkit     11.0+
cuBLAS           Used for performance comparison
Architecture
System Architecture
┌────────────────────────────────────────────────────────────────┐
│                       Application Layer                        │
│  Benchmark │ MNIST Demo │ Tests │ User Application             │
└────────────────────────────────────────────────────────────────┘
                                │
┌────────────────────────────────────────────────────────────────┐
│                          Engine Layer                          │
│  InferenceEngine │ Tensor │ AutoTuner │ Profiler               │
└────────────────────────────────────────────────────────────────┘
                                │
┌────────────────────────────────────────────────────────────────┐
│                          Kernel Layer                          │
│  Naive │ Tiled │ Coalesced │ DoubleBuffer │ Optimized │ Fused  │
│  Vectorized │ Half-Precision │ Batched │ cuBLAS wrapper        │
└────────────────────────────────────────────────────────────────┘
                                │
┌────────────────────────────────────────────────────────────────┐
│                      Infrastructure Layer                      │
│  MemoryPool │ StreamManager │ Logger │ Config │ Quantization   │
└────────────────────────────────────────────────────────────────┘
Memory Hierarchy Optimization Strategy
Optimization Technique   Target Memory   Bandwidth    Optimization Effect
──────────────────────────────────────────────────────────────────────────────────
Tiling                   Shared Memory   ~10 TB/s     Reduces global memory accesses by ~32×
Coalescing               Global Memory   ~500 GB/s    Improves bandwidth utilization
Double Buffer            Shared Memory   ~10 TB/s     Hides memory latency
Register Blocking        Registers       ~100 TB/s    Maximizes compute density
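To make the tiling row concrete, the sketch below stages TILE × TILE blocks of A and B in shared memory before the inner product; with TILE = 32 each global element is read once per tile instead of 32 times, which is where the ~32× reduction comes from. This is a minimal illustration under that assumption, not the engine's actual tiled_gemm kernel, and boundary handling is kept deliberately simple.

// Minimal shared-memory tiled GEMM sketch (row-major, C = A * B).
// TILE is an illustrative tile size; the real kernel may choose differently.
#define TILE 32

__global__ void tiled_gemm_sketch(const float* A, const float* B, float* C,
                                  int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;   // output row in C
    int col = blockIdx.x * TILE + threadIdx.x;   // output column in C
    float acc = 0.0f;

    // Each iteration stages one TILE x TILE block of A and B in shared memory,
    // so every global element is loaded once per tile instead of TILE times.
    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
        int a_col = t * TILE + threadIdx.x;
        int b_row = t * TILE + threadIdx.y;
        As[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
        __syncthreads();

        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();
    }

    if (row < M && col < N)
        C[row * N + col] = acc;
}

// Launch sketch: dim3 block(TILE, TILE);
//                dim3 grid((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);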
Components and Interfaces
Core Data Structures
// Matrix descriptor
struct MatrixDesc {
    float* data;         // Device pointer
    int rows;            // Row count M
    int cols;            // Column count N
    int ld;              // Leading dimension
    bool is_transposed;  // Whether transposed
};

// GEMM configuration
struct GemmConfig {
    int BLOCK_M;  // Tile row size
    int BLOCK_N;  // Tile column size
    int BLOCK_K;  // K-dimension block size
    int WARP_M;   // Warp-level M blocking
    int WARP_N;   // Warp-level N blocking
    bool use_double_buffer;
    bool use_vectorized_load;
};

// Fusion operation configuration
struct FusionConfig {
    bool add_bias;
    bool apply_relu;
    float* bias;
};

// Performance statistics
struct PerfStats {
    float kernel_time_ms;
    float gflops;
    float memory_bandwidth_gb;
    float cublas_ratio;
};
Kernel Interfaces
// Naive MatMul: each thread computes one output element
__global__ void naive_matmul(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Tiled GEMM: uses shared memory tiling
__global__ void tiled_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Optimized GEMM: fully optimized version
template <int BM, int BN, int BK, int TM, int TN>
__global__ void optimized_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Fused kernel: MatMul + Bias + ReLU
template <int BM, int BN, int BK, int TM, int TN,
          bool ADD_BIAS, bool APPLY_RELU>
__global__ void fused_gemm_bias_relu(
    const float* A, const float* B, float* C,
    const float* bias, int M, int N, int K
);
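A hypothetical host-side launch of optimized_gemm with the "Medium" configuration is sketched below; it assumes a 1-D thread block of (BM/TM)·(BN/TN) threads and that each block produces one BM × BN tile of C. The actual kernel's internal thread mapping may differ.

// Sketch: launching optimized_gemm with the "Medium" configuration.
// One block produces a BM x BN tile of C; each thread computes a TM x TN sub-tile.
constexpr int BM = 128, BN = 128, BK = 8, TM = 8, TN = 8;

void launch_optimized_gemm(const float* A, const float* B, float* C,
                           int M, int N, int K, cudaStream_t stream) {
    dim3 block((BM / TM) * (BN / TN));                 // 16 * 16 = 256 threads
    dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);   // tiles of C: x covers N, y covers M
    optimized_gemm<BM, BN, BK, TM, TN><<<grid, block, 0, stream>>>(A, B, C, M, N, K);
}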
Engine Interface
class InferenceEngine {
public:
    void init(int device_id = 0);
    bool load_weights(const std::string& path);
    void forward(const float* input, float* output, int batch_size);
    void forward_with_timing(const float* input, float* output,
                             int batch_size, std::vector<float>& times);
    void cleanup();

    size_t num_layers() const;
    int input_dim() const;
    int output_dim() const;

private:
    std::vector<LayerWeights> layers_;
    cublasHandle_t cublas_handle_;
    cudaStream_t stream_;
};
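A typical call sequence is sketched below. The weight-file name is illustrative, and whether forward() expects host or device pointers is an assumption here; the real interface may require device buffers.

// Sketch: running a batch of MNIST images through the engine.
#include <vector>

void run_batch() {
    InferenceEngine engine;
    engine.init(/*device_id=*/0);
    if (!engine.load_weights("mnist_mlp.bin")) {   // illustrative file name
        return;                                    // handle the error in real code
    }

    int batch = 64;
    std::vector<float> input(batch * engine.input_dim());    // host buffers; the real API
    std::vector<float> output(batch * engine.output_dim());  // may expect device pointers
    // ... fill `input` with normalized 28x28 images ...

    engine.forward(input.data(), output.data(), batch);
    engine.cleanup();
}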
GEMM Optimization Details
Optimization Levels
Level   Technique          Performance (vs cuBLAS)   Key Optimization
──────────────────────────────────────────────────────────────────────────
1       Naive              ~10%                      Baseline implementation
2       Tiled              ~20%                      Shared memory tiling
3       Coalesced          ~30%                      Memory access coalescing
4       Double Buffer      ~40%                      Double buffering to hide latency
5       Register Blocked   ~70%                      Register blocking
6       Fused              ~80%                      Operator fusion
7       Vectorized         ~85%                      Vectorized loads
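The jump from ~40% to ~70% at level 5 comes from register blocking: each thread keeps a TM × TN accumulator tile in registers, so every value fetched from shared memory feeds TM (or TN) fused multiply-adds instead of one. The device-function sketch below shows only that inner step; micro_tile_accumulate is a hypothetical name, and the surrounding tile-staging loop is omitted.

// Sketch: the register-blocked micro-kernel behind levels 5-7.
// Each thread owns a TM x TN accumulator tile `acc`; As/Bs are the shared-memory
// tiles of A and B staged by the enclosing (omitted) K-loop of the kernel.
template <int BM, int BN, int BK, int TM, int TN>
__device__ void micro_tile_accumulate(const float (&As)[BM][BK],
                                      const float (&Bs)[BK][BN],
                                      float (&acc)[TM][TN],
                                      int thread_row, int thread_col) {
    float a_frag[TM];   // registers: one column slice of the A tile
    float b_frag[TN];   // registers: one row slice of the B tile
    #pragma unroll
    for (int k = 0; k < BK; ++k) {
        #pragma unroll
        for (int i = 0; i < TM; ++i) a_frag[i] = As[thread_row * TM + i][k];
        #pragma unroll
        for (int j = 0; j < TN; ++j) b_frag[j] = Bs[k][thread_col * TN + j];
        // Outer product: TM + TN shared-memory loads feed TM * TN FMAs,
        // which is what raises arithmetic intensity at this level.
        #pragma unroll
        for (int i = 0; i < TM; ++i) {
            #pragma unroll
            for (int j = 0; j < TN; ++j)
                acc[i][j] += a_frag[i] * b_frag[j];
        }
    }
}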
Parameter Constraints
Constraints:
─────────────────────────────────────────────────────────────
1. Thread count: (BM / TM) × (BN / TN) ≤ 1024
2. Shared memory: (BM × BK + BK × BN) × 4 ≤ 48KB
3. Registers: TM × TN + TM + TN + overhead ≤ 255
Recommended Configurations:
─────────────────────────────────────────────────────────────
Config BM BN BK TM TN Threads Shared
─────────────────────────────────────────────────────────────
Small 64 64 8 4 4 256 4KB
Medium 128 128 8 8 8 256 8KB
Large 128 256 16 8 8 512 24KB
─────────────────────────────────────────────────────────────
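The thread-count and shared-memory constraints can be enforced at compile time when instantiating the templated kernels, as in the sketch below (check_gemm_config is a hypothetical helper; the register constraint is left to profiling because the per-thread overhead is compiler-dependent).

// Sketch: compile-time validation of a (BM, BN, BK, TM, TN) configuration.
template <int BM, int BN, int BK, int TM, int TN>
constexpr void check_gemm_config() {
    static_assert(BM % TM == 0 && BN % TN == 0,
                  "per-thread tiles must divide the block tile evenly");
    static_assert((BM / TM) * (BN / TN) <= 1024,
                  "thread count per block exceeds the 1024 limit");
    static_assert((BM * BK + BK * BN) * sizeof(float) <= 48 * 1024,
                  "A and B tiles exceed the 48 KB static shared-memory budget");
    // The register constraint (TM*TN + TM + TN + overhead <= 255) is checked by
    // profiling instead, since the overhead depends on the compiler.
}

// Usage: check_gemm_config<128, 128, 8, 8, 8>();  // the "Medium" configuration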
Data Models
+------------------+
| Header (32 bytes)|
| - magic: 4B | (0x4D494E49 = "MINI")
| - version: 4B |
| - num_layers: 4B|
| - reserved: 20B |
+------------------+
| Layer Meta |
| - type: 4B |
| - in_features |
| - out_features |
| - has_bias |
+------------------+
| Layer Weights |
| - W: float[] |
| - bias: float[] |
+------------------+
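A loader might parse the 32-byte header like this (a sketch; MiniHeader and read_header are hypothetical names, and the struct layout simply mirrors the diagram above).

#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Sketch: parsing the 32-byte weight-file header described above.
struct MiniHeader {
    uint32_t magic;       // expected 0x4D494E49 ("MINI")
    uint32_t version;
    uint32_t num_layers;
    uint8_t  reserved[20];
};
static_assert(sizeof(MiniHeader) == 32, "header must be exactly 32 bytes");

MiniHeader read_header(FILE* f) {
    MiniHeader h;
    if (fread(&h, sizeof(h), 1, f) != 1)
        throw std::runtime_error("failed to read weight-file header");
    if (h.magic != 0x4D494E49u)
        throw std::runtime_error("not a MINI weight file");
    return h;
}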
Network Architecture (MNIST)
Input: 784 (28x28)
↓
Linear(784, 256) + ReLU
↓
Linear(256, 128) + ReLU
↓
Linear(128, 10)
↓
Output: 10 (logits)
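Conceptually, the forward pass is three fused GEMM launches where M is the batch size, N the layer's output width, and K its input width. The sketch below assumes weights are stored as [in_features × out_features] device matrices and uses launch_fused, a hypothetical host wrapper around fused_gemm_bias_relu.

// Hypothetical host-side wrapper (defined elsewhere) around fused_gemm_bias_relu:
// computes C[M x N] = act(A[M x K] * B[K x N] + bias), act = ReLU if APPLY_RELU.
template <bool ADD_BIAS, bool APPLY_RELU>
void launch_fused(const float* A, const float* B, float* C,
                  const float* bias, int M, int N, int K);

// Sketch: the MNIST forward pass as three fused launches.
// x0..x3, W1..W3 and b1..b3 are pre-allocated device buffers.
void mnist_forward(const float* x0,                   // [batch x 784] input
                   float* x1, float* x2, float* x3,   // activations / logits
                   const float* W1, const float* b1,  // [784 x 256], [256]
                   const float* W2, const float* b2,  // [256 x 128], [128]
                   const float* W3, const float* b3,  // [128 x 10],  [10]
                   int batch) {
    launch_fused<true, true >(x0, W1, x1, b1, batch, 256, 784);  // Linear(784, 256) + ReLU
    launch_fused<true, true >(x1, W2, x2, b2, batch, 128, 256);  // Linear(256, 128) + ReLU
    launch_fused<true, false>(x2, W3, x3, b3, batch,  10, 128);  // Linear(128, 10), raw logits
}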
Error Handling
CUDA Error Handling
#define CUDA_CHECK(call) do {                             \
    cudaError_t err = call;                               \
    if (err != cudaSuccess) {                             \
        throw CudaException(err, __FILE__, __LINE__);     \
    }                                                     \
} while (0)

class CudaException : public std::exception {
public:
    CudaException(cudaError_t err, const char* file, int line);
    const char* what() const noexcept override;
    cudaError_t error() const;
};
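Typical usage wraps every runtime call so failures surface as exceptions instead of being silently ignored (a sketch; the buffer size is arbitrary).

#include <cuda_runtime.h>

// Sketch: wrapping runtime calls with CUDA_CHECK so failures become CudaException.
void clear_device_buffer() {
    float* d_buf = nullptr;
    CUDA_CHECK(cudaMalloc(&d_buf, 1024 * sizeof(float)));
    CUDA_CHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
    // ... launch kernels here, then surface asynchronous errors ...
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());
    CUDA_CHECK(cudaFree(d_buf));
}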
RAII Resource Management
class DeviceMemory {
public:
    explicit DeviceMemory(size_t bytes);
    ~DeviceMemory();

    DeviceMemory(const DeviceMemory&) = delete;
    DeviceMemory& operator=(const DeviceMemory&) = delete;
    DeviceMemory(DeviceMemory&&) noexcept;

    float* get();
    size_t size() const;
};
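One plausible implementation is sketched below: cudaMalloc in the constructor, an unconditional cudaFree in the destructor, and a move constructor that transfers ownership, so device allocations are released even when an exception unwinds the stack. The real class may differ, for example by drawing from MemoryPool.

#include <cuda_runtime.h>
#include <cstddef>
#include <utility>

// Sketch: RAII wrapper for device memory (relies on the CUDA_CHECK macro above).
class DeviceMemory {
public:
    explicit DeviceMemory(size_t bytes) : size_(bytes) {
        CUDA_CHECK(cudaMalloc(&ptr_, bytes));
    }
    ~DeviceMemory() {
        if (ptr_) cudaFree(ptr_);            // destructors must not throw
    }
    DeviceMemory(const DeviceMemory&) = delete;
    DeviceMemory& operator=(const DeviceMemory&) = delete;
    DeviceMemory(DeviceMemory&& other) noexcept
        : ptr_(std::exchange(other.ptr_, nullptr)),
          size_(std::exchange(other.size_, 0)) {}

    float* get()        { return ptr_; }
    size_t size() const { return size_; }

private:
    float* ptr_  = nullptr;
    size_t size_ = 0;
};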
Testing Strategy
Test Framework
Type                  Tool
────────────────────────────────────────────────
Unit Testing          Google Test
Property Testing      Custom random matrix generator
Performance Testing   Custom benchmark framework
Test Coverage
Test File                 Coverage
─────────────────────────────────────────
test_gemm.cpp             All GEMM kernels
test_tensor.cpp           Tensor operations
test_inference.cpp        InferenceEngine
test_memory_pool.cpp      MemoryPool
test_stream_manager.cpp   StreamManager
test_config.cpp           Config
test_logger.cpp           Logger
test_quantization.cpp     INT8 quantization
test_fusion.cpp           Fusion kernels
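As an illustration of the unit-testing style, a GEMM correctness test might compare a kernel against a CPU reference as sketched below; the test body is hypothetical, and helpers such as fill_random, cpu_gemm_reference, and run_tiled_gemm_on_device are assumed to exist in the test utilities.

#include <gtest/gtest.h>
#include <vector>

// Sketch: verify tiled_gemm against a straightforward CPU reference.
TEST(GemmTest, TiledMatchesReference) {
    const int M = 128, N = 128, K = 64;
    std::vector<float> A(M * K), B(K * N), C(M * N), C_ref(M * N);
    fill_random(A);
    fill_random(B);

    cpu_gemm_reference(A.data(), B.data(), C_ref.data(), M, N, K);
    run_tiled_gemm_on_device(A.data(), B.data(), C.data(), M, N, K);  // H2D copies + launch + D2H

    for (int i = 0; i < M * N; ++i)
        EXPECT_NEAR(C[i], C_ref[i], 1e-3f) << "mismatch at index " << i;
}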