自定义模型加载
学习如何使用 Tiny-DL-Inference 定义、创建和加载自定义神经网络模型。
概述
本指南展示如何从零开始构建模型,包括:
- 定义模型架构
- 创建具有正确形状的权重张量
- 将模型加载到推理引擎中
- 理解层连接和参数
模型定义结构
Tiny-DL-Inference 中的每个模型都定义为包含两个主要部分的 JavaScript 对象:layers 和 weights。
typescript
const modelDef = {
name: 'MyModel', // 可选:模型标识符
layers: [ /* 层定义 */ ],
weights: { /* 权重张量 */ }
};1
2
3
4
5
2
3
4
5
步骤 1:定义层
每个层是一个包含以下属性的对象:
| 属性 | 类型 | 描述 |
|---|---|---|
name | string | 该层的唯一标识符 |
type | string | 算子类型('conv2d'、'relu'、'dense' 等) |
inputs | string[] | 输入张量/层的名称 |
params | object | 算子特定参数 |
层定义示例
typescript
{
name: 'conv1',
type: 'conv2d',
inputs: ['input', 'conv1_weight', 'conv1_bias'],
params: {
kernelSize: [3, 3],
stride: [1, 1],
padding: [1, 1],
useBias: true
}
}1
2
3
4
5
6
7
8
9
10
11
2
3
4
5
6
7
8
9
10
11
特殊输入名称
'input'- 保留名称,表示模型的主输入张量- 前一层的层名称 - 引用更早层的输出
步骤 2:创建权重
权重定义为从张量名称到数据和形状的映射:
typescript
weights: {
conv1_weight: {
data: new Float32Array(/* 权重值 */),
shape: [输出通道数, 输入通道数, 卷积核高度, 卷积核宽度]
},
conv1_bias: {
data: new Float32Array(/* 偏置值 */),
shape: [输出通道数]
}
}1
2
3
4
5
6
7
8
9
10
2
3
4
5
6
7
8
9
10
权重形状约定
| 层类型 | 权重形状 | 偏置形状 |
|---|---|---|
| Conv2d | [outChannels, inChannels, kH, kW] | [outChannels] |
| Dense | [units, inFeatures] | [units] |
创建随机权重
typescript
function randomWeights(shape: number[], scale = 0.01): Float32Array {
const size = shape.reduce((a, b) => a * b, 1);
return new Float32Array(size).map(() =>
(Math.random() - 0.5) * 2 * scale
);
}
// 使用示例
const conv1Weights = randomWeights([32, 1, 3, 3], 0.1);
const fcWeights = randomWeights([10, 3136], 0.01);1
2
3
4
5
6
7
8
9
10
2
3
4
5
6
7
8
9
10
步骤 3:完整示例
以下是一个完整的图像分类自定义模型:
typescript
import { InferenceEngine } from 'tiny-dl-inference';
async function createAndLoadModel() {
const engine = new InferenceEngine();
await engine.initialize();
// 定义模型
const modelDef = {
name: 'CustomClassifier',
layers: [
// 卷积块
{
name: 'conv1',
type: 'conv2d',
inputs: ['input', 'conv1_weight', 'conv1_bias'],
params: {
kernelSize: [3, 3],
stride: [1, 1],
padding: [1, 1],
useBias: true
}
},
{
name: 'relu1',
type: 'relu',
inputs: ['conv1'],
params: {}
},
{
name: 'pool1',
type: 'maxpool',
inputs: ['relu1'],
params: {
poolSize: [2, 2],
stride: [2, 2]
}
},
// 全连接分类器
{
name: 'flatten',
type: 'flatten',
inputs: ['pool1'],
params: {}
},
{
name: 'fc',
type: 'dense',
inputs: ['flatten', 'fc_weight', 'fc_bias'],
params: {
units: 10,
useBias: true
}
},
{
name: 'output',
type: 'softmax',
inputs: ['fc'],
params: { axis: -1 }
}
],
weights: {
conv1_weight: {
data: randomWeights([32, 3, 3, 3]),
shape: [32, 3, 3, 3]
},
conv1_bias: {
data: new Float32Array(32).fill(0),
shape: [32]
},
fc_weight: {
data: randomWeights([10, 32 * 16 * 16]),
shape: [10, 32 * 16 * 16]
},
fc_bias: {
data: new Float32Array(10).fill(0),
shape: [10]
}
}
};
// 加载模型
await engine.loadModel(modelDef);
console.log('模型加载成功');
return engine;
}
function randomWeights(shape: number[], scale = 0.01): Float32Array {
const size = shape.reduce((a, b) => a * b, 1);
return new Float32Array(size).map(() => (Math.random() - 0.5) * 2 * scale);
}1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
步骤 4:运行推理
typescript
async function runInference(engine: InferenceEngine) {
// 创建输入张量 [batch=1, channels=3, height=32, width=32]
const inputData = new Float32Array(1 * 3 * 32 * 32);
// 填充实际数据...
const input = engine.tensorFromArray(inputData, [1, 3, 32, 32]);
// 运行推理
const output = await engine.infer(input);
const predictions = await output.download();
console.log('预测结果:', predictions);
// 清理
input.destroy();
output.destroy();
}1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
支持的层类型
| 类型 | 描述 | 关键参数 |
|---|---|---|
conv2d | 2D 卷积 | kernelSize、stride、padding、useBias |
relu | ReLU 激活函数 | 无 |
maxpool | 最大池化 | poolSize、stride |
flatten | 展平为一维 | 无 |
dense | 全连接层 | units、useBias |
softmax | Softmax 激活函数 | axis |
从文件加载权重
在生产环境中,权重通常从文件加载,而不是随机生成:
typescript
async function loadWeightsFromFile(url: string): Promise<Record<string, { data: Float32Array; shape: number[] }>> {
const response = await fetch(url);
const weightData = await response.json();
const weights: Record<string, { data: Float32Array; shape: number[] }> = {};
for (const [name, entry] of Object.entries(weightData)) {
weights[name] = {
data: new Float32Array(entry.data),
shape: entry.shape
};
}
return weights;
}
// 使用示例
const externalWeights = await loadWeightsFromFile('model-weights.json');
const modelDef = {
layers: [ /* ... */ ],
weights: externalWeights
};1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
常见错误
形状不匹配
typescript
// 错误:Dense 输入特征与展平大小不匹配
{
name: 'fc',
type: 'dense',
inputs: ['flatten', 'fc_weight', 'fc_bias'],
params: { units: 10 }
}
// 权重形状为 [10, 1000],但 flatten 产生 8192 个特征
// 正确:匹配实际的展平大小
fc_weight: {
data: randomWeights([10, 8192]),
shape: [10, 8192]
}1
2
3
4
5
6
7
8
9
10
11
12
13
14
2
3
4
5
6
7
8
9
10
11
12
13
14
缺少偏置
typescript
// 如果 useBias: true,必须提供偏置权重
{
name: 'conv1',
type: 'conv2d',
inputs: ['input', 'conv1_weight', 'conv1_bias'],
params: { useBias: true }
}1
2
3
4
5
6
7
2
3
4
5
6
7
不正确的输入引用
typescript
// 错误:引用不存在的层
{
name: 'relu1',
type: 'relu',
inputs: ['nonexistent_layer'], // 错误!
params: {}
}
// 正确:引用实际的前一层名称
{
name: 'relu1',
type: 'relu',
inputs: ['conv1'],
params: {}
}1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
2
3
4
5
6
7
8
9
10
11
12
13
14
15