Skip to content

Commit 6d71a32

Browse files
authored
[TIR] Make loop unrolling in LoopPartition optional (apache#6823)
* [TIR] Make loop unrolling in LoopPartition optional. For certain analysis/tensorization, it can be useful to keep the loop structure when partitioning loops. The current behaviour removes `For` loops of extent 1. This change introduces the option to preserve these loops with the `no_unroll_loop_with_extent_one` flag of the `tir.LoopPartition` config.
1 parent 3bfe6d3 commit 6d71a32

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

src/tir/transforms/loop_partition.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ namespace tir {
4040

4141
struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> {
4242
bool partition_const_loop;
43+
bool no_unroll_loop_with_extent_one;
4344

4445
TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") {
4546
TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false);
47+
TVM_ATTR_FIELD(no_unroll_loop_with_extent_one)
48+
.describe("Don't unroll loops with extent 1")
49+
.set_default(false);
4650
}
4751
};
4852

@@ -334,8 +338,9 @@ class ThreadPartitionInserter : public StmtMutator {
334338
// likely conditions
335339
class LoopPartitioner : public StmtMutator {
336340
public:
337-
explicit LoopPartitioner(bool partition_const_loop)
338-
: selector(CandidateSelector(partition_const_loop)) {}
341+
explicit LoopPartitioner(bool partition_const_loop, bool no_unroll_loop_with_extent_one)
342+
: selector(CandidateSelector(partition_const_loop)),
343+
no_unroll_loop_with_extent_one_(no_unroll_loop_with_extent_one) {}
339344

340345
Stmt VisitAndMutate(Stmt stmt) {
341346
selector(stmt);
@@ -402,6 +407,7 @@ class LoopPartitioner : public StmtMutator {
402407
std::unordered_map<const VarNode*, IntSet> relax_map_;
403408
arith::Analyzer analyzer_;
404409
CandidateSelector selector;
410+
bool no_unroll_loop_with_extent_one_;
405411
};
406412

407413
// Returns an interval (in the first component) in which all the conditions
@@ -596,7 +602,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
596602
inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
597603
const ForNode* for_node = static_cast<const ForNode*>(node);
598604
ICHECK(for_node);
599-
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
605+
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
606+
!no_unroll_loop_with_extent_one_) {
600607
// If the loop extent is 1, do not create the loop anymore
601608
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
602609
} else {
@@ -617,8 +624,9 @@ class RemoveLikelyTags : public StmtExprMutator {
617624
}
618625
};
619626

620-
Stmt LoopPartition(Stmt stmt, bool partition_const_loop) {
621-
stmt = LoopPartitioner(partition_const_loop).VisitAndMutate(std::move(stmt));
627+
Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
628+
stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one)
629+
.VisitAndMutate(std::move(stmt));
622630
stmt = RemoveLikelyTags()(std::move(stmt));
623631
return stmt;
624632
}
@@ -632,7 +640,8 @@ Pass LoopPartition() {
632640
if (!cfg.defined()) {
633641
cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
634642
}
635-
n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop);
643+
n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop,
644+
cfg.value()->no_unroll_loop_with_extent_one);
636645
return f;
637646
};
638647
return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});

tests/python/unittest/test_tir_transform_loop_partition.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,34 @@ def test_const_loop():
6666
assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
6767

6868

69+
def test_no_unroll_loop():
70+
n = 21
71+
A = te.placeholder((n,), name="A")
72+
B = te.placeholder((n,), name="B")
73+
74+
T = te.compute((n,), lambda i: A[i] + B[i])
75+
s = te.create_schedule(T.op)
76+
xo, xi = s[T].split(T.op.axis[0], factor=4)
77+
78+
bounds = tvm.te.schedule.InferBound(s)
79+
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
80+
81+
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
82+
with tvm.transform.PassContext(
83+
config={
84+
"tir.LoopPartition": {
85+
"partition_const_loop": True,
86+
"no_unroll_loop_with_extent_one": True,
87+
}
88+
}
89+
):
90+
mod = tvm.tir.transform.LoopPartition()(mod)
91+
mod = tvm.tir.transform.Simplify()(mod)
92+
stmt = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
93+
94+
assert sum(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.For))) == 4
95+
96+
6997
def test_multi_loop():
7098
ib = tvm.tir.ir_builder.create()
7199
m = te.size_var("m")

0 commit comments

Comments
 (0)