diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 1dfe2a57df587..19fbcf64b2360 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> { Due to the constraints of affine maps, all the basis elements must be strictly positive. A dynamic basis element being 0 or negative causes undefined behavior. + + As with other affine operations, lowerings of delinearize_index may assume + that the underlying computations do not overflow the index type in a signed sense + - that is, the product of all basis elements is positive as an `index` as well. }]; let arguments = (ins Index:$linear_index, @@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index", If the `disjoint` property is present, this is an optimization hint that, for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index, except that `%idx_0` may be negative to make the index as a whole negative. + In addition, `disjoint` is an assertion that all basis elements are non-negative. Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`. + As with other affine ops, undefined behavior occurs if the linearization + computation overflows in the signed sense. + Example: ```mlir diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index 7e335ea929c4f..35205a6ca2eee 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -35,10 +35,13 @@ using namespace mlir::affine; /// /// If excess dynamic values are provided, the values at the beginning /// will be ignored. This allows for dropping the outer bound without -/// needing to manipulate the dynamic value array. 
+/// needing to manipulate the dynamic value array. `knownNonNegative` +/// indicates that the values being used to compute the strides are known +/// to be non-negative. static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, ValueRange dynamicBasis, - ArrayRef<int64_t> staticBasis) { + ArrayRef<int64_t> staticBasis, + bool knownNonNegative) { if (staticBasis.empty()) return {}; @@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, size_t dynamicIndex = dynamicBasis.size(); Value dynamicPart = nullptr; int64_t staticPart = 1; + // The products of the strides can't have overflow by definition of + // affine.*_index. + arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw; + if (knownNonNegative) + ovflags = ovflags | arith::IntegerOverflowFlags::nuw; for (int64_t elem : llvm::reverse(staticBasis)) { if (ShapedType::isDynamic(elem)) { + // Note: basis elements and their products are, definitionally, + // non-negative, so `nuw` is justified. 
if (dynamicPart) dynamicPart = rewriter.create<arith::MulIOp>( - loc, dynamicPart, dynamicBasis[dynamicIndex - 1]); + loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags); else dynamicPart = dynamicBasis[dynamicIndex - 1]; --dynamicIndex; @@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, Value stride = rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart); if (dynamicPart) - stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride); + stride = + rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags); result.push_back(stride); } } @@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps SmallVector<Value> results; results.reserve(numResults); SmallVector<Value> strides = - computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); + computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis, + /*knownNonNegative=*/true); Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); @@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride); Value remainderNegative = rewriter.create<arith::CmpIOp>( loc, arith::CmpIPredicate::slt, remainder, zero); - Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride); + // If the correction is relevant, this term is <= stride, which is known + // to be positive in `index`. Otherwise, while 2 * stride might overflow, + // this branch won't be taken, so the risk of `poison` is fine. 
+ Value corrected = rewriter.create<arith::AddIOp>( + loc, remainder, stride, arith::IntegerOverflowFlags::nsw); Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative, corrected, remainder); return mod; @@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { staticBasis = staticBasis.drop_front(); SmallVector<Value> strides = - computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); + computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis, + /*knownNonNegative=*/op.getDisjoint()); SmallVector<std::pair<Value, int64_t>> scaledValues; scaledValues.reserve(numIndexes); @@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { // our hands on an `OpOperand&` for the loop invariant counting function. for (auto [stride, idxOp] : llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { - Value scaledIdx = - rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride); + Value scaledIdx = rewriter.create<arith::MulIOp>( + loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw); int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); scaledValues.emplace_back(scaledIdx, numHoistableLoops); } @@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) { std::ignore = numHoistableLoops; - result = rewriter.create<arith::AddIOp>(loc, result, scaledValue); + result = rewriter.create<arith::AddIOp>(loc, result, scaledValue, + arith::IntegerOverflowFlags::nsw); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir index 9bfaafb8c2468..202050489b7e4 100644 --- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -8,12 +8,12 @@ // CHECK: 
%[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]] // CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]] // CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]] -// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]] +// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]] overflow<nsw> // CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]] // CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]] // CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]] // CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]] -// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]] +// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]] overflow<nsw> // CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]] // CHECK: return %[[N]], %[[P]], %[[Q]] func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) { @@ -30,16 +30,16 @@ func.func @delinearize_static_basis(%linear_index: index) -> (index, index, inde // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] : // CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] : -// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]] +// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]] overflow<nsw, nuw> // CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]] // CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]] // CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]] -// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] +// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] overflow<nsw> // CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]] // CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]] // CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]] // CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]] -// CHECK-DAG: 
%[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] +// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] overflow<nsw> // CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]] // CHECK: return %[[N]], %[[P]], %[[Q]] func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) { @@ -58,10 +58,10 @@ func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf3 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index -// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]] -// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]] -// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] -// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]] overflow<nsw> +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]] overflow<nsw> +// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw> +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw> // CHECK: return %[[val_1]] func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index @@ -72,11 +72,11 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { // CHECK-LABEL: @linearize_dynamic // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index) -// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] -// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] -// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] -// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] -// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] +// CHECK: %[[stride_0:.+]] = arith.muli 
%[[arg4]], %[[arg3]] overflow<nsw> +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw> +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw> +// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw> +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw> // CHECK: return %[[val_1]] func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index { // Note: no outer bounds @@ -86,17 +86,33 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in // ----- +// CHECK-LABEL: @linearize_dynamic_disjoint +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index) +// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] overflow<nsw, nuw> +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw> +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw> +// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw> +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw> +// CHECK: return %[[val_1]] +func.func @linearize_dynamic_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index { + // Note: no outer bounds + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index + func.return %0 : index +} + +// ----- + // CHECK-LABEL: @linearize_sort_adds // CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index) // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index // CHECK: scf.for %[[arg3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} { // CHECK: scf.for %[[arg4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} { -// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] -// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] -// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] +// CHECK: 
%[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] overflow<nsw, nuw> +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] overflow<nsw> +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] overflow<nsw> // Note: even though %arg3 has a lower stride, we add it first -// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] -// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] +// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] overflow<nsw> +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] overflow<nsw> // CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]] func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) { %c0 = arith.constant 0 : index