
Commit 23d482e

jdduke authored and tensorflower-gardener committed
Add flag for using optimized TFLite CPU kernels on Android
Add an experimental flag which allows opting in to a set of highly optimized floating point kernels provided via the XNNPACK delegate. This is offered as a preview, with the plan to enable these kernels by default in a future release. The flag can be enabled via:

    Interpreter.Options options = new Interpreter.Options().setUseXNNPACK(true);

See tensorflow/lite/delegates/xnnpack/README.md for more details about these kernels and the associated delegate functionality.

PiperOrigin-RevId: 316909226
Change-Id: Ib60cf259225b8a48a9830ccbb24ec10534b038ce
1 parent 2779d9e commit 23d482e
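As a rough usage sketch (the class name, model buffer, thread count, and tensor shapes are hypothetical; only setUseXNNPACK and setNumThreads come from the Interpreter.Options API touched by this change), an Android client could opt in as follows:

import java.nio.ByteBuffer;
import org.tensorflow.lite.Interpreter;

public final class XnnpackOptInExample {
  // Runs a float model once with the experimental XNNPACK CPU kernels enabled.
  // The output shape below is a placeholder for whatever the model actually produces.
  static float[][] runOnce(ByteBuffer modelBuffer, float[][] input) {
    Interpreter.Options options =
        new Interpreter.Options()
            .setUseXNNPACK(true)  // opt in to the optimized floating point kernels
            .setNumThreads(2);    // the thread count is also forwarded to the XNNPACK delegate
    Interpreter interpreter = new Interpreter(modelBuffer, options);
    float[][] output = new float[1][1001];  // hypothetical output shape
    interpreter.run(input, output);
    interpreter.close();
    return output;
  }
}

Per the Javadoc added in this change, only a subset of floating point operations is currently covered, and quantized models will not see any benefit from the flag.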

File tree

9 files changed: +144 −0 lines changed

tensorflow/lite/delegates/xnnpack/BUILD

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ cc_library(
     linkstatic = True,
     deps = [
         "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
@@ -47,6 +48,7 @@ cc_library(
     linkstatic = True,
     deps = [
         "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",

tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
 
 namespace tflite {
@@ -52,6 +53,8 @@ class Delegate {
           pthreadpool_create(static_cast<size_t>(options->num_threads)));
     }
 #endif
+    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
+                         "Created TensorFlow Lite XNNPACK delegate for CPU.");
   }
 
   TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);

tensorflow/lite/java/BUILD

Lines changed: 1 addition & 0 deletions

@@ -408,6 +408,7 @@ tflite_jni_binary(
         "//tensorflow/lite/c:c_api",
         "//tensorflow/lite/c:c_api_experimental",
         "//tensorflow/lite/delegates/nnapi/java/src/main/native",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "//tensorflow/lite/java/src/main/native",
     ],
 )

tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java

Lines changed: 27 additions & 0 deletions

@@ -137,10 +137,37 @@ public Options setAllowBufferHandleOutput(boolean allow) {
       return this;
     }
 
+    /**
+     * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
+     *
+     * <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
+     * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
+     * operations. Eventually, we plan to enable this by default, as it can provide significant
+     * performance benefits for many classes of floating point models. See
+     * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
+     * for more details.
+     *
+     * <p>Things to keep in mind when enabling this flag:
+     *
+     * <ul>
+     *   <li>Startup time and resize time may increase.
+     *   <li>Baseline memory consumption may increase.
+     *   <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
+     *   <li>Quantized models will not see any benefit.
+     * </ul>
+     *
+     * <p>WARNING: This is an experimental interface that is subject to change.
+     */
+    public Options setUseXNNPACK(boolean useXNNPACK) {
+      this.useXNNPACK = useXNNPACK;
+      return this;
+    }
+
     int numThreads = -1;
     Boolean useNNAPI;
     Boolean allowFp16PrecisionForFp32;
     Boolean allowBufferHandleOutput;
+    Boolean useXNNPACK;
     final List<Delegate> delegates = new ArrayList<>();
   }

tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java

Lines changed: 7 additions & 0 deletions

@@ -80,6 +80,10 @@ private void init(long errorHandle, long modelHandle, Interpreter.Options option
       allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
     }
     applyDelegates(options);
+    if (options.useXNNPACK != null) {
+      useXNNPACK(
+          interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
+    }
     allocateTensors(interpreterHandle, errorHandle);
     this.isMemoryAllocated = true;
   }
@@ -438,6 +442,9 @@ private static Delegate maybeCreateFlexDelegate(List<Delegate> delegates) {
 
   private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
 
+  private static native void useXNNPACK(
+      long interpreterHandle, long errorHandle, boolean state, int numThreads);
+
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);

tensorflow/lite/java/src/main/native/BUILD

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ cc_library(
         "//tensorflow/lite:string_util",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
        "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels",
         "//tensorflow/lite/java/jni",
     ],

tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc

Lines changed: 55 additions & 0 deletions

@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <dlfcn.h>
 #include <jni.h>
 #include <stdio.h>
 #include <time.h>
 
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/util.h"
@@ -323,6 +325,59 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
   interpreter->SetAllowBufferHandleOutput(allow);
 }
 
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
+    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
+    jint num_threads) {
+  // If not using xnnpack, simply don't apply the delegate.
+  if (!state) {
+    return;
+  }
+
+  tflite_api_dispatcher::Interpreter* interpreter =
+      convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) {
+    return;
+  }
+
+  BufferErrorReporter* error_reporter =
+      convertLongToErrorReporter(env, error_handle);
+  if (error_reporter == nullptr) {
+    return;
+  }
+
+  // We use dynamic loading to avoid taking a hard dependency on XNNPack.
+  // This allows clients that use trimmed builds to save on binary size.
+  auto xnnpack_options_default =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault"));
+  auto xnnpack_create =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate"));
+  auto xnnpack_delete =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete"));
+
+  if (xnnpack_options_default && xnnpack_create && xnnpack_delete) {
+    TfLiteXNNPackDelegateOptions options = xnnpack_options_default();
+    if (num_threads > 0) {
+      options.num_threads = num_threads;
+    }
+    tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
+        xnnpack_create(&options), xnnpack_delete);
+    if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
+        kTfLiteOk) {
+      ThrowException(env, kIllegalArgumentException,
+                     "Internal error: Failed to apply XNNPACK delegate: %s",
+                     error_reporter->CachedErrorMessage());
+    }
+  } else {
+    ThrowException(env, kIllegalArgumentException,
+                   "Failed to load XNNPACK delegate from current runtime. "
+                   "Have you added the necessary dependencies?");
+  }
+}
+
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                              jclass clazz,
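Because the JNI layer above resolves the delegate entry points with dlsym rather than linking them directly, a runtime built without the XNNPACK delegate rejects the flag when the interpreter is constructed. A minimal sketch of what a caller might do (the class name and fallback logic are hypothetical; the exception type matches the kIllegalArgumentException thrown by the native code above):

import java.nio.ByteBuffer;
import org.tensorflow.lite.Interpreter;

public final class XnnpackFallbackExample {
  // Tries to construct an interpreter with the XNNPACK kernels enabled; falls back to
  // the default CPU path when the delegate symbols are not linked into this runtime.
  static Interpreter createInterpreter(ByteBuffer modelBuffer) {
    try {
      return new Interpreter(modelBuffer, new Interpreter.Options().setUseXNNPACK(true));
    } catch (IllegalArgumentException e) {
      // The native layer throws IllegalArgumentException when the delegate cannot be
      // loaded or applied ("Failed to load XNNPACK delegate from current runtime...").
      return new Interpreter(modelBuffer, new Interpreter.Options());
    }
  }
}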

tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java

Lines changed: 16 additions & 0 deletions

@@ -54,6 +54,16 @@ public void testMobileNetMultithreaded() {
     runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2));
   }
 
+  @Test
+  public void testMobileNetEnhancedCpuKernels() {
+    runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true));
+  }
+
+  @Test
+  public void testMobileNetEnhancedCpuKernelsMultithreaded() {
+    runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2));
+  }
+
   @Test
   public void testMobileNetQuantized() {
     runMobileNetQuantizedTest(new Interpreter.Options());
@@ -64,6 +74,12 @@ public void testMobileNetQuantizedMultithreaded() {
     runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2));
   }
 
+  @Test
+  public void testMobileNetQuantizedEnhancedCpu() {
+    // The "enhanced CPU" flag should only impact float models; this is a sanity test to confirm.
+    runMobileNetQuantizedTest(new Interpreter.Options().setUseXNNPACK(true));
+  }
+
   private static void runMobileNetFloatTest(Interpreter.Options options) {
     ByteBuffer img =
         TestUtils.getTestImageAsFloatByteBuffer(

tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java

Lines changed: 32 additions & 0 deletions

@@ -409,6 +409,38 @@ public void testTurnOnNNAPI() throws Exception {
     interpreter.close();
   }
 
+  @Test
+  public void testUseXNNPACK() throws Exception {
+    Interpreter interpreter =
+        new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
+    float[] oneD = {1.23f, 6.54f, 7.81f};
+    float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+    float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+    float[][][][] fourD = {threeD, threeD};
+    float[][][][] parsedOutputs = new float[2][8][8][3];
+    interpreter.run(fourD, parsedOutputs);
+    float[] outputOneD = parsedOutputs[0][0][0];
+    float[] expected = {3.69f, 19.62f, 23.43f};
+    assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    interpreter.close();
+  }
+
+  @Test
+  public void testResizeWithEnhancedCpuKernels() throws Exception {
+    Interpreter interpreter =
+        new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
+    float[] input = {1.f};
+    float[] output = new float[1];
+    interpreter.run(input, output);
+    assertThat(output).usingTolerance(0.1f).containsExactly(new float[] {3.f}).inOrder();
+
+    // The new input shape should trigger a resize. Inference should still work properly.
+    float[] input2 = {1.f, 2.f};
+    float[] output2 = new float[2];
+    interpreter.run(input2, output2);
+    assertThat(output2).usingTolerance(0.1f).containsExactly(new float[] {3.f, 6.f}).inOrder();
+  }
+
   @Test
   public void testRedundantClose() throws Exception {
     Interpreter interpreter = new Interpreter(MODEL_BUFFER);
