@@ -2534,21 +2534,20 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
25342534
25352535static LogicalResult VerifyWhileTypes (Operation *op, TypeRange cond_input,
25362536 TypeRange body_input,
2537- TypeRange body_result) {
2538- // Collect all the type lists for the op so that different pairs of type lists
2539- // can be compared for the compatibility.
2540- constexpr int kNumTypeLists = 5 ;
2541- const std::array<TypeRangeWithDesc, kNumTypeLists > type_lists = {{
2542- {op-> getOperandTypes (), " input " },
2537+ TypeRange body_result,
2538+ bool shape_invariant) {
2539+ const TypeRangeWithDesc input_type = {op-> getOperandTypes (), " input " };
2540+ const TypeRangeWithDesc result_type = {op-> getResultTypes (), " result " } ;
2541+ constexpr int kNumRegionTypeLists = 3 ;
2542+ const std::array<TypeRangeWithDesc, kNumRegionTypeLists > region_types = {{
25432543 {body_result, " body result" },
2544- {op->getResultTypes (), " result" },
25452544 {cond_input, " condition input" },
25462545 {body_input, " body input" },
25472546 }};
25482547
25492548 // A pair of type lists should be cast compatible with each other if one is
25502549 // converted to the another for a function call or assignment or there is a
2551- // common source of inputs for both. Therefore, the While op requires the
2550+ // common source of inputs for both. Therefore, the While op requires the
25522551 // following pairs of type lists to be cast compatible for the tensor_cast
25532552 // operation:
25542553 //
@@ -2557,7 +2556,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
25572556 // * Operands and body inputs to call the body function for the first
25582557 // iteration if the cond functions returns True or equivalent result.
25592558 // * Operands and results to assign cond function arguments to op results if
2560- // the cond function returns False or equivalent result.
2559+ // the cond function returns False or equivalent result. If the op is shape
2560+ // invariant, this does not hold as shapes can differ.
25612561 // * All three pairs using cond inputs, body inputs and results as operand is
25622562 // a common source for all three.
25632563 // * Body result and cond inputs to call the cond function for the subsequent
@@ -2566,17 +2566,28 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
25662566 //
25672567 // Note that the operands and body results need not be compatible as they are
25682568 // never converted from one to the another nor there is a common source
2569- // tensors. Compatibility requirement is not transitive.
2570-
2571- for (int i = 0 ; i < kNumTypeLists ; ++i) {
2572- // Skip the first pair as the While op operands and body function results
2573- // does not need to be compatible with each other.
2574- for (int j = std::max (2 , i + 1 ); j < kNumTypeLists ; ++j) {
2575- auto &a = type_lists[i];
2576- auto &b = type_lists[j];
2577- if (failed (VerifyTypeRangesAreCompatible (op, a, b))) return failure ();
2578- }
2579- }
2569+ // tensors. Compatibility requirement is not transitive.
2570+
2571+ if (!shape_invariant &&
2572+ failed (VerifyTypeRangesAreCompatible (op, input_type, result_type)))
2573+ return failure ();
2574+
2575+ // Skip the first pair as the While op operands and body function results does
2576+ // not need to be compatible with each other.
2577+ for (int i = 1 ; i < kNumRegionTypeLists ; ++i)
2578+ if (failed (VerifyTypeRangesAreCompatible (op, input_type, region_types[i])))
2579+ return failure ();
2580+
2581+ for (int i = 0 ; i < kNumRegionTypeLists ; ++i)
2582+ if (failed (VerifyTypeRangesAreCompatible (op, result_type, region_types[i])))
2583+ return failure ();
2584+
2585+ for (int i = 0 ; i < kNumRegionTypeLists ; ++i)
2586+ for (int j = i + 1 ; j < kNumRegionTypeLists ; ++j)
2587+ if (failed (VerifyTypeRangesAreCompatible (op, region_types[i],
2588+ region_types[j])))
2589+ return failure ();
2590+
25802591 return success ();
25812592}
25822593
@@ -2601,7 +2612,8 @@ static LogicalResult Verify(WhileOp op) {
26012612
26022613 if (failed (VerifyWhileTypes (op, /* cond_input=*/ cond_fn_type.getInputs (),
26032614 /* body_input=*/ body_fn_type.getInputs (),
2604- /* body_result=*/ body_fn_type.getResults ())))
2615+ /* body_result=*/ body_fn_type.getResults (),
2616+ op.shape_invariant ())))
26052617 return failure ();
26062618 return success ();
26072619}
@@ -2626,7 +2638,8 @@ static LogicalResult Verify(WhileRegionOp op) {
26262638 Operation *body_yield = op.body ().front ().getTerminator ();
26272639 if (failed (VerifyWhileTypes (op, /* cond_input=*/ op.cond ().getArgumentTypes (),
26282640 /* body_input=*/ op.body ().getArgumentTypes (),
2629- /* body_result=*/ body_yield->getOperandTypes ())))
2641+ /* body_result=*/ body_yield->getOperandTypes (),
2642+ op.shape_invariant ())))
26302643 return failure ();
26312644 return success ();
26322645}
0 commit comments