@@ -40,9 +40,13 @@ namespace tir {
4040
// Attribute node holding the options of the loop-partition pass, registered
// under the key given to TVM_DECLARE_ATTRS below. Both options default to
// false. This hunk adds the second option, no_unroll_loop_with_extent_one.
4141struct LoopPartitionConfigNode : public tvm ::AttrsNode<LoopPartitionConfigNode> {
// Described below as "Split constant loop".
4242 bool partition_const_loop;
// Described below as "Don't unroll loops with extent 1".
43+ bool no_unroll_loop_with_extent_one;
4344
4445 TVM_DECLARE_ATTRS (LoopPartitionConfigNode, " tir.transform.LoopPartitionConfig" ) {
4546 TVM_ATTR_FIELD (partition_const_loop).describe (" Split constant loop" ).set_default (false );
// New field: declared with a description and a default of false.
47+ TVM_ATTR_FIELD (no_unroll_loop_with_extent_one)
48+ .describe (" Don't unroll loops with extent 1" )
49+ .set_default (false );
4650 }
4751};
4852
@@ -334,8 +338,9 @@ class ThreadPartitionInserter : public StmtMutator {
334338// likely conditions
335339class LoopPartitioner : public StmtMutator {
336340 public:
// Constructor (old version removed, new version added by this hunk).
// Both versions build the CandidateSelector from partition_const_loop; the new
// one additionally stores no_unroll_loop_with_extent_one in the member
// no_unroll_loop_with_extent_one_.
337- explicit LoopPartitioner (bool partition_const_loop)
338- : selector(CandidateSelector(partition_const_loop)) {}
341+ explicit LoopPartitioner (bool partition_const_loop, bool no_unroll_loop_with_extent_one)
342+ : selector(CandidateSelector(partition_const_loop)),
343+ no_unroll_loop_with_extent_one_(no_unroll_loop_with_extent_one) {}
339344
340345 Stmt VisitAndMutate (Stmt stmt) {
341346 selector (stmt);
@@ -402,6 +407,7 @@ class LoopPartitioner : public StmtMutator {
402407 std::unordered_map<const VarNode*, IntSet> relax_map_;
403408 arith::Analyzer analyzer_;
404409 CandidateSelector selector;
410+ bool no_unroll_loop_with_extent_one_;
405411};
406412
407413// Returns an interval (in the first component) in which all the conditions
@@ -596,7 +602,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
596602inline Stmt LoopPartitioner::MakeFor (const Object* node, PrimExpr extent, Stmt body) {
597603 const ForNode* for_node = static_cast <const ForNode*>(node);
598604 ICHECK (for_node);
599- if (analyzer_.CanProve (extent == make_const (DataType::Int (32 ), 1 ))) {
605+ if (analyzer_.CanProve (extent == make_const (DataType::Int (32 ), 1 )) &&
606+ !no_unroll_loop_with_extent_one_) {
600607 // If the loop extent is 1, do not create the loop anymore
601608 return Substitute (body, {{Var{for_node->loop_var }, make_const (DataType::Int (32 ), 0 )}});
602609 } else {
@@ -617,8 +624,9 @@ class RemoveLikelyTags : public StmtExprMutator {
617624 }
618625};
619626
// Entry point that applies loop partitioning to a statement.
// It runs LoopPartitioner::VisitAndMutate over stmt, passes the result through
// RemoveLikelyTags, and returns it. This hunk widens the signature so that
// no_unroll_loop_with_extent_one is forwarded to the LoopPartitioner
// constructor together with partition_const_loop.
620- Stmt LoopPartition (Stmt stmt, bool partition_const_loop) {
621- stmt = LoopPartitioner (partition_const_loop).VisitAndMutate (std::move (stmt));
627+ Stmt LoopPartition (Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
628+ stmt = LoopPartitioner (partition_const_loop, no_unroll_loop_with_extent_one)
629+ .VisitAndMutate (std::move (stmt));
// Second step: apply the RemoveLikelyTags mutator to the partitioned statement.
622630 stmt = RemoveLikelyTags ()(std::move (stmt));
623631 return stmt;
624632}
@@ -632,7 +640,8 @@ Pass LoopPartition() {
632640 if (!cfg.defined ()) {
633641 cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
634642 }
635- n->body = LoopPartition (std::move (n->body ), cfg.value ()->partition_const_loop );
643+ n->body = LoopPartition (std::move (n->body ), cfg.value ()->partition_const_loop ,
644+ cfg.value ()->no_unroll_loop_with_extent_one );
636645 return f;
637646 };
638647 return CreatePrimFuncPass (pass_func, 0 , " tir.LoopPartition" , {});
0 commit comments