自定义算子
为 Tiny-DL-Inference 创建新的 WebGPU 神经网络算子。
概述
虽然 Tiny-DL-Inference 提供了常用算子,但你可能需要创建自定义算子来实现:
- 新的激活函数(Sigmoid、Tanh、GELU)
- 专用层(注意力机制、归一化)
- 模型特定操作
算子结构
每个算子继承自 Operator 基类:
typescript
import { Operator, Tensor, TensorShape, OperatorParams } from 'tiny-dl-inference';
class SigmoidOperator extends Operator {
  // Pipeline is built lazily on first forward() call — see ensureInitialized().
  private initialized = false;

  /**
   * Element-wise sigmoid: output[i] = 1 / (1 + e^-input[i]).
   *
   * The shader declares all three bindings that forward() places in the bind
   * group. The previous version declared none of them (the module would fail
   * WGSL validation) and never declared the uniform buffer bound at
   * @binding(2).
   * NOTE(review): assumes the Operator base class does not prepend its own
   * binding declarations — confirm against initializePipeline().
   */
  protected compileShader(): string {
    return /* wgsl */`
      struct TensorData { data: array<f32> }
      struct Params { batchSize: u32, features: u32 }

      @group(0) @binding(0) var<storage, read> input: TensorData;
      @group(0) @binding(1) var<storage, read_write> output: TensorData;
      @group(0) @binding(2) var<uniform> params: Params;

      @compute @workgroup_size(256)
      fn main(@builtin(global_invocation_id) id: vec3<u32>) {
        let idx = id.x;
        let total = arrayLength(&input.data);
        if (idx >= total) { return; }
        let x = input.data[idx];
        output.data[idx] = 1.0 / (1.0 + exp(-x));
      }
    `;
  }

  /** Sigmoid is element-wise, so the output shape equals the input shape. */
  computeOutputShape(inputShape: TensorShape, params?: OperatorParams): TensorShape {
    return inputShape;
  }

  /**
   * Runs sigmoid over inputs[0] on the GPU and returns a new output tensor.
   *
   * @param inputs - exactly one input tensor
   * @param params - unused by sigmoid
   * @returns a freshly allocated output tensor with the same shape as the input
   * @throws Error when the number of inputs is not exactly 1
   */
  async forward(inputs: Tensor[], params?: OperatorParams): Promise<Tensor> {
    if (inputs.length !== 1) {
      throw new Error(`SigmoidOperator expects 1 input, got ${inputs.length}`);
    }
    this.ensureInitialized();

    const input = inputs[0];
    const outputShape = this.computeOutputShape(input.shape);
    const output = new Tensor(this.context, outputShape);

    // Uniform buffer: two u32 values (batchSize, features) = 8 bytes.
    const paramsBuffer = this.context.createBuffer({
      size: 8,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
    });
    this.context.getQueue().writeBuffer(paramsBuffer, 0, new Uint32Array([
      input.shape[0],               // batchSize
      input.size / input.shape[0]   // features per batch element
    ]));

    const bindGroup = this.context.getDevice().createBindGroup({
      layout: this.bindGroupLayout!,
      entries: [
        { binding: 0, resource: { buffer: input.buffer } },
        { binding: 1, resource: { buffer: output.buffer } },
        { binding: 2, resource: { buffer: paramsBuffer } }
      ]
    });

    // One invocation per element; 256 matches the shader's @workgroup_size.
    const workgroupCount = Math.ceil(input.size / 256);
    const encoder = this.context.getDevice().createCommandEncoder();
    // Fixed: the WebGPU API is beginComputePass(); computePass() does not exist.
    const pass = encoder.beginComputePass();
    pass.setPipeline(this.pipeline!);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(workgroupCount);
    pass.end();
    this.context.getQueue().submit([encoder.finish()]);

    // Release only the temporary uniform buffer. The input tensor belongs to
    // the caller (the unit tests call input.destroy() themselves), so
    // destroying input.buffer here would double-free it.
    this.context.deferDestroy(paramsBuffer);

    return output;
  }

  // Build the pipeline exactly once, on first use (lazy initialization).
  private ensureInitialized(): void {
    if (this.initialized) return;
    this.initializePipeline(this.compileShader(), []);
    this.initialized = true;
  }
}
WGSL 编写指南
基本结构
wgsl
// Storage buffers: read-only input, writable output.
@group(0) @binding(0)
var<storage, read> input: array<f32>;
@group(0) @binding(1)
var<storage, read_write> output: array<f32>;
// Uniform parameters. The Params struct is declared below the usage —
// WGSL permits out-of-order module-scope declarations.
@group(0) @binding(2)
var<uniform> params: Params;

struct Params {
  batchSize: u32,
  features: u32
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let idx = id.x;
  // Guard against the final, partially filled workgroup.
  let total = arrayLength(&input);
  if (idx >= total) { return; }
  // Computation logic goes here:
  output[idx] = /* computed result */;
}
常用 WGSL 函数
wgsl
// Common math builtins
exp(x)          // e^x
log(x)          // natural logarithm
sqrt(x)         // square root
abs(x)          // absolute value
max(a, b)       // maximum
min(a, b)       // minimum
clamp(x, a, b)  // clamp to [a, b]

// Activation functions
fn sigmoid(x: f32) -> f32 {
  return 1.0 / (1.0 + exp(-x));
}

// NOTE: tanh() is already a WGSL builtin — prefer it directly. Declaring a
// module-scope function named `tanh` is a shader-creation error because it
// shadows the builtin, so a hand-rolled version must use a different name.
// This manual form also overflows exp(2x) for large |x|; the builtin does not.
fn tanh_manual(x: f32) -> f32 {
  let exp2x = exp(2.0 * x);
  return (exp2x - 1.0) / (exp2x + 1.0);
}

fn relu(x: f32) -> f32 {
  return max(0.0, x);
}
测试算子
单元测试
typescript
import { describe, it, expect } from 'vitest';
import { createMockContext } from '../helpers/mockContext';
describe('SigmoidOperator', () => {
it('should compute sigmoid function', async () => {
const context = createMockContext();
const op = new SigmoidOperator(context);
const input = Tensor.fromArray(context,
new Float32Array([0.0, 1.0, -1.0, 2.0]),
[1, 4]
);
const output = await op.forward([input]);
const result = await output.download();
expect(result[0]).toBeCloseTo(0.5, 3);
expect(result[1]).toBeCloseTo(0.731, 3);
expect(result[2]).toBeCloseTo(0.269, 3);
expect(result[3]).toBeCloseTo(0.881, 3);
input.destroy();
output.destroy();
});
});1
属性测试
typescript
import * as fc from 'fast-check';
it('should preserve shape', async () => {
await fc.assert(fc.asyncProperty(
fc.array(fc.float({ noDefaultInfinity: true, noNaN: true })),
async (values) => {
const context = createMockContext();
const op = new SigmoidOperator(context);
const input = Tensor.fromArray(context, new Float32Array(values), [1, values.length]);
const output = await op.forward([input]);
expect(output.shape).toEqual(input.shape);
input.destroy();
output.destroy();
return true;
}
));
});1
最佳实践
- 继承 Operator:使用基类管理管线
- 延迟初始化:使用 ensureInitialized() 模式
- 资源清理:使用 deferDestroy() 管理临时缓冲区
- 形状验证:在 computeOutputShape() 中验证参数
- 错误处理:快速失败并提供描述性错误消息
- WGSL 优化:使用适当的 workgroup 大小