-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][affine] Set overflow flags when lowering [de]linearize_index #139612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][affine] Set overflow flags when lowering [de]linearize_index #139612
Conversation
By analogy to some changes to the affine.apply lowering which put `nsw`s on various multiplications, add appropriate overflow flags to the multiplications and additions that are emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.
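@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: By analogy to some changes to the affine.apply lowering which put `nsw`s on various multiplications, add appropriate overflow flags to the multiplications and additions that are emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops. Full diff: https://github.com/llvm/llvm-project/pull/139612.diff 2 Files Affected: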
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dfe2a57df587..19fbcf64b2360 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
+
+ As with other affine operations, lowerings of delinearize_index may assume
+ that the underlying computations do not overflow the index type in a signed sense
+ - that is, the product of all basis elements is positive as an `index` as well.
}];
let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
+ In addition, `disjoint` is an assertion that all basis elements are non-negative.
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
+ As with other affine ops, undefined behavior occurs if the linearization
+ computation overflows in the signed sense.
+
Example:
```mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 7e335ea929c4f..35205a6ca2eee 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,10 +35,13 @@ using namespace mlir::affine;
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
-/// needing to manipulate the dynamic value array.
+/// needing to manipulate the dynamic value array. `knownNonNegative`
+/// indicates that the values being used to compute the strides are known
+/// to be non-negative.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
ValueRange dynamicBasis,
- ArrayRef<int64_t> staticBasis) {
+ ArrayRef<int64_t> staticBasis,
+ bool knownNonNegative) {
if (staticBasis.empty())
return {};
@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
size_t dynamicIndex = dynamicBasis.size();
Value dynamicPart = nullptr;
int64_t staticPart = 1;
+ // The products of the strides can't have overflow by definition of
+ // affine.*_index.
+ arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
+ if (knownNonNegative)
+ ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
for (int64_t elem : llvm::reverse(staticBasis)) {
if (ShapedType::isDynamic(elem)) {
+ // Note: basis elements and their products are, definitionally,
+ // non-negative, so `nuw` is justified.
if (dynamicPart)
dynamicPart = rewriter.create<arith::MulIOp>(
- loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
Value stride =
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
- stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
+ stride =
+ rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
SmallVector<Value> results;
results.reserve(numResults);
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/true);
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
Value remainderNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, remainder, zero);
- Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+ // If the correction is relevant, this term is <= stride, which is known
+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+ // this branch won't be taken, so the risk of `poison` is fine.
+ Value corrected = rewriter.create<arith::AddIOp>(
+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
corrected, remainder);
return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
staticBasis = staticBasis.drop_front();
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/op.getDisjoint());
SmallVector<std::pair<Value, int64_t>> scaledValues;
scaledValues.reserve(numIndexes);
@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx =
- rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+ Value scaledIdx = rewriter.create<arith::MulIOp>(
+ loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
for (auto [scaledValue, numHoistableLoops] :
llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+ result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
|
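@llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) Changes: By analogy to some changes to the affine.apply lowering which put `nsw`s on various multiplications, add appropriate overflow flags to the multiplications and additions that are emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops. Full diff: https://github.com/llvm/llvm-project/pull/139612.diff 2 Files Affected: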
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dfe2a57df587..19fbcf64b2360 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
+
+ As with other affine operations, lowerings of delinearize_index may assume
+ that the underlying computations do not overflow the index type in a signed sense
+ - that is, the product of all basis elements is positive as an `index` as well.
}];
let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
+ In addition, `disjoint` is an assertion that all basis elements are non-negative.
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
+ As with other affine ops, undefined behavior occurs if the linearization
+ computation overflows in the signed sense.
+
Example:
```mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 7e335ea929c4f..35205a6ca2eee 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,10 +35,13 @@ using namespace mlir::affine;
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
-/// needing to manipulate the dynamic value array.
+/// needing to manipulate the dynamic value array. `knownNonNegative`
+/// indicates that the values being used to compute the strides are known
+/// to be non-negative.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
ValueRange dynamicBasis,
- ArrayRef<int64_t> staticBasis) {
+ ArrayRef<int64_t> staticBasis,
+ bool knownNonNegative) {
if (staticBasis.empty())
return {};
@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
size_t dynamicIndex = dynamicBasis.size();
Value dynamicPart = nullptr;
int64_t staticPart = 1;
+ // The products of the strides can't have overflow by definition of
+ // affine.*_index.
+ arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
+ if (knownNonNegative)
+ ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
for (int64_t elem : llvm::reverse(staticBasis)) {
if (ShapedType::isDynamic(elem)) {
+ // Note: basis elements and their products are, definitionally,
+ // non-negative, so `nuw` is justified.
if (dynamicPart)
dynamicPart = rewriter.create<arith::MulIOp>(
- loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
Value stride =
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
- stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
+ stride =
+ rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
SmallVector<Value> results;
results.reserve(numResults);
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/true);
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
Value remainderNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, remainder, zero);
- Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+ // If the correction is relevant, this term is <= stride, which is known
+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+ // this branch won't be taken, so the risk of `poison` is fine.
+ Value corrected = rewriter.create<arith::AddIOp>(
+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
corrected, remainder);
return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
staticBasis = staticBasis.drop_front();
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/op.getDisjoint());
SmallVector<std::pair<Value, int64_t>> scaledValues;
scaledValues.reserve(numIndexes);
@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx =
- rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+ Value scaledIdx = rewriter.create<arith::MulIOp>(
+ loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
for (auto [scaledValue, numHoistableLoops] :
llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+ result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some tests?
By analogy to some changes to the affine.apply lowering which put `nsw`s on various multiplications, add appropriate overflow flags to the multiplications and additions that are emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.