Skip to content

[mlir][tosa] Fix mul folder conformance to the spec #137601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025

Conversation

lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented Apr 28, 2025

Change the folder for mul with a shift such that the rounding happens correctly according to the spec
pesudo-code.

Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: #128059

Change the folder for mul with a shift such that the
rounding happens correctly according to the spec
pesudo-code.

Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: llvm#128059

Co-authored-by: Tai Ly <[email protected]>
Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
@llvmbot
Copy link
Member

llvmbot commented Apr 28, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Change the folder for mul with a shift such that the rounding happens correctly according to the spec
pesudo-code.

Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: #128059


Full diff: https://github.com/llvm/llvm-project/pull/137601.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+25-6)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+40-1)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..e73e2c4e33522 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 }
 
 namespace {
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+                            unsigned bitwidth) {
+  APInt result = lhs.sext(64) * rhs.sext(64);
+
+  if (shift > 0) {
+    auto round = APInt(64, 1) << (shift - 1);
+    result += round;
+    result.ashrInPlace(shift);
+    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+    if (!(result.getSExtValue() >= INT32_MIN &&
+          result.getSExtValue() <= INT32_MAX)) {
+      // REQUIRE failed
+      return std::nullopt;
+    }
+  }
+
+  return result.trunc(bitwidth);
+}
+
 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                   RankedTensorType ty, int32_t shift) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
       }
 
       auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
-      l = l.sext(bitwidth * 2);
-      r = r.sext(bitwidth * 2);
-      auto result = l * r;
-      result.lshrInPlace(shift);
-      result = result.trunc(bitwidth);
-      return DenseElementsAttr::get(ty, result);
+      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
+      if (!result)
+        return {};
+      return DenseElementsAttr::get(ty, result.value());
     }
 
     if (llvm::isa<FloatType>(ty.getElementType())) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c98335cdafe65 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
     %1 = tosa.const_shape  {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
     %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
     return %2 : tensor<2x60x58x?xf32>
-  }
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_shift() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_no_shift
+// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
+func.func @fold_mul_no_shift() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
+// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
+    %0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
+    %1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
+    %2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+    %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    return %3 : tensor<i32>
+}

result.lshrInPlace(shift);
result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, reviewed

@lhutton1 lhutton1 merged commit 18f8928 into llvm:main May 8, 2025
14 checks passed
@lhutton1 lhutton1 deleted the fix-mul-folder branch May 8, 2025 09:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants