Skip to content

[mlir][tosa] Stop support the custom simplified form of COND_IF #139576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

tatwaichong
Copy link
Contributor

Since the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary.

Since the tensor_list_shape for input_list, output_list, then_graph,
and else_graph is required to be equal according to the spec, this
information must be explicitly provided during operation construction.
The current custom simplified form does not meet this requirement.
For example, the input_list and output_list can be empty in the
simplified form. A new compatible simplified form will be introduced
in the future if necessary.
@llvmbot
Copy link
Member

llvmbot commented May 12, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

Changes

Since the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary.


Full diff: https://github.com/llvm/llvm-project/pull/139576.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (-59)
  • (modified) mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir (+9-14)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+7-5)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+73-26)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+5-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+32-24)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+26-17)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..70aecfcfa3ec7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2558,7 +2558,6 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
     SizedRegion<1>:$else_graph
   );
 
-  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 371c6dc27b428..2d7c80cbf7848 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3518,65 +3518,6 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
-// parse and print of IfOp refer to the implementation of SCF dialect.
-ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Create the regions for 'then'.
-  result.regions.reserve(2);
-  Region *thenRegion = result.addRegion();
-  Region *elseRegion = result.addRegion();
-
-  auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand cond;
-  // Create a i1 tensor type for the boolean condition.
-  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
-    return failure();
-  // Parse optional results type list.
-  if (parser.parseOptionalArrowTypeList(result.types))
-    return failure();
-  // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
-    return failure();
-
-  // If we find an 'else' keyword then parse the 'else' region.
-  if (!parser.parseOptionalKeyword("else")) {
-    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
-      return failure();
-  }
-
-  // Parse the optional attribute list.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
-}
-
-void IfOp::print(OpAsmPrinter &p) {
-  bool printBlockTerminators = false;
-
-  p << " " << getCondition();
-  if (!getResults().empty()) {
-    p << " -> (" << getResultTypes() << ")";
-    // Print yield explicitly if the op defines values.
-    printBlockTerminators = true;
-  }
-  p << ' ';
-  p.printRegion(getThenGraph(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
-
-  // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseGraph();
-  if (!elseRegion.empty()) {
-    p << " else ";
-    p.printRegion(elseRegion,
-                  /*printEntryBlockArgs=*/false,
-                  /*printBlockTerminators=*/printBlockTerminators);
-  }
-
-  p.printOptionalAttrDict((*this)->getAttrs());
-}
-
 LogicalResult IfOp::verify() {
   if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                  "'then_graph' arguments", getInputList(),
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..78f5040eab97a 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,20 +36,15 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
 func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
   // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
   // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-
-  // CHECK:   scf.yield [[ARG0]]
-    tosa.yield %arg0 : tensor<f32>
-
-  // CHECK: } else {
-  } else {
-
-  // CHECK:   scf.yield [[ARG1]]
-    tosa.yield %arg1 : tensor<f32>
-
-  // CHECK: }
-  // CHECK: return [[IF]]
-  }
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    // CHECK:   scf.yield [[ARG0]]
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    // CHECK:   scf.yield [[ARG1]]
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 
   return %0 : tensor<f32>
 }
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..5381d6c533d01 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,13 +645,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // CHECK: tosa.cond_if profiles: [ ]
   // CHECK: tosa.cond_if extensions: [ [controlflow] ]
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..c688b6592ed9f 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,13 +337,15 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
 // -----
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d24c1fa57883d..5b11aa782637a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1503,40 +1503,87 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
 
 // -----
 
-func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
-      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
-        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
-          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>,  %arg3: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+  // COM: then graph of IF-1
+  ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+    %cond1 = tosa.equal %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %1 = "tosa.cond_if"(%cond1, %a1, %b1) ({
+
+    // COM: then graph of IF-2
+    ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+      %cond2 = tosa.equal %a2, %b2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+      %2 = "tosa.cond_if"(%cond2, %a2, %b2) ({
+
+      // COM: then graph of IF-3
+      ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
+        %cond3 = tosa.equal %a3, %b3 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+        %3 = "tosa.cond_if"(%cond3, %a3, %b3) ({
+
+        // COM: then graph of IF-4
+        ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
+          %cond4 = tosa.equal %a4, %b4 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+          %4 = "tosa.cond_if"(%cond4, %a4, %b4) ({
+
+          // COM: then graph of IF-5
+          ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+            %cond5 = tosa.equal %a5, %b5 : (tensor<f32>, tensor<f32>) -> tensor<i1>
             // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
-            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+            %5 = "tosa.cond_if"(%cond5, %a5, %b5) ({
+
+            // COM: then graph of IF-6
+            ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
               %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
               tosa.yield %res : tensor<f32>
-            } else {
+            },  {
+
+            // COM: else graph of IF-6
+            ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
               tosa.yield %arg0 : tensor<f32>
-            }
+            }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
             tosa.yield %5 : tensor<f32>
-          } else {
+          },  {
+
+            // COM: else graph of IF-5
+            ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+              tosa.yield %arg0 : tensor<f32>
+          }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+            tosa.yield %4 : tensor<f32>
+        },  {
+
+          // COM: else graph of IF-4
+          ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
             tosa.yield %arg0 : tensor<f32>
-          }
-          tosa.yield %4 : tensor<f32>
-        } else {
-          tosa.yield %arg0 : tensor<f32>
-        }
-        tosa.yield %3 : tensor<f32>
-      } else {
+        }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+          tosa.yield %3 : tensor<f32>
+      },  {
+
+      // COM: else graph of IF-3
+      ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
         tosa.yield %arg0 : tensor<f32>
-      }
-      tosa.yield %2 : tensor<f32>
-    } else {
-      tosa.yield %arg0 : tensor<f32>
-    }
-    tosa.yield %1 : tensor<f32>
-  } else {
-    %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+        tosa.yield %2 : tensor<f32>
+    },  {
+
+      // COM: else graph of IF-2
+      ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+        tosa.yield %arg0 : tensor<f32>
+    }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+      tosa.yield %1 : tensor<f32>
+  },  {
+
+  // COM: else graph of IF-1
+  ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+    %res = tosa.sub %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %res : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..e3036cf07171f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -781,13 +781,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..981e3cc7fc129 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1121,12 +1121,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
   %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
 
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    tosa.yield %a : tensor<f32>
-  } else {
-    tosa.yield %b : tensor<f32>
-  }
+  // CHECK: -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+  ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+    tosa.yield %a1 : tensor<f32>
+  },  {
+  ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+    tosa.yield %b1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return
 }
 
@@ -1135,12 +1137,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
 // CHECK-LABEL: @if_test_dynamic
 func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<?xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
-    tosa.yield %arg0 : tensor<2xf32>
-  } else {
-    tosa.yield %arg1 : tensor<3xf32>
-  }
+  // CHECK: -> tensor<?xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+    tosa.yield %a : tensor<2xf32>
+  },  {
+  ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+    tosa.yield %b : tensor<3xf32>
+  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
   return
 }
 
@@ -1149,12 +1153,14 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_unranked
 func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<*xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
-    tosa.yield %arg0 : tensor<f32>
-  } else {
-    tosa.yield %arg1 : tensor<3xf32>
-  }
+  // CHECK: -> tensor<*xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+    tosa.yield %a : tensor<f32>
+  },  {
+  ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+    tosa.yield %b : tensor<3xf32>
+  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
   return
 }
 
@@ -1163,14 +1169,16 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_propagate
 func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+    %1 = tosa.add %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+    %1 = tosa.sub %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..e99608dfbeff4 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -502,14 +502,17 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
 
 func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
   return %0 : tensor<f32>
 }
 
@@ -517,13 +520,15 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
   return %0 : tensor<f32>
 }
 
@@ -531,14 +536,16 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
 
 func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg3 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
@@ -546,14 +553,16 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
   return %0 : tensor<f32>
 }
 

@tatwaichong
Copy link
Contributor Author

@lhutton1 @Jerry-Ge Could you guys review this?

Copy link
Member

@Jerry-Ge Jerry-Ge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the PR!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants