@@ -427,3 +427,67 @@ func @unsupported_op() -> tensor<i32> {
427427
428428 let constructor = "TFDevice::CreateMarkOpsForOutsideCompilationPass()";
429429}
430+
431+ def FunctionalControlFlowToRegionsPass : Pass<"tf-functional-control-flow-to-regions", "ModuleOp"> {
432+ let summary = "Transforms functional control flow operations to their region-based counterparts";
433+
434+ let description = [{
435+ This pass transforms functional control flow operations in the TensorFlow
436+ dialect to their region-based counterparts, i.e., `tf.If` is transformed to
437+ `tf.IfRegion` and `tf.While` is transformed to `tf.WhileRegion`.
438+
439+ For example, this functional operation
440+
441+ ```mlir
442+ %0 = "tf.If"(%arg0, %arg1) {
443+ then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
444+ } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
445+ ```
446+
447+ will be transformed into this region-based operation
448+
449+ ```mlir
450+ %0 = "tf.IfRegion"(%arg0) ( {
451+ %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
452+ "tf.Yield"(%1) : (tensor<*xf32>) -> ()
453+ }, {
454+ %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
455+ "tf.Yield"(%1) : (tensor<*xf32>) -> ()
456+ }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
457+ ```
458+ }];
459+
460+ let constructor = "TF::CreateTFFunctionalControlFlowToRegions()";
461+ }
462+
463+ def RegionControlFlowToFunctionalPass : Pass<"tf-region-control-flow-to-functional", "ModuleOp"> {
464+ let summary = "Transforms region-based control flow operations to their functional counterparts";
465+
466+ let description = [{
467+ This pass transforms region-based control flow operations in the TensorFlow
468+ dialect to their functional counterparts, i.e., `tf.IfRegion` is transformed to
469+ `tf.If` and `tf.WhileRegion` is transformed to `tf.While`.
470+
471+ For example, this region-based operation
472+
473+ ```mlir
474+ %0 = "tf.IfRegion"(%arg0) ( {
475+ %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
476+ "tf.Yield"(%1) : (tensor<*xf32>) -> ()
477+ }, {
478+ %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
479+ "tf.Yield"(%1) : (tensor<*xf32>) -> ()
480+ }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
481+ ```
482+
483+ will be transformed into this functional operation
484+
485+ ```mlir
486+ %0 = "tf.If"(%arg0, %arg1) {
487+ then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
488+ } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
489+ ```
490+ }];
491+
492+ let constructor = "TF::CreateTFRegionControlFlowToFunctional()";
493+ }
0 commit comments