[mlir][affine] Set overflow flags when lowering [de]linearize_index #139612

Merged
merged 2 commits into from
May 13, 2025

krzysz00
Contributor

By analogy to some changes to the affine.apply lowering that put `nsw` flags on various multiplications, add appropriate overflow flags to the multiplications and additions emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.
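
As a rough illustration of what the two flags assert, here is a minimal sketch in plain C++ rather than MLIR. The helper names `mulHoldsNsw` and `mulHoldsNuw` are hypothetical, and the code assumes a GCC- or Clang-compatible compiler for `__builtin_mul_overflow`. It shows why `nuw` is only safe to add when the operands are known to be non-negative: a product with a negative operand can be fine in the signed sense and still wrap when the same bits are read as unsigned.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an MLIR API): true if a * b does not overflow
// when both operands are interpreted as signed 64-bit integers.
bool mulHoldsNsw(int64_t a, int64_t b) {
  int64_t out;
  return !__builtin_mul_overflow(a, b, &out);
}

// Hypothetical helper (not an MLIR API): true if a * b does not overflow
// when the same bit patterns are interpreted as unsigned 64-bit integers.
bool mulHoldsNuw(int64_t a, int64_t b) {
  uint64_t out;
  return !__builtin_mul_overflow(static_cast<uint64_t>(a),
                                 static_cast<uint64_t>(b), &out);
}

// Small self-check exercising both helpers.
void selfCheck() {
  assert(mulHoldsNsw(-2, 3));   // -6 is representable as a signed value
  assert(!mulHoldsNuw(-2, 3));  // -2 reads as 2^64 - 2 when unsigned
  assert(mulHoldsNsw(4, 5) && mulHoldsNuw(4, 5));
}
```

With non-negative operands, a product that satisfies the signed check also satisfies the unsigned one, which is the situation the `knownNonNegative` case relies on.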

@llvmbot
Member

llvmbot commented May 12, 2025

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

By analogy to some changes to the affine.apply lowering that put `nsw` flags on various multiplications, add appropriate overflow flags to the multiplications and additions emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.


Full diff: https://github.com/llvm/llvm-project/pull/139612.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+8)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+28-10)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dfe2a57df587..19fbcf64b2360 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     Due to the constraints of affine maps, all the basis elements must
     be strictly positive. A dynamic basis element being 0 or negative causes
     undefined behavior.
+
+    As with other affine operations, lowerings of delinearize_index may assume
+    that the underlying computations do not overflow the index type in a signed sense
+    - that is, the product of all basis elements is positive as an `index` as well.
   }];
 
   let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     If the `disjoint` property is present, this is an optimization hint that,
     for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
     except that `%idx_0` may be negative to make the index as a whole negative.
+    In addition, `disjoint` is an assertion that all basis elements are non-negative.
 
     Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
 
+    As with other affine ops, undefined behavior occurs if the linearization
+    computation overflows in the signed sense.
+
     Example:
 
     ```mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 7e335ea929c4f..35205a6ca2eee 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,10 +35,13 @@ using namespace mlir::affine;
 ///
 /// If excess dynamic values are provided, the values at the beginning
 /// will be ignored. This allows for dropping the outer bound without
-/// needing to manipulate the dynamic value array.
+/// needing to manipulate the dynamic value array. `knownNonNegative`
+/// indicates that the values being used to compute the strides are known
+/// to be non-negative.
 static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                          ValueRange dynamicBasis,
-                                         ArrayRef<int64_t> staticBasis) {
+                                         ArrayRef<int64_t> staticBasis,
+                                         bool knownNonNegative) {
   if (staticBasis.empty())
     return {};
 
@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
   size_t dynamicIndex = dynamicBasis.size();
   Value dynamicPart = nullptr;
   int64_t staticPart = 1;
+  // The products of the strides can't have overflow by definition of
+  // affine.*_index.
+  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
+  if (knownNonNegative)
+    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
   for (int64_t elem : llvm::reverse(staticBasis)) {
     if (ShapedType::isDynamic(elem)) {
+      // Note: basis elements and their products are, definitionally,
+      // non-negative, so `nuw` is justified.
       if (dynamicPart)
         dynamicPart = rewriter.create<arith::MulIOp>(
-            loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
+            loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
       else
         dynamicPart = dynamicBasis[dynamicIndex - 1];
       --dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
       Value stride =
           rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
       if (dynamicPart)
-        stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
+        stride =
+            rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
       result.push_back(stride);
     }
   }
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
     SmallVector<Value> results;
     results.reserve(numResults);
     SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                       /*knownNonNegative=*/true);
 
     Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
 
@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
       Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
       Value remainderNegative = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::slt, remainder, zero);
-      Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+      // If the correction is relevant, this term is <= stride, which is known
+      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+      // this branch won't be taken, so the risk of `poison` is fine.
+      Value corrected = rewriter.create<arith::AddIOp>(
+          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
       Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                    corrected, remainder);
       return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
       staticBasis = staticBasis.drop_front();
 
     SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                       /*knownNonNegative=*/op.getDisjoint());
     SmallVector<std::pair<Value, int64_t>> scaledValues;
     scaledValues.reserve(numIndexes);
 
@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
     // our hands on an `OpOperand&` for the loop invariant counting function.
     for (auto [stride, idxOp] :
          llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-      Value scaledIdx =
-          rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+      Value scaledIdx = rewriter.create<arith::MulIOp>(
+          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
       int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
       scaledValues.emplace_back(scaledIdx, numHoistableLoops);
     }
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
     for (auto [scaledValue, numHoistableLoops] :
          llvm::drop_begin(scaledValues)) {
       std::ignore = numHoistableLoops;
-      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+                                              arith::IntegerOverflowFlags::nsw);
     }
     rewriter.replaceOp(op, result);
     return success();
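
The arithmetic that the lowering in the diff above emits can be modelled in plain C++ as a rough sketch. This is not the pass itself: every function name here (`mulNsw`, `addNsw`, `stridesModel`, `floorModModel`, `floorDivModel`, `delinearizeModel`, `linearizeModel`) is hypothetical, the basis is assumed to include the outermost bound, and `__builtin_mul_overflow` / `__builtin_add_overflow` assume a GCC- or Clang-compatible compiler. An empty `std::optional` marks a signed overflow, the case that the `nsw` flags declare cannot happen.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helpers: std::nullopt marks the signed overflow `nsw` rules out.
std::optional<int64_t> mulNsw(int64_t a, int64_t b) {
  int64_t r;
  if (__builtin_mul_overflow(a, b, &r)) return std::nullopt;
  return r;
}
std::optional<int64_t> addNsw(int64_t a, int64_t b) {
  int64_t r;
  if (__builtin_add_overflow(a, b, &r)) return std::nullopt;
  return r;
}

// Strides for every basis element except the outermost:
// (B0, B1, B2) -> (B1 * B2, B2).
std::optional<std::vector<int64_t>> stridesModel(
    const std::vector<int64_t> &basis) {
  std::vector<int64_t> strides;
  int64_t running = 1;
  for (size_t i = basis.size(); i > 1; --i) {
    std::optional<int64_t> next = mulNsw(running, basis[i - 1]);
    if (!next) return std::nullopt;
    running = *next;
    strides.push_back(running);
  }
  std::reverse(strides.begin(), strides.end());
  return strides;
}

// Floor-style remainder, mirroring remsi + cmpi slt + addi + select.
int64_t floorModModel(int64_t x, int64_t stride) {
  assert(stride > 0 && "basis elements must be strictly positive");
  int64_t rem = x % stride;
  // rem < 0 implies -stride < rem, so rem + stride cannot overflow here.
  // The lowering computes the sum unconditionally and lets `select` discard
  // a possible poison value; C++ has to guard it to avoid undefined behavior.
  return rem < 0 ? rem + stride : rem;
}

// Floor division for a positive divisor.
int64_t floorDivModel(int64_t x, int64_t stride) {
  assert(stride > 0 && "basis elements must be strictly positive");
  int64_t q = x / stride;
  return (x % stride < 0) ? q - 1 : q;
}

std::optional<std::vector<int64_t>> delinearizeModel(
    int64_t linear, const std::vector<int64_t> &basis) {
  std::optional<std::vector<int64_t>> strides = stridesModel(basis);
  if (!strides) return std::nullopt;
  if (strides->empty()) return std::vector<int64_t>{linear};
  std::vector<int64_t> out;
  out.push_back(floorDivModel(linear, strides->front()));
  for (size_t i = 1; i < strides->size(); ++i)
    out.push_back(floorModModel(linear, (*strides)[i - 1]) / (*strides)[i]);
  out.push_back(floorModModel(linear, strides->back()));
  return out;
}

std::optional<int64_t> linearizeModel(const std::vector<int64_t> &idx,
                                      const std::vector<int64_t> &basis) {
  assert(idx.size() == basis.size());
  std::optional<std::vector<int64_t>> strides = stridesModel(basis);
  if (!strides) return std::nullopt;
  int64_t result = idx.empty() ? 0 : idx.back();
  for (size_t i = 0; i < strides->size(); ++i) {
    std::optional<int64_t> scaled = mulNsw(idx[i], (*strides)[i]);
    if (!scaled) return std::nullopt;
    std::optional<int64_t> sum = addNsw(result, *scaled);
    if (!sum) return std::nullopt;
    result = *sum;
  }
  return result;
}
```

In this model, `delinearizeModel(23, {2, 3, 4})` yields `{1, 2, 3}` and `linearizeModel({1, 2, 3}, {2, 3, 4})` yields `23`. The model adds the scaled terms in index order, which is a simplification and may differ from the order the pass uses.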

@llvmbot
Member

llvmbot commented May 12, 2025

@llvm/pr-subscribers-mlir-affine

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Some tests?

@krzysz00 krzysz00 merged commit 698fcb1 into llvm:main May 13, 2025
9 of 10 checks passed