MNIST Digit Recognition
A complete example showing how to classify handwritten digits using a convolutional neural network running on WebGPU.
Overview
The MNIST demo walks through the core inference workflow:
- Initialize the inference engine
- Define a CNN model architecture
- Load model weights
- Prepare input data
- Run inference and interpret results
The source code is located at /examples/mnist-demo.ts.
Model Architecture
The MNIST model uses a standard CNN architecture:
Input (1x28x28)
|
v
Conv2d (32 filters, 3x3, padding=1)
|
v
ReLU
|
v
MaxPool (2x2, stride=2)
|
v
Conv2d (64 filters, 3x3, padding=1)
|
v
ReLU
|
v
MaxPool (2x2, stride=2)
|
v
Flatten
|
v
Dense (10 units)
|
v
Softmax
|
v
Output (10 probabilities)
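The diagram does not show intermediate tensor sizes, so here is a minimal sketch that traces them using the standard convolution and pooling output-size formulas. The helper names `conv2dShape` and `maxPoolShape` are hypothetical and not part of the library; the sketch only illustrates why the dense layer in this architecture sees 3136 input features.

```typescript
// Shape of one image as [channels, height, width]; the batch dimension is omitted.
type Shape3 = [number, number, number];

// Output shape of a 2D convolution with a square kernel (hypothetical helper).
function conv2dShape(
  [, h, w]: Shape3,
  filters: number,
  kernel: number,
  stride = 1,
  padding = 0,
): Shape3 {
  const out = (n: number) => Math.floor((n + 2 * padding - kernel) / stride) + 1;
  return [filters, out(h), out(w)];
}

// Output shape of max pooling with a square window and no padding (hypothetical helper).
function maxPoolShape([c, h, w]: Shape3, pool: number, stride: number): Shape3 {
  const out = (n: number) => Math.floor((n - pool) / stride) + 1;
  return [c, out(h), out(w)];
}

let shape: Shape3 = [1, 28, 28];
shape = conv2dShape(shape, 32, 3, 1, 1); // [32, 28, 28]
shape = maxPoolShape(shape, 2, 2);       // [32, 14, 14]
shape = conv2dShape(shape, 64, 3, 1, 1); // [64, 14, 14]
shape = maxPoolShape(shape, 2, 2);       // [64, 7, 7]

// Flatten multiplies the remaining dimensions: 64 * 7 * 7 = 3136.
const flattened = shape[0] * shape[1] * shape[2];
console.log(shape, flattened);
```

With padding=1 the 3x3 convolutions preserve height and width, so only the two pooling layers shrink the image, from 28 to 14 to 7.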
Step-by-Step Guide
Step 1: Initialize the Engine
import { InferenceEngine } from 'tiny-dl-inference';
const engine = new InferenceEngine();
await engine.initialize();
console.log('Engine initialized');
Step 2: Define the Model
const modelDef = {
  name: 'MNIST-CNN',
  layers: [
    {
      name: 'conv1',
      type: 'conv2d',
      inputs: ['input', 'conv1_weight', 'conv1_bias'],
      params: {
        kernelSize: [3, 3],
        stride: [1, 1],
        padding: [1, 1],
        useBias: true
      }
    },
    {
      name: 'relu1',
      type: 'relu',
      inputs: ['conv1'],
      params: {}
    },
    {
      name: 'pool1',
      type: 'maxpool',
      inputs: ['relu1'],
      params: {
        poolSize: [2, 2],
        stride: [2, 2]
      }
    },
    {
      name: 'conv2',
      type: 'conv2d',
      inputs: ['pool1', 'conv2_weight', 'conv2_bias'],
      params: {
        kernelSize: [3, 3],
        stride: [1, 1],
        padding: [1, 1],
        useBias: true
      }
    },
    {
      name: 'relu2',
      type: 'relu',
      inputs: ['conv2'],
      params: {}
    },
    {
      name: 'pool2',
      type: 'maxpool',
      inputs: ['relu2'],
      params: {
        poolSize: [2, 2],
        stride: [2, 2]
      }
    },
    {
      name: 'flatten',
      type: 'flatten',
      inputs: ['pool2'],
      params: {}
    },
    {
      name: 'fc',
      type: 'dense',
      inputs: ['flatten', 'fc_weight', 'fc_bias'],
      params: {
        units: 10,
        useBias: true
      }
    },
    {
      name: 'output',
      type: 'softmax',
      inputs: ['fc'],
      params: {
        axis: -1
      }
    }
  ],
  weights: {
    conv1_weight: { data: new Float32Array(32 * 1 * 3 * 3).fill(0.01), shape: [32, 1, 3, 3] },
    conv1_bias: { data: new Float32Array(32).fill(0), shape: [32] },
    conv2_weight: { data: new Float32Array(64 * 32 * 3 * 3).fill(0.01), shape: [64, 32, 3, 3] },
    conv2_bias: { data: new Float32Array(64).fill(0), shape: [64] },
    fc_weight: { data: new Float32Array(10 * 3136).fill(0.01), shape: [10, 3136] },
    fc_bias: { data: new Float32Array(10).fill(0), shape: [10] }
  }
};
await engine.loadModel(modelDef);
Note
In a real application, weights would be loaded from a trained model file rather than initialized with placeholder values.
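As one illustration of what loading real weights could look like, the sketch below splits a single flat Float32Array into the six named tensors used in Step 2. The file layout (all tensors stored back to back in this order) and the names `WeightSpec`, `MNIST_WEIGHT_SPECS`, and `splitWeights` are assumptions made for this example; the library may define its own weight format.

```typescript
interface WeightSpec { name: string; shape: number[]; }
interface Weight { data: Float32Array; shape: number[]; }

// Assumed storage order; the shapes match the placeholder weights in Step 2.
const MNIST_WEIGHT_SPECS: WeightSpec[] = [
  { name: 'conv1_weight', shape: [32, 1, 3, 3] },
  { name: 'conv1_bias', shape: [32] },
  { name: 'conv2_weight', shape: [64, 32, 3, 3] },
  { name: 'conv2_bias', shape: [64] },
  { name: 'fc_weight', shape: [10, 3136] },
  { name: 'fc_bias', shape: [10] },
];

// Cut a flat buffer into named tensors, checking that the sizes add up.
function splitWeights(flat: Float32Array, specs: WeightSpec[]): Record<string, Weight> {
  const weights: Record<string, Weight> = {};
  let offset = 0;
  for (const spec of specs) {
    const size = spec.shape.reduce((a, b) => a * b, 1);
    if (offset + size > flat.length) {
      throw new RangeError(`Buffer too short for ${spec.name}`);
    }
    // subarray creates a view on the same memory rather than a copy.
    weights[spec.name] = { data: flat.subarray(offset, offset + size), shape: spec.shape };
    offset += size;
  }
  if (offset !== flat.length) {
    throw new RangeError(`Expected ${offset} values, got ${flat.length}`);
  }
  return weights;
}
```

The resulting object has the same structure as the `weights` field of `modelDef`. If the engine needs each tensor to own its buffer, replacing `subarray` with `slice` would produce copies instead of views.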
Step 3: Prepare Input
// Create a 28x28 grayscale image tensor [batch=1, channels=1, height=28, width=28]
const inputData = new Float32Array(1 * 1 * 28 * 28);
// Fill with actual pixel data (normalized to [0, 1]).
// pixelData is assumed to hold 784 grayscale values in [0, 255], e.g. a Uint8Array.
for (let i = 0; i < inputData.length; i++) {
  inputData[i] = pixelData[i] / 255.0;
}
const input = engine.tensorFromArray(inputData, [1, 1, 28, 28]);
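The normalization loop above can be wrapped in a small helper that also checks the input length. This is a sketch only: `normalizePixels` is a hypothetical name, and it assumes the source image is already grayscale with one value per pixel in [0, 255].

```typescript
const IMAGE_SIZE = 28 * 28; // 784 pixels per MNIST image

// Convert 0-255 grayscale values into a Float32Array scaled to [0, 1].
function normalizePixels(pixels: ArrayLike<number>): Float32Array {
  if (pixels.length !== IMAGE_SIZE) {
    throw new RangeError(`Expected ${IMAGE_SIZE} pixels, got ${pixels.length}`);
  }
  const out = new Float32Array(IMAGE_SIZE);
  for (let i = 0; i < IMAGE_SIZE; i++) {
    out[i] = pixels[i] / 255;
  }
  return out;
}
```

The returned array can be passed to `engine.tensorFromArray` with the shape [1, 1, 28, 28] exactly as in the step above.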
Step 4: Run Inference
const startTime = performance.now();
const output = await engine.infer(input);
const endTime = performance.now();
const result = await output.download();
console.log(`Execution time: ${(endTime - startTime).toFixed(2)}ms`);
Step 5: Interpret Results
console.log('\nPredictions:');
for (let i = 0; i < 10; i++) {
  const probability = (result[i] * 100).toFixed(2);
  const bar = '\u2588'.repeat(Math.floor(result[i] * 50));
  console.log(`Digit ${i}: ${bar} ${probability}%`);
}
const predictedDigit = result.indexOf(Math.max(...Array.from(result)));
console.log(`\nPredicted digit: ${predictedDigit}`);
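When the model is uncertain, the runner-up digits can be as informative as the winner. The sketch below shows one way to extract the best index and the top few candidates from the probability array; `argmax` and `topK` are hypothetical helpers, not library functions.

```typescript
// Index of the largest value; ties resolve to the lowest index.
function argmax(values: ArrayLike<number>): number {
  if (values.length === 0) throw new RangeError('argmax of empty array');
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// The k most probable digits, sorted from most to least likely.
function topK(values: ArrayLike<number>, k: number): { digit: number; probability: number }[] {
  return Array.from(values, (probability, digit) => ({ digit, probability }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);
}
```

Calling `argmax(result)` gives the same digit as the `indexOf`/`Math.max` expression above in a single pass over the array, and `topK(result, 3)` lists the three most likely digits.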
Step 6: Cleanup
input.destroy();
output.destroy();
engine.destroy();
Key Concepts
Tensor Shapes
MNIST images are 28x28 grayscale pixels. The tensor shape [1, 1, 28, 28] represents:
- 1: Batch size (single image)
- 1: Channels (grayscale)
- 28: Height
- 28: Width
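To make the four dimensions concrete, the sketch below maps a (batch, channel, row, column) coordinate to a position in the flat Float32Array. It assumes a row-major layout in batch, channel, height, width order, which this page does not confirm for the library; `nchwOffset` is a hypothetical helper.

```typescript
type Shape4 = [number, number, number, number];

// Flat index of element (n, c, h, w) in a row-major [N, C, H, W] tensor.
function nchwOffset(shape: Shape4, n: number, c: number, h: number, w: number): number {
  const [, C, H, W] = shape;
  return ((n * C + c) * H + h) * W + w;
}

const mnistShape: Shape4 = [1, 1, 28, 28];
// The pixel at row 1, column 0 follows the 28 pixels of row 0.
console.log(nchwOffset(mnistShape, 0, 0, 1, 0)); // 28
```

Under this assumption a single MNIST image is simply its 784 pixels in reading order, which is consistent with how Step 3 fills `inputData` from a flat pixel array.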
Layer Connections
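Each layer's inputs array references the special 'input' token (for the first layer) or the name of a previous layer. Layers with parameters, such as conv2d and dense, also list their weight and bias tensors by the keys used in the weights object. The engine builds a computation graph from these connections.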
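A misspelled layer or weight name is an easy mistake in a definition like the one in Step 2. The sketch below checks that every entry in each inputs array refers to something already defined. It is an illustration only: `LayerDef` and `findUnresolvedInputs` are hypothetical names, and the engine may perform its own validation when the model is loaded.

```typescript
interface LayerDef { name: string; inputs: string[]; }

// Returns "layer -> reference" strings for every input that does not resolve
// to 'input', a weight key, or a layer defined earlier in the list.
function findUnresolvedInputs(layers: LayerDef[], weightNames: string[]): string[] {
  const known = new Set<string>(['input', ...weightNames]);
  const problems: string[] = [];
  for (const layer of layers) {
    for (const ref of layer.inputs) {
      if (!known.has(ref)) problems.push(`${layer.name} -> ${ref}`);
    }
    // A layer becomes referenceable only after its own definition.
    known.add(layer.name);
  }
  return problems;
}
```

For the model in Step 2, `findUnresolvedInputs(modelDef.layers, Object.keys(modelDef.weights))` would return an empty array.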
Memory Management
All tensors and the engine must be explicitly destroyed to free GPU memory. Use try/finally blocks in production code to ensure cleanup even on errors.
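One way to apply the try/finally advice is a small wrapper that records every resource as it is created and destroys them all when the work finishes or throws. This is a sketch under the assumption that each resource exposes a synchronous `destroy()` method, as the Step 6 calls suggest; `Destroyable`, `Track`, and `withCleanup` are hypothetical names.

```typescript
interface Destroyable { destroy(): void; }

// Registers a resource for cleanup and returns it unchanged.
type Track = <R extends Destroyable>(resource: R) => R;

// Runs `work`, then destroys every tracked resource, even if `work` throws.
async function withCleanup<T>(work: (track: Track) => Promise<T>): Promise<T> {
  const tracked: Destroyable[] = [];
  const track: Track = (resource) => {
    tracked.push(resource);
    return resource;
  };
  try {
    return await work(track);
  } finally {
    // Destroy in reverse creation order, so tensors go before the engine.
    for (const resource of tracked.reverse()) {
      resource.destroy();
    }
  }
}
```

With the demo's objects, the body passed to `withCleanup` would wrap each creation call, for example `const input = track(engine.tensorFromArray(inputData, [1, 1, 28, 28]))`, and the explicit `destroy()` calls from Step 6 would no longer be needed.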
Running the Demo
npx ts-node examples/mnist-demo.ts
The demo will output prediction probabilities for each digit (0-9) and display the predicted digit.
Next Steps
- See Custom Model Loading for building models from scratch
- Read about Memory Layout for performance optimization
- Check the API Reference for detailed method documentation