Machine learning in the browser with TensorFlow.js

Hubert Legęć
Published in Pragmatists
8 min read · Mar 13, 2019


Recently, more and more attention has been paid to artificial intelligence and machine learning. At first glance, these fields seem completely unrelated to web development and JavaScript. They are usually associated with Python or R environments, or even with C++ libraries. One of the most popular frameworks, used by an ever-increasing number of developers, is TensorFlow. It was developed at Google and open-sourced in 2015 (its predecessor, DistBelief, dates back to 2011). Its core is written in C++, with bindings for languages such as Python, R and Java. But what about JavaScript?

Time for JavaScript

For a long time, machine learning in JavaScript applications was performed server-side. A trained model was deployed on a server and exposed, for example, over HTTP. The web app used JavaScript to send a request with the necessary data and received the result from the server.

TensorFlow.js, released in March 2018, lets you do ML/DL in JavaScript without any server-side application. You can use it to define, train and run machine learning models entirely in the browser, using either a low-level API or a high-level Layers API. From the user’s perspective, this is very simple and convenient: there is no need to install any libraries or drivers. Just open the web page and the program is ready to run.

Speed is crucial

So far, everything looks good, but you’re probably wondering about performance. JavaScript runs on a single main thread (leaving Web Workers aside) and, by default, executes on the CPU, which is optimized for low-latency work and fast task switching rather than for raw throughput. A GPU, on the other hand, is designed for heavy, highly parallel workloads and high throughput. What distinguishes neural networks is that the dot products between the weights and the inputs at every neuron are independent of one another, so they can be computed in parallel, and the thousands of cores inside a GPU are perfect candidates for this job. Thanks to this, such calculations can run many times faster than on a CPU.
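To make the parallelism argument concrete, here is a minimal plain-JavaScript sketch of a fully connected layer's forward pass. It does not use TensorFlow.js, and the function names and numbers are made up for illustration. Each output neuron computes its own dot product of the input with its weight row, and no neuron depends on another neuron's result. That independence is what lets a GPU compute all neurons at once, whereas this CPU loop handles them one after another.

```javascript
// Dot product of two equal-length arrays.
function dot(a, b) {
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}

// Forward pass of a fully connected (dense) layer, without an activation.
// weights[j] holds the weights of output neuron j; bias[j] is its bias.
// Each map() iteration is independent of the others, so on a GPU
// all neurons could be computed at the same time.
function denseForward(input, weights, bias) {
  return weights.map((row, j) => dot(row, input) + bias[j]);
}

// Hypothetical layer: 3 inputs, 2 output neurons.
const input = [1, 2, 3];
const weights = [
  [0.5, 0.0, -1.0], // neuron 0
  [1.0, 1.0, 1.0],  // neuron 1
];
const bias = [0.5, -1];
console.log(denseForward(input, weights, bias)); // [ -2, 5 ]
```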

Fortunately, there is no reason to worry. TensorFlow.js supports WebGL out of the box and automatically accelerates your code when a GPU is available. WebGL is a browser API based on OpenGL ES; TensorFlow.js uses it to run its computations as shader programs on the GPU instead of as ordinary JavaScript on the CPU.

What can it be used for?

You can consider three workflows when working with TensorFlow.js:

  • An existing, pretrained model can be imported and used for inference
  • An imported model can be retrained
  • A brand-new model can be designed, trained and run in the browser

There are two types of API: low-level and high-level. The low-level API (also called the Core API) can be used for linear algebra and automatic differentiation; TensorFlow SavedModels can also be imported through it. The high-level API (also called the Layers API) will feel familiar to anyone who has used Keras. It can import trained Keras models, as well as design, train and run new models directly in JavaScript.

TensorFlow.js layers

Handwritten digit recognition

The best way to find out how something works is to use it in practice. So let’s get our hands dirty and create a sample application. I’ve decided to use the MNIST dataset of handwritten digits and build an application that recognizes handwritten digits. It’s a browser application based on React.

The source code is available on GitHub and a working application is deployed here.

Getting started

To add TensorFlow.js to your project, you only need to run one command, with either Yarn or npm:

yarn add @tensorflow/tfjs

or

npm install @tensorflow/tfjs

Then import it in your JS/TS file:

import * as tf from '@tensorflow/tfjs';

Building the model

The most important element of the application we are creating is the neural network model. We’ll build a convolutional image classifier with the Layers API. To do so, we’ll use a Sequential model (the simplest type of model), in which tensors are passed consecutively from one layer to the next. We instantiate it with tf.sequential and then add layers to it one by one.

Building MNIST model

The first layer we’ll add is a two-dimensional convolutional layer (conv2d). Its configuration object has an inputShape property: the shape of the data that flows into the first layer of the model (MNIST examples are 28x28 grayscale images, so the shape is [28, 28, 1]). The second layer, maxPooling2d, downsamples the result of the convolution by taking the maximum value in each 2x2 sliding window. Repeating layers is a common pattern in neural networks, so let’s add two more convolutional layers to our model. Next is a flatten layer, which flattens the output of the previous layer into a vector. Finally, a dense (fully connected) layer with one output per digit class performs the final classification.
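The paragraph above does not give the layers' hyperparameters, so the following plain-JavaScript sketch assumes some: 5x5 kernels with stride 1 and no padding, 8 and 16 filters in the convolutional layers, 2x2 pooling windows with stride 2, and a convolution, pooling, convolution, pooling stack (one possible reading of the layers described). Under those assumptions, it traces how a 28x28x1 MNIST image changes shape on its way to the dense layer. The helper names are hypothetical; TensorFlow.js computes these shapes itself when you add layers.

```javascript
// Output length along one spatial axis for an unpadded ("valid") window operation.
function validOutputSize(inputSize, windowSize, stride) {
  return Math.floor((inputSize - windowSize) / stride) + 1;
}

// Assumed architecture; shapes are [height, width, channels].
const layers = [
  { type: 'conv2d', kernelSize: 5, strides: 1, filters: 8 },
  { type: 'maxPooling2d', poolSize: 2, strides: 2 },
  { type: 'conv2d', kernelSize: 5, strides: 1, filters: 16 },
  { type: 'maxPooling2d', poolSize: 2, strides: 2 },
  { type: 'flatten' },
  { type: 'dense', units: 10 }, // one output per digit 0-9
];

function traceShapes(inputShape, layers) {
  let shape = inputShape;
  const trace = [];
  for (const layer of layers) {
    const [h, w, c] = shape;
    if (layer.type === 'conv2d') {
      // Each filter produces one output channel.
      shape = [validOutputSize(h, layer.kernelSize, layer.strides),
               validOutputSize(w, layer.kernelSize, layer.strides),
               layer.filters];
    } else if (layer.type === 'maxPooling2d') {
      // Pooling shrinks height and width but keeps the channel count.
      shape = [validOutputSize(h, layer.poolSize, layer.strides),
               validOutputSize(w, layer.poolSize, layer.strides),
               c];
    } else if (layer.type === 'flatten') {
      // All values are laid out in a single vector.
      shape = [shape.reduce((a, b) => a * b, 1)];
    } else if (layer.type === 'dense') {
      shape = [layer.units];
    }
    trace.push({ type: layer.type, shape });
  }
  return trace;
}

for (const { type, shape } of traceShapes([28, 28, 1], layers)) {
  console.log(type.padEnd(13), JSON.stringify(shape));
}
```

Under these assumptions, the image shrinks to 24x24, 12x12, 8x8 and 4x4, so the flattened vector fed into the dense layer has 4 × 4 × 16 = 256 values.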

Preparation for training

To train the model, we need an optimizer, a loss function and an evaluation metric that measures how well our model performs on the data.

We’ll use a stochastic gradient descent (SGD) optimizer with a learning rate of 0.15. For the loss function, we’ll use categorical cross-entropy (categoricalCrossentropy), the standard choice for multi-class classification. As the evaluation metric, we’ll use accuracy: the percentage of correct predictions out of all predictions. All three are passed to the model’s compile method.
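To show what the optimizer, loss and metric actually compute, here is a simplified plain-JavaScript sketch. It is not TensorFlow.js's implementation: it works on plain arrays, computes the loss for a single example, and the clipping constant eps that guards against log(0) is an assumption. The helper names are hypothetical.

```javascript
// Categorical cross-entropy for one example.
// oneHot is the true label (e.g. [0, 1, 0]); probs are the predicted probabilities.
function categoricalCrossentropy(oneHot, probs, eps = 1e-7) {
  let loss = 0;
  for (let i = 0; i < oneHot.length; i++) {
    // Clip to avoid log(0); with one-hot labels only the true class contributes.
    loss -= oneHot[i] * Math.log(Math.min(Math.max(probs[i], eps), 1 - eps));
  }
  return loss;
}

// Index of the largest value.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Accuracy: the share of examples whose most probable class matches the label.
function accuracy(oneHotLabels, predictions) {
  let correct = 0;
  for (let k = 0; k < predictions.length; k++) {
    if (argMax(predictions[k]) === argMax(oneHotLabels[k])) correct++;
  }
  return correct / predictions.length;
}

// One plain SGD step: move each weight against its gradient.
function sgdStep(weights, gradients, learningRate = 0.15) {
  return weights.map((w, i) => w - learningRate * gradients[i]);
}

const labels = [[0, 1, 0], [1, 0, 0]];
const preds = [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]];
console.log(categoricalCrossentropy(labels[0], preds[0])); // -ln(0.8) ≈ 0.223
console.log(accuracy(labels, preds)); // 0.5
console.log(sgdStep([1.0, -2.0], [0.2, -0.4]));
```

In TensorFlow.js, these correspond to tf.train.sgd(0.15), the 'categoricalCrossentropy' loss and the 'accuracy' metric; the sketch only spells out the math behind them.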

Before we begin training, we need to define a few more parameters:

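class MnistModel {
  static BATCH_SIZE = 128;        // images per parameter update
  static VALIDATION_SPLIT = 0.15; // fraction of training images held out for validation
  trainEpochs = 5;

  async loadData() {
    // Assumed: this._data is an MnistData instance (from data.js) created elsewhere in the class
    await this._data.load();
    this._trainData = this._data.trainData;
    this._testData = this._data.testData;
    // Batches per epoch (rounded up) times the number of epochs, e.g. for progress reporting
    this._totalNumBatches = Math.ceil(
      this._trainData.xs.shape[0] *
        (1 - MnistModel.VALIDATION_SPLIT) /
        MnistModel.BATCH_SIZE
    ) * this.trainEpochs;
  }
}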

BATCH_SIZE is the number of images the model processes before updating its parameters. We batch multiple inputs together and feed them through the network in a single forward pass to take advantage of the GPU’s ability to parallelize computation. The second reason to batch inputs is that we want to update the internal parameters (take a step) only after averaging gradients over several examples, so that a single unrepresentative example cannot push the step in the wrong direction. VALIDATION_SPLIT is the fraction of the training images (15% here) that is set aside for validation: the model is not trained on it, and the loss and accuracy on it are reported after each epoch. This validation set is separate from the test set.
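The batch arithmetic from loadData and the idea of averaging gradients over a batch can be sketched in plain JavaScript. The dataset size of 55,000 examples is a hypothetical round number, not the size of the app's training set, and makeBatches and averageGradients are illustrative helpers; in the app, fit handles batching internally.

```javascript
const BATCH_SIZE = 128;
const VALIDATION_SPLIT = 0.15;

// Same formula as loadData(): batches per epoch (the last one may be partial) times epochs.
function totalNumBatches(numExamples, epochs) {
  return Math.ceil(numExamples * (1 - VALIDATION_SPLIT) / BATCH_SIZE) * epochs;
}

// Split an array of examples into consecutive batches of at most batchSize items.
function makeBatches(examples, batchSize) {
  const batches = [];
  for (let i = 0; i < examples.length; i += batchSize) {
    batches.push(examples.slice(i, i + batchSize));
  }
  return batches;
}

// Average per-example gradient vectors, so one step reflects the whole batch
// rather than a single (possibly misleading) example.
function averageGradients(perExampleGradients) {
  const n = perExampleGradients.length;
  const avg = new Array(perExampleGradients[0].length).fill(0);
  for (const g of perExampleGradients) {
    g.forEach((v, i) => { avg[i] += v / n; });
  }
  return avg;
}

console.log(totalNumBatches(55000, 5)); // hypothetical 55,000 training examples, 5 epochs
console.log(makeBatches([1, 2, 3, 4, 5], 2)); // [ [ 1, 2 ], [ 3, 4 ], [ 5 ] ]
console.log(averageGradients([[1, -2], [3, 0]])); // [ 2, -1 ]
```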

To preprocess the MNIST data, I’ve used data.js, a script prepared by the TensorFlow team. It contains the MnistData class, which fetches random batches of MNIST images from a hosted version of the MNIST dataset.

Training

Training starts when we call the fit method on the compiled model. Its first two arguments are the training images and labels. The third is a configuration object with parameters such as batchSize, the aforementioned validationSplit, the number of training epochs, and callbacks. onBatchEnd is called after each batch with values such as the current loss and accuracy. onEpochEnd is called at the end of each epoch and also receives the loss and accuracy achieved on the validation data.

Training MNIST model

When training is finished, it’s worth checking the quality of the model on the test dataset with the evaluate method.

To start training in the prepared application, select the number of training epochs and press the Train button. During training, you can watch the progress live on charts.

Training view of the prepared application

It makes no sense to train the model every time you want to use it, unless it has to be adjusted whenever new data appears. In most cases, you train your model once and then simply use it. TensorFlow.js provides several ways to save and load a trained model: Local Storage, IndexedDB, HTTP requests or the local file system. You can find out more about this here.

Recognition

To make a prediction, we need a trained model and an image of a handwritten digit. We already have the first, so let’s prepare some images. To give the user an area to write on, I’ve used the react-sketch library. It’s easy to use and has all the features we need. Its output is a base64-encoded PNG image the size of the canvas. We need to resize it to 28x28, convert it to an ImageData object and finally create a single-channel (monochrome) tensor from its data.

The easiest way to convert a base64-encoded image to an ImageData object and resize it at the same time is to use a native HTML canvas element:

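// Resize a base64-encoded image to outputSize x outputSize (28 for MNIST)
// and resolve with its ImageData. The function name is illustrative.
function loadImageData(base64Image, outputSize = 28) {
  return new Promise((resolve, reject) => {
    const image = new Image();
    image.onload = () => {
      const canvas = document.createElement('canvas');
      canvas.width = outputSize;
      canvas.height = outputSize;
      const context = canvas.getContext('2d');
      // drawImage scales the whole source image into the outputSize x outputSize canvas
      context.drawImage(image, 0, 0, outputSize, outputSize);
      resolve(context.getImageData(0, 0, outputSize, outputSize));
    };
    image.onerror = reject;
    image.src = base64Image;
  });
}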

To create the tensor that flows into the first layer of the model, we can use tf.browser.fromPixels (called tf.fromPixels in TensorFlow.js versions before 1.0):

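// Keep only the first channel: an int32 tensor of shape [28, 28, 1]
const inputTensor = tf.browser.fromPixels(imageData, 1)
  .reshape([1, 28, 28, 1]) // add the batch dimension: [1, height, width, channels]
  .cast('float32')
  .div(tf.scalar(255)); // scale pixel values from 0-255 to 0-1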
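What this chain does to the raw pixel data can be reproduced in plain JavaScript, which makes it possible to run outside the browser. The sketch below works on an RGBA buffer laid out like ImageData.data (four bytes per pixel, row by row); the function name and the 2x2 test image are hypothetical. The returned values are the flat, row-major data of a [1, height, width, 1] tensor, so the reshape only changes how the same numbers are indexed.

```javascript
// Keep the first (red) channel of an RGBA buffer and scale it to [0, 1].
// rgba holds width * height * 4 bytes in row-major order, like ImageData.data.
function firstChannelNormalized(rgba, width, height) {
  const out = new Float32Array(width * height);
  for (let p = 0; p < width * height; p++) {
    out[p] = rgba[p * 4] / 255; // p * 4 is the red byte of pixel p
  }
  return out; // flat data for a tensor of shape [1, height, width, 1]
}

// A 2x2 example image: white, black, dark grey and a red-only pixel.
const rgba = new Uint8ClampedArray([
  255, 255, 255, 255,   0, 0, 0, 255,
  51, 51, 51, 255,      255, 0, 0, 255,
]);
console.log(Array.from(firstChannelNormalized(rgba, 2, 2)));
```

The red-only pixel maps to 1, just like the white one, because only the first channel is kept.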

Passing 1 as the second argument of fromPixels tells it to keep only the first channel of the RGBA image. The model can be fed multiple images at once, so its input has the shape [number_of_images, image_height, image_width, number_of_channels]. We have only one image, so we reshape the tensor to match these dimensions. The model expects each pixel value to be between 0 and 1, so we divide all pixel values by 255. And that’s it: we are ready to make a prediction:

// predict returns a [1, 10] tensor of digit probabilities; dataSync copies it into a Float32Array
const predictionResult = this._model.predict(inputTensor).dataSync();
const recognizedDigit = predictionResult.indexOf(Math.max(...predictionResult));

As the prediction result, we get an array of probabilities, one for each digit, so we pick the index with the highest probability and here we are ;)
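Picking the digit is an argmax over the ten probabilities. A plain-JavaScript version with a single loop avoids spreading the whole array into Math.max; with the made-up probabilities below, it gives the same answer as the indexOf approach. The function name and values are illustrative.

```javascript
// Return the index of the highest probability (the recognized digit).
// On ties, the first index wins, matching indexOf.
function recognizeDigit(probabilities) {
  let best = 0;
  for (let i = 1; i < probabilities.length; i++) {
    if (probabilities[i] > probabilities[best]) best = i;
  }
  return best;
}

// Made-up model output for digits 0-9 (sums to 1).
const probabilities = Float32Array.from([
  0.01, 0.02, 0.05, 0.70, 0.02, 0.10, 0.01, 0.04, 0.03, 0.02,
]);
console.log(recognizeDigit(probabilities)); // 3
// Same answer as the indexOf approach:
console.log(probabilities.indexOf(Math.max(...probabilities))); // 3
```

With TensorFlow.js, the same can also be done on the tensor itself by calling argMax(-1) before dataSync().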

In the prepared application, you have to draw a digit with the mouse cursor on the left side, press the Recognize button, and the result should appear on the right side of the screen.

Digit recognition

Conclusions

As you can see, starting an adventure with machine learning is not too complicated. You don’t have to know Python or C++, nor do you need to be an expert in math. As a frontend programmer, you can create a neural network model, train it and make predictions.

If you’re interested in the details, I encourage you to look into the TensorFlow.js documentation, e.g. here.
