MNIST 数字识别
一个完整的示例,展示如何使用在 WebGPU 上运行的卷积神经网络对手写数字进行分类。
概述
MNIST 演示展示了核心推理工作流:
- 初始化推理引擎
- 定义 CNN 模型架构
- 加载模型权重
- 准备输入数据
- 运行推理并解释结果
源代码位于 /examples/mnist-demo.ts。
模型架构
MNIST 模型使用标准 CNN 架构:
输入 (1x28x28)
|
v
Conv2d (32 个滤波器, 3x3, 填充=1)
|
v
ReLU
|
v
MaxPool (2x2, 步长=2)
|
v
Conv2d (64 个滤波器, 3x3, 填充=1)
|
v
ReLU
|
v
MaxPool (2x2, 步长=2)
|
v
Flatten
|
v
Dense (10 个单元)
|
v
Softmax
|
v
输出 (10 个概率值)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
逐步指南
步骤 1:初始化引擎
typescript
import { InferenceEngine } from 'tiny-dl-inference';
const engine = new InferenceEngine();
await engine.initialize();
console.log('引擎已初始化');1
2
3
4
5
2
3
4
5
步骤 2:定义模型
typescript
const modelDef = {
name: 'MNIST-CNN',
layers: [
{
name: 'conv1',
type: 'conv2d',
inputs: ['input', 'conv1_weight', 'conv1_bias'],
params: {
kernelSize: [3, 3],
stride: [1, 1],
padding: [1, 1],
useBias: true
}
},
{
name: 'relu1',
type: 'relu',
inputs: ['conv1'],
params: {}
},
{
name: 'pool1',
type: 'maxpool',
inputs: ['relu1'],
params: {
poolSize: [2, 2],
stride: [2, 2]
}
},
{
name: 'conv2',
type: 'conv2d',
inputs: ['pool1', 'conv2_weight', 'conv2_bias'],
params: {
kernelSize: [3, 3],
stride: [1, 1],
padding: [1, 1],
useBias: true
}
},
{
name: 'relu2',
type: 'relu',
inputs: ['conv2'],
params: {}
},
{
name: 'pool2',
type: 'maxpool',
inputs: ['relu2'],
params: {
poolSize: [2, 2],
stride: [2, 2]
}
},
{
name: 'flatten',
type: 'flatten',
inputs: ['pool2'],
params: {}
},
{
name: 'fc',
type: 'dense',
inputs: ['flatten', 'fc_weight', 'fc_bias'],
params: {
units: 10,
useBias: true
}
},
{
name: 'output',
type: 'softmax',
inputs: ['fc'],
params: {
axis: -1
}
}
],
weights: {
conv1_weight: { data: new Float32Array(32 * 1 * 3 * 3).fill(0.01), shape: [32, 1, 3, 3] },
conv1_bias: { data: new Float32Array(32).fill(0), shape: [32] },
conv2_weight: { data: new Float32Array(64 * 32 * 3 * 3).fill(0.01), shape: [64, 32, 3, 3] },
conv2_bias: { data: new Float32Array(64).fill(0), shape: [64] },
fc_weight: { data: new Float32Array(10 * 3136).fill(0.01), shape: [10, 3136] },
fc_bias: { data: new Float32Array(10).fill(0), shape: [10] }
}
};
await engine.loadModel(modelDef);1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
提示
在实际应用中,权重应从训练好的模型文件中加载,而不是使用占位值初始化。
步骤 3:准备输入
typescript
// 创建 28x28 灰度图像张量 [batch=1, channels=1, height=28, width=28]
const inputData = new Float32Array(1 * 1 * 28 * 28);
// 填充实际像素数据(归一化到 [0, 1])
for (let i = 0; i < inputData.length; i++) {
inputData[i] = pixelData[i] / 255.0;
}
const input = engine.tensorFromArray(inputData, [1, 1, 28, 28]);1
2
3
4
5
6
7
8
9
2
3
4
5
6
7
8
9
步骤 4:运行推理
typescript
const startTime = performance.now();
const output = await engine.infer(input);
const endTime = performance.now();
const result = await output.download();
console.log(`执行时间:${(endTime - startTime).toFixed(2)}ms`);1
2
3
4
5
6
2
3
4
5
6
步骤 5:解释结果
typescript
console.log('\n预测结果:');
for (let i = 0; i < 10; i++) {
const probability = (result[i] * 100).toFixed(2);
const bar = '\u2588'.repeat(Math.floor(result[i] * 50));
console.log(`数字 ${i}: ${bar} ${probability}%`);
}
const predictedDigit = result.indexOf(Math.max(...Array.from(result)));
console.log(`\n预测数字:${predictedDigit}`);1
2
3
4
5
6
7
8
9
2
3
4
5
6
7
8
9
步骤 6:清理资源
typescript
input.destroy();
output.destroy();
engine.destroy();1
2
3
2
3
核心概念
张量形状
MNIST 图像是 28x28 的灰度像素。张量形状 [1, 1, 28, 28] 表示:
1- 批次大小(单张图像)1- 通道数(灰度)28- 高度28- 宽度
层连接
每层的 inputs 数组引用特殊标记 'input'(用于第一层)或前一层的名称。引擎从这些连接构建计算图。
内存管理
所有张量和引擎必须显式销毁以释放 GPU 内存。在生产代码中使用 try/finally 块,确保即使发生错误也能清理。
运行演示
bash
npx ts-node examples/mnist-demo.ts1
演示将输出每个数字(0-9)的预测概率,并显示预测的数字。