Edge AI Model Evaluation and Optimization with TensorFlow, JAX, and TVM
The article demonstrates how to evaluate, compress, and convert deep‑learning models for edge devices using TensorFlow, JAX, and TVM—showing a faster iPhone‑based MNIST training benchmark, FLOPs measurement scripts, TFLite/ONNX/CoreML conversion, TVM compilation with auto‑tuning, and up to 50 % speed improvements on mobile NPU hardware.
This article introduces practical engineering for edge AI, starting with a quantitative comparison of training an MNIST model on a 2015 15‑inch MacBook Pro (i7 CPU) and an iPhone 13 Pro Max. The iPhone completes a pass over the 60 000 training samples in 86 seconds versus 128 seconds on the laptop, demonstrating that modern mobile devices can handle both inference and training.
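A quick sanity check on those figures, using only the numbers reported above:

```python
# Throughput and speedup implied by the reported benchmark numbers
samples = 60_000              # one pass over the MNIST training set
iphone_s, macbook_s = 86, 128 # seconds, as reported above

iphone_throughput = samples / iphone_s    # samples per second on the phone
macbook_throughput = samples / macbook_s  # samples per second on the laptop
speedup = macbook_s / iphone_s            # how much faster the phone is

print(f"iPhone: {iphone_throughput:.0f} samples/s, "
      f"MacBook: {macbook_throughput:.0f} samples/s, "
      f"speedup: {speedup:.2f}x")
```

So the phone trains roughly 1.5x faster than the 2015 laptop on this workload.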
It then discusses how to evaluate and prepare models for mobile deployment. Model complexity is measured by parameters, size, and FLOPs (floating‑point operations). The A15 NPU in the iPhone 13 Pro Max delivers 15.8 TOPS, while a Qualcomm Snapdragon 855 provides only 7 TOPS, highlighting the need for model compression, pruning, quantization, or knowledge distillation before deployment.
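To illustrate why quantization is such an effective compression step, here is a minimal numpy sketch of affine uint8 quantization (the scale/zero‑point scheme that TFLite's integer quantization is based on). The function names and the random weight matrix are illustrative, not part of any toolkit API:

```python
import numpy as np

def quantize_uint8(w):
    """Affine quantization: map the float range [min, max] onto 0..255."""
    lo, hi = w.min(), w.max()
    scale = (hi - lo) / 255.0
    zero_point = int(np.round(-lo / scale))
    q = np.clip(np.round(w / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Recover approximate float values from the quantized tensor."""
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)  # stand-in weight matrix
q, scale, zp = quantize_uint8(w)
w_hat = dequantize(q, scale, zp)

print("bytes fp32:", w.nbytes, "bytes uint8:", q.nbytes)  # 4x smaller
print("max abs error:", np.abs(w - w_hat).max())          # on the order of scale/2
```

The 4x size reduction comes at the cost of a bounded rounding error per weight, which is why quantized models usually need an accuracy check (or quantization-aware training) before deployment.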
TensorFlow FLOPs calculation
# Recommended TensorFlow method for computing FLOPs
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph

def get_flops(model):
    concrete = tf.function(lambda inputs: model(inputs))
    concrete_func = concrete.get_concrete_function(
        [tf.TensorSpec([1, *inputs.shape[1:]]) for inputs in model.inputs])
    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name='')
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd="op", options=opts)
        return flops.total_float_ops

print("The FLOPs is: {}".format(get_flops(model)), flush=True)

For PyTorch users, the thop library provides a similar utility:
# Recommended open-source tool: pytorch-OpCounter (thop)
import torch
from thop import profile

input = torch.randn(1, 1, 28, 28)
macs, params = profile(model, inputs=(input,))
print('Total macc:{}, Total params: {}'.format(macs, params))
# Output: Total macc: 2307720.0, Total params: 431080.0

TensorFlow model conversion to TFLite
import tensorflow as tf
# Create a model using high‑level tf.keras APIs
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units=1, input_shape=[1]),
    tf.keras.layers.Dense(units=16, activation='relu'),
    tf.keras.layers.Dense(units=1)
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x=[-1, 0, 1], y=[-3, -1, 1], epochs=5)
# (to generate a SavedModel) tf.saved_model.save(model, "saved_model_keras_dir")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

Command‑line conversion of a SavedModel or an H5 file:
# Convert a SavedModel file
tflite_convert \
--saved_model_dir=/tmp/mobilenet_saved_model \
--output_file=/tmp/mobilenet.tflite
# Convert an H5 Keras model file
tflite_convert \
--keras_model_file=/tmp/mobilenet_keras_model.h5 \
--output_file=/tmp/mobilenet.tflite

ONNX can be generated from TensorFlow models using tf2onnx:
# Install the tf2onnx command-line tool and convert the model to ONNX
pip install -U tf2onnx
python -m tf2onnx.convert \
--saved-model ./output/saved_model \
--output ./output/mnist1.onnx \
--opset 7

CoreML models can be converted to ONNX with onnxmltools:
import coremltools
import onnxmltools
input_coreml_model = 'model.mlmodel'
output_onnx_model = 'model.onnx'
coreml_model = coremltools.utils.load_spec(input_coreml_model)
onnx_model = onnxmltools.convert_coreml(coreml_model)
onnxmltools.utils.save_model(onnx_model, output_onnx_model)

JAX example for MNIST
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade

import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
# JAX-specific imports
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax

def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

# params are the trained weights produced by the training loop (elided here)
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)

JAX accelerates training on TPU/GPU/NPU, but the exported TFLite model runs on the TFLite runtime, not JAX.
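For readers unfamiliar with stax, the exported network is just a plain MLP. A numpy sketch of the same forward pass (with random placeholder weights in the shapes stax would produce; trained params would come from JAX) makes the computation explicit:

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax along the last axis
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def mlp_predict(params, images):
    """Forward pass equivalent to the stax.serial network above:
    Flatten -> Dense(1024) -> ReLU -> Dense(1024) -> ReLU -> Dense(10) -> LogSoftmax."""
    x = images.reshape(images.shape[0], -1)   # Flatten
    for w, b in params[:-1]:
        x = np.maximum(x @ w + b, 0.0)        # Dense + ReLU
    w, b = params[-1]
    return log_softmax(x @ w + b)             # Dense + LogSoftmax

# Placeholder weights with the right shapes (illustrative only)
rng = np.random.default_rng(0)
sizes = [(28 * 28, 1024), (1024, 1024), (1024, 10)]
params = [(rng.normal(scale=0.01, size=s), np.zeros(s[1])) for s in sizes]

logits = mlp_predict(params, rng.normal(size=(2, 28, 28)))
print(logits.shape)                # (2, 10): one log-probability row per image
print(np.exp(logits).sum(axis=-1)) # each row exponentiates to probabilities summing to 1
```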
TVM compilation workflow
# Install dependencies on macOS
brew install gcc git cmake
brew install llvm
brew install [email protected]

# Clone TVM source
git clone --recursive https://github.com/apache/tvm tvm
cd tvm
mkdir build
cp cmake/config.cmake build

# Build TVM (enable LLVM, CUDA, etc. in config.cmake as needed)
cd build
cmake ..
make -j4

# Install the TVM Python package and other dependencies
pip3 install --user numpy decorator attrs tornado psutil xgboost cloudpickle
pip3 install --user onnx onnxoptimizer pillow
brew install libomp  # libomp is a Homebrew package, not a pip one
export MACOSX_DEPLOYMENT_TARGET=10.9
cd python; python setup.py install --user; cd ..

Example scripts for preprocessing input and post‑processing output:
#!python ./preprocess.py (generates the model input)
from tvm.contrib.download import download_testdata
from PIL import Image
import numpy as np
img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
img_path = download_testdata(img_url, "imagenet_cat.png", module="data")
resized_image = Image.open(img_path).resize((224, 224))
img_data = np.asarray(resized_image).astype("float32")
img_data = np.transpose(img_data, (2, 0, 1))
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_stddev = np.array([0.229, 0.224, 0.225])
norm_img_data = np.zeros(img_data.shape).astype("float32")
for i in range(img_data.shape[0]):
    norm_img_data[i, :, :] = (img_data[i, :, :] / 255 - imagenet_mean[i]) / imagenet_stddev[i]
img_data = np.expand_dims(norm_img_data, axis=0)
np.savez("imagenet_cat", data=img_data)

#!python ./postprocess.py
import os.path
import numpy as np
from scipy.special import softmax
from tvm.contrib.download import download_testdata
labels_url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt"
labels_path = download_testdata(labels_url, "synset.txt", module="data")
with open(labels_path, "r") as f:
    labels = [l.rstrip() for l in f]
output_file = "predictions.npz"
if os.path.exists(output_file):
    with np.load(output_file) as data:
        scores = softmax(data["output_0"])
        scores = np.squeeze(scores)
        ranks = np.argsort(scores)[::-1]
        for rank in ranks[0:5]:
            print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

Compile an ONNX model to a TVM package:
python -m tvm.driver.tvmc compile \
--target "llvm" \
--output resnet50-v2-7-tvm.tar \
resnet50-v2-7.onnx

# Extract the compiled artifacts
mkdir model
tar -xvf resnet50-v2-7-tvm.tar -C model
ls model

Run the compiled model:
python -m tvm.driver.tvmc run \
--inputs imagenet_cat.npz \
--output predictions.npz \
resnet50-v2-7-tvm.tar

Auto‑tuning for better performance:
python -m tvm.driver.tvmc tune \
--target "llvm" \
--output resnet50-v2-7-autotuner_records.json \
resnet50-v2-7.onnx

python -m tvm.driver.tvmc compile \
--target "llvm" \
--tuning-records resnet50-v2-7-autotuner_records.json \
--output resnet50-v2-7-tvm_autotuned.tar \
resnet50-v2-7.onnx

Benchmark the tuned vs. untuned models:
# Tuned model benchmark
python -m tvm.driver.tvmc run \
--inputs imagenet_cat.npz \
--output predictions.npz \
--print-time \
--repeat 100 \
resnet50-v2-7-tvm_autotuned.tar
# Execution time summary (ms): mean 108.9, median 106.1, max 172.0, min 103.3, std 10.4
# Untuned model benchmark
python -m tvm.driver.tvmc run \
--inputs imagenet_cat.npz \
--output predictions.npz \
--print-time \
--repeat 100 \
resnet50-v2-7-tvm.tar
# Execution time summary (ms): mean 135.0, median 131.9, max 202.7, min 124.5, std 11.1

The article concludes that converting and compiling models with TVM, TensorFlow, or JAX enables efficient edge AI deployment across iOS, Android, and WebGPU backends, while auto‑tuning can yield 30‑50% speedups.
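For context, the gain in this particular run, computed from the printed means above, is smaller than the 30‑50% figure; the achievable speedup depends on the model, the tuning budget, and the target:

```python
# Speedup implied by the two benchmark summaries above
untuned_ms, tuned_ms = 135.0, 108.9  # mean latencies from the tvmc runs

speedup = untuned_ms / tuned_ms                    # relative throughput gain
time_saved = (untuned_ms - tuned_ms) / untuned_ms  # fraction of latency removed

print(f"speedup: {speedup:.2f}x, latency reduction: {time_saved:.1%}")
```

So auto-tuning here buys roughly a 1.24x speedup (about 19% lower mean latency).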
DaTaobao Tech
Official account of DaTaobao Technology