-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][tosa] Stop support the custom simplified form of COND_IF #139576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Since the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary.
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: TatWai Chong (tatwaichong) ChangesSince the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary. Full diff: https://github.com/llvm/llvm-project/pull/139576.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..70aecfcfa3ec7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2558,7 +2558,6 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
SizedRegion<1>:$else_graph
);
- let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 371c6dc27b428..2d7c80cbf7848 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3518,65 +3518,6 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
-// parse and print of IfOp refer to the implementation of SCF dialect.
-ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
- // Create the regions for 'then'.
- result.regions.reserve(2);
- Region *thenRegion = result.addRegion();
- Region *elseRegion = result.addRegion();
-
- auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
- return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
- return failure();
- // Parse the 'then' region.
- if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
-
- // If we find an 'else' keyword then parse the 'else' region.
- if (!parser.parseOptionalKeyword("else")) {
- if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
- }
-
- // Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
- return success();
-}
-
-void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
- p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
- }
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
-
- // Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseGraph();
- if (!elseRegion.empty()) {
- p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
- }
-
- p.printOptionalAttrDict((*this)->getAttrs());
-}
-
LogicalResult IfOp::verify() {
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
"'then_graph' arguments", getInputList(),
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..78f5040eab97a 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,20 +36,15 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-
- // CHECK: scf.yield [[ARG0]]
- tosa.yield %arg0 : tensor<f32>
-
- // CHECK: } else {
- } else {
-
- // CHECK: scf.yield [[ARG1]]
- tosa.yield %arg1 : tensor<f32>
-
- // CHECK: }
- // CHECK: return [[IF]]
- }
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ // CHECK: scf.yield [[ARG0]]
+ tosa.yield %arg3 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ // CHECK: scf.yield [[ARG1]]
+ tosa.yield %arg4 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..5381d6c533d01 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,13 +645,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if profiles: [ ]
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
- %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..c688b6592ed9f 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,13 +337,15 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
// -----
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d24c1fa57883d..5b11aa782637a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1503,40 +1503,87 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
// -----
-func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+ // COM: then graph of IF-1
+ ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+ %cond1 = tosa.equal %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %1 = "tosa.cond_if"(%cond1, %a1, %b1) ({
+
+ // COM: then graph of IF-2
+ ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+ %cond2 = tosa.equal %a2, %b2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %2 = "tosa.cond_if"(%cond2, %a2, %b2) ({
+
+ // COM: then graph of IF-3
+ ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
+ %cond3 = tosa.equal %a3, %b3 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %3 = "tosa.cond_if"(%cond3, %a3, %b3) ({
+
+ // COM: then graph of IF-4
+ ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
+ %cond4 = tosa.equal %a4, %b4 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %4 = "tosa.cond_if"(%cond4, %a4, %b4) ({
+
+ // COM: then graph of IF-5
+ ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+ %cond5 = tosa.equal %a5, %b5 : (tensor<f32>, tensor<f32>) -> tensor<i1>
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
- %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %5 = "tosa.cond_if"(%cond5, %a5, %b5) ({
+
+ // COM: then graph of IF-6
+ ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
- } else {
+ }, {
+
+ // COM: else graph of IF-6
+ ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
tosa.yield %arg0 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
tosa.yield %5 : tensor<f32>
- } else {
+ }, {
+
+ // COM: else graph of IF-5
+ ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+ tosa.yield %arg0 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+ tosa.yield %4 : tensor<f32>
+ }, {
+
+ // COM: else graph of IF-4
+ ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
tosa.yield %arg0 : tensor<f32>
- }
- tosa.yield %4 : tensor<f32>
- } else {
- tosa.yield %arg0 : tensor<f32>
- }
- tosa.yield %3 : tensor<f32>
- } else {
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+ tosa.yield %3 : tensor<f32>
+ }, {
+
+ // COM: else graph of IF-3
+ ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
tosa.yield %arg0 : tensor<f32>
- }
- tosa.yield %2 : tensor<f32>
- } else {
- tosa.yield %arg0 : tensor<f32>
- }
- tosa.yield %1 : tensor<f32>
- } else {
- %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+ tosa.yield %2 : tensor<f32>
+ }, {
+
+ // COM: else graph of IF-2
+ ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+ tosa.yield %arg0 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+ tosa.yield %1 : tensor<f32>
+ }, {
+
+ // COM: else graph of IF-1
+ ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+ %res = tosa.sub %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
return %0 : tensor<f32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..e3036cf07171f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -781,13 +781,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..981e3cc7fc129 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1121,12 +1121,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- tosa.yield %a : tensor<f32>
- } else {
- tosa.yield %b : tensor<f32>
- }
+ // CHECK: -> tensor<f32>
+ %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+ ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+ tosa.yield %a1 : tensor<f32>
+ }, {
+ ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+ tosa.yield %b1 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
@@ -1135,12 +1137,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<?xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
- tosa.yield %arg0 : tensor<2xf32>
- } else {
- tosa.yield %arg1 : tensor<3xf32>
- }
+ // CHECK: -> tensor<?xf32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+ tosa.yield %a : tensor<2xf32>
+ }, {
+ ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+ tosa.yield %b : tensor<3xf32>
+ }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
return
}
@@ -1149,12 +1153,14 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<*xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
- tosa.yield %arg0 : tensor<f32>
- } else {
- tosa.yield %arg1 : tensor<3xf32>
- }
+ // CHECK: -> tensor<*xf32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+ tosa.yield %a : tensor<f32>
+ }, {
+ ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+ tosa.yield %b : tensor<3xf32>
+ }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
return
}
@@ -1163,14 +1169,16 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: -> tensor<f32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+ %1 = tosa.add %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
- %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }, {
+ ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+ %1 = tosa.sub %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..e99608dfbeff4 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -502,14 +502,17 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
- } else {
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
return %0 : tensor<f32>
}
@@ -517,13 +520,15 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0 : tensor<f32>
}
@@ -531,14 +536,16 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- } else {
- %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
- %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %2 = tosa.add %1, %arg3 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
@@ -546,14 +553,16 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
- %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
- %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %2 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
- } else {
- %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
- }
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0 : tensor<f32>
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the PR!
Since the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary.