-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][tosa] Fix mul folder conformance to the spec #137601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Change the folder for mul with a shift such that the rounding happens correctly according to the spec pesudo-code. Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Partial cherry-pick from: llvm#128059 Co-authored-by: Tai Ly <[email protected]> Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesChange the folder for mul with a shift such that the rounding happens correctly according to the spec Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Full diff: https://github.com/llvm/llvm-project/pull/137601.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..e73e2c4e33522 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64);
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1);
+ result += round;
+ result.ashrInPlace(shift);
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth);
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
- auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
- return DenseElementsAttr::get(ty, result);
+ const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
+ if (!result)
+ return {};
+ return DenseElementsAttr::get(ty, result.value());
}
if (llvm::isa<FloatType>(ty.getElementType())) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c98335cdafe65 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
%1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
return %2 : tensor<2x60x58x?xf32>
- }
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_no_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_no_shift() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
+// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
+ %0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+ %2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+ return %3 : tensor<i32>
+}
|
result.lshrInPlace(shift); | ||
result = result.trunc(bitwidth); | ||
return DenseElementsAttr::get(ty, result); | ||
const std::optional<APInt> result = mulInt(l, r, shift, bitwidth); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, reviewed
Change the folder for mul with a shift such that the rounding happens correctly according to the spec
pesudo-code.
Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: #128059