
Exporting AI models on Android with XNNPACK and ExecuTorch

Wojtek Jasinski · Aug 14, 2024 · 12 min read

Bringing Native AI to Your Mobile Apps with ExecuTorch — part III

ExecuTorch is a new PyTorch-based framework that allows you to export your models to formats suitable for local deployments on devices such as smartphones or microcontrollers. This article will guide you through converting your PyTorch models into a format optimized for running on Android devices with good performance.

Check out our previous blog posts on this topic: [part I: Running models on iOS with CoreML using ExecuTorch] [part II: Running models on Android using ExecuTorch] 🤖

Environment setup 🛠️

Follow the official ExecuTorch setup guide. We use ExecuTorch v0.3.0 with Python 3.10. Unless specified otherwise, run all code and commands from the ExecuTorch repository root.

⚠️ Keep in mind that the project is under heavy development and many interfaces will change in future releases.

Lowering the Model ⬇️

The process of converting a model involves several steps, referred to as “lowering” since each step brings us closer to the hardware level.

TLDR: We can divide lowering the model into 4 steps:
1. Export: Convert PyTorch code to an execution graph.
2. To Edge: Simplify the graph to core ATen operators.
3. Delegate: Offload some operations to optimized backends.
4. Serialize: Save the model as an ExecuTorch program.

In our case, we work with a style transfer model from the PyTorch examples repository. The model is a simple convolutional autoencoder. It has a few quirks that surface during the export process, which is common for models found on GitHub and similar sources.

Start in Python with the model as a torch.nn.Module. Load the weights and set the model to eval mode.

import torch
from fast_neural_style.neural_style.transformer_net import TransformerNet

state_dict = torch.load("saved_models/candy.pth")
model = TransformerNet()
model.load_state_dict(state_dict)
model.eval()

1. Export

The first step is to convert the PyTorch program into an execution graph. This graph represents the model through operations like addition, multiplication, and convolution, rather than plain Python code.

To export the model, prepare example_inputs, a tuple of input tensors, and pass it to the export function.

from torch.export import export, ExportedProgram

example_inputs = (torch.randn(1, 3, 640, 640),)
exported_program: ExportedProgram = export(model, example_inputs)

print(exported_program)

The output shows the new graph representation. Each node corresponds to an operation together with its inputs, and includes a link back to the line of the original Python code that generated it. This is incredibly useful for debugging.

A few example lines from the output:

...

# File: /Users/woj/export_tut/examples/fast_neural_style/neural_style/transformer_net.py:53 in forward, code: out = self.conv2d(out)
conv2d_11: "f32[1, 128, 160, 160]" = torch.ops.aten.conv2d.default(pad_11, p_res5_conv1_conv2d_weight, p_res5_conv1_conv2d_bias);  pad_11 = p_res5_conv1_conv2d_weight = p_res5_conv1_conv2d_bias = None

# File: /Users/woj/export_tut/examples/fast_neural_style/neural_style/transformer_net.py:73 in forward, code: out = self.relu(self.in1(self.conv1(x)))
instance_norm_11: "f32[1, 128, 160, 160]" = torch.ops.aten.instance_norm.default(conv2d_11, p_res5_in1_weight, p_res5_in1_bias, None, None, True, 0.1, 1e-05, True);  conv2d_11 = p_res5_in1_weight = p_res5_in1_bias = None
relu_7: "f32[1, 128, 160, 160]" = torch.ops.aten.relu.default(instance_norm_11);  instance_norm_11 = None

...

2. To edge

Next, we convert the graph to a representation suited to edge devices: the Edge Dialect. During this step, operators are decomposed into core ATen operators. Every graph input, graph output, and node must be a Tensor, so all Scalar types are converted to Tensors.

from executorch.exir import EdgeProgramManager, to_edge

edge: EdgeProgramManager = to_edge(exported_program)

print(edge.exported_program())

A few example lines from the output referring to the same part of the execution graph as above:

...

# File: /Users/woj/export_tut/examples/fast_neural_style/neural_style/transformer_net.py:53 in forward, code: out = self.conv2d(out)
aten_convolution_default_3: "f32[1, 128, 160, 160]" = executorch_exir_dialects_edge__ops_aten_convolution_default(aten_index_tensor_7, p_res1_conv1_conv2d_weight, p_res1_conv1_conv2d_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  aten_index_tensor_7 = p_res1_conv1_conv2d_weight = p_res1_conv1_conv2d_bias = None

# File: /Users/woj/export_tut/examples/fast_neural_style/neural_style/transformer_net.py:73 in forward, code: out = self.relu(self.in1(self.conv1(x)))
aten_repeat_default_6: "f32[128]" = executorch_exir_dialects_edge__ops_aten_repeat_default(p_res1_in1_weight, [1]);  p_res1_in1_weight = None
aten_repeat_default_7: "f32[128]" = executorch_exir_dialects_edge__ops_aten_repeat_default(p_res1_in1_bias, [1]);  p_res1_in1_bias = None
aten_view_copy_default_6: "f32[1, 128, 160, 160]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_convolution_default_3, [1, 128, 160, 160]);  aten_convolution_default_3 = None
aten__native_batch_norm_legit_no_stats_3 = executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats(aten_view_copy_default_6, aten_repeat_default_6, aten_repeat_default_7, True, 0.1, 1e-05);  aten_view_copy_default_6 = aten_repeat_default_6 = aten_repeat_default_7 = None
getitem_3: "f32[1, 128, 160, 160]" = aten__native_batch_norm_legit_no_stats_3[0];  aten__native_batch_norm_legit_no_stats_3 = None
aten_view_copy_default_7: "f32[1, 128, 160, 160]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem_3, [1, 128, 160, 160]);  getitem_3 = None
aten_relu_default_3: "f32[1, 128, 160, 160]" = executorch_exir_dialects_edge__ops_aten_relu_default(aten_view_copy_default_7);  aten_view_copy_default_7 = None
...

3. Delegate

Parts of a model can be offloaded to optimized backends, allowing the rest of the program to run on the basic ExecuTorch runtime. This delegation leverages the performance and efficiency benefits of specialized hardware and libraries.

We’ll use XNNPACK, a highly optimized library for neural network inference targeting ARM, WebAssembly, and x86 platforms. It enhances neural network performance on mobile and edge devices by optimizing key operations. For more details, check out XNNPACK.

ℹ️ For Apple devices, CoreML is an excellent choice as it can leverage GPU and Neural Engine chips and is generally well-documented. On Android, Vulkan can be used for GPU leverage, though it is not yet well-supported. Meta developers are actively working on it!

To make it work, we call the to_backend() API with the XnnpackPartitioner. The partitioner identifies the suitable subgraphs, which will be serialized with the XNNPACK Delegate flatbuffer schema. At runtime, each subgraph will be replaced with a call to the XNNPACK Delegate.

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# to_backend() is a method of EdgeProgramManager, so no separate import is needed
edge = edge.to_backend(XnnpackPartitioner())

Output:

INFO:executorch.backends.xnnpack.partition.xnnpack_partitioner:Found 31 subgraphs to be partitioned.

The execution graph has been split into 31 subgraphs, each of which can be delegated to XNNPACK. It wasn’t processed in one piece because some operations are not supported by the XNNPACK backend. The model should still work correctly, but generally the more subgraphs there are, the slower inference tends to be. We’ll address this shortly.

4. Serialize

After lowering to the XNNPACK Program, we save the model as a .pte file. It is a binary format that stores the serialized ExecuTorch graph, suitable for the runtime environment.

from executorch.exir import ExecutorchProgramManager

exec_prog: ExecutorchProgramManager = edge.to_executorch()

with open("style_transfer_xnnpack.pte", "wb") as file:
    exec_prog.write_to_file(file)

Your model should now be optimized and ready for deployment!

Troubleshooting and optimization 🔬

In our case, the model lowered without any architecture changes was not very fast. Let’s patch it in a few places and lower it again.

TLDR: We made several changes to the model architecture to fix issues and improve performance:
- Swapped reflection padding for zero padding.
- Replaced nearest neighbor interpolation with bilinear interpolation.

Identify and replace weird operations

Look for uncommon operations in the model architecture and, if possible, replace them with more standard ones. Looking through our model’s architecture code, you may notice torch.nn.ReflectionPad2d and torch.nn.InstanceNorm2d standing out as operators that might cause issues.

We can confirm the suspicion by looking at what the ReflectionPad2d operator was decomposed into (the graph after the to_edge() step).

...
# File: /Users/woj/export_tut/executorch/fast_neural_style/neural_style/transformer_net.py:52 in forward, code: out = self.reflection_pad(x)
aten_arange_start_step_34: "i64[648]" = executorch_exir_dialects_edge__ops_aten_arange_start_step(-4, 644, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
aten_abs_default_60: "i64[648]" = executorch_exir_dialects_edge__ops_aten_abs_default(aten_arange_start_step_34);  aten_arange_start_step_34 = None
aten_sub_tensor_74: "i64[648]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(_lifted_tensor_constant76, aten_abs_default_60);  _lifted_tensor_constant76 = aten_abs_default_60 = None
aten_abs_default_61: "i64[648]" = executorch_exir_dialects_edge__ops_aten_abs_default(aten_sub_tensor_74);  aten_sub_tensor_74 = None
aten_sub_tensor_75: "i64[648]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(_lifted_tensor_constant77, aten_abs_default_61);  _lifted_tensor_constant77 = aten_abs_default_61 = None
aten_index_tensor_38: "f32[1, 32, 648, 640]" = executorch_exir_dialects_edge__ops_aten_index_Tensor(aten_relu_default_9, [None, None, aten_sub_tensor_75, None]);  aten_relu_default_9 = aten_sub_tensor_75 = None
aten_arange_start_step_35: "i64[648]" = executorch_exir_dialects_edge__ops_aten_arange_start_step(-4, 644, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
aten_abs_default_62: "i64[648]" = executorch_exir_dialects_edge__ops_aten_abs_default(aten_arange_start_step_35);  aten_arange_start_step_35 = None
aten_sub_tensor_76: "i64[648]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(_lifted_tensor_constant78, aten_abs_default_62);  _lifted_tensor_constant78 = aten_abs_default_62 = None
aten_abs_default_63: "i64[648]" = executorch_exir_dialects_edge__ops_aten_abs_default(aten_sub_tensor_76);  aten_sub_tensor_76 = None
aten_sub_tensor_77: "i64[648]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(_lifted_tensor_constant79, aten_abs_default_63);  _lifted_tensor_constant79 = aten_abs_default_63 = None
aten_index_tensor_39: "f32[1, 32, 648, 648]" = executorch_exir_dialects_edge__ops_aten_index_Tensor(aten_index_tensor_38, [None, None, None, aten_sub_tensor_77]);  aten_index_tensor_38 = aten_sub_tensor_77 = None
...
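The chain of arange, abs and sub nodes above is plain index arithmetic: for every position along a padded axis, it computes which source position reflection padding reads from. Below is a minimal pure-Python sketch of the same arithmetic. It assumes both lifted constants equal size - 1 (639 here), which the dump does not show explicitly; reflect_indices is a name made up for this illustration.

```python
def reflect_indices(size: int, pad: int) -> list[int]:
    """Source index for every position of a reflection-padded axis."""
    last = size - 1  # assumed value of the lifted constants in the graph
    indices = []
    for i in range(-pad, size + pad):   # arange(-4, 644) in the dump
        step = abs(last - abs(i))        # abs -> sub -> abs
        indices.append(last - step)      # final sub
    return indices


idx = reflect_indices(size=640, pad=4)
print(len(idx))   # 648
print(idx[:6])    # [4, 3, 2, 1, 0, 1]
print(idx[-5:])   # [639, 638, 637, 636, 635]
```

The graph builds this index vector once per spatial axis and then gathers with aten.index.Tensor, which is why a single padding layer expands into a dozen nodes.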

If we replace reflection padding with zero padding, it gets much simpler.

# File: /Users/woj/export_tut/executorch/fast_neural_style/neural_style/transformer_net.py:52 in forward, code: out = self.zero_pad(x)
aten_constant_pad_nd_default: "f32[1, 3, 648, 648]" = executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default(x, [4, 4, 4, 4], 0.0);  x = None

We can get away with that substitution without retraining the model, with negligible visual output difference. This made the model run much faster!
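One way to try this substitution without editing the model source is to swap the padding modules on an already constructed model. The sketch below is an illustration under that assumption, not necessarily how the change was made for this article. The helper name swap_reflection_for_zero_pad and the small test block are hypothetical. Padding layers hold no weights, so a loaded checkpoint still applies after the swap.

```python
import torch
from torch import nn


def swap_reflection_for_zero_pad(module: nn.Module) -> int:
    """Recursively replace nn.ReflectionPad2d children with nn.ZeroPad2d."""
    swapped = 0
    for name, child in module.named_children():
        if isinstance(child, nn.ReflectionPad2d):
            # Reuse the same (left, right, top, bottom) padding amounts.
            setattr(module, name, nn.ZeroPad2d(child.padding))
            swapped += 1
        else:
            swapped += swap_reflection_for_zero_pad(child)
    return swapped


# Hypothetical stand-in for one conv layer of the style transfer model.
block = nn.Sequential(nn.ReflectionPad2d(4), nn.Conv2d(3, 8, kernel_size=9))
x = torch.randn(1, 3, 32, 32)

shape_before = block(x).shape
count = swap_reflection_for_zero_pad(block)
shape_after = block(x).shape

print(count, shape_before == shape_after)  # 1 True
```

Call the helper on the model before export so that the new padding ends up in the lowered graph.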

Use operators supported by the backend

While it was not obvious at first, nearest neighbor interpolation was slowing our model down: it is not covered by the XNNPACK operator set. We changed it to bilinear interpolation, which can be delegated to XNNPACK. Again, we could apply this change without retraining the model.
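As a rough illustration of what changes, the sketch below upsamples a tiny tensor with both modes using torch.nn.functional.interpolate. The input values are made up, and align_corners=False is an assumption, since the article does not say which setting was used.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])  # shape [1, 1, 2, 2]

# Nearest neighbor repeats each pixel; bilinear blends neighboring pixels.
nearest = F.interpolate(x, scale_factor=2, mode="nearest")
bilinear = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)

print(nearest.shape == bilinear.shape)  # True
print(torch.equal(nearest, bilinear))   # False
```

Both modes produce the same output shape, so the surrounding layers need no changes. Only the interpolated values differ.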

We’ve successfully addressed the issues with our model, reducing the inference time to just over one second. 🏎️💨

Running the .pte model 🏃

To check if the lowered model works, use the xnn_executor_runner, a sample wrapper for the ExecuTorch Runtime and XNNPACK Backend.

cd executorch

rm -rf cmake-out && mkdir cmake-out
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
    
cmake --build cmake-out -j16 --target install --config Release

The executable should now be built at ./cmake-out/backends/xnnpack/xnn_executor_runner. Run it with the model you just lowered:

./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=style_transfer_xnnpack.pte

It should print out the outputs (or throw an error if the lowered model doesn’t work 🙃).

Example outputs:

I 00:00:00.000477 executorch:executor_runner.cpp:73] Model file ../style_transfer_xnnpack.pte is loaded.
I 00:00:00.000486 executorch:executor_runner.cpp:82] Using method forward
I 00:00:00.000489 executorch:executor_runner.cpp:129] Setting up planned buffer 0, size 9830400.
I 00:00:00.009933 executorch:executor_runner.cpp:152] Method loaded.
I 00:00:00.010493 executorch:executor_runner.cpp:162] Inputs prepared.
I 00:00:00.416360 executorch:executor_runner.cpp:171] Model executed successfully.
I 00:00:00.416385 executorch:executor_runner.cpp:175] 1 outputs: 
Output 0: tensor(sizes=[1, 3, 640, 640], [
  ...
])
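The timestamps in the log give a rough idea of how long the forward pass took on the machine that ran the runner. The sketch below subtracts the "Inputs prepared." timestamp from the "Model executed successfully." one. It assumes the prefix format is hours:minutes:seconds, and the helper seconds_by_message is hypothetical.

```python
import re

# Two lines copied from the example runner output above.
LOG = """\
I 00:00:00.010493 executorch:executor_runner.cpp:162] Inputs prepared.
I 00:00:00.416360 executorch:executor_runner.cpp:171] Model executed successfully.
"""

# Assumed layout: "I hh:mm:ss.ffffff <source>] <message>"
STAMP = re.compile(r"^I (\d+):(\d+):(\d+\.\d+) .*\] (.*)$")


def seconds_by_message(log: str) -> dict[str, float]:
    """Map each log message to its timestamp in seconds."""
    stamps = {}
    for line in log.splitlines():
        match = STAMP.match(line)
        if match:
            hours, minutes, seconds, message = match.groups()
            stamps[message.strip()] = (
                int(hours) * 3600 + int(minutes) * 60 + float(seconds)
            )
    return stamps


stamps = seconds_by_message(LOG)
elapsed = stamps["Model executed successfully."] - stamps["Inputs prepared."]
print(f"{elapsed:.3f} s")  # 0.406 s
```

This measures a single run of the runner on the build machine, so treat it as a sanity check and not as a benchmark of the target phone.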

Building the runtime 🚀

Once we’ve confirmed that our model works, we can build the proper runtime for Android. Specify your Android NDK path (e.g., /Users/woj/Library/Android/sdk/ndk/27.0.11902837) and the target ABI (Application Binary Interface), here arm64-v8a.

The script below handles the entire process of setting up the environment for the ExecuTorch library, building necessary components, and constructing the libexecutorch library. It includes all required components from the XNNPACK backend and additional modules, all wrapped up with a JNI layer that lets Java applications communicate with the native library.

export ANDROID_NDK=<path-to-android-ndk>
export ANDROID_ABI=arm64-v8a

rm -rf cmake-android-out && mkdir cmake-android-out

# Build the core executorch library
cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \
  -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
  -DANDROID_ABI="${ANDROID_ABI}" \
  -DEXECUTORCH_BUILD_XNNPACK=ON \
  -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
  -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
  -Bcmake-android-out

cmake --build cmake-android-out -j16 --target install

# Build the android extension
cmake extension/android \
  -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
  -DANDROID_ABI="${ANDROID_ABI}" \
  -DCMAKE_INSTALL_PREFIX=cmake-android-out \
  -Bcmake-android-out/extension/android

cmake --build cmake-android-out/extension/android -j16

If successful, the library will be built as libexecutorch_jni.so. Copy it to your app as libexecutorch.so.

cp cmake-android-out/extension/android/libexecutorch_jni.so \
	~/executorch-style-transfer/android/app/src/main/jniLibs/arm64-v8a/libexecutorch.so

Integrating It Within Your App

For seamless integration with the runtime, refer to extension/android/src/main/java/org/pytorch/executorch in the ExecuTorch repo for Java wrapper classes for ExecuTorch components.

Fortunately, we’ve already covered the next steps in our previous articles. Check them out for detailed instructions on integrating and using your exported models on Android or iOS:

Bringing native AI to your mobile apps with ExecuTorch, part II: Android (blog.swmansion.com)

Bringing native AI to your mobile apps with ExecuTorch, part I: iOS (blog.swmansion.com)

We’re Software Mansion: software development consultants, AI explorers, multimedia experts, React Native core contributors, and community builders. Hire us: [email protected].
