Skip to content

[mlir] [presburger] Add IntegerRelation::rangeProduct #148092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented Jul 11, 2025

This is intended to match isl::map's flat_range_product.

I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction.

This is intended to match `isl::map`'s `flat_range_product`.
@j2kun j2kun requested review from Groverkss and Superty as code owners July 11, 2025 00:27
@j2kun j2kun changed the title Add IntegerRelation::rangeProduct [mlir] [presburger] Add IntegerRelation::rangeProduct Jul 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 11, 2025

@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

This is intended to match isl::map's flat_range_product.

I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction.


Full diff: https://github.com/llvm/llvm-project/pull/148092.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+15)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+35)
  • (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+14)
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index b68262f09f485..ae213181f2c3b 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -707,6 +707,21 @@ class IntegerRelation {
   /// this for uniformity with `applyDomain`.
   void applyRange(const IntegerRelation &rel);
 
+  /// Let the relation `this` be R1, and the relation `rel` be R2. Requires
+  /// R1 and R2 to have the same domain.
+  ///
+  /// This operation computes the relation whose domain is the same as R1 and
+  /// whose range is the product of the ranges of R1 and R2, and whose
+  /// constraints are the conjunction of the constraints of R1 and R2 applied
+  /// to the relevant subspaces of the range.
+  ///
+  /// Example:
+  ///
+  /// R1: (i, j) -> k : f(i, j, k) = 0
+  /// R2: (i, j) -> l : g(i, j, l) = 0
+  /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  IntegerRelation rangeProduct(const IntegerRelation &rel);
+
   /// Given a relation `other: (A -> B)`, this operation merges the symbol and
   /// local variables and then takes the composition of `other` on `this: (B ->
   /// C)`. The resulting relation represents tuples of the form: `A -> C`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 631e085574fd0..87cec81e86b59 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2481,6 +2481,41 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
 
 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
 
+IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
+  /// R1: (i, j) -> k : f(i, j, k) = 0
+  /// R2: (i, j) -> l : g(i, j, l) = 0
+  /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  assert(getNumDomainVars() == rel.getNumDomainVars() &&
+         "Range product is only defined for relations with equal domains");
+
+  // explicit copy of the context relation
+  IntegerRelation result = *this;
+  unsigned srcOffset = getVarKindOffset(VarKind::Range);
+  unsigned newNumRangeVars = rel.getNumRangeVars();
+
+  result.appendVar(VarKind::Range, newNumRangeVars);
+
+  for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
+    // Add a new equality that uses the new range variables.
+    // The old equality is a list of coefficients of the variables
+    // from `rel`, and so the range variables need to be shifted
+    // right by the number of range variables added to `result`.
+    SmallVector<DynamicAPInt> copy =
+        SmallVector<DynamicAPInt>(rel.getEquality(i));
+    copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+    result.addEquality(copy);
+  }
+
+  for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
+    SmallVector<DynamicAPInt> copy =
+        SmallVector<DynamicAPInt>(rel.getInequality(i));
+    copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+    result.addInequality(copy);
+  }
+
+  return result;
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..dd8b9e3f03330 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
 }
+
+TEST(IntegerRelationTest, rangeProduct) {
+  IntegerRelation r1 = parseRelationFromSet(
+      "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
+  IntegerRelation r2 = parseRelationFromSet(
+      "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
+
+  IntegerRelation rangeProd = r1.rangeProduct(r2);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+
+  EXPECT_TRUE(expected.isEqual(rangeProd));
+}
+

@llvmbot
Copy link
Member

llvmbot commented Jul 11, 2025

@llvm/pr-subscribers-mlir-presburger

Author: Jeremy Kun (j2kun)

Changes

This is intended to match isl::map's flat_range_product.

I'd like to add some more tests, so hoping for a brief early review to make sure I'm going in the right direction.


Full diff: https://github.com/llvm/llvm-project/pull/148092.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+15)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+35)
  • (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+14)
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index b68262f09f485..ae213181f2c3b 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -707,6 +707,21 @@ class IntegerRelation {
   /// this for uniformity with `applyDomain`.
   void applyRange(const IntegerRelation &rel);
 
+  /// Let the relation `this` be R1, and the relation `rel` be R2. Requires
+  /// R1 and R2 to have the same domain.
+  ///
+  /// This operation computes the relation whose domain is the same as R1 and
+  /// whose range is the product of the ranges of R1 and R2, and whose
+  /// constraints are the conjunction of the constraints of R1 and R2 applied
+  /// to the relevant subspaces of the range.
+  ///
+  /// Example:
+  ///
+  /// R1: (i, j) -> k : f(i, j, k) = 0
+  /// R2: (i, j) -> l : g(i, j, l) = 0
+  /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  IntegerRelation rangeProduct(const IntegerRelation &rel);
+
   /// Given a relation `other: (A -> B)`, this operation merges the symbol and
   /// local variables and then takes the composition of `other` on `this: (B ->
   /// C)`. The resulting relation represents tuples of the form: `A -> C`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 631e085574fd0..87cec81e86b59 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2481,6 +2481,41 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
 
 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
 
+IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
+  /// R1: (i, j) -> k : f(i, j, k) = 0
+  /// R2: (i, j) -> l : g(i, j, l) = 0
+  /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  assert(getNumDomainVars() == rel.getNumDomainVars() &&
+         "Range product is only defined for relations with equal domains");
+
+  // explicit copy of the context relation
+  IntegerRelation result = *this;
+  unsigned srcOffset = getVarKindOffset(VarKind::Range);
+  unsigned newNumRangeVars = rel.getNumRangeVars();
+
+  result.appendVar(VarKind::Range, newNumRangeVars);
+
+  for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
+    // Add a new equality that uses the new range variables.
+    // The old equality is a list of coefficients of the variables
+    // from `rel`, and so the range variables need to be shifted
+    // right by the number of range variables added to `result`.
+    SmallVector<DynamicAPInt> copy =
+        SmallVector<DynamicAPInt>(rel.getEquality(i));
+    copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+    result.addEquality(copy);
+  }
+
+  for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
+    SmallVector<DynamicAPInt> copy =
+        SmallVector<DynamicAPInt>(rel.getInequality(i));
+    copy.insert(copy.begin() + srcOffset, newNumRangeVars, DynamicAPInt(0));
+    result.addInequality(copy);
+  }
+
+  return result;
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..dd8b9e3f03330 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
 }
+
+TEST(IntegerRelationTest, rangeProduct) {
+  IntegerRelation r1 = parseRelationFromSet(
+      "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
+  IntegerRelation r2 = parseRelationFromSet(
+      "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
+
+  IntegerRelation rangeProd = r1.rangeProduct(r2);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+
+  EXPECT_TRUE(expected.isEqual(rangeProd));
+}
+

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- mlir/include/mlir/Analysis/Presburger/IntegerRelation.h mlir/lib/Analysis/Presburger/IntegerRelation.cpp mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
View the diff from clang-format here.
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index dd8b9e3f0..db2460011 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -616,9 +616,10 @@ TEST(IntegerRelationTest, rangeProduct) {
       "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
 
   IntegerRelation rangeProd = r1.rangeProduct(r2);
-  IntegerRelation expected = parseRelationFromSet(
-      "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, i >= 0, j >= 0, k >= 0, l >= 0)", 2);
+  IntegerRelation expected =
+      parseRelationFromSet("(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == "
+                           "0, i >= 0, j >= 0, k >= 0, l >= 0)",
+                           2);
 
   EXPECT_TRUE(expected.isEqual(rangeProd));
 }
-

@Xinlong-Wu
Copy link
Contributor

Xinlong-Wu commented Jul 11, 2025

I didn't read any implementation details yet
I think what you need to do should be add more test case and format the code before commit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants