Build a Browser‑Based MNIST Classifier with TensorFlow.js: A Step‑by‑Step Guide
Learn how to create a browser‑compatible MNIST image classification model with TensorFlow.js, covering data preprocessing with sprite images, model construction, training, and evaluation, with JavaScript code examples and practical tips for handling ArrayBuffer, DataView, and visualization.
1. Introduction
In 2017, to reduce the heavy workload of image moderation during the Double‑11 shopping festival, we launched an intelligent image‑review project that used deep learning to audit tens of millions of pictures. We later attempted a JavaScript version (tensjs) but ran into many obstacles; the official TensorFlow.js release, especially version 2.0, resolved most of them. Deep learning now matches or exceeds human performance on a number of specific tasks, and TensorFlow.js lowers the barrier for front‑end engineers to put it into production.
2. Overview
This series requires no prior deep‑learning background, avoids mathematical notation, and explains concepts through JavaScript code examples using TensorFlow.js in both browser and Node.js environments.
3. Hello World
Most deep‑learning tutorials start with the MNIST dataset, a classic collection of 28 × 28 grayscale images of handwritten digits: 60,000 for training and 10,000 for testing. Training on these images produces a simple image‑recognition model and introduces the basics of TensorFlow.js.
3.1 Data Preprocessing
3.1.1 Image Data Preprocessing
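In the browser, the MNIST images are stored as a single sprite sheet (a PNG in which each row holds one flattened 28 × 28 digit) and the labels as a binary file of one‑hot‑encoded uint8 values (10 bytes per image). Both are loaded directly from Google Cloud Storage URLs.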
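To make that layout concrete: once decoded, image k occupies values 784k to 784k + 783 of the flat pixel array, and its label occupies bytes 10k to 10k + 9 of the label array, with a single 1 at the digit's index. The sketch below reads one example back out; imageAt and labelAt are hypothetical helper names (not TensorFlow.js APIs), and the tiny arrays stand in for the real data.

```javascript
const IMAGE_SIZE = 784; // 28 × 28 pixels, one sprite row per image
const NUM_CLASSES = 10; // one-hot label length (digits 0 to 9)

// Hypothetical helper: view (without copying) the pixels of image k.
function imageAt(images, k) {
  return images.subarray(k * IMAGE_SIZE, (k + 1) * IMAGE_SIZE);
}

// Hypothetical helper: decode the one-hot label of image k into a digit.
function labelAt(labels, k) {
  const oneHot = labels.subarray(k * NUM_CLASSES, (k + 1) * NUM_CLASSES);
  return oneHot.indexOf(1); // position of the single 1 is the digit
}

// Two fake examples: image 0 is a "3", image 1 is a "7".
const labels = new Uint8Array(2 * NUM_CLASSES);
labels[0 * NUM_CLASSES + 3] = 1;
labels[1 * NUM_CLASSES + 7] = 1;
const images = new Float32Array(2 * IMAGE_SIZE);
images[IMAGE_SIZE] = 1; // first pixel of image 1

console.log(labelAt(labels, 0), labelAt(labels, 1)); // 3 7
console.log(imageAt(images, 1)[0]); // 1
```

subarray returns a view that shares memory with the original array, so reading individual examples this way costs no copies.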
<code>const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';</code>
To convert the sprite into binary data, we create an Image, draw it onto a canvas a chunk of rows at a time, and read the pixel buffer back with getImageData().
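Because every digit is grayscale, the red, green, and blue bytes of each pixel are equal, so the conversion only needs to read one channel per pixel and scale it from 0..255 to 0..1. Here is that step pulled out of the canvas code into a DOM‑free function so it can run in Node.js; rgbaToGrayscale is a hypothetical name, and the two‑pixel input stands in for imageData.data.

```javascript
// Hypothetical helper: convert RGBA bytes (as in ImageData.data) into
// normalized grayscale values, one float per pixel.
function rgbaToGrayscale(rgba) {
  const pixelCount = rgba.length / 4; // 4 bytes per pixel: R, G, B, A
  const gray = new Float32Array(pixelCount);
  for (let j = 0; j < pixelCount; j++) {
    gray[j] = rgba[j * 4] / 255; // grayscale image: R === G === B, so read R only
  }
  return gray;
}

// Two pixels: pure white and pure black, both fully opaque.
const rgba = new Uint8ClampedArray([255, 255, 255, 255, 0, 0, 0, 255]);
const gray = rgbaToGrayscale(rgba);
console.log(Array.from(gray)); // [ 1, 0 ]
```

Reading a single channel also cuts the stored data to a quarter of the RGBA size, which matters for 65,000 images.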
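<code>const IMAGE_SIZE = 784;             // 28 × 28 pixels per digit
const NUM_DATASET_ELEMENTS = 65000; // digits stored in the sprite
const chunkSize = 5000;             // sprite rows (digits) drawn per pass
const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
img.crossOrigin = '';               // request with CORS so getImageData() is allowed
img.onload = () => {
  img.width = img.naturalWidth;
  img.height = img.naturalHeight;
  canvas.width = img.width;         // one flattened 784-pixel digit per row
  canvas.height = chunkSize;
  const datasetImages = new Float32Array(NUM_DATASET_ELEMENTS * IMAGE_SIZE); // normalized pixels
  for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
    ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);
    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    for (let j = 0; j < imageData.data.length / 4; j++) { // grayscale: read the red channel only
      datasetImages[i * chunkSize * IMAGE_SIZE + j] = imageData.data[j * 4] / 255;
    }
  }
};
img.src = MNIST_IMAGES_SPRITE_PATH;</code>
3.1.2 ArrayBuffer & DataView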
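TensorFlow.js code frequently works with ArrayBuffer objects (raw, fixed‑length byte buffers), for example when reading canvas pixels, fetching binary resources, or reading local files. An ArrayBuffer is not read directly but through views: typed arrays such as Uint8Array, Uint8ClampedArray, and Float32Array interpret the bytes as a sequence of numbers of one type, while DataView reads and writes individual values of any type at any byte offset with an explicit byte order.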
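Byte order is where DataView earns its place. Typed arrays always use the platform's byte order (little‑endian on almost all current hardware), whereas DataView takes the byte order per call. This matters for formats with big‑endian headers, such as the IDX files of the original MNIST distribution: an IDX image file starts with four 32‑bit big‑endian integers, namely the magic number 2051, the image count, the row count, and the column count. The sprite and uint8 label files used in this article have no such header; the sketch below, with the hypothetical parseIdxImageHeader, only illustrates the DataView API on a synthetic buffer.

```javascript
// Hypothetical helper: read the 16-byte header of an IDX image file.
function parseIdxImageHeader(arrayBuffer) {
  const view = new DataView(arrayBuffer);
  // getUint32(byteOffset, littleEndian): false (the default) means big-endian
  const magic = view.getUint32(0, false);
  if (magic !== 2051) throw new Error(`not an IDX image file: magic ${magic}`);
  return {
    count: view.getUint32(4, false),
    rows: view.getUint32(8, false),
    cols: view.getUint32(12, false),
  };
}

// Build a synthetic header for 10,000 images of 28 × 28 pixels.
const buffer = new ArrayBuffer(16);
const writer = new DataView(buffer);
[2051, 10000, 28, 28].forEach((value, i) => writer.setUint32(i * 4, value)); // big-endian by default

console.log(parseIdxImageHeader(buffer)); // { count: 10000, rows: 28, cols: 28 }

// The same bytes viewed through a Uint32Array use platform byte order instead,
// so on a little-endian machine the magic number comes out byte-swapped.
console.log(new Uint32Array(buffer)[0] === 2051); // false on little-endian hosts
```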
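<code>// 1. Canvas: ImageData.data is a Uint8ClampedArray (RGBA, 4 bytes per pixel)
const canvas = document.getElementById('myCanvas');
const ctx = canvas.getContext('2d');
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const uint8ClampedArray = imageData.data;
const pixelBuffer = uint8ClampedArray.buffer; // the underlying ArrayBuffer

// 2. Network: Response.arrayBuffer() resolves to an ArrayBuffer
fetch(MNIST_LABELS_PATH)
  .then((response) => response.arrayBuffer())
  .then((arrayBuffer) => {
    const labels = new Uint8Array(arrayBuffer); // view the raw bytes as uint8 values
    /* ... */
  });

// 3. Local file: FileReader.readAsArrayBuffer() delivers an ArrayBuffer in reader.result
const fileInput = document.getElementById('fileInput');
const file = fileInput.files[0];
const reader = new FileReader();
reader.onload = () => {               // register the handler before starting the read
  const arrayBuffer = reader.result;
  /* ... */
};
reader.readAsArrayBuffer(file);</code>
3.1.3 Test and Validation Data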
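Evaluating a model on the same images it was trained on hides over‑fitting (memorizing the training set instead of learning to generalize), so we hold part of the data out. Each image occupies IMAGE_SIZE (784) values and each one‑hot label NUM_CLASSES (10) bytes, so the first NUM_TRAIN_ELEMENTS images and labels become the training set and the rest become the test set. During training, the validationSplit option of model.fit() sets aside a further fraction of the training set as validation data.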
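Because every image has the same size and every label the same length, the split is plain index arithmetic on the flat arrays. Here is the same slicing as a reusable function; splitDataset is a hypothetical name, and the toy sizes (4‑pixel images, 3 classes) only keep the example readable.

```javascript
// Hypothetical helper: split flat image/label arrays into train and test parts.
// imageSize = values per image, numClasses = bytes per one-hot label.
function splitDataset(images, labels, numTrain, imageSize, numClasses) {
  return {
    trainImages: images.slice(0, imageSize * numTrain),
    testImages: images.slice(imageSize * numTrain),
    trainLabels: labels.slice(0, numClasses * numTrain),
    testLabels: labels.slice(numClasses * numTrain),
  };
}

// Toy dataset: 5 images of 4 pixels each, 3 classes, one-hot labels.
const imageSize = 4;
const numClasses = 3;
const images = Float32Array.from({ length: 5 * imageSize }, (_, i) => i / 20);
const labels = new Uint8Array(5 * numClasses);
[0, 2, 1, 1, 0].forEach((cls, k) => { labels[k * numClasses + cls] = 1; });

const split = splitDataset(images, labels, 3, imageSize, numClasses);
console.log(split.trainImages.length / imageSize); // 3 training images
console.log(split.testLabels.length / numClasses); // 2 test labels
```

Typed‑array slice copies the selected range into a new buffer, so the train and test sets are independent of the original arrays; subarray would share memory instead.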
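<code>const IMAGE_SIZE = 784;           // values per image (28 × 28)
const NUM_CLASSES = 10;           // bytes per one-hot label
const NUM_TRAIN_ELEMENTS = 55000; // the remaining 10,000 of the 65,000 digits form the test set
// Inside the data class's async load(), after the sprite has been decoded into this.datasetImages
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);</code>
3.2 Build Model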
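A deep‑learning model is a stack of layers, and the tf.sequential() API in TensorFlow.js defines one in a few lines. The simple fully connected network below has three layers: flatten turns each 28 × 28 × 1 image into a 784‑element vector, a hidden dense layer of 42 units with ReLU activation learns intermediate features, and a 10‑unit softmax output layer produces one probability per digit.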
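What flatten, dense with ReLU, and softmax compute can be spelled out without TensorFlow.js. The sketch below handles a single input and assumes weights stored as plain nested arrays, where weights[i][j] connects input i to unit j (the same [inputs, units] layout TensorFlow.js uses for dense kernels). The helper names and tiny sizes are illustrative, not the model's real weights. The last lines count parameters as model.summary() reports them. A dense layer with n inputs and m units has n × m weights plus m biases, so the model above has 784 × 42 + 42 + 42 × 10 + 10 = 33,400 trainable parameters.

```javascript
// Dense layer for one input vector: out[j] = bias[j] + sum over i of input[i] * weights[i][j]
function dense(input, weights, bias) {
  return bias.map((b, j) => input.reduce((sum, x, i) => sum + x * weights[i][j], b));
}

const relu = (v) => v.map((x) => Math.max(0, x));

// Softmax: exponentiate and normalize so outputs are positive and sum to 1.
// Subtracting the max first avoids overflow in Math.exp.
function softmax(v) {
  const max = Math.max(...v);
  const exps = v.map((x) => Math.exp(x - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Flatten: a 2 × 2 "image" becomes a 4-element vector (784 elements for 28 × 28).
const image = [[0, 1], [1, 0]];
const flat = image.flat();

// Toy sizes: 4 inputs -> 2 hidden units (ReLU) -> 3 output classes (softmax).
const w1 = [[1, -1], [1, 0], [1, 0], [0, 1]];
const b1 = [0, 0];
const w2 = [[1, 0, 0], [0, 1, 0]];
const b2 = [0, 0, 0];
const probs = softmax(dense(relu(dense(flat, w1, b1)), w2, b2));
console.log(probs.indexOf(Math.max(...probs))); // 0 (predicted class)

// Parameter count of a dense layer: inputs × units weights + units biases.
const denseParams = (inputs, units) => inputs * units + units;
console.log(denseParams(784, 42) + denseParams(42, 10)); // 33400
```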
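<code>import * as tf from '@tensorflow/tfjs';

const IMAGE_H = 28; // MNIST images are 28 × 28 pixels
const IMAGE_W = 28;

function createDenseModel() {
  const model = tf.sequential();
  // Flatten each 28 × 28 × 1 image into a 784-element vector
  model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]}));
  // Hidden fully connected layer with ReLU activation
  model.add(tf.layers.dense({units: 42, activation: 'relu'}));
  // Output layer: one probability per digit class
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  return model;
}</code>
3.3 Train Model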
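Before training, model.compile() configures the optimizer, the loss function, and the metrics to report; categoricalCrossentropy is the standard loss for one‑hot labels with a softmax output. model.fit() then trains for the given number of epochs, updating the weights after every batch of batchSize images and holding out the validationSplit fraction of the training data for validation. The tfjs‑vis library can chart loss and accuracy as training progresses.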
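The categoricalCrossentropy loss and accuracy metric passed to compile() can be written out in a few lines, which shows why they pair with one‑hot labels and a softmax output. For each example, the loss is −log of the probability the model assigned to the correct class, averaged over the batch. Accuracy is the fraction of examples whose highest‑probability class matches the label. This is a plain‑JavaScript sketch with hypothetical function names. TensorFlow.js computes these on tensors with its own numerical safeguards; the small EPSILON here is only a guard against log(0).

```javascript
const EPSILON = 1e-7; // guard against Math.log(0)

// Mean categorical cross-entropy over a batch of one-hot labels and predicted probabilities.
function categoricalCrossentropy(labels, predictions) {
  let total = 0;
  labels.forEach((oneHot, n) => {
    oneHot.forEach((y, c) => {
      const p = Math.min(Math.max(predictions[n][c], EPSILON), 1 - EPSILON);
      total -= y * Math.log(p); // only the true class (y = 1) contributes
    });
  });
  return total / labels.length;
}

const argmax = (v) => v.indexOf(Math.max(...v));

// Accuracy: fraction of examples whose most probable class is the labeled class.
function accuracy(labels, predictions) {
  const correct = labels.filter((oneHot, n) => argmax(oneHot) === argmax(predictions[n])).length;
  return correct / labels.length;
}

// Two examples, three classes: the first is predicted correctly, the second is not.
const labels = [[0, 1, 0], [1, 0, 0]];
const predictions = [[0.2, 0.7, 0.1], [0.3, 0.6, 0.1]];
console.log(categoricalCrossentropy(labels, predictions).toFixed(4)); // 0.7803, i.e. (-ln 0.7 - ln 0.3) / 2
console.log(accuracy(labels, predictions)); // 0.5
```

The loss keeps falling as the probability on the correct class rises even after that class already wins, which is why it is a smoother training signal than accuracy alone.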
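<code>import * as tfvis from '@tensorflow/tfjs-vis';

const optimizer = 'rmsprop';       // 'adam' or 'sgd' also work
model.compile({
  optimizer,
  loss: 'categoricalCrossentropy', // matches one-hot labels and a softmax output
  metrics: ['accuracy'],
});
const batchSize = 320;             // images per weight update
const validationSplit = 0.15;      // last 15% of the training data is used for validation
const trainEpochs = 3;             // example value: full passes over the training data
// Run inside an async function; trainData.xs / trainData.labels are the image and one-hot label tensors
await model.fit(trainData.xs, trainData.labels, {
  batchSize,
  validationSplit,
  epochs: trainEpochs,
  // Live loss/accuracy charts in the tfjs-vis visor
  callbacks: tfvis.show.fitCallbacks({name: 'Training'}, ['loss', 'val_loss', 'acc', 'val_acc']),
});</code>
4. Online Demo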
Live demo
Official example code
5. Next Steps
Future articles will dive deeper into core TensorFlow.js concepts.
6. References
tfjs‑examples
Python Deep Learning
Taobao Frontend Technology