RFC: Mini-Inference Engine Architecture Design

Status

Status: Accepted
Created: 2024
Last Updated: 2024

Overview

Mini-Inference Engine is a lightweight neural-network inference engine focused on GEMM (General Matrix Multiply) optimization. Through a progression of optimization strategies, its hand-written kernels reach roughly 70%-80% of cuBLAS performance.

Tech Stack

Component       Version / Notes
──────────────────────────────────────────────────
CUDA C++        Compute Capability 7.0+
CMake           3.18+
CUDA Toolkit    11.0+
cuBLAS          Used for performance comparison

Architecture

System Architecture

┌─────────────────────────────────────────────────────────────────┐
│                      Application Layer                          │
│   Benchmark  │  MNIST Demo  │  Tests  │  User Application      │
└─────────────────────────────────────────────────────────────────┘
                              │
┌─────────────────────────────────────────────────────────────────┐
│                       Engine Layer                              │
│   InferenceEngine  │  Tensor  │  AutoTuner  │  Profiler         │
└─────────────────────────────────────────────────────────────────┘
                              │
┌─────────────────────────────────────────────────────────────────┐
│                       Kernel Layer                              │
│  Naive │ Tiled │ Coalesced │ DoubleBuffer │ Optimized │ Fused  │
│  Vectorized │ Half-Precision │ Batched │ cuBLAS wrapper        │
└─────────────────────────────────────────────────────────────────┘
                              │
┌─────────────────────────────────────────────────────────────────┐
│                   Infrastructure Layer                          │
│  MemoryPool │ StreamManager │ Logger │ Config │ Quantization   │
└─────────────────────────────────────────────────────────────────┘

Memory Hierarchy Optimization Strategy

Optimization Technique   Target Memory    Bandwidth     Optimization Effect
─────────────────────────────────────────────────────────────────────────────
Tiling                   Shared Memory    ~10 TB/s      Cuts global memory reads by ~32x via tile reuse
Coalescing               Global Memory    ~500 GB/s     Improves bandwidth utilization
Double Buffer            Shared Memory    ~10 TB/s      Overlaps loads with compute to hide latency
Register Blocking        Registers        ~100 TB/s     Maximizes compute density

Components and Interfaces

Core Data Structures

// Matrix descriptor
struct MatrixDesc {
    float* data;        // Device pointer
    int rows;           // Row count M
    int cols;           // Column count N
    int ld;             // Leading dimension
    bool is_transposed; // Whether transposed
};

// GEMM configuration
struct GemmConfig {
    int BLOCK_M;        // Tile row size
    int BLOCK_N;        // Tile column size
    int BLOCK_K;        // K dimension block size
    int WARP_M;         // Warp-level M blocking
    int WARP_N;         // Warp-level N blocking
    bool use_double_buffer;
    bool use_vectorized_load;
};

// Fusion operation configuration
struct FusionConfig {
    bool add_bias;
    bool apply_relu;
    float* bias;
};

// Performance statistics
struct PerfStats {
    float kernel_time_ms;
    float gflops;
    float memory_bandwidth_gb;
    float cublas_ratio;
};
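
The PerfStats fields follow standard GEMM accounting: one M×N×K multiply performs 2·M·N·K floating-point operations and, at a minimum, moves (M·K + K·N + M·N)·4 bytes. The sketch below shows how these fields could be derived from a measured kernel time; compute_perf_stats is an illustrative helper, not part of the interfaces above.

// Hypothetical helper (sketch): fill PerfStats from measured times.
inline PerfStats compute_perf_stats(int M, int N, int K,
                                    float kernel_time_ms,
                                    float cublas_time_ms) {
    PerfStats stats{};
    stats.kernel_time_ms = kernel_time_ms;

    // 2*M*N*K FLOPs: one multiply and one add per inner-product term.
    const double flops = 2.0 * M * N * K;
    stats.gflops = static_cast<float>(flops / (kernel_time_ms * 1e6));

    // Minimum traffic: read A (M*K) and B (K*N), write C (M*N), in FP32.
    const double bytes =
        (double(M) * K + double(K) * N + double(M) * N) * sizeof(float);
    stats.memory_bandwidth_gb = static_cast<float>(bytes / (kernel_time_ms * 1e6));

    // Ratio of cuBLAS time to kernel time; 1.0 means parity with cuBLAS.
    stats.cublas_ratio = cublas_time_ms / kernel_time_ms;
    return stats;
}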

Kernel Interfaces

// Naive MatMul: Each thread computes one output element
__global__ void naive_matmul(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Tiled GEMM: Uses shared memory tiling
__global__ void tiled_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Optimized GEMM: Full optimization version
template<int BM, int BN, int BK, int TM, int TN>
__global__ void optimized_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K
);

// Fused Kernel: MatMul + Bias + ReLU
template<int BM, int BN, int BK, int TM, int TN,
         bool ADD_BIAS, bool APPLY_RELU>
__global__ void fused_gemm_bias_relu(
    const float* A, const float* B, float* C,
    const float* bias, int M, int N, int K
);
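
For reference, a simplified sketch of the shared-memory tiling idea behind tiled_gemm is shown below. It assumes row-major matrices whose dimensions are multiples of the tile size and omits the edge handling, transposition, and leading-dimension support of the real kernel.

// Illustrative tiled GEMM (sketch): C = A * B, row-major,
// launched with dim3 block(32, 32) and dim3 grid(N / 32, M / 32).
constexpr int TILE = 32;

__global__ void tiled_gemm_sketch(const float* A, const float* B, float* C,
                                  int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;

    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Each thread stages one element of A and one of B into shared memory.
        As[threadIdx.y][threadIdx.x] = A[row * K + (k0 + threadIdx.x)];
        Bs[threadIdx.y][threadIdx.x] = B[(k0 + threadIdx.y) * N + col];
        __syncthreads();

        // Each staged element is reused TILE times, which is where the
        // ~32x reduction in global memory reads comes from.
        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();
    }

    C[row * N + col] = acc;
}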

Engine Interface

class InferenceEngine {
public:
    void init(int device_id = 0);
    bool load_weights(const std::string& path);
    void forward(const float* input, float* output, int batch_size);
    void forward_with_timing(const float* input, float* output,
                             int batch_size, std::vector<float>& times);
    void cleanup();

    size_t num_layers() const;
    int input_dim() const;
    int output_dim() const;

private:
    std::vector<LayerWeights> layers_;
    cublasHandle_t cublas_handle_;
    cudaStream_t stream_;
};
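
A typical host-side use of this interface might look like the sketch below. It assumes forward() consumes device pointers, uses the MNIST-sized network described later (784 inputs, 10 outputs), and uses a placeholder weight path.

#include <cuda_runtime.h>

int main() {
    InferenceEngine engine;
    engine.init(0);                                // select GPU 0
    if (!engine.load_weights("model.bin"))         // placeholder path
        return 1;

    const int batch = 64;
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in,  batch * engine.input_dim()  * sizeof(float));
    cudaMalloc(&d_out, batch * engine.output_dim() * sizeof(float));

    // ... copy preprocessed input images into d_in ...

    engine.forward(d_in, d_out, batch);            // one inference pass

    cudaFree(d_in);
    cudaFree(d_out);
    engine.cleanup();
    return 0;
}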

GEMM Optimization Details

Optimization Levels

Level   Technique          Performance (vs cuBLAS)   Key Optimization
──────────────────────────────────────────────────────────────────────
1       Naive              ~10%                      Baseline implementation
2       Tiled              ~20%                      Shared memory tiling
3       Coalesced          ~30%                      Memory access coalescing
4       Double Buffer      ~40%                      Double buffering to hide latency
5       Register Blocked   ~70%                      Register blocking
6       Fused              ~80%                      Operator fusion
7       Vectorized         ~85%                      Vectorized loads
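
The "vs cuBLAS" column comes from timing each kernel against cublasSgemm on the same problem size. A minimal event-based timing sketch is shown below; the launch itself (grid/block setup, cuBLAS layout handling) is passed in as a callable and elided here.

#include <cuda_runtime.h>

// Average time of a kernel launch in milliseconds, measured with CUDA events.
// 'launch' is any callable that enqueues the work on the default stream.
template <typename Launch>
float time_kernel_ms(Launch launch, int warmup = 3, int iters = 10) {
    for (int i = 0; i < warmup; ++i) launch();   // warm up

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / iters;
}

// cublas_ratio = time_kernel_ms(cublas_launch) / time_kernel_ms(custom_launch)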

Parameter Constraints

Constraints:
─────────────────────────────────────────────────────────────
1. Thread count: (BM / TM) × (BN / TN) ≤ 1024

2. Shared memory: (BM × BK + BK × BN) × 4 ≤ 48KB

3. Registers: TM × TN + TM + TN + overhead ≤ 255

Recommended Configurations:
─────────────────────────────────────────────────────────────
Config        BM   BN   BK   TM   TN   Threads   Shared
─────────────────────────────────────────────────────────────
Small         64   64    8    4    4     256     4KB
Medium       128  128    8    8    8     256     8KB
Large        128  256   16    8    8     512    24KB
─────────────────────────────────────────────────────────────
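
Because the tile parameters are template arguments, these constraints can be checked at compile time. A hedged sketch is shown below; the register bound is only an estimate, and the "+ overhead" term is left to the compiler.

// Compile-time sanity checks for a tile configuration (illustrative).
template <int BM, int BN, int BK, int TM, int TN>
constexpr bool validate_gemm_config() {
    constexpr int threads    = (BM / TM) * (BN / TN);
    constexpr int smem_bytes = (BM * BK + BK * BN) * 4;   // FP32 tiles
    constexpr int regs_est   = TM * TN + TM + TN;         // accumulators + fragments

    static_assert(threads <= 1024,         "thread block exceeds 1024 threads");
    static_assert(smem_bytes <= 48 * 1024, "tiles do not fit in 48KB shared memory");
    static_assert(regs_est <= 255,         "per-thread register estimate too high");
    return true;
}

// Example: the "Medium" configuration from the table above.
static_assert(validate_gemm_config<128, 128, 8, 8, 8>(), "invalid GEMM config");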

Data Models

Weight File Format

+------------------+
| Header (32 bytes)|
|  - magic: 4B     |  (0x4D494E49 = "MINI")
|  - version: 4B   |
|  - num_layers: 4B|
|  - reserved: 20B |
+------------------+
| Layer Meta       |
|  - type: 4B      |
|  - in_features   |
|  - out_features  |
|  - has_bias      |
+------------------+
| Layer Weights    |
|  - W: float[]    |
|  - bias: float[] |
+------------------+
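
A hedged sketch of how load_weights might parse this header is shown below; the struct layout mirrors the 32-byte diagram, and only the magic number is validated.

#include <cstdint>
#include <cstdio>

// On-disk header matching the 32-byte layout above (illustrative).
#pragma pack(push, 1)
struct WeightFileHeader {
    uint32_t magic;          // expected 0x4D494E49 ("MINI")
    uint32_t version;
    uint32_t num_layers;
    uint8_t  reserved[20];
};
#pragma pack(pop)
static_assert(sizeof(WeightFileHeader) == 32, "header must be 32 bytes");

bool read_header(std::FILE* f, WeightFileHeader& hdr) {
    if (std::fread(&hdr, sizeof(hdr), 1, f) != 1) return false;
    return hdr.magic == 0x4D494E49;   // reject files with the wrong magic
}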

Network Architecture (MNIST)

Input: 784 (28x28)
    ↓
Linear(784, 256) + ReLU
    ↓
Linear(256, 128) + ReLU
    ↓
Linear(128, 10)
    ↓
Output: 10 (logits)
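
For this three-layer MLP the forward pass reduces to three GEMM launches, the first two of which can use the fused bias + ReLU kernel. The sketch below assumes row-major (batch x features) device buffers; launch_fused_gemm is a hypothetical host wrapper around fused_gemm_bias_relu, shown here only as a declaration.

// Hypothetical host wrapper: C = A * B (+ bias, optional ReLU), M x N x K.
void launch_fused_gemm(const float* A, const float* B, float* C,
                       const float* bias, int M, int N, int K, bool relu);

// Hedged forward-pass sketch for the MNIST network above. w0/b0, w1/b1, w2/b2
// are device pointers to the three layers' weights and biases; act0/act1 are
// scratch activation buffers of size batch*256 and batch*128.
void forward_mnist(const float* input, float* logits, int batch,
                   const float* w0, const float* b0,
                   const float* w1, const float* b1,
                   const float* w2, const float* b2,
                   float* act0, float* act1) {
    launch_fused_gemm(input, w0, act0,   b0, batch, 256, 784, /*relu=*/true);
    launch_fused_gemm(act0,  w1, act1,   b1, batch, 128, 256, /*relu=*/true);
    launch_fused_gemm(act1,  w2, logits, b2, batch, 10,  128, /*relu=*/false);
}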

Error Handling

CUDA Error Handling

#define CUDA_CHECK(call) do { \
    cudaError_t err = call; \
    if (err != cudaSuccess) { \
        throw CudaException(err, __FILE__, __LINE__); \
    } \
} while(0)

class CudaException : public std::exception {
public:
    CudaException(cudaError_t err, const char* file, int line);
    const char* what() const noexcept override;
    cudaError_t error() const;
};
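
Host-side runtime calls can then be wrapped so that failures surface as exceptions rather than silently ignored status codes, for example:

// Example use of CUDA_CHECK around runtime API calls.
float* d_buf = nullptr;
CUDA_CHECK(cudaMalloc(&d_buf, 1024 * sizeof(float)));
CUDA_CHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));

// Kernel launches return void, so check launch and completion separately:
//   my_kernel<<<grid, block>>>(d_buf);
CUDA_CHECK(cudaGetLastError());       // catches launch-configuration errors
CUDA_CHECK(cudaDeviceSynchronize());  // catches asynchronous execution errors

CUDA_CHECK(cudaFree(d_buf));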

RAII Resource Management

class DeviceMemory {
public:
    explicit DeviceMemory(size_t bytes);
    ~DeviceMemory();

    DeviceMemory(const DeviceMemory&) = delete;
    DeviceMemory& operator=(const DeviceMemory&) = delete;
    DeviceMemory(DeviceMemory&&) noexcept;

    float* get();
    size_t size() const;
};
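
A hedged implementation sketch of this wrapper is shown below; it assumes private ptr_ / bytes_ members and a raw cudaMalloc allocation released in the destructor. A move-assignment operator, if added, would follow the same pattern as the move constructor.

// Illustrative definitions (assumes private members float* ptr_ and size_t bytes_).
DeviceMemory::DeviceMemory(size_t bytes) : ptr_(nullptr), bytes_(bytes) {
    CUDA_CHECK(cudaMalloc(&ptr_, bytes));
}

DeviceMemory::~DeviceMemory() {
    if (ptr_) cudaFree(ptr_);        // never throw from a destructor
}

DeviceMemory::DeviceMemory(DeviceMemory&& other) noexcept
    : ptr_(other.ptr_), bytes_(other.bytes_) {
    other.ptr_   = nullptr;          // leave the moved-from object empty
    other.bytes_ = 0;
}

float* DeviceMemory::get()        { return ptr_; }
size_t DeviceMemory::size() const { return bytes_; }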

Testing Strategy

Test Framework

Type                   Tool
─────────────────────────────────────────────────
Unit Testing           Google Test
Property Testing       Custom random matrix generator
Performance Testing    Custom benchmark framework

Test Coverage

Test File                  Coverage
─────────────────────────────────────────────────
test_gemm.cpp              All GEMM kernels
test_tensor.cpp            Tensor operations
test_inference.cpp         InferenceEngine
test_memory_pool.cpp       MemoryPool
test_stream_manager.cpp    StreamManager
test_config.cpp            Config
test_logger.cpp            Logger
test_quantization.cpp      INT8 quantization
test_fusion.cpp            Fusion kernels
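
Correctness tests compare each kernel against a reference result on randomly generated matrices. The Google Test sketch below illustrates the pattern with a plain CPU reference; run_tiled_gemm is a hypothetical host wrapper that moves data to the device, launches tiled_gemm, and copies the result back.

#include <gtest/gtest.h>
#include <random>
#include <vector>

// Hypothetical host wrapper around the tiled kernel (declaration only).
void run_tiled_gemm(const float* A, const float* B, float* C, int M, int N, int K);

// CPU reference for C = A * B, row-major.
static void cpu_gemm(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int M, int N, int K) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
}

TEST(GemmTest, TiledMatchesCpuReference) {
    const int M = 64, N = 64, K = 64;
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    std::vector<float> A(M * K), B(K * N), ref(M * N), out(M * N);
    for (auto& x : A) x = dist(rng);
    for (auto& x : B) x = dist(rng);

    cpu_gemm(A, B, ref, M, N, K);
    run_tiled_gemm(A.data(), B.data(), out.data(), M, N, K);

    for (int i = 0; i < M * N; ++i)
        EXPECT_NEAR(out[i], ref[i], 1e-3f) << "mismatch at index " << i;
}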


MIT License | A learning project for the CUDA community