Skip to content

Commit f954b27

Browse files
abatterytensorflower-gardener
authored andcommitted
Add uint64 tensor support in TFLite
Even though we do not support uint64 op kernels on mobile, it is inevitable to support uint64 tensors in order to enable TF uint64 ops via flex delegate. This CL enables the uint64 tensor type in MLIR converter only. PiperOrigin-RevId: 342939673 Change-Id: I24f422040f82cad7affce4b921361f79e8a51730
1 parent 668b1d7 commit f954b27

File tree

30 files changed

+102
-10
lines changed

30 files changed

+102
-10
lines changed

tensorflow/compiler/mlir/lite/flatbuffer_export.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
167167
case 32:
168168
return tflite::TensorType_INT32;
169169
case 64:
170-
return tflite::TensorType_INT64;
170+
return itype.isUnsigned() ? tflite::TensorType_UINT64
171+
: tflite::TensorType_INT64;
171172
}
172173
} else if (auto q_uniform_type =
173174
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {

tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
119119
return DT_INT32;
120120
case toco::IODataType::INT64:
121121
return DT_INT64;
122+
case toco::IODataType::UINT64:
123+
return DT_UINT64;
122124
case toco::IODataType::STRING:
123125
return DT_STRING;
124126
case toco::IODataType::BOOL:

tensorflow/compiler/mlir/lite/utils/convert_type.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
5757
return mlir::ComplexType::get(builder.getF64Type());
5858
case tflite::TensorType_INT8:
5959
return builder.getIntegerType(8);
60+
case tflite::TensorType_UINT64:
61+
return builder.getIntegerType(64, /*isSigned=*/false);
6062
}
6163
}
6264

@@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
8688
return tensorflow::DT_STRING;
8789
case tflite::TensorType_UINT8:
8890
return tensorflow::DT_UINT8;
91+
case tflite::TensorType_UINT64:
92+
return tensorflow::DT_UINT64;
8993
}
9094
}
9195

tensorflow/lite/c/common.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
203203
return "INT8";
204204
case kTfLiteInt64:
205205
return "INT64";
206+
case kTfLiteUInt64:
207+
return "UINT64";
206208
case kTfLiteBool:
207209
return "BOOL";
208210
case kTfLiteComplex64:

tensorflow/lite/c/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ typedef enum {
300300
kTfLiteFloat16 = 10,
301301
kTfLiteFloat64 = 11,
302302
kTfLiteComplex128 = 12,
303+
kTfLiteUInt64 = 13,
303304
} TfLiteType;
304305

305306
// Return the name of a given type, for error reporting purposes.
@@ -354,6 +355,7 @@ typedef union TfLitePtrUnion {
354355
* members are deprecated. */
355356
int32_t* i32;
356357
int64_t* i64;
358+
uint64_t* u64;
357359
float* f;
358360
TfLiteFloat16* f16;
359361
double* f64;

tensorflow/lite/c/common_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ TEST(Types, TestTypeNames) {
8484
EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
8585
EXPECT_EQ(type_name(kTfLiteInt32), "INT32");
8686
EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8");
87+
EXPECT_EQ(type_name(kTfLiteUInt64), "UINT64");
8788
EXPECT_EQ(type_name(kTfLiteInt8), "INT8");
8889
EXPECT_EQ(type_name(kTfLiteInt64), "INT64");
8990
EXPECT_EQ(type_name(kTfLiteBool), "BOOL");

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
859859
case TensorType_INT64:
860860
*type = kTfLiteInt64;
861861
return kTfLiteOk;
862+
case TensorType_UINT64:
863+
*type = kTfLiteUInt64;
864+
return kTfLiteOk;
862865
case TensorType_STRING:
863866
*type = kTfLiteString;
864867
return kTfLiteOk;

tensorflow/lite/delegates/flex/util.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
7474
return TF_INT8;
7575
case kTfLiteInt64:
7676
return TF_INT64;
77+
case kTfLiteUInt64:
78+
return TF_UINT64;
7779
case kTfLiteComplex64:
7880
return TF_COMPLEX64;
7981
case kTfLiteComplex128:
@@ -103,6 +105,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
103105
return kTfLiteInt8;
104106
case TF_INT64:
105107
return kTfLiteInt64;
108+
case TF_UINT64:
109+
return kTfLiteUInt64;
106110
case TF_COMPLEX64:
107111
return kTfLiteComplex64;
108112
case TF_COMPLEX128:

tensorflow/lite/delegates/flex/util_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ TEST(UtilTest, TypeConversionsFromTFLite) {
115115
EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32));
116116
EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8));
117117
EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64));
118+
EXPECT_EQ(TF_UINT64, GetTensorFlowDataType(kTfLiteUInt64));
118119
EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64));
119120
EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
120121
EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
@@ -129,6 +130,7 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
129130
EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
130131
EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
131132
EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
133+
EXPECT_EQ(kTfLiteUInt64, GetTensorFlowLiteType(TF_UINT64));
132134
EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
133135
EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
134136
EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));

tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,9 @@ - (TFLTensorDataType)tensorDataTypeFromCTensorType:(TfLiteType)cTensorType {
421421
case kTfLiteString:
422422
case kTfLiteComplex64:
423423
case kTfLiteComplex128:
424-
// kTfLiteString, kTfLiteComplex64 and kTfLiteComplex128 are not supported in TensorFlow Lite
425-
// Objc API.
424+
case kTfLiteUInt64:
425+
// kTfLiteString, kTfLiteUInt64, kTfLiteComplex64 and kTfLiteComplex128 are not supported in
426+
// TensorFlow Lite Objc API.
426427
return TFLTensorDataTypeNoType;
427428
}
428429
}

0 commit comments

Comments
 (0)