Skip to content

Commit fe0dcca

Browse files
mgestertensorflower-gardener
authored andcommitted
Migrate control flow conversion passes to TableGen
Now FunctionalControlFlowToRegions and RegionControlFlowToFunctional passes are defined and documented in TableGen. PiperOrigin-RevId: 348863824 Change-Id: I94dfabd2582dc0e00252d808794ee19f8ae87932
1 parent d4cc37b commit fe0dcca

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
3434
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
3535
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
3637
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3738

3839
#define DEBUG_TYPE "tf-functional-cf-to-region"
@@ -43,8 +44,8 @@ namespace TF {
4344
namespace {
4445

4546
struct FunctionalControlFlowToRegions
46-
: public PassWrapper<FunctionalControlFlowToRegions,
47-
OperationPass<ModuleOp>> {
47+
: public TF::FunctionalControlFlowToRegionsPassBase<
48+
FunctionalControlFlowToRegions> {
4849
void runOnOperation() override;
4950
};
5051

@@ -157,9 +158,5 @@ CreateTFFunctionalControlFlowToRegions() {
157158
return std::make_unique<FunctionalControlFlowToRegions>();
158159
}
159160

160-
static PassRegistration<FunctionalControlFlowToRegions> pass(
161-
"tf-functional-control-flow-to-regions",
162-
"Transform functional control flow Ops to Region based counterparts");
163-
164161
} // namespace TF
165162
} // namespace mlir

tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
3737
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
3838
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
3940
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
4041

4142
#define DEBUG_TYPE "tf-region-cf-to-functional"
@@ -46,8 +47,8 @@ namespace TF {
4647
namespace {
4748

4849
struct RegionControlFlowToFunctional
49-
: public PassWrapper<RegionControlFlowToFunctional,
50-
OperationPass<ModuleOp>> {
50+
: public TF::RegionControlFlowToFunctionalPassBase<
51+
RegionControlFlowToFunctional> {
5152
void runOnOperation() override;
5253

5354
private:
@@ -445,9 +446,5 @@ CreateTFRegionControlFlowToFunctional() {
445446
return std::make_unique<RegionControlFlowToFunctional>();
446447
}
447448

448-
static PassRegistration<RegionControlFlowToFunctional> pass(
449-
"tf-region-control-flow-to-functional",
450-
"Transform region bases control flow Ops to functional counterparts");
451-
452449
} // namespace TF
453450
} // namespace mlir

tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,67 @@ func @unsupported_op() -> tensor<i32> {
427427

428428
let constructor = "TFDevice::CreateMarkOpsForOutsideCompilationPass()";
429429
}
430+
431+
def FunctionalControlFlowToRegionsPass : Pass<"tf-functional-control-flow-to-regions", "ModuleOp"> {
432+
let summary = "Transforms functional control flow operations to their region-based counterparts";
433+
434+
let description = [{
435+
This pass transforms functional control flow operations in the TensorFlow
436+
dialect to their region-based counterparts, i.e., `tf.If` is transformed to
437+
`tf.IfRegion` and `tf.While` is transformed to `tf.WhileRegion`.
438+
439+
For example, this functional operation
440+
441+
```mlir
442+
%0 = "tf.If"(%arg0, %arg1) {
443+
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
444+
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
445+
```
446+
447+
will be transformed into this region-based operation
448+
449+
```mlir
450+
%0 = "tf.IfRegion"(%arg0) ( {
451+
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
452+
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
453+
}, {
454+
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
455+
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
456+
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
457+
```
458+
}];
459+
460+
let constructor = "TF::CreateTFFunctionalControlFlowToRegions()";
461+
}
462+
463+
def RegionControlFlowToFunctionalPass : Pass<"tf-region-control-flow-to-functional", "ModuleOp"> {
464+
let summary = "Transforms region-based control flow operations to their functional counterparts";
465+
466+
let description = [{
467+
This pass transforms region-based control flow operations in the TensorFlow
468+
dialect to their functional counterparts, i.e., `tf.IfRegion` is transformed to
469+
`tf.If` and `tf.WhileRegion` is transformed to `tf.While`.
470+
471+
For example, this region-based operation
472+
473+
```mlir
474+
%0 = "tf.IfRegion"(%arg0) ( {
475+
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
476+
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
477+
}, {
478+
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
479+
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
480+
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
481+
```
482+
483+
will be transformed into this functional operation
484+
485+
```mlir
486+
%0 = "tf.If"(%arg0, %arg1) {
487+
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
488+
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
489+
```
490+
}];
491+
492+
let constructor = "TF::CreateTFRegionControlFlowToFunctional()";
493+
}

0 commit comments

Comments
 (0)