Skip to content

Commit e290ea6

Browse files
andylytensorflower-gardener
authored andcommitted
Migrate TF MLIR shape inference pass to use declarative pass registration instead of manually defined pass registration (NFC).
PiperOrigin-RevId: 347430171 Change-Id: Iff3e9c3a6c2ddcca1ef7351adadc2c9ba75e0d4a
1 parent 89ac5d4 commit e290ea6

File tree

6 files changed

+8
-4
lines changed

6 files changed

+8
-4
lines changed

tensorflow/compiler/mlir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cc_library(
7979
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
8080
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
8181
"//tensorflow/compiler/mlir/tensorflow",
82+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
8283
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
8384
"//tensorflow/core:lib",
8485
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",

tensorflow/compiler/mlir/tensorflow/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ cc_library(
639639
":tensorflow_tfrt_ops_inc_gen",
640640
":tensorflow_traits",
641641
":tensorflow_types",
642+
":tf_pass_inc_gen",
642643
":tf_saved_model_inc_gen",
643644
"//tensorflow/compiler/mlir/lite:validators",
644645
"//tensorflow/core:framework",

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ CreateTPUCompileOpReplicationPass();
401401

402402
} // namespace TFTPU
403403

404+
#define GEN_PASS_REGISTRATION
405+
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
406+
404407
} // namespace mlir
405408

406409
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_

tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
3636
}
3737
};
3838

39-
PassRegistration<ShapeInference> pass(
40-
"tf-shape-inference", "Simple Shape Inference on TensorFlow Dialect");
41-
4239
} // namespace
4340

4441
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {

tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
2121
let summary = "Simple Shape Inference on TensorFlow Dialect";
2222
// TODO(jpienaar): Write `description`.
2323

24-
let constructor = "CreateTFShapeInferencePass()";
24+
let constructor = "TF::CreateTFShapeInferencePass()";
2525

2626
let options = [
2727
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",

tensorflow/compiler/mlir/tf_mlir_opt_main.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ limitations under the License.
2222
#include "tensorflow/compiler/mlir/init_mlir.h"
2323
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
2424
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
25+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
2526
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
2627
#include "tensorflow/core/platform/init_main.h"
2728

2829
int main(int argc, char **argv) {
2930
tensorflow::InitMlir y(&argc, &argv);
3031

3132
mlir::registerAllPasses();
33+
mlir::registerTensorFlowPasses();
3234
mlir::mhlo::registerAllMhloPasses();
3335
mlir::lmhlo::registerAllLmhloPasses();
3436

0 commit comments

Comments
 (0)