Custom Operators
Creating your own operators for Tiny-DL-Inference.
Overview
You can extend Tiny-DL-Inference by creating custom operators. This guide shows you how to implement a new operator from scratch.
Operator Structure
Base Class
All operators extend the Operator base class:
```typescript
import { Operator, OperatorParams, Tensor, TensorShape, GPUContext } from 'tiny-dl-inference';

export class MyOperator extends Operator {
  constructor(context: GPUContext) {
    super(context);
  }

  // 1. Define the WGSL shader
  protected compileShader(): string {
    return /* wgsl */`
      @compute @workgroup_size(256)
      fn main(@builtin(global_invocation_id) id: vec3<u32>) {
        // Your shader implementation
      }
    `;
  }

  // 2. Compute output shape
  computeOutputShape(inputShape: TensorShape, params?: OperatorParams): TensorShape {
    // Return the shape of the output tensor
    return inputShape;
  }

  // 3. Execute the operation
  async forward(inputs: Tensor[], params?: OperatorParams): Promise<Tensor> {
    this.ensureInitialized();
    // Implementation: allocate the output tensor, bind buffers, dispatch the shader,
    // and return the output. Until then, fail loudly instead of returning undefined.
    throw new Error('MyOperator.forward is not implemented');
  }
}
```
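The Sigmoid example below keeps the input shape unchanged. For an operator that changes the shape, `computeOutputShape` is where the new shape is derived and where malformed inputs are best rejected. The following is a standalone sketch for a hypothetical operator that transposes the last two axes; it assumes a shape can be treated as a plain `number[]` of dimension sizes, which may not match how `TensorShape` is actually defined in Tiny-DL-Inference.

```typescript
// Assumption: shapes are plain arrays of dimension sizes, e.g. [batch, rows, cols].
// Tiny-DL-Inference's real TensorShape type may differ.
type Shape = readonly number[];

// Hypothetical output-shape rule for a transpose that swaps the last two axes.
// Validates first so bad shapes fail before any GPU work is queued.
function transposeLastTwoOutputShape(inputShape: Shape): number[] {
  if (inputShape.length < 2) {
    throw new Error(`transpose needs at least 2 dimensions, got ${inputShape.length}`);
  }
  for (const dim of inputShape) {
    if (!Number.isInteger(dim) || dim <= 0) {
      throw new Error(`invalid dimension ${dim} in shape [${inputShape.join(', ')}]`);
    }
  }
  // Copy so the caller's shape is never mutated.
  const out = [...inputShape];
  const n = out.length;
  // Swap the last two dimensions; leading (batch) dimensions are unchanged.
  const last = out[n - 1];
  out[n - 1] = out[n - 2];
  out[n - 2] = last;
  return out;
}
```

For example, a `[2, 3, 4]` input yields `[2, 4, 3]`, while a rank-1 shape or a zero-sized dimension throws.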
Complete Example: Sigmoid
Let's implement a Sigmoid activation operator:
1. Create the Operator Class
```typescript
// src/operators/Sigmoid.ts
import { Operator, Tensor, TensorShape, GPUContext } from '../index';

export class Sigmoid extends Operator {
  constructor(context: GPUContext) {
    super(context);
  }

  protected compileShader(): string {
    return /* wgsl */`
      @group(0) @binding(0) var<storage, read_write> output: array<f32>;
      @group(0) @binding(1) var<storage, read> input: array<f32>;

      @compute @workgroup_size(256)
      fn main(@builtin(global_invocation_id) id: vec3<u32>) {
        let idx = id.x;
        // The last workgroup can extend past the end of the buffers
        if (idx >= arrayLength(&output)) { return; }
        let x = input[idx];
        output[idx] = 1.0 / (1.0 + exp(-x));
      }
    `;
  }

  computeOutputShape(inputShape: TensorShape): TensorShape {
    // Element-wise: the output has the same shape as the input
    return inputShape;
  }

  async forward(inputs: Tensor[]): Promise<Tensor> {
    this.ensureInitialized();
    if (inputs.length !== 1) {
      throw new Error(`Sigmoid expects 1 input, got ${inputs.length}`);
    }
    const [input] = inputs;
    const outputShape = this.computeOutputShape(input.shape);
    const output = Tensor.zeros(this.context, outputShape);
    // Create bind group and dispatch
    // ... implementation details
    return output;
  }
}
```
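The elided dispatch step in `forward` has to decide how many workgroups to launch. With `@workgroup_size(256)` and one invocation per element, the count is the element count divided by 256, rounded up, which is also why a bounds check is needed in the shader. This standalone helper sketches that arithmetic. It assumes shapes are plain `number[]` and elements are 4-byte f32 values, and it checks against 65535, WebGPU's default `maxComputeWorkgroupsPerDimension` limit; `planElementwiseDispatch` is a hypothetical name, not part of Tiny-DL-Inference.

```typescript
// Assumptions: shapes are plain arrays of dimension sizes and elements are f32.
const WORKGROUP_SIZE = 256; // must match @workgroup_size(256) in the shader
const MAX_WORKGROUPS_PER_DIM = 65535; // WebGPU default maxComputeWorkgroupsPerDimension
const BYTES_PER_F32 = Float32Array.BYTES_PER_ELEMENT; // 4

interface DispatchPlan {
  numElements: number;
  byteLength: number; // size of the output storage buffer in bytes
  workgroupCountX: number; // x dimension to dispatch
}

function planElementwiseDispatch(shape: readonly number[]): DispatchPlan {
  // Total element count is the product of all dimensions (1 for a scalar []).
  const numElements = shape.reduce((acc, dim) => acc * dim, 1);
  // Round up so the final, partially filled workgroup is still launched.
  const workgroupCountX = Math.ceil(numElements / WORKGROUP_SIZE);
  if (workgroupCountX > MAX_WORKGROUPS_PER_DIM) {
    throw new Error(
      `${numElements} elements need ${workgroupCountX} workgroups, ` +
        `above the ${MAX_WORKGROUPS_PER_DIM} per-dimension limit`,
    );
  }
  return { numElements, byteLength: numElements * BYTES_PER_F32, workgroupCountX };
}
```

For example, a `[4, 1000]` input has 4000 elements, needs a 16000-byte output buffer, and needs 16 workgroups. In raw WebGPU, that count is what you would pass to `GPUComputePassEncoder.dispatchWorkgroups`. Larger tensors would need a 2D dispatch with matching index math in the shader.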
2. Register with Engine
```typescript
// In your engine initialization
import { Sigmoid } from './operators/Sigmoid';

const sigmoid = new Sigmoid(context);
engine.registerOperator('sigmoid', sigmoid);
```
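Registration is essentially a name-to-instance lookup: once `'sigmoid'` is registered, the engine can resolve that string (for example, from a model graph) to the operator object. The class below is a hypothetical, dependency-free sketch of that pattern, not Tiny-DL-Inference's actual `registerOperator` implementation; `OperatorLike` and `OperatorRegistry` are illustrative names, and `forward` is simplified to plain `Float32Array`s.

```typescript
// Hypothetical stand-in for an operator: forward maps input arrays to an output array.
interface OperatorLike {
  forward(inputs: Float32Array[]): Promise<Float32Array>;
}

// Hypothetical name -> operator registry illustrating the registration pattern.
class OperatorRegistry {
  private readonly ops = new Map<string, OperatorLike>();

  register(name: string, op: OperatorLike): void {
    // Refuse silent overwrites so two operators cannot claim the same name.
    if (this.ops.has(name)) {
      throw new Error(`operator '${name}' is already registered`);
    }
    this.ops.set(name, op);
  }

  get(name: string): OperatorLike {
    const op = this.ops.get(name);
    if (op === undefined) {
      throw new Error(`unknown operator '${name}'`);
    }
    return op;
  }
}
```

Rejecting duplicate names is a design choice here; a real engine might instead allow deliberate overrides.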
WGSL Shader Guide
Compute Shader Structure
```wgsl
// Bindings
@group(0) @binding(0) var<storage, read_write> output: array<f32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;

// Per-element function (placeholder: identity); replace with your operation
fn process(x: f32) -> f32 {
  return x;
}

// Entry point
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  // Check bounds: dispatches are rounded up to whole workgroups
  if (id.x >= arrayLength(&output)) { return; }
  // Compute
  output[id.x] = process(input[id.x]);
}
```
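The bounds check matters because dispatches are sized in whole workgroups, so the grid usually contains more invocations than elements. A small CPU emulator makes this concrete: the sketch below runs a kernel callback once per `global_invocation_id.x` across every launched workgroup, applies the same `id.x >= length` guard as the shader above, and counts the invocations that return early. `emulateElementwise` is a hypothetical test helper, not a library function, and it models only one-dimensional dispatches.

```typescript
// Hypothetical CPU emulation of a 1-D element-wise compute dispatch.
function emulateElementwise(
  input: Float32Array,
  process: (x: number) => number,
  workgroupSize = 256,
): { output: Float32Array; idleInvocations: number } {
  const output = new Float32Array(input.length);
  // Same rounding a real dispatch uses: only whole workgroups are launched.
  const workgroups = Math.ceil(output.length / workgroupSize);
  let idleInvocations = 0;
  for (let wg = 0; wg < workgroups; wg++) {
    for (let local = 0; local < workgroupSize; local++) {
      const id = wg * workgroupSize + local; // global_invocation_id.x
      // Mirrors `if (id.x >= arrayLength(&output)) { return; }`
      if (id >= output.length) {
        idleInvocations++;
        continue;
      }
      // Float32Array stores each result rounded to f32, like the GPU buffer.
      output[id] = process(input[id]);
    }
  }
  return { output, idleInvocations };
}
```

With 300 elements, two workgroups (512 invocations) are launched and 212 of them hit the guard; without it, those invocations would index past the end of the buffers.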
Common Patterns
Element-wise Operations
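```wgsl
@group(0) @binding(0) var<storage, read_write> output: array<f32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;
// Example activation (ReLU); replace with your own element-wise function
fn activation(x: f32) -> f32 { return max(x, 0.0); }
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let idx = id.x;
  if (idx >= arrayLength(&output)) { return; }
  output[idx] = activation(input[idx]);
}
```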
Reduction Operations
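```wgsl
@group(0) @binding(0) var<storage, read_write> output: array<f32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;
// Workgroup (shared) memory must be declared at module scope, not inside a function
var<workgroup> scratch: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
  // ... load values into scratch[lid.x], then combine pairs of slots,
  // calling workgroupBarrier() between steps
}
```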
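A common choice for the elided reduction logic is a tree reduction. Each invocation loads one value into workgroup memory; then, at every step, invocations with `lid < stride` add in the value `stride` slots away, with `stride` halving from 128 down to 1 and a `workgroupBarrier()` between steps, until slot 0 holds the workgroup's sum. The TypeScript sketch below replays that schedule sequentially for a single workgroup so the indexing can be checked on the CPU. It illustrates the general technique, assuming a power-of-two workgroup size and a sum; `simulateWorkgroupSum` is a hypothetical name.

```typescript
// Sequential replay of a workgroup tree reduction (sum) for one workgroup.
function simulateWorkgroupSum(values: Float32Array, workgroupSize = 256): number {
  // The halving schedule only covers every slot when the size is a power of two.
  if (workgroupSize <= 0 || (workgroupSize & (workgroupSize - 1)) !== 0) {
    throw new Error('workgroupSize must be a power of two');
  }
  if (values.length > workgroupSize) {
    throw new Error('one workgroup reduces at most workgroupSize values');
  }
  // Stand-in for var<workgroup> scratch; unused slots stay 0, the identity for +.
  const scratch = new Float32Array(workgroupSize);
  scratch.set(values);
  for (let stride = workgroupSize / 2; stride > 0; stride >>= 1) {
    // On the GPU these run in parallel, one per invocation with lid < stride.
    // Each reads a slot >= stride that no one writes in this step, so order is irrelevant.
    for (let lid = 0; lid < stride; lid++) {
      scratch[lid] += scratch[lid + stride];
    }
    // A workgroupBarrier() would sit here before the next, smaller stride.
  }
  return scratch[0];
}
```

Because the replay rounds every partial sum to f32 in the same tree order, it can also serve as a bit-faithful reference for the GPU result, unlike a plain left-to-right loop.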
Testing Custom Operators
Property-Based Tests
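```typescript
import { describe, it } from 'vitest';
import * as fc from 'fast-check';

// Assumes `sigmoid` (a Sigmoid instance) and `tensorFromValue` (a helper that
// uploads one number as a single-element tensor) are set up elsewhere in this file.
describe('Sigmoid', () => {
  it('should output values in [0, 1]', async () => {
    await fc.assert(
      fc.asyncProperty(
        // noNaN: sigmoid(NaN) is NaN, which would fail the range check
        fc.float({ min: -100, max: 100, noNaN: true }),
        async (x) => {
          const output = await sigmoid.forward([tensorFromValue(x)]);
          const result = await output.download();
          return result[0] >= 0 && result[0] <= 1;
        }
      ),
      { numRuns: 100 }
    );
  });
});
```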
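Each `forward` call in the test above uploads, dispatches, and downloads a single value, so 100 runs mean 100 GPU round trips. One alternative is to check a property over a whole batch in a single call. The sketch below lays out inputs as `[x..., -x...]` using a small seeded PRNG (mulberry32, chosen only so failures are reproducible), then checks the identity `sigmoid(x) + sigmoid(-x) = 1` within a tolerance. The GPU call that turns `xs` into `outputs` is left to your test and would need a multi-element upload helper; all names here are hypothetical, not part of Tiny-DL-Inference or fast-check, and the default tolerance is an assumption to tune.

```typescript
// mulberry32: tiny seeded PRNG returning floats in [0, 1).
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Inputs laid out as [x_0..x_{n-1}, -x_0..-x_{n-1}] so one forward() covers both signs.
function makeSymmetricInputs(n: number, range: number, seed: number): Float32Array {
  const rand = mulberry32(seed);
  const xs = new Float32Array(2 * n);
  for (let i = 0; i < n; i++) {
    const x = (rand() * 2 - 1) * range; // uniform in [-range, range)
    xs[i] = x;
    xs[n + i] = -x;
  }
  return xs;
}

// Indices i where outputs[i] + outputs[n + i] is not within tol of 1.
function symmetryViolations(outputs: Float32Array, tol = 1e-5): number[] {
  if (outputs.length % 2 !== 0) throw new Error('expected [x..., -x...] layout');
  const n = outputs.length / 2;
  const bad: number[] = [];
  for (let i = 0; i < n; i++) {
    // Written as !(... <= tol) so NaN outputs are reported too.
    if (!(Math.abs(outputs[i] + outputs[n + i] - 1) <= tol)) bad.push(i);
  }
  return bad;
}
```

In a GPU test, `outputs` would be the downloaded result of one `sigmoid.forward` call on a tensor built from `xs`, and an empty violation list means the property held for every sample.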
CPU Reference
Always implement a CPU reference for validation:
```typescript
function sigmoidCPU(x: number): number {
  return 1.0 / (1.0 + Math.exp(-x));
}
```
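The GPU computes in f32 while `sigmoidCPU` uses JavaScript's 64-bit numbers, and WGSL does not require `exp` to be correctly rounded, so exact equality between the two is too strict. A common approach is a NumPy-style closeness test, `|actual - expected| <= atol + rtol * |expected|`, applied element by element. The sketch below reports the first failing index; `allClose` and its default tolerances are assumptions to tune per operator, not values prescribed by Tiny-DL-Inference.

```typescript
// Element-wise |actual - expected| <= atol + rtol * |expected| (NumPy allclose rule).
function allClose(
  actual: ArrayLike<number>,
  expected: ArrayLike<number>,
  rtol = 1e-5,
  atol = 1e-6,
): { ok: boolean; firstFailure: number } {
  if (actual.length !== expected.length) {
    throw new Error(`length mismatch: ${actual.length} vs ${expected.length}`);
  }
  for (let i = 0; i < actual.length; i++) {
    const tol = atol + rtol * Math.abs(expected[i]);
    // Written as !(err <= tol) so a NaN on either side counts as a failure.
    if (!(Math.abs(actual[i] - expected[i]) <= tol)) {
      return { ok: false, firstFailure: i };
    }
  }
  return { ok: true, firstFailure: -1 };
}
```

A typical use is `allClose(await output.download(), inputs.map(sigmoidCPU))`, failing the test with the reported index so the offending input can be reproduced.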
Best Practices
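- Validate inputs: check the input count, shapes, and types early, before any GPU work is queued
- Use workgroup size 256: a reasonable default for most GPUs, and equal to WebGPU's default limit on invocations per workgroup (maxComputeInvocationsPerWorkgroup)
- Check bounds: dispatches are rounded up to whole workgroups, so guard every buffer access in the shader to prevent out-of-bounds memory access
- Test thoroughly: property tests catch edge cases that hand-picked inputs miss
- Profile: measure actual GPU performance instead of assuming it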
API Reference
See the Operators API Reference for the complete operator API.