Build a Handwritten Digit Recognizer in Java with TensorFlow
This article walks through the complete process of creating, training, evaluating, saving, and loading an MNIST handwritten‑digit recognition model with TensorFlow in Java, comparing it with the equivalent Python implementation and covering the required background knowledge, environment setup, and code details.
Introduction: In the spirit of teaching people to fish rather than giving them a fish, this guide shows how to implement a simple handwritten‑digit recognition model on the MNIST dataset in Java with TensorFlow, offering a bridge for Java‑oriented backend developers who find that most tutorials are written in Python.
Goal
Train a model on the MNIST dataset to recognize handwritten digits.
Required knowledge
Machine‑learning basics (supervised, unsupervised, reinforcement learning)
Data processing and analysis (cleaning, feature engineering, visualization)
Programming language (Python is common, but this guide focuses on Java)
Mathematics (linear algebra, probability, calculus)
ML algorithms (linear regression, decision trees, neural networks, SVM)
Deep‑learning frameworks (TensorFlow, PyTorch)
Model evaluation and optimization (cross‑validation, hyper‑parameter tuning, metrics)
Practical experience through projects and competitions
The Hello‑World example uses the TensorFlow framework.
Main requirements
Understand the MNIST data shape (60000,28,28) and label format
Know the role of activation functions
Understand forward and backward propagation
Train and save the model
Load and use the saved model
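To make the first two requirements concrete: each MNIST image is a 28×28 grid of grayscale bytes, and each label is a digit 0–9 that the training graph consumes as a one‑hot vector of length 10. The following framework‑free Java sketch shows this preprocessing; the class and method names are illustrative, not part of the TensorFlow API:

```java
public class MnistPreprocess {
    public static final int IMAGE_SIZE = 28 * 28; // flattened length of one 28x28 image
    public static final int NUM_CLASSES = 10;     // digits 0-9

    // Flatten a 28x28 image and scale pixel values from [0, 255] to [0, 1].
    public static float[] flatten(int[][] image) {
        float[] out = new float[IMAGE_SIZE];
        for (int row = 0; row < 28; row++) {
            for (int col = 0; col < 28; col++) {
                out[row * 28 + col] = image[row][col] / 255.0f;
            }
        }
        return out;
    }

    // Encode a digit label as a one-hot vector, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0].
    public static float[] oneHot(int label) {
        float[] out = new float[NUM_CLASSES];
        out[label] = 1.0f;
        return out;
    }
}
```

The training set therefore has shape (60000, 28, 28) before flattening and (60000, 784) after, which is the shape the weight matrix below is sized against.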
Java vs Python code comparison
Python code for loading the data (abbreviated), followed by the Java equivalent:
<code>def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    # ... (gzip loading and reshaping) ...
    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
</code>Java constants for archive paths:
<code>private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
</code>Model construction
Python (Keras) model:
<code>model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
</code>Java (TensorFlow Java API) model building:
<code>Ops tf = Ops.create(graph);
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
Add<TFloat32> add = tf.math.add(matMul, biases);
Softmax<TFloat32> softmax = tf.nn.softmax(add);
Mean<TFloat32> crossEntropy = tf.math.mean(
        tf.math.neg(tf.reduceSum(tf.math.mul(labels, tf.math.log(softmax)), tf.array(1))),
        tf.array(0));
Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
Op minimize = optimizer.minimize(crossEntropy);
Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
</code>Training the model
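Each training step, in both versions, evaluates the forward pass the graph above defines: logits = X·W + b, softmax over the 10 classes, then mean cross‑entropy against the one‑hot labels. A plain‑Java sketch of that computation for a single example, purely for illustration (no TensorFlow involved):

```java
public class SoftmaxForward {
    // softmax(z)_i = exp(z_i - max(z)) / sum_j exp(z_j - max(z));
    // subtracting the max keeps exp() from overflowing without changing the result.
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) max = Math.max(max, z);
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    // Cross-entropy against a one-hot label: -sum_i y_i * log(p_i).
    // This is the per-example quantity the graph averages over the batch.
    public static double crossEntropy(double[] probs, double[] oneHotLabel) {
        double loss = 0.0;
        for (int i = 0; i < probs.length; i++) {
            if (oneHotLabel[i] > 0) loss -= oneHotLabel[i] * Math.log(probs[i]);
        }
        return loss;
    }
}
```

The gradient-descent optimizer then nudges the weights and biases in the direction that lowers this loss, which is what `optimizer.minimize(crossEntropy)` wires into the graph.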
Python:
<code>history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
</code>Java:
<code>for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
    try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
         TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
        session.runner()
                .addTarget(minimize)
                .feed(images.asOutput(), batchImages)
                .feed(labels.asOutput(), batchLabels)
                .run();
    }
}
</code>Model evaluation
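Both evaluation paths reduce to the metric the graph computes with `argMax` and `equal`: the predicted class is the index of the largest softmax output, and accuracy is the fraction of examples where it matches the true label. A framework‑free Java sketch of the same arithmetic (hypothetical helper class, not the TensorFlow API):

```java
public class AccuracyMetric {
    // Index of the largest value, mirroring tf.math.argMax over the class axis.
    public static int argMax(float[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }

    // Fraction of rows whose argmax matches the expected label,
    // mirroring the mean over tf.math.equal(predicted, expected).
    public static float accuracy(float[][] batchProbs, int[] expected) {
        int correct = 0;
        for (int i = 0; i < batchProbs.length; i++) {
            if (argMax(batchProbs[i]) == expected[i]) correct++;
        }
        return (float) correct / batchProbs.length;
    }
}
```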
Python:
<code>test_loss, test_acc = model.evaluate(test_x, test_y)
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
</code>Java:
<code>ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
     TFloat32 testLabels = preprocessLabels(testBatch.labels());
     TFloat32 accuracyValue = (TFloat32) session.runner()
             .fetch(accuracy)
             .fetch(predicted)
             .fetch(expected)
             .feed(images.asOutput(), testImages)
             .feed(labels.asOutput(), testLabels)
             .run()
             .get(0)) {
    System.out.println("Accuracy: " + accuracyValue.getFloat());
}
</code>Saving the model
Python:
<code>save_model(model, r'D:\pythonProject\mnistDemo\number_model', save_format='pb')  # raw string: otherwise \n in the path is read as a newline
</code>Java:
<code>SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
</code>Loading the model
Python:
<code>model = load_model(r'D:\pythonProject\mnistDemo\number_model')  # raw string for the path; don't shadow the load_model function itself
model.summary()
predictValue = model.predict(input_data)
</code>Java:
<code>SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
printSignature(model);
Result run = model.session().runner()
        .feed("Placeholder:0", testImages)
        .feed("Placeholder_1:0", testLabels)
        .fetch("ArgMax:0")
        .fetch("ArgMax_1:0")
        .fetch("Mean_1:0")
        .run();
// process outputs …
</code>Full Python code (mnistTrainDemo.py)
<code>import gzip, os.path, tensorflow as tf, matplotlib.pyplot as plt, numpy as np
# load_data, model definition, compile, fit, evaluate, save (see source for details)
</code>Full Java code
Dependencies (Maven):
<code><dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.6.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-framework</artifactId>
<version>0.6.0-SNAPSHOT</version>
</dependency>
</code>Key classes:
MnistDataset – reads MNIST gzip archives, provides training/validation/test tensors.
SimpleMnist – builds the graph, trains, evaluates, saves, and loads the model.
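The core of MnistDataset's job is decompressing the archives and parsing the IDX format: a 4‑byte magic number whose low byte gives the number of dimensions, one 4‑byte big‑endian size per dimension, then the raw pixel or label bytes. The sketch below illustrates that header layout; it is a simplified stand‑in, not the actual class from the source:

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

public class IdxHeader {
    public final int magic;
    public final int[] dims; // e.g. {60000, 28, 28} for the training images archive

    public IdxHeader(int magic, int[] dims) {
        this.magic = magic;
        this.dims = dims;
    }

    // Parse the IDX header from an already-decompressed stream
    // (in practice the stream would be wrapped in a GZIPInputStream first).
    public static IdxHeader read(InputStream in) {
        try {
            DataInputStream data = new DataInputStream(in);
            int magic = data.readInt();   // e.g. 0x00000803 for 3-D image data
            int numDims = magic & 0xFF;   // low byte encodes the dimension count
            int[] dims = new int[numDims];
            for (int i = 0; i < numDims; i++) {
                dims[i] = data.readInt(); // big-endian dimension sizes
            }
            return new IdxHeader(magic, dims);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

After the header, the remaining bytes are read straight into the image and label tensors that the training loop consumes.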
<code>package org.example.tensorDemo.datasets.mnist;
public class MnistDataset { /* code as in source (readArchive, getOneValidationImage, etc.) */ }
</code> <code>package org.example.tensorDemo.dense;
public class SimpleMnist implements Runnable { /* full implementation from source */ }
</code>Running results
Pending improvements
Add a web service to accept image input and perform binary preprocessing.
Replace the simple linear model with convolutional neural networks for higher accuracy.
Explore deeper network architectures and hyper‑parameter tuning.
JD Cloud Developers
JD Cloud Developers is JD Technology Group's platform for technical sharing and communication among AI, cloud computing, IoT, and related developers. It publishes JD product technical information, industry content, and tech event news.