Language: 简体中文 English
目录 (Table of Contents)
核心类
DeviceMemory
GPU 内存的 RAII 包装类,自动管理内存生命周期。
头文件: common.h
构造函数
1
2
DeviceMemory(); // 默认构造,空内存
explicit DeviceMemory(size_t bytes); // 分配指定字节数
成员函数
| 函数 | 说明 |
|---|---|
void allocate(size_t bytes) |
重新分配内存 |
void free() |
释放内存 |
void copy_from_host(const void* data, size_t bytes) |
从主机拷贝 |
void copy_to_host(void* data, size_t bytes) const |
拷贝到主机 |
void zero() |
清零内存 |
T* get() |
获取设备指针 |
size_t size() const |
获取字节数 |
bool empty() const |
是否为空 |
使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include "common.h"
// 创建并分配内存
DeviceMemory mem(1024 * 1024); // 1MB
// 从主机拷贝数据
std::vector<float> host_data(256, 1.0f);
mem.copy_from_host(host_data.data(), 256 * sizeof(float));
// 拷贝到主机
std::vector<float> result(256);
mem.copy_to_host(result.data(), 256 * sizeof(float));
// 清零
mem.zero();
// 获取原始指针
float* ptr = mem.get<float>();
Tensor
N 维张量类,提供 GPU 存储和常用操作。
头文件: tensor.h
构造函数
1
2
3
Tensor(); // 默认构造,空张量
explicit Tensor(const std::vector<int>& shape); // 指定形状
Tensor(const std::vector<int>& shape, const float* data); // 用数据初始化
成员函数
| 函数 | 说明 |
|---|---|
Tensor clone() const |
深拷贝 |
void reshape(const std::vector<int>& new_shape) |
重塑形状 |
void fill(float value) |
填充值 |
void zero() |
清零 |
void copy_from_host(const float* data) |
从主机拷贝 |
void copy_to_host(float* data) const |
拷贝到主机 |
std::vector<float> to_host() const |
转换为主机向量 |
MatrixDesc as_matrix() const |
获取矩阵视图 |
const std::vector<int>& shape() const |
获取形状 |
const std::vector<int>& strides() const |
获取步长 |
size_t size() const |
元素数量 |
int ndim() const |
维度数 |
int dim(int i) const |
第 i 维大小 |
使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include "tensor.h"
// 创建张量
Tensor t({batch, height, width, channels});
// 用数据创建
std::vector<float> data(100);
Tensor t2({10, 10}, data.data());
// 属性查询
auto shape = t.shape(); // std::vector<int>
auto strides = t.strides(); // std::vector<int>
size_t n = t.size(); // 元素数量
int d = t.ndim(); // 维度数
// 数据操作
t.fill(1.0f);
t.zero();
Tensor t3 = t.clone();
t.reshape({100});
// 主机数据交换
t.copy_from_host(host_data.data());
auto host = t.to_host();
// 矩阵视图(2D)
MatrixDesc mat = t.as_matrix();
数学运算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 矩阵乘法: C = A @ B
Tensor C = matmul(A, B);
// 逐元素加法
Tensor C = add(A, B);
// 加偏置: Y = X + bias(broadcast)
Tensor Y = add_bias(X, bias);
// ReLU 激活
Tensor Y = relu(X);
// Softmax
Tensor Y = softmax(X);
// 融合线性 + ReLU: Y = ReLU(X @ W + bias)
Tensor Y = fused_linear_relu(X, W, bias);
MemoryPool
GPU 内存池,通过缓存减少 cudaMalloc 开销。
头文件: memory_pool.h
单例访问
1
static MemoryPool& instance();
成员函数
| 函数 | 说明 |
|---|---|
void* allocate(size_t size) |
分配内存 |
void deallocate(void* ptr) |
释放内存(返回缓存) |
void clear_cache() |
清理缓存,保留活跃分配 |
void clear_all() |
清理缓存并重置统计 |
void release_all() |
释放所有内存(仅用于关闭时) |
bool owns(void* ptr) const |
是否管理该指针 |
bool is_cached(void* ptr) const |
是否在缓存中 |
size_t active_block_count() const |
活跃块数 |
size_t cached_block_count() const |
缓存块数 |
Stats get_stats() const |
获取统计信息 |
void print_stats() const |
打印统计信息 |
统计信息结构体
1
2
3
4
5
6
struct Stats {
size_t total_allocated; // 累计分配字节数
size_t cached_size; // 缓存大小
size_t cache_hits; // 缓存命中次数
size_t cache_misses; // 缓存未命中次数
};
使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "memory_pool.h"
// 单例访问
auto& pool = MemoryPool::instance();
// 分配内存
void* ptr = pool.allocate(1024);
// 释放内存(返回缓存)
pool.deallocate(ptr);
// 查询统计
auto stats = pool.get_stats();
printf("Cache hits: %zu\n", stats.cache_hits);
printf("Cache misses: %zu\n", stats.cache_misses);
printf("Hit rate: %.1f%%\n",
100.0 * stats.cache_hits / (stats.cache_hits + stats.cache_misses));
// 清理缓存
pool.clear_cache();
pool.release_all();
PooledMemory 包装类
1
2
3
4
5
6
7
// RAII 包装,自动从池中分配和释放
PooledMemory mem(1024 * sizeof(float));
float* ptr = mem.get();
mem.copy_from_host(host_data.data(), bytes);
mem.copy_to_host(result.data(), bytes);
mem.zero();
// 析构时自动归还池
StreamManager
CUDA 流管理器,支持多流并发执行。
头文件: stream_manager.h
单例访问
1
static StreamManager& instance();
成员函数
| 函数 | 说明 |
|---|---|
void init(int num_streams = 4) |
初始化流池 |
cudaStream_t get_stream() |
轮询获取流 |
cudaStream_t get_stream(int index) |
按索引获取流 |
void sync(int index) |
同步指定流 |
void sync_all() |
同步所有流 |
void cleanup() |
清理所有流 |
int num_streams() const |
流数量 |
使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "stream_manager.h"
// 单例访问
auto& sm = StreamManager::instance();
// 初始化(可选,会自动懒初始化)
sm.init(4); // 创建 4 个流
// 获取流(轮询分配)
cudaStream_t s1 = sm.get_stream();
cudaStream_t s2 = sm.get_stream();
// 按索引获取
cudaStream_t s0 = sm.get_stream(0);
// 同步
sm.sync(0); // 同步流 0
sm.sync_all(); // 同步所有流
// 清理
sm.cleanup();
AsyncOperation 辅助类
1
2
3
4
5
6
7
8
9
10
11
12
13
AsyncOperation op;
// 记录事件
op.record(stream);
// 等待完成
op.wait();
// 非阻塞检查
bool done = op.is_complete();
// 让另一个流等待此操作
op.wait_on(other_stream);
InferenceEngine
神经网络推理引擎,支持多层前向传播。
头文件: inference_engine.h
成员函数
| 函数 | 说明 |
|---|---|
void init(int device_id = 0) |
初始化引擎 |
void cleanup() |
清理资源 |
bool load_weights(const std::string& path) |
加载权重 |
bool save_weights(const std::string& path) const |
保存权重 |
void add_layer(int in, int out, bool bias, const float* w, const float* b) |
添加层 |
void forward(const float* in, float* out, int batch) |
前向传播 |
void forward_with_timing(const float* in, float* out, int batch, std::vector<float>& times) |
带计时前向传播 |
size_t num_layers() const |
层数 |
int input_dim() const |
输入维度 |
int output_dim() const |
输出维度 |
使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "inference_engine.h"
// 创建引擎
InferenceEngine engine;
engine.init(0);
// 加载权重
if (!engine.load_weights("model.bin")) {
std::cerr << "Failed to load weights" << std::endl;
return -1;
}
// 查询网络信息
size_t n_layers = engine.num_layers();
int in_dim = engine.input_dim();
int out_dim = engine.output_dim();
// 准备数据
int batch_size = 32;
DeviceMemory d_input(batch_size * in_dim * sizeof(float));
DeviceMemory d_output(batch_size * out_dim * sizeof(float));
// 前向传播
engine.forward(d_input.get(), d_output.get(), batch_size);
// 带计时的前向传播
std::vector<float> layer_times;
engine.forward_with_timing(d_input.get(), d_output.get(),
batch_size, layer_times);
// 保存权重
engine.save_weights("model_copy.bin");
// 清理
engine.cleanup();
创建权重文件
1
2
3
4
5
6
7
8
9
// 定义网络结构: {输入维度, 输出维度}
std::vector<std::pair<int, int>> layer_dims = {
{784, 256}, // Layer 0: 784 -> 256
{256, 128}, // Layer 1: 256 -> 128
{128, 10} // Layer 2: 128 -> 10
};
// 创建随机权重文件(用于测试)
create_random_weights("model.bin", layer_dims, true); // true = 包含 bias
GEMM Kernel API
基础 Kernel 启动函数
头文件: kernels.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 基础实现
void launch_naive_matmul(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
void launch_tiled_gemm(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
void launch_coalesced_gemm(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
void launch_double_buffer_gemm(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
// 优化实现(推荐)
void launch_optimized_gemm(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
// 融合 GEMM + Bias + ReLU
void launch_fused_gemm(const float* A, const float* B, float* C,
const float* bias, int M, int N, int K,
bool add_bias, bool apply_relu,
cudaStream_t stream = 0);
// cuBLAS 包装
void launch_cublas_gemm(cublasHandle_t handle,
const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
批量 GEMM
头文件: batch_gemm.h
1
2
3
4
5
6
7
8
9
10
11
12
// 创建批量描述符
BatchGemmDesc desc(M, N, K, batch_size);
for (int i = 0; i < batch_size; i++) {
desc.add_matrices(A_ptrs[i], B_ptrs[i], C_ptrs[i]);
}
// 使用多流并行执行
launch_batched_gemm_streams(desc, GemmKernelType::REGISTER_BLOCKED, cublas_handle);
// Strided 批量 GEMM(矩阵连续存储)
launch_strided_batched_gemm(A, B, C, M, N, K, batch_size,
GemmKernelType::REGISTER_BLOCKED, cublas_handle);
半精度 GEMM
头文件: half_gemm.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
// FP16 GEMM(FP16 输入/输出, FP32 累加)
void launch_half_gemm(const half* A, const half* B, half* C,
int M, int N, int K, cudaStream_t stream = 0);
// 混合精度(FP16 输入, FP32 输出)
void launch_mixed_precision_gemm(const half* A, const half* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
// 类型转换
void convert_float_to_half(const float* src, half* dst, size_t n,
cudaStream_t stream = 0);
void convert_half_to_float(const half* src, float* dst, size_t n,
cudaStream_t stream = 0);
向量化 GEMM
头文件: vectorized_gemm.cuh
1
2
3
// 使用 float4 向量化加载
void launch_vectorized_gemm(const float* A, const float* B, float* C,
int M, int N, int K, cudaStream_t stream = 0);
批量 GEMM 操作
头文件: batch_gemm.h
BatchGemmDesc
批量 GEMM 操作描述符,保存多个矩阵的指针。
1
2
3
4
5
6
7
8
9
10
11
12
13
struct BatchGemmDesc {
std::vector<const float*> A_ptrs; // A 矩阵指针数组
std::vector<const float*> B_ptrs; // B 矩阵指针数组
std::vector<float*> C_ptrs; // C 矩阵指针数组
int M, N, K; // 矩阵维度(所有批次相同)
int batch_size; // 矩阵对数量
// 构造函数
BatchGemmDesc(int m, int n, int k, int batch);
// 添加一组矩阵到批次
void add_matrices(const float* A, const float* B, float* C);
};
使用示例
1
2
3
4
5
6
7
// 为 4 个 128x64 的矩阵创建批次描述符
BatchGemmDesc desc(128, 64, 32, 4);
// 添加矩阵指针
for (int i = 0; i < 4; i++) {
desc.add_matrices(d_A[i], d_B[i], d_C[i]);
}
launch_batched_gemm_streams
使用 CUDA 流并行执行多个 GEMM 操作。
1
2
3
4
5
void launch_batched_gemm_streams(
const BatchGemmDesc& desc,
GemmKernelType kernel_type = GemmKernelType::REGISTER_BLOCKED,
cublasHandle_t cublas_handle = nullptr
);
参数
| 参数 | 说明 |
|---|---|
desc |
包含矩阵指针的批次描述符 |
kernel_type |
每个操作使用的 GEMM kernel |
cublas_handle |
cuBLAS 句柄(当 kernel_type 为 CUBLAS 时需要) |
launch_strided_batched_gemm
在连续存储的矩阵数组上执行批量 GEMM。
1
2
3
4
5
6
void launch_strided_batched_gemm(
const float* A, const float* B, float* C,
int M, int N, int K, int batch_size,
GemmKernelType kernel_type = GemmKernelType::REGISTER_BLOCKED,
cublasHandle_t cublas_handle = nullptr
);
矩阵步长计算:
- A:
M * K个元素 - B:
K * N个元素 - C:
M * N个元素
launch_cublas_batched_gemm
优化的 cuBLAS 批量 GEMM,适用于多个小矩阵。
1
2
3
4
5
void launch_cublas_batched_gemm(
cublasHandle_t handle,
const float** A_array, const float** B_array, float** C_array,
int M, int N, int K, int batch_size
);
BatchPerfStats
批量 GEMM 性能统计。
1
2
3
4
5
6
7
8
struct BatchPerfStats {
float total_time_ms; // 总执行时间
float avg_time_per_gemm_ms; // 每个 GEMM 平均时间
float total_gflops; // 总吞吐量
int batch_size; // 操作数量
void compute(int M, int N, int K, int batch);
};
benchmark_batched_gemm
基准测试批量 GEMM 性能。
1
2
3
4
5
6
7
BatchPerfStats benchmark_batched_gemm(
const BatchGemmDesc& desc,
GemmKernelType kernel_type,
int warmup_iters = 3,
int bench_iters = 10,
cublasHandle_t cublas_handle = nullptr
);
工具类
GpuTimer
GPU 计时器,使用 CUDA Events 精确测量 kernel 执行时间。
头文件: kernels.cuh
1
2
3
4
5
6
7
GpuTimer timer;
timer.start(stream);
// ... kernel 执行 ...
timer.stop(stream);
float ms = timer.elapsed_ms();
Profiler
性能分析器,提供详细的统计数据和 Roofline 分析。
头文件: profiler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Profiler profiler;
// 分析 kernel
auto result = profiler.profile("Optimized GEMM", M, N, K,
5, // warmup iterations
20, // benchmark iterations
[&]() { launch_optimized_gemm(A, B, C, M, N, K); });
printf("Time: %.3f ms\n", result.avg_time_ms);
printf("GFLOPS: %.2f\n", result.gflops);
printf("Bandwidth: %.2f GB/s\n", result.memory_bandwidth_gb);
// 打印结果
Profiler::print_result(result);
// 比较多个 kernel
std::vector<ProfileResult> results;
results.push_back(result1);
results.push_back(result2);
Profiler::print_comparison(results, "cuBLAS");
ProfileResult 结构体:
1
2
3
4
5
6
7
8
9
10
11
struct ProfileResult {
std::string name;
float avg_time_ms; // 平均时间
float min_time_ms; // 最小时间
float max_time_ms; // 最大时间
float std_dev_ms; // 标准差
float gflops; // GFLOPS
float memory_bandwidth_gb; // 内存带宽
float arithmetic_intensity; // 计算强度
int iterations; // 迭代次数
};
RooflineAnalyzer
Roofline 模型分析器,用于识别性能瓶颈。
头文件: profiler.h
1
2
3
4
5
6
7
8
9
10
11
12
// 创建分析器,指定峰值性能指标
RooflineAnalyzer analyzer(
10000.0f, // 峰值 GFLOPS(设备特定)
900.0f // 峰值带宽 GB/s
);
// 添加分析结果
analyzer.add_point(result1);
analyzer.add_point(result2);
// 打印 Roofline 分析
analyzer.analyze();
输出包括:
- 计算强度 (FLOPs/Byte)
- 实际 vs. 理论 GFLOPS
- 效率百分比
- 内存受限 vs. 计算受限分类
AutoTuner
自动调优器,为给定矩阵尺寸选择最优 kernel。
头文件: autotuner.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
AutoTuner tuner;
// 调优并缓存结果
auto result = tuner.tune(M, N, K, cublas_handle);
printf("Best kernel: %s\n", kernel_type_name(result.config.kernel_type));
printf("Performance: %.2f GFLOPS\n", result.gflops);
// 获取缓存结果(不重新调优)
auto cached = tuner.get_best(M, N, K, cublas_handle);
// 使用最优 kernel 执行
tuner.execute_best(A, B, C, M, N, K, stream);
// 打印缓存
tuner.print_cache();
TuningResult 结构体:
1
2
3
4
5
6
7
8
9
struct TuningResult {
TuningConfig config; // 包含 kernel_type
float time_ms; // 执行时间
float gflops; // 性能指标
};
struct TuningConfig {
GemmKernelType kernel_type; // 选中的 kernel 类型
};
配置与日志
Config
配置管理器,支持文件和环境变量。
头文件: config.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 单例访问
auto& config = Config::instance();
// 从文件加载
config.load_from_file("config/default.ini");
// 获取值(带默认值)
int device = config.get_int("CUDA_DEVICE", 0);
float rate = config.get_float("LEARNING_RATE", 0.001f);
bool enabled = config.get_bool("ENABLE_TENSOR_CORES", true);
std::string preset = config.get("GEMM_PRESET", "medium");
// 设置值
config.set("LOG_LEVEL", "DEBUG");
// 检查键是否存在
if (config.has("CUDA_DEVICE")) { ... }
// 获取所有键
auto keys = config.keys();
// 清空
config.clear();
GEMM 预设: get_gemm_preset("medium")
可用预设: "small", "medium", "large", "volta", "ampere"
设备配置: get_device_config()
Logger
线程安全的日志系统。
头文件: logger.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 设置日志级别
Logger::instance().set_level(LogLevel::DEBUG);
// 输出到文件
Logger::instance().set_file("app.log");
// 控制台输出开关
Logger::instance().set_console(true);
// 颜色开关
Logger::instance().set_colors(true);
// 日志宏
LOG_TRACE("Trace message");
LOG_DEBUG("Debug: value = %d", value);
LOG_INFO("Info message");
LOG_WARN("Warning: %s", msg);
LOG_ERROR("Error occurred");
LOG_FATAL("Fatal error");
日志级别:
| 级别 | 说明 |
|---|---|
TRACE |
详细跟踪 |
DEBUG |
调试信息 |
INFO |
一般信息 |
WARN |
警告 |
ERROR |
错误 |
FATAL |
致命错误 |
OFF |
关闭日志 |
量化支持
QuantizationParams
量化参数。
头文件: quantization.h
1
2
3
4
5
6
7
8
9
10
11
// 计算量化参数(对称量化)
QuantizationParams params = compute_quant_params(data, n);
// params.scale, zero_point, min_val, max_val
// 量化
std::vector<int8_t> quantized(n);
quantize_tensor(data, quantized.data(), n, params);
// 反量化
std::vector<float> dequantized(n);
dequantize_tensor(quantized.data(), dequantized.data(), n, params);
QuantizedWeight
量化权重存储。
1
2
3
4
5
6
7
8
// 从浮点权重创建
QuantizedWeight qw(float_weights, rows, cols);
// 反量化
auto restored = qw.dequantize();
// 压缩比
float ratio = qw.compression_ratio(); // 4.0 (FP32 -> INT8)
Per-Channel 量化
1
2
3
4
5
6
7
8
// 计算每通道参数
auto params = compute_per_channel_params(data, rows, cols);
// 量化
quantize_per_channel(data, quantized.data(), rows, cols, params);
// 反量化
dequantize_per_channel(quantized.data(), restored.data(), rows, cols, params);
QuantizationCalibrator
量化校准器,用于动态量化。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
QuantizationCalibrator calibrator;
// 观察多个批次
calibrator.observe(batch1, n1);
calibrator.observe(batch2, n2);
// 获取最终参数
auto params = calibrator.get_params();
// 统计信息
float mean = calibrator.mean();
float var = calibrator.variance();
// 重置
calibrator.reset();
错误处理
CUDA_CHECK 宏
1
2
3
4
5
6
// 检查 CUDA 调用
CUDA_CHECK(cudaMalloc(&ptr, size));
CUDA_CHECK(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
// 失败时抛出 CudaException
CUBLAS_CHECK 宏
1
2
3
4
// 检查 cuBLAS 调用
CUBLAS_CHECK(cublasSgemm(handle, ...));
// 失败时抛出 std::runtime_error
异常处理示例
1
2
3
4
5
6
7
8
9
10
11
try {
// CUDA 操作
} catch (const CudaException& e) {
// CUDA 错误
std::cerr << e.what() << std::endl;
std::cerr << "Error: " << cudaGetErrorName(e.error()) << std::endl;
} catch (const std::invalid_argument& e) {
// 参数验证错误
} catch (const std::runtime_error& e) {
// 其他运行时错误
}
性能测量
工具函数
头文件: common.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 用随机值初始化数组(范围 [-1.0, 1.0])
void random_init(float* data, size_t n);
// 用零初始化数组
void zero_init(float* data, size_t n);
// CPU 参考实现(GEMM)
void cpu_matmul(const float* A, const float* B, float* C, int M, int N, int K);
// 带偏置和 ReLU 的 CPU 参考
void cpu_matmul_bias_relu(const float* A, const float* B, float* C,
const float* bias, int M, int N, int K,
bool add_bias, bool apply_relu);
// 比较矩阵,返回最大绝对误差
float compare_matrices(const float* a, const float* b, size_t n);
benchmark_kernel 函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include "kernels.cuh"
PerfStats stats = benchmark_kernel(
GemmKernelType::REGISTER_BLOCKED, // kernel 类型
A, B, C, // 矩阵指针
M, N, K, // 维度
5, // warmup iterations
20, // benchmark iterations
cublas_handle, // cuBLAS 句柄
stream // CUDA 流
);
printf("Time: %.3f ms\n", stats.kernel_time_ms);
printf("GFLOPS: %.2f\n", stats.gflops);
printf("Bandwidth: %.2f GB/s\n", stats.memory_bandwidth_gb);
PerfStats 结构体
1
2
3
4
5
6
struct PerfStats {
float kernel_time_ms; // kernel 时间
float gflops; // GFLOPS
float memory_bandwidth_gb; // 内存带宽
float cublas_ratio; // 相对 cuBLAS 的比例
};
完整示例
端到端推理
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "inference_engine.h"
#include "tensor.h"
#include <iostream>
int main() {
// 初始化
InferenceEngine engine;
engine.init(0);
// 加载权重
if (!engine.load_weights("mnist_model.bin")) {
std::cerr << "Failed to load model" << std::endl;
return -1;
}
// 准备输入
int batch_size = 32;
Tensor input({batch_size, engine.input_dim()});
input.fill(0.5f);
// 推理
Tensor output({batch_size, engine.output_dim()});
engine.forward(input.data(), output.data(), batch_size);
// 获取结果
auto result = output.to_host();
// 打印前 5 个预测
for (int i = 0; i < 5; i++) {
float* logits = &result[i * engine.output_dim()];
int pred = std::max_element(logits, logits + engine.output_dim()) - logits;
std::cout << "Sample " << i << ": predicted " << pred << std::endl;
}
engine.cleanup();
return 0;
}
相关链接
| *最后更新:2025-04-16 | 文档版本:v1.1.0* |