RFC 0003: Quantization System (INT8/FP16)

Status

Status: Accepted
Created: 2024
Last Updated: 2024

Overview

Design a quantization subsystem to support INT8 and FP16 inference, reducing memory bandwidth requirements and increasing throughput on modern GPUs with Tensor Core support.

Motivation

  1. Memory bandwidth reduction: INT8 uses 4x less memory than FP32
  2. Throughput improvement: Tensor Cores deliver 4x TOPS for INT8
  3. Power efficiency: Lower precision = less energy per operation
  4. Edge deployment: Enables inference on constrained devices

Design

Quantization Pipeline

FP32 Weights → Calibration → Scale Computation → Quantized Storage
                                                        ↓
     FP32 Output ← Dequantize (runtime) ← FP16/INT8 Compute

INT8 Quantization

Symmetric Quantization

scale       = max(|tensor|) / 127.0
quantized   = clamp(round(fp32_value / scale), -128, 127)
dequantized = quantized * scale

Constraints:

  • Range: [-128, 127]
  • Zero point: always 0 (symmetric)
  • Per-tensor scale (not per-channel, for simplicity)
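
A minimal host-side reference of this scheme, assuming the per-tensor symmetric quantization described above. Function names here (symmetric_scale, quantize_int8, dequantize_int8) are illustrative and not part of the proposed API; the all-zeros case is guarded so the scale never becomes zero:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Per-tensor scale from the largest magnitude; falls back to 1.0 for an all-zeros tensor.
float symmetric_scale(const float* t, size_t n) {
    float amax = 0.0f;
    for (size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(t[i]));
    return amax > 0.0f ? amax / 127.0f : 1.0f;
}

// Round-to-nearest, then clamp into the INT8 range.
void quantize_int8(const float* src, int8_t* dst, size_t n, float scale) {
    for (size_t i = 0; i < n; ++i) {
        int q = static_cast<int>(std::lround(src[i] / scale));
        dst[i] = static_cast<int8_t>(std::max(-128, std::min(127, q)));
    }
}

// Recover approximate FP32 values by rescaling.
void dequantize_int8(const int8_t* src, float* dst, size_t n, float scale) {
    for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]) * scale;
}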

FP16 Quantization

fp16_value = __float2half(fp32_value)
fp32_value = __half2float(fp16_value)

Constraints:

  • Range: ±[6.10e-5, 65504] for normal values (subnormals extend down to ~5.96e-8)
  • Precision: ~3.3 decimal digits (10-bit mantissa plus implicit leading bit)
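
For completeness, a sketch of how these intrinsics would be applied element-wise on the device (kernel names are illustrative, not part of the proposed API):

#include <cuda_fp16.h>

// FP32 -> FP16: round-to-nearest-even conversion per element.
__global__ void fp32_to_fp16_kernel(const float* src, __half* dst, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i < n) dst[i] = __float2half(src[i]);
}

// FP16 -> FP32: exact widening conversion per element.
__global__ void fp16_to_fp32_kernel(const __half* src, float* dst, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i < n) dst[i] = __half2float(src[i]);
}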

API Design

// Quantization types
enum class QuantType {
    NONE,    // FP32 passthrough
    INT8,    // 8-bit symmetric quantization
    FP16     // Half-precision
};

// Quantization parameters
struct QuantParams {
    QuantType type;
    float scale;         // For INT8: scale factor
    int8_t zero_point;   // For INT8: zero point (currently 0)
};

// Core functions
class Quantizer {
public:
    // Quantize FP32 tensor to INT8/FP16
    static void quantize(const float* src, void* dst,
                         QuantType type, QuantParams& params,
                         size_t elements);

    // Dequantize INT8/FP16 back to FP32
    static void dequantize(const void* src, float* dst,
                           QuantType type, const QuantParams& params,
                           size_t elements);

    // Calibration: compute scale from FP32 tensor
    static float compute_scale(const float* tensor, size_t elements);

    // Quantized GEMM wrapper
    static void quantized_gemm(const void* A, const void* B, float* C,
                               QuantType type,
                               int M, int N, int K,
                               const QuantParams& a_params,
                               const QuantParams& b_params,
                               cudaStream_t stream);
};
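
A sketch of how the API above is intended to be used for a single INT8 layer. It assumes quantize() computes the per-tensor scale internally and writes it back through the non-const QuantParams& (the RFC leaves this open), and it glosses over host/device memory placement and error handling:

#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include "quantization.h"   // Quantizer, QuantType, QuantParams (per this RFC)

// activations: M x K, weights: K x N, output: M x N (all row-major, sizes assumed valid).
void run_int8_layer(const std::vector<float>& activations,
                    const std::vector<float>& weights,
                    std::vector<float>& output,
                    int M, int N, int K, cudaStream_t stream) {
    std::vector<int8_t> a_q(activations.size()), w_q(weights.size());
    QuantParams a_params{QuantType::INT8, 0.0f, 0};
    QuantParams w_params{QuantType::INT8, 0.0f, 0};

    // Quantize both operands; assumed to fill in each params.scale.
    Quantizer::quantize(activations.data(), a_q.data(), QuantType::INT8, a_params, activations.size());
    Quantizer::quantize(weights.data(),     w_q.data(), QuantType::INT8, w_params, weights.size());

    // INT8 x INT8 -> INT32 accumulate -> FP32 output, dequantized with both scales.
    Quantizer::quantized_gemm(a_q.data(), w_q.data(), output.data(), QuantType::INT8,
                              M, N, K, a_params, w_params, stream);
}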

Kernel Design

INT8 GEMM with FP32 Accumulation

INT8 A × INT8 B → INT32 accumulation → FP32 C (dequantized)
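
The final conversion falls out of the scales: each product q_a·q_b stands in for (a/scale_A)·(b/scale_B), so the INT32 accumulator is rescaled by scale_A × scale_B to recover FP32. A simplified one-thread-per-output sketch using the __dp4a intrinsic (sm_61+); the layout assumptions (K divisible by 4, B stored transposed) are illustrative, not the final kernel design:

#include <cstdint>
#include <cuda_runtime.h>

__global__ void int8_gemm_dp4a(const int8_t* A, const int8_t* B_t, float* C,
                               int M, int N, int K,
                               float scale_a, float scale_b) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    const int* a4 = reinterpret_cast<const int*>(A + row * K);    // 4 packed int8 per 32-bit load
    const int* b4 = reinterpret_cast<const int*>(B_t + col * K);  // B stored as N x K (transposed)
    int acc = 0;                                                  // INT32 accumulator
    for (int k = 0; k < K / 4; ++k)
        acc = __dp4a(a4[k], b4[k], acc);                          // 4-way INT8 dot product + add

    // Dequantize the accumulated integer dot product back to FP32.
    C[row * N + col] = static_cast<float>(acc) * scale_a * scale_b;
}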

FP16 GEMM with FP32 Accumulation

FP16 A × FP16 B → FP32 accumulation (using WMMA or half2) → FP32 C
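
A minimal WMMA sketch of this path, assuming row-major operands with M, N, K multiples of 16; each warp computes one 16×16 tile of C with an FP32 accumulator fragment (the indexing is illustrative, not the production tiling):

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

__global__ void fp16_gemm_wmma(const half* A, const half* B, float* C,
                               int M, int N, int K) {
    int warp_m = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;  // tile row
    int warp_n = blockIdx.y * blockDim.y + threadIdx.y;               // tile column

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;    // FP32 accumulation
    wmma::fill_fragment(acc_frag, 0.0f);

    for (int k = 0; k < K; k += 16) {
        wmma::load_matrix_sync(a_frag, A + warp_m * 16 * K + k, K);
        wmma::load_matrix_sync(b_frag, B + k * N + warp_n * 16, N);
        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);           // Tensor Core multiply-accumulate
    }
    wmma::store_matrix_sync(C + warp_m * 16 * N + warp_n * 16, acc_frag, N,
                            wmma::mem_row_major);
}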

Accuracy Analysis

Model        FP32 Top-1   INT8 Top-1   FP16 Top-1   INT8 Loss   FP16 Loss
MNIST-MLP    98.2%        98.1%        98.2%        -0.1%       0.0%
LeNet-5      99.1%        98.9%        99.1%        -0.2%       0.0%

Error Budget

Quantization        Max Error           Mean Error          Notes
INT8 (symmetric)    0.5 × scale         0.25 × scale        Bounded by step size
FP16                2^-11 ≈ 4.9e-4      2^-12 ≈ 2.4e-4      Bounded by mantissa
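
As an illustrative worked example (values chosen here, not measured): a weight tensor whose largest magnitude is 2.54 gets scale = 2.54 / 127 = 0.02, so the worst-case INT8 round-trip error is 0.5 × 0.02 = 0.01 and the expected error is about 0.005, independent of the tensor's size.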

Testing Strategy

  1. Quantization accuracy: Verify scale computation
  2. Round-trip fidelity: FP32 → INT8 → FP32 error within budget (a sketch follows this list)
  3. GEMM correctness: Quantized GEMM matches FP32 within tolerance
  4. Performance: Throughput comparison vs FP32 GEMM
  5. Edge cases: All-zeros tensor, extreme values, NaN handling
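
A minimal host-side sketch of test 2 (round-trip fidelity), independent of the actual harness in tests/test_quantization.cpp; the tolerance follows the 0.5 × scale bound from the error budget:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

int main() {
    // Generate a reproducible FP32 tensor.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-3.0f, 3.0f);
    std::vector<float> src(1024);
    for (auto& v : src) v = dist(rng);

    // Per-tensor symmetric scale, as specified in this RFC.
    float amax = 0.0f;
    for (float v : src) amax = std::max(amax, std::fabs(v));
    float scale = amax / 127.0f;

    // Round-trip each element and check against the 0.5 * scale error budget.
    for (float v : src) {
        int q = static_cast<int>(std::lround(v / scale));
        q = std::max(-128, std::min(127, q));
        float deq = q * scale;
        assert(std::fabs(deq - v) <= 0.5f * scale + 1e-6f);
    }
    return 0;
}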

Implementation Files

  • include/quantization.h - Quantizer class and types
  • src/quantization.cu - Kernel implementations
  • tests/test_quantization.cpp - Quantization tests

Future Work

  • Per-channel quantization
  • Asymmetric quantization (with zero point)
  • Dynamic quantization (runtime calibration)
  • INT4 support for extreme compression

