Skip to content

Commit bc295ed

Browse files
renjie-liutensorflower-gardener
authored andcommitted
Promote rfft2d to builtin op and add mlir conversion support.
PiperOrigin-RevId: 344732711 Change-Id: I811c45a03d7c204120f2c9bc491e39b32cc5222b
1 parent 4d1142b commit bc295ed

File tree

19 files changed

+689
-568
lines changed

19 files changed

+689
-568
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
* Added support for saved model's session initializer through
5151
`TFLiteConverter.from_saved_model`.
5252
* Added dynamic range quantization support for the BatchMatMul op.
53+
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
54+
only supports float32 input.
5355

5456
* TF Core:
5557
* Corrected higher-order gradients of control flow constructs (`tf.cond`,

tensorflow/compiler/mlir/lite/ir/tfl_ops.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def TFL_FpTensor : TFL_TensorOf<[F32]>;
172172
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
173173
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
174174
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
175+
def TFL_Complex64Tensor : TFL_TensorOf<[Complex<F<32>>]>;
176+
175177
// TODO(jpienaar): Expand to all int types.
176178
def TFL_IntTensor : TypeAlias<TFL_I32Tensor, "tensor of any integer type">;
177179

@@ -4481,4 +4483,31 @@ subsequent operation and then be optimized away, however.)
44814483
);
44824484
}
44834485

4486+
def TFL_RFFT2dOp : TFL_Op<"RFFT2D", [NoSideEffect, NoQuantizableResult]> {
4487+
let summary = "2D real-valued fast Fourier transform.";
4488+
4489+
let description = [{
4490+
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
4491+
over the inner-most 2 dimensions of `input`.
4492+
4493+
Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
4494+
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
4495+
of `output`: the zero-frequency term, followed by the `fft_length / 2`
4496+
positive-frequency terms.
4497+
4498+
Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
4499+
corresponding dimension of `input`, the dimension is cropped. If it is larger,
4500+
the dimension is padded with zeros.
4501+
}];
4502+
4503+
let arguments = (ins
4504+
TFL_FpTensor:$input,
4505+
TFL_I32Tensor:$fft_length
4506+
);
4507+
4508+
let results = (outs
4509+
TFL_Complex64Tensor:$output
4510+
);
4511+
}
4512+
44844513
#endif // TFL_OPS

tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,3 +1966,17 @@ func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32
19661966
// CHECK: "tfl.cast"
19671967
// CHECK: "tfl.segment_sum"
19681968
}
1969+
1970+
func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>> {
1971+
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
1972+
return %0 : tensor<10x20x10x30xcomplex<f32>>
1973+
// CHECK-LABEL: rfft2d
1974+
// CHECK: "tfl.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
1975+
}
1976+
1977+
func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>> {
1978+
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf64>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>>
1979+
return %0 : tensor<10x20x10x30xcomplex<f64>>
1980+
// CHECK-LABEL: rfft2d_invalid
1981+
// CHECK-NOT: "tfl.RFFT2D"
1982+
}

tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,7 @@ def LegalizeStridedSlice : Pat<
491491
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
492492
(convertIntAttrTo32Bit $new_axis_mask),
493493
(convertIntAttrTo32Bit $shrink_axis_mask))>;
494+
495+
def LegalizeRfft2d : Pat<
496+
(TF_RFFT2DOp $input, $fft_length),
497+
(TFL_RFFT2dOp $input, $fft_length)>;

tensorflow/lite/build_def.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,6 @@ def generated_test_models():
359359
"resolve_constant_strided_slice",
360360
"reverse_sequence",
361361
"reverse_v2",
362-
"rfft2d",
363362
"round",
364363
"rsqrt",
365364
"scatter_nd",

tensorflow/lite/builtin_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ typedef enum {
158158
kTfLiteBuiltinCumsum = 128,
159159
kTfLiteBuiltinCallOnce = 129,
160160
kTfLiteBuiltinBroadcastTo = 130,
161+
kTfLiteBuiltinRfft2d = 131,
161162
} TfLiteBuiltinOperator;
162163

163164
#ifdef __cplusplus

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
823823
case BuiltinOperator_DENSIFY:
824824
case BuiltinOperator_SEGMENT_SUM:
825825
case BuiltinOperator_BROADCAST_TO:
826+
case BuiltinOperator_RFFT2D:
826827
return kTfLiteOk;
827828
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
828829
return kTfLiteError;

tensorflow/lite/kernels/BUILD

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ BUILTIN_KERNEL_SRCS = [
621621
"where.cc",
622622
"while.cc",
623623
"zeros_like.cc",
624+
"rfft2d.cc",
624625
]
625626

626627
BUILTIN_KERNEL_DEPS = [
@@ -669,10 +670,12 @@ cc_library(
669670
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
670671
visibility = ["//visibility:private"],
671672
deps = BUILTIN_KERNEL_DEPS + [
673+
"@fft2d",
672674
"@ruy//ruy/profiler:instrumentation",
673675
"//tensorflow/lite/kernels/internal:cppmath",
674676
"//tensorflow/lite:string",
675677
"@farmhash_archive//:farmhash",
678+
"//third_party/fft2d:fft2d_headers",
676679
],
677680
)
678681

@@ -713,7 +716,6 @@ cc_library(
713716
"complex_support.cc",
714717
"multinomial.cc",
715718
"random_standard_normal.cc",
716-
"rfft2d.cc",
717719
],
718720
hdrs = ["custom_ops_register.h"],
719721
copts = tflite_copts(),
@@ -722,8 +724,6 @@ cc_library(
722724
"//tensorflow/lite/c:common",
723725
"//tensorflow/lite/kernels/internal:tensor",
724726
"//tensorflow/lite/kernels/internal:types",
725-
"//third_party/fft2d:fft2d_headers",
726-
"@fft2d",
727727
"@ruy//ruy/profiler:instrumentation",
728728
],
729729
)
@@ -2187,13 +2187,9 @@ cc_test(
21872187
size = "small",
21882188
srcs = ["rfft2d_test.cc"],
21892189
deps = [
2190-
":custom_ops",
21912190
":test_main",
21922191
":test_util",
2193-
"//tensorflow/lite:framework",
2194-
"//tensorflow/lite/c:common",
21952192
"//tensorflow/lite/schema:schema_fbs",
2196-
"//tensorflow/lite/testing:util",
21972193
"@com_google_googletest//:gtest",
21982194
],
21992195
)

tensorflow/lite/kernels/builtin_op_kernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ TfLiteRegistration* Register_RESIZE_BILINEAR();
118118
TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR();
119119
TfLiteRegistration* Register_REVERSE_SEQUENCE();
120120
TfLiteRegistration* Register_REVERSE_V2();
121+
TfLiteRegistration* Register_RFFT2D();
121122
TfLiteRegistration* Register_RNN();
122123
TfLiteRegistration* Register_ROUND();
123124
TfLiteRegistration* Register_RSQRT();

tensorflow/lite/kernels/custom_ops_register.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ TfLiteRegistration* Register_IMAG();
2929
TfLiteRegistration* Register_MULTINOMIAL();
3030
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
3131
TfLiteRegistration* Register_REAL();
32-
TfLiteRegistration* Register_RFFT2D();
3332

3433
} // namespace custom
3534
} // namespace ops

0 commit comments

Comments
 (0)