Language: 简体中文 English

目录 (Table of Contents)


核心类

DeviceMemory

GPU 内存的 RAII 包装类,自动管理内存生命周期。

头文件: common.h

构造函数

1
2
DeviceMemory();                          // 默认构造,空内存
explicit DeviceMemory(size_t bytes);     // 分配指定字节数

成员函数

函数 说明
void allocate(size_t bytes) 重新分配内存
void free() 释放内存
void copy_from_host(const void* data, size_t bytes) 从主机拷贝
void copy_to_host(void* data, size_t bytes) const 拷贝到主机
void zero() 清零内存
T* get() 获取设备指针
size_t size() const 获取字节数
bool empty() const 是否为空

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include "common.h"

// 创建并分配内存
DeviceMemory mem(1024 * 1024);  // 1MB

// 从主机拷贝数据
std::vector<float> host_data(256, 1.0f);
mem.copy_from_host(host_data.data(), 256 * sizeof(float));

// 拷贝到主机
std::vector<float> result(256);
mem.copy_to_host(result.data(), 256 * sizeof(float));

// 清零
mem.zero();

// 获取原始指针
float* ptr = mem.get<float>();

Tensor

N 维张量类,提供 GPU 存储和常用操作。

头文件: tensor.h

构造函数

1
2
3
Tensor();                                                    // 默认构造,空张量
explicit Tensor(const std::vector<int>& shape);             // 指定形状
Tensor(const std::vector<int>& shape, const float* data);   // 用数据初始化

成员函数

函数 说明
Tensor clone() const 深拷贝
void reshape(const std::vector<int>& new_shape) 重塑形状
void fill(float value) 填充值
void zero() 清零
void copy_from_host(const float* data) 从主机拷贝
void copy_to_host(float* data) const 拷贝到主机
std::vector<float> to_host() const 转换为主机向量
MatrixDesc as_matrix() const 获取矩阵视图
const std::vector<int>& shape() const 获取形状
const std::vector<int>& strides() const 获取步长
size_t size() const 元素数量
int ndim() const 维度数
int dim(int i) const 第 i 维大小

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include "tensor.h"

// 创建张量
Tensor t({batch, height, width, channels});

// 用数据创建
std::vector<float> data(100);
Tensor t2({10, 10}, data.data());

// 属性查询
auto shape = t.shape();      // std::vector<int>
auto strides = t.strides();  // std::vector<int>
size_t n = t.size();         // 元素数量
int d = t.ndim();            // 维度数

// 数据操作
t.fill(1.0f);
t.zero();
Tensor t3 = t.clone();
t.reshape({100});

// 主机数据交换
t.copy_from_host(host_data.data());
auto host = t.to_host();

// 矩阵视图(2D)
MatrixDesc mat = t.as_matrix();

数学运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 矩阵乘法: C = A @ B
Tensor C = matmul(A, B);

// 逐元素加法
Tensor C = add(A, B);

// 加偏置: Y = X + bias(broadcast)
Tensor Y = add_bias(X, bias);

// ReLU 激活
Tensor Y = relu(X);

// Softmax
Tensor Y = softmax(X);

// 融合线性 + ReLU: Y = ReLU(X @ W + bias)
Tensor Y = fused_linear_relu(X, W, bias);

MemoryPool

GPU 内存池,通过缓存减少 cudaMalloc 开销。

头文件: memory_pool.h

单例访问

1
static MemoryPool& instance();

成员函数

函数 说明
void* allocate(size_t size) 分配内存
void deallocate(void* ptr) 释放内存(返回缓存)
void clear_cache() 清理缓存,保留活跃分配
void clear_all() 清理缓存并重置统计
void release_all() 释放所有内存(仅用于关闭时)
bool owns(void* ptr) const 是否管理该指针
bool is_cached(void* ptr) const 是否在缓存中
size_t active_block_count() const 活跃块数
size_t cached_block_count() const 缓存块数
Stats get_stats() const 获取统计信息
void print_stats() const 打印统计信息

统计信息结构体

1
2
3
4
5
6
struct Stats {
    size_t total_allocated;   // 累计分配字节数
    size_t cached_size;       // 缓存大小
    size_t cache_hits;        // 缓存命中次数
    size_t cache_misses;      // 缓存未命中次数
};

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "memory_pool.h"

// 单例访问
auto& pool = MemoryPool::instance();

// 分配内存
void* ptr = pool.allocate(1024);

// 释放内存(返回缓存)
pool.deallocate(ptr);

// 查询统计
auto stats = pool.get_stats();
printf("Cache hits: %zu\n", stats.cache_hits);
printf("Cache misses: %zu\n", stats.cache_misses);
printf("Hit rate: %.1f%%\n", 
       100.0 * stats.cache_hits / (stats.cache_hits + stats.cache_misses));

// 清理缓存
pool.clear_cache();
pool.release_all();

PooledMemory 包装类

1
2
3
4
5
6
7
// RAII 包装,自动从池中分配和释放
PooledMemory mem(1024 * sizeof(float));
float* ptr = mem.get();
mem.copy_from_host(host_data.data(), bytes);
mem.copy_to_host(result.data(), bytes);
mem.zero();
// 析构时自动归还池

StreamManager

CUDA 流管理器,支持多流并发执行。

头文件: stream_manager.h

单例访问

1
static StreamManager& instance();

成员函数

函数 说明
void init(int num_streams = 4) 初始化流池
cudaStream_t get_stream() 轮询获取流
cudaStream_t get_stream(int index) 按索引获取流
void sync(int index) 同步指定流
void sync_all() 同步所有流
void cleanup() 清理所有流
int num_streams() const 流数量

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "stream_manager.h"

// 单例访问
auto& sm = StreamManager::instance();

// 初始化(可选,会自动懒初始化)
sm.init(4);  // 创建 4 个流

// 获取流(轮询分配)
cudaStream_t s1 = sm.get_stream();
cudaStream_t s2 = sm.get_stream();

// 按索引获取
cudaStream_t s0 = sm.get_stream(0);

// 同步
sm.sync(0);      // 同步流 0
sm.sync_all();   // 同步所有流

// 清理
sm.cleanup();

AsyncOperation 辅助类

1
2
3
4
5
6
7
8
9
10
11
12
13
AsyncOperation op;

// 记录事件
op.record(stream);

// 等待完成
op.wait();

// 非阻塞检查
bool done = op.is_complete();

// 让另一个流等待此操作
op.wait_on(other_stream);

InferenceEngine

神经网络推理引擎,支持多层前向传播。

头文件: inference_engine.h

成员函数

函数 说明
void init(int device_id = 0) 初始化引擎
void cleanup() 清理资源
bool load_weights(const std::string& path) 加载权重
bool save_weights(const std::string& path) const 保存权重
void add_layer(int in, int out, bool bias, const float* w, const float* b) 添加层
void forward(const float* in, float* out, int batch) 前向传播
void forward_with_timing(const float* in, float* out, int batch, std::vector<float>& times) 带计时前向传播
size_t num_layers() const 层数
int input_dim() const 输入维度
int output_dim() const 输出维度

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "inference_engine.h"

// 创建引擎
InferenceEngine engine;
engine.init(0);

// 加载权重
if (!engine.load_weights("model.bin")) {
    std::cerr << "Failed to load weights" << std::endl;
    return -1;
}

// 查询网络信息
size_t n_layers = engine.num_layers();
int in_dim = engine.input_dim();
int out_dim = engine.output_dim();

// 准备数据
int batch_size = 32;
DeviceMemory d_input(batch_size * in_dim * sizeof(float));
DeviceMemory d_output(batch_size * out_dim * sizeof(float));

// 前向传播
engine.forward(d_input.get(), d_output.get(), batch_size);

// 带计时的前向传播
std::vector<float> layer_times;
engine.forward_with_timing(d_input.get(), d_output.get(), 
                           batch_size, layer_times);

// 保存权重
engine.save_weights("model_copy.bin");

// 清理
engine.cleanup();

创建权重文件

1
2
3
4
5
6
7
8
9
// 定义网络结构: {输入维度, 输出维度}
std::vector<std::pair<int, int>> layer_dims = {
    {784, 256},   // Layer 0: 784 -> 256
    {256, 128},   // Layer 1: 256 -> 128
    {128, 10}     // Layer 2: 128 -> 10
};

// 创建随机权重文件(用于测试)
create_random_weights("model.bin", layer_dims, true);  // true = 包含 bias

GEMM Kernel API

基础 Kernel 启动函数

头文件: kernels.cuh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 基础实现
void launch_naive_matmul(const float* A, const float* B, float* C,
                         int M, int N, int K, cudaStream_t stream = 0);

void launch_tiled_gemm(const float* A, const float* B, float* C,
                       int M, int N, int K, cudaStream_t stream = 0);

void launch_coalesced_gemm(const float* A, const float* B, float* C,
                           int M, int N, int K, cudaStream_t stream = 0);

void launch_double_buffer_gemm(const float* A, const float* B, float* C,
                               int M, int N, int K, cudaStream_t stream = 0);

// 优化实现(推荐)
void launch_optimized_gemm(const float* A, const float* B, float* C,
                           int M, int N, int K, cudaStream_t stream = 0);

// 融合 GEMM + Bias + ReLU
void launch_fused_gemm(const float* A, const float* B, float* C,
                       const float* bias, int M, int N, int K,
                       bool add_bias, bool apply_relu, 
                       cudaStream_t stream = 0);

// cuBLAS 包装
void launch_cublas_gemm(cublasHandle_t handle, 
                        const float* A, const float* B, float* C,
                        int M, int N, int K, cudaStream_t stream = 0);

批量 GEMM

头文件: batch_gemm.h

1
2
3
4
5
6
7
8
9
10
11
12
// 创建批量描述符
BatchGemmDesc desc(M, N, K, batch_size);
for (int i = 0; i < batch_size; i++) {
    desc.add_matrices(A_ptrs[i], B_ptrs[i], C_ptrs[i]);
}

// 使用多流并行执行
launch_batched_gemm_streams(desc, GemmKernelType::REGISTER_BLOCKED, cublas_handle);

// Strided 批量 GEMM(矩阵连续存储)
launch_strided_batched_gemm(A, B, C, M, N, K, batch_size, 
                            GemmKernelType::REGISTER_BLOCKED, cublas_handle);

半精度 GEMM

头文件: half_gemm.cuh

1
2
3
4
5
6
7
8
9
10
11
12
13
// FP16 GEMM(FP16 输入/输出, FP32 累加)
void launch_half_gemm(const half* A, const half* B, half* C,
                      int M, int N, int K, cudaStream_t stream = 0);

// 混合精度(FP16 输入, FP32 输出)
void launch_mixed_precision_gemm(const half* A, const half* B, float* C,
                                  int M, int N, int K, cudaStream_t stream = 0);

// 类型转换
void convert_float_to_half(const float* src, half* dst, size_t n, 
                           cudaStream_t stream = 0);
void convert_half_to_float(const half* src, float* dst, size_t n,
                           cudaStream_t stream = 0);

向量化 GEMM

头文件: vectorized_gemm.cuh

1
2
3
// 使用 float4 向量化加载
void launch_vectorized_gemm(const float* A, const float* B, float* C,
                            int M, int N, int K, cudaStream_t stream = 0);

批量 GEMM 操作

头文件: batch_gemm.h

BatchGemmDesc

批量 GEMM 操作描述符,保存多个矩阵的指针。

1
2
3
4
5
6
7
8
9
10
11
12
13
struct BatchGemmDesc {
    std::vector<const float*> A_ptrs;  // A 矩阵指针数组
    std::vector<const float*> B_ptrs;  // B 矩阵指针数组
    std::vector<float*> C_ptrs;        // C 矩阵指针数组
    int M, N, K;                       // 矩阵维度(所有批次相同)
    int batch_size;                    // 矩阵对数量

    // 构造函数
    BatchGemmDesc(int m, int n, int k, int batch);

    // 添加一组矩阵到批次
    void add_matrices(const float* A, const float* B, float* C);
};

使用示例

1
2
3
4
5
6
7
// 为 4 个 128x64 的矩阵创建批次描述符
BatchGemmDesc desc(128, 64, 32, 4);

// 添加矩阵指针
for (int i = 0; i < 4; i++) {
    desc.add_matrices(d_A[i], d_B[i], d_C[i]);
}

launch_batched_gemm_streams

使用 CUDA 流并行执行多个 GEMM 操作。

1
2
3
4
5
void launch_batched_gemm_streams(
    const BatchGemmDesc& desc,
    GemmKernelType kernel_type = GemmKernelType::REGISTER_BLOCKED,
    cublasHandle_t cublas_handle = nullptr
);

参数

参数 说明
desc 包含矩阵指针的批次描述符
kernel_type 每个操作使用的 GEMM kernel
cublas_handle cuBLAS 句柄(当 kernel_type 为 CUBLAS 时需要)

launch_strided_batched_gemm

在连续存储的矩阵数组上执行批量 GEMM。

1
2
3
4
5
6
void launch_strided_batched_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K, int batch_size,
    GemmKernelType kernel_type = GemmKernelType::REGISTER_BLOCKED,
    cublasHandle_t cublas_handle = nullptr
);

矩阵步长计算:

  • A: M * K 个元素
  • B: K * N 个元素
  • C: M * N 个元素

launch_cublas_batched_gemm

优化的 cuBLAS 批量 GEMM,适用于多个小矩阵。

1
2
3
4
5
void launch_cublas_batched_gemm(
    cublasHandle_t handle,
    const float** A_array, const float** B_array, float** C_array,
    int M, int N, int K, int batch_size
);

BatchPerfStats

批量 GEMM 性能统计。

1
2
3
4
5
6
7
8
struct BatchPerfStats {
    float total_time_ms;         // 总执行时间
    float avg_time_per_gemm_ms;  // 每个 GEMM 平均时间
    float total_gflops;          // 总吞吐量
    int batch_size;              // 操作数量

    void compute(int M, int N, int K, int batch);
};

benchmark_batched_gemm

基准测试批量 GEMM 性能。

1
2
3
4
5
6
7
BatchPerfStats benchmark_batched_gemm(
    const BatchGemmDesc& desc,
    GemmKernelType kernel_type,
    int warmup_iters = 3,
    int bench_iters = 10,
    cublasHandle_t cublas_handle = nullptr
);

工具类

GpuTimer

GPU 计时器,使用 CUDA Events 精确测量 kernel 执行时间。

头文件: kernels.cuh

1
2
3
4
5
6
7
GpuTimer timer;

timer.start(stream);
// ... kernel 执行 ...
timer.stop(stream);

float ms = timer.elapsed_ms();

Profiler

性能分析器,提供详细的统计数据和 Roofline 分析。

头文件: profiler.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Profiler profiler;

// 分析 kernel
auto result = profiler.profile("Optimized GEMM", M, N, K, 
    5,   // warmup iterations
    20,  // benchmark iterations
    [&]() { launch_optimized_gemm(A, B, C, M, N, K); });

printf("Time: %.3f ms\n", result.avg_time_ms);
printf("GFLOPS: %.2f\n", result.gflops);
printf("Bandwidth: %.2f GB/s\n", result.memory_bandwidth_gb);

// 打印结果
Profiler::print_result(result);

// 比较多个 kernel
std::vector<ProfileResult> results;
results.push_back(result1);
results.push_back(result2);
Profiler::print_comparison(results, "cuBLAS");

ProfileResult 结构体:

1
2
3
4
5
6
7
8
9
10
11
struct ProfileResult {
    std::string name;
    float avg_time_ms;           // 平均时间
    float min_time_ms;           // 最小时间
    float max_time_ms;           // 最大时间
    float std_dev_ms;            // 标准差
    float gflops;                // GFLOPS
    float memory_bandwidth_gb;   // 内存带宽
    float arithmetic_intensity;  // 计算强度
    int iterations;              // 迭代次数
};

RooflineAnalyzer

Roofline 模型分析器,用于识别性能瓶颈。

头文件: profiler.h

1
2
3
4
5
6
7
8
9
10
11
12
// 创建分析器,指定峰值性能指标
RooflineAnalyzer analyzer(
    10000.0f,  // 峰值 GFLOPS(设备特定)
    900.0f     // 峰值带宽 GB/s
);

// 添加分析结果
analyzer.add_point(result1);
analyzer.add_point(result2);

// 打印 Roofline 分析
analyzer.analyze();

输出包括:

  • 计算强度 (FLOPs/Byte)
  • 实际 vs. 理论 GFLOPS
  • 效率百分比
  • 内存受限 vs. 计算受限分类

AutoTuner

自动调优器,为给定矩阵尺寸选择最优 kernel。

头文件: autotuner.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
AutoTuner tuner;

// 调优并缓存结果
auto result = tuner.tune(M, N, K, cublas_handle);
printf("Best kernel: %s\n", kernel_type_name(result.config.kernel_type));
printf("Performance: %.2f GFLOPS\n", result.gflops);

// 获取缓存结果(不重新调优)
auto cached = tuner.get_best(M, N, K, cublas_handle);

// 使用最优 kernel 执行
tuner.execute_best(A, B, C, M, N, K, stream);

// 打印缓存
tuner.print_cache();

TuningResult 结构体:

1
2
3
4
5
6
7
8
9
struct TuningResult {
    TuningConfig config;   // 包含 kernel_type
    float time_ms;         // 执行时间
    float gflops;          // 性能指标
};

struct TuningConfig {
    GemmKernelType kernel_type;  // 选中的 kernel 类型
};

配置与日志

Config

配置管理器,支持文件和环境变量。

头文件: config.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 单例访问
auto& config = Config::instance();

// 从文件加载
config.load_from_file("config/default.ini");

// 获取值(带默认值)
int device = config.get_int("CUDA_DEVICE", 0);
float rate = config.get_float("LEARNING_RATE", 0.001f);
bool enabled = config.get_bool("ENABLE_TENSOR_CORES", true);
std::string preset = config.get("GEMM_PRESET", "medium");

// 设置值
config.set("LOG_LEVEL", "DEBUG");

// 检查键是否存在
if (config.has("CUDA_DEVICE")) { ... }

// 获取所有键
auto keys = config.keys();

// 清空
config.clear();

GEMM 预设: get_gemm_preset("medium")

可用预设: "small", "medium", "large", "volta", "ampere"

设备配置: get_device_config()


Logger

线程安全的日志系统。

头文件: logger.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 设置日志级别
Logger::instance().set_level(LogLevel::DEBUG);

// 输出到文件
Logger::instance().set_file("app.log");

// 控制台输出开关
Logger::instance().set_console(true);

// 颜色开关
Logger::instance().set_colors(true);

// 日志宏
LOG_TRACE("Trace message");
LOG_DEBUG("Debug: value = %d", value);
LOG_INFO("Info message");
LOG_WARN("Warning: %s", msg);
LOG_ERROR("Error occurred");
LOG_FATAL("Fatal error");

日志级别:

级别 说明
TRACE 详细跟踪
DEBUG 调试信息
INFO 一般信息
WARN 警告
ERROR 错误
FATAL 致命错误
OFF 关闭日志

量化支持

QuantizationParams

量化参数。

头文件: quantization.h

1
2
3
4
5
6
7
8
9
10
11
// 计算量化参数(对称量化)
QuantizationParams params = compute_quant_params(data, n);
// params.scale, zero_point, min_val, max_val

// 量化
std::vector<int8_t> quantized(n);
quantize_tensor(data, quantized.data(), n, params);

// 反量化
std::vector<float> dequantized(n);
dequantize_tensor(quantized.data(), dequantized.data(), n, params);

QuantizedWeight

量化权重存储。

1
2
3
4
5
6
7
8
// 从浮点权重创建
QuantizedWeight qw(float_weights, rows, cols);

// 反量化
auto restored = qw.dequantize();

// 压缩比
float ratio = qw.compression_ratio();  // 4.0 (FP32 -> INT8)

Per-Channel 量化

1
2
3
4
5
6
7
8
// 计算每通道参数
auto params = compute_per_channel_params(data, rows, cols);

// 量化
quantize_per_channel(data, quantized.data(), rows, cols, params);

// 反量化
dequantize_per_channel(quantized.data(), restored.data(), rows, cols, params);

QuantizationCalibrator

量化校准器,用于动态量化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
QuantizationCalibrator calibrator;

// 观察多个批次
calibrator.observe(batch1, n1);
calibrator.observe(batch2, n2);

// 获取最终参数
auto params = calibrator.get_params();

// 统计信息
float mean = calibrator.mean();
float var = calibrator.variance();

// 重置
calibrator.reset();

错误处理

CUDA_CHECK 宏

1
2
3
4
5
6
// 检查 CUDA 调用
CUDA_CHECK(cudaMalloc(&ptr, size));
CUDA_CHECK(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
CUDA_CHECK(cudaDeviceSynchronize());

// 失败时抛出 CudaException

CUBLAS_CHECK 宏

1
2
3
4
// 检查 cuBLAS 调用
CUBLAS_CHECK(cublasSgemm(handle, ...));

// 失败时抛出 std::runtime_error

异常处理示例

1
2
3
4
5
6
7
8
9
10
11
try {
    // CUDA 操作
} catch (const CudaException& e) {
    // CUDA 错误
    std::cerr << e.what() << std::endl;
    std::cerr << "Error: " << cudaGetErrorName(e.error()) << std::endl;
} catch (const std::invalid_argument& e) {
    // 参数验证错误
} catch (const std::runtime_error& e) {
    // 其他运行时错误
}

性能测量

工具函数

头文件: common.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 用随机值初始化数组(范围 [-1.0, 1.0])
void random_init(float* data, size_t n);

// 用零初始化数组
void zero_init(float* data, size_t n);

// CPU 参考实现(GEMM)
void cpu_matmul(const float* A, const float* B, float* C, int M, int N, int K);

// 带偏置和 ReLU 的 CPU 参考
void cpu_matmul_bias_relu(const float* A, const float* B, float* C,
                          const float* bias, int M, int N, int K,
                          bool add_bias, bool apply_relu);

// 比较矩阵,返回最大绝对误差
float compare_matrices(const float* a, const float* b, size_t n);

benchmark_kernel 函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include "kernels.cuh"

PerfStats stats = benchmark_kernel(
    GemmKernelType::REGISTER_BLOCKED,  // kernel 类型
    A, B, C,                            // 矩阵指针
    M, N, K,                            // 维度
    5,    // warmup iterations
    20,   // benchmark iterations
    cublas_handle,  // cuBLAS 句柄
    stream           // CUDA 流
);

printf("Time: %.3f ms\n", stats.kernel_time_ms);
printf("GFLOPS: %.2f\n", stats.gflops);
printf("Bandwidth: %.2f GB/s\n", stats.memory_bandwidth_gb);

PerfStats 结构体

1
2
3
4
5
6
struct PerfStats {
    float kernel_time_ms;       // kernel 时间
    float gflops;               // GFLOPS
    float memory_bandwidth_gb;  // 内存带宽
    float cublas_ratio;         // 相对 cuBLAS 的比例
};

完整示例

端到端推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "inference_engine.h"
#include "tensor.h"
#include <iostream>

int main() {
    // 初始化
    InferenceEngine engine;
    engine.init(0);
    
    // 加载权重
    if (!engine.load_weights("mnist_model.bin")) {
        std::cerr << "Failed to load model" << std::endl;
        return -1;
    }
    
    // 准备输入
    int batch_size = 32;
    Tensor input({batch_size, engine.input_dim()});
    input.fill(0.5f);
    
    // 推理
    Tensor output({batch_size, engine.output_dim()});
    engine.forward(input.data(), output.data(), batch_size);
    
    // 获取结果
    auto result = output.to_host();
    
    // 打印前 5 个预测
    for (int i = 0; i < 5; i++) {
        float* logits = &result[i * engine.output_dim()];
        int pred = std::max_element(logits, logits + engine.output_dim()) - logits;
        std::cout << "Sample " << i << ": predicted " << pred << std::endl;
    }
    
    engine.cleanup();
    return 0;
}

相关链接


*最后更新:2025-04-16 文档版本:v1.1.0*

Back to top

MIT License | A learning project for the CUDA community