-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir] [presburger] Add IntegerRelation::rangeProduct #148092
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This is intended to match `isl::map`'s `flat_range_product`.
@llvm/pr-subscribers-mlir Author: Jeremy Kun (j2kun) ChangesThis is intended to match I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction. Full diff: https://github.com/llvm/llvm-project/pull/148092.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index b68262f09f485..ae213181f2c3b 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -707,6 +707,21 @@ class IntegerRelation {
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);
+ /// Let the relation `this` be R1, and the relation `rel` be R2. Requires
+ /// R1 and R2 to have the same domain.
+ ///
+ /// This operation computes the relation whose domain is the same as R1 and
+ /// whose range is the product of the ranges of R1 and R2, and whose
+ /// constraints are the conjunction of the constraints of R1 and R2 applied
+ /// to the relevant subspaces of the range.
+ ///
+ /// Example:
+ ///
+ /// R1: (i, j) -> k : f(i, j, k) = 0
+ /// R2: (i, j) -> l : g(i, j, l) = 0
+ /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+ IntegerRelation rangeProduct(const IntegerRelation &rel);
+
/// Given a relation `other: (A -> B)`, this operation merges the symbol and
/// local variables and then takes the composition of `other` on `this: (B ->
/// C)`. The resulting relation represents tuples of the form: `A -> C`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 631e085574fd0..87cec81e86b59 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2481,6 +2481,41 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
+IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
+ /// R1: (i, j) -> k : f(i, j, k) = 0
+ /// R2: (i, j) -> l : g(i, j, l) = 0
+ /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+ assert(getNumDomainVars() == rel.getNumDomainVars() &&
+ "Range product is only defined for relations with equal domains");
+
+ // explicit copy of the context relation
+ IntegerRelation result = *this;
+ unsigned srcOffset = getVarKindOffset(VarKind::Range);
+ unsigned newNumRangeVars = rel.getNumRangeVars();
+
+ result.appendVar(VarKind::Range, newNumRangeVars);
+
+ for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
+ // Add a new equality that uses the new range variables.
+ // The old equality is a list of coefficients of the variables
+ // from `rel`, and so the range variables need to be shifted
+ // right by the number of range variables added to `result`.
+ SmallVector<DynamicAPInt> copy =
+ SmallVector<DynamicAPInt>(rel.getEquality(i));
+ copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+ result.addEquality(copy);
+ }
+
+ for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
+ SmallVector<DynamicAPInt> copy =
+ SmallVector<DynamicAPInt>(rel.getInequality(i));
+ copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+ result.addInequality(copy);
+ }
+
+ return result;
+}
+
void IntegerRelation::printSpace(raw_ostream &os) const {
space.print(os);
os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..dd8b9e3f03330 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
}
+
+TEST(IntegerRelationTest, rangeProduct) {
+ IntegerRelation r1 = parseRelationFromSet(
+ "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
+ IntegerRelation r2 = parseRelationFromSet(
+ "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
+
+ IntegerRelation rangeProd = r1.rangeProduct(r2);
+ IntegerRelation expected = parseRelationFromSet(
+ "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+
+ EXPECT_TRUE(expected.isEqual(rangeProd));
+}
+
|
@llvm/pr-subscribers-mlir-presburger Author: Jeremy Kun (j2kun) ChangesThis is intended to match I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction. Full diff: https://github.com/llvm/llvm-project/pull/148092.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index b68262f09f485..ae213181f2c3b 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -707,6 +707,21 @@ class IntegerRelation {
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);
+ /// Let the relation `this` be R1, and the relation `rel` be R2. Requires
+ /// R1 and R2 to have the same domain.
+ ///
+ /// This operation computes the relation whose domain is the same as R1 and
+ /// whose range is the product of the ranges of R1 and R2, and whose
+ /// constraints are the conjunction of the constraints of R1 and R2 applied
+ /// to the relevant subspaces of the range.
+ ///
+ /// Example:
+ ///
+ /// R1: (i, j) -> k : f(i, j, k) = 0
+ /// R2: (i, j) -> l : g(i, j, l) = 0
+ /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+ IntegerRelation rangeProduct(const IntegerRelation &rel);
+
/// Given a relation `other: (A -> B)`, this operation merges the symbol and
/// local variables and then takes the composition of `other` on `this: (B ->
/// C)`. The resulting relation represents tuples of the form: `A -> C`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 631e085574fd0..87cec81e86b59 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2481,6 +2481,41 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
+IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
+ /// R1: (i, j) -> k : f(i, j, k) = 0
+ /// R2: (i, j) -> l : g(i, j, l) = 0
+ /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+ assert(getNumDomainVars() == rel.getNumDomainVars() &&
+ "Range product is only defined for relations with equal domains");
+
+ // explicit copy of the context relation
+ IntegerRelation result = *this;
+ unsigned srcOffset = getVarKindOffset(VarKind::Range);
+ unsigned newNumRangeVars = rel.getNumRangeVars();
+
+ result.appendVar(VarKind::Range, newNumRangeVars);
+
+ for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
+ // Add a new equality that uses the new range variables.
+ // The old equality is a list of coefficients of the variables
+ // from `rel`, and so the range variables need to be shifted
+ // right by the number of range variables added to `result`.
+ SmallVector<DynamicAPInt> copy =
+ SmallVector<DynamicAPInt>(rel.getEquality(i));
+ copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+ result.addEquality(copy);
+ }
+
+ for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
+ SmallVector<DynamicAPInt> copy =
+ SmallVector<DynamicAPInt>(rel.getInequality(i));
+ copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+ result.addInequality(copy);
+ }
+
+ return result;
+}
+
void IntegerRelation::printSpace(raw_ostream &os) const {
space.print(os);
os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..dd8b9e3f03330 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
}
+
+TEST(IntegerRelationTest, rangeProduct) {
+ IntegerRelation r1 = parseRelationFromSet(
+ "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
+ IntegerRelation r2 = parseRelationFromSet(
+ "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
+
+ IntegerRelation rangeProd = r1.rangeProduct(r2);
+ IntegerRelation expected = parseRelationFromSet(
+ "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+
+ EXPECT_TRUE(expected.isEqual(rangeProd));
+}
+
|
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- mlir/include/mlir/Analysis/Presburger/IntegerRelation.h mlir/lib/Analysis/Presburger/IntegerRelation.cpp mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp View the diff from clang-format here.diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index dd8b9e3f0..db2460011 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -616,9 +616,10 @@ TEST(IntegerRelationTest, rangeProduct) {
"(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
IntegerRelation rangeProd = r1.rangeProduct(r2);
- IntegerRelation expected = parseRelationFromSet(
- "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+ IntegerRelation expected =
+ parseRelationFromSet("(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == "
+ "0, i >= 0, j >= 0, k >= 0, l >= 0)",
+ 2);
EXPECT_TRUE(expected.isEqual(rangeProd));
}
-
|
I didn't read any implementation details yet |
This is intended to match
isl::map
'sflat_range_product
.I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction.