Skip to content

Commit 3c572a3

Browse files
andyly authored and tensorflower-gardener committed
Update tf.While/tf.WhileRegion shape inference to support different operand and result shapes.
tf.While/tf.WhileRegion supports changing shapes in each iteration of its loop body. That case currently can be checked via the `shape_invariant` attribute, and special handling is necessary for refining shapes and propagating handle types. PiperOrigin-RevId: 344255608 Change-Id: Ibcca0c48b7fd6e305a2b38b9261dfcbcb03815ba
1 parent 69e92a3 commit 3c572a3

File tree

2 files changed

+233
-18
lines changed

2 files changed

+233
-18
lines changed

tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,4 +622,99 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
622622
}
623623
return %0 : tensor<*xi32>
624624
}
625+
626+
// Test shape invariant While only propagates operand handle types into
627+
// results and functions/regions.
628+
// CHECK-LABEL: func @while_shape_invariant_propagate
629+
// CHECK-SAME: ({{%.+}}: tensor<4xf32>, {{%.+}}: tensor<!tf.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf.resource<tensor<8xf32>>>, {{%.+}}: tensor<1xi32>)
630+
// CHECK-SAME: -> (tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>, tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>)
631+
func @while_shape_invariant_propagate(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<8xf32>>>, %arg3: tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>, tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>) {
632+
// CHECK: "tf.While"
633+
// CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<1xi32>)
634+
// CHECK-SAME: -> (tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>)
635+
%0:4 = "tf.While"(%arg0, %arg1, %arg2, %arg3) {cond = @while_shape_invariant_func_propagate, body = @while_shape_invariant_body_func_propagate, is_stateless = false, shape_invariant} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>)
636+
637+
// CHECK: "tf.WhileRegion"
638+
%1:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) ( {
639+
// CHECK-NEXT: ^{{.+}}({{%.+}}: tensor<*xf32>, {{%.+}}: tensor<*x!tf.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf.resource<tensor<8xf32>>>, {{%.+}}: tensor<?xi32>):
640+
^cond(%carg0: tensor<*xf32>, %carg1: tensor<*x!tf.resource>, %carg2: tensor<!tf.resource>, %carg3: tensor<?xi32>):
641+
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
642+
"tf.Yield"(%2) : (tensor<i1>) -> ()
643+
}, {
644+
// CHECK: ^{{.+}}({{%.+}}: tensor<*xf32>, {{%.+}}: tensor<*x!tf.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf.resource<tensor<8xf32>>>, {{%.+}}: tensor<?xi32>):
645+
^body(%barg0: tensor<*xf32>, %barg1: tensor<*x!tf.resource>, %barg2: tensor<!tf.resource>, %barg3: tensor<?xi32>):
646+
%2 = "tf.SomeOp"(%barg3) : (tensor<?xi32>) -> tensor<?xi32>
647+
// CHECK: "tf.Yield"
648+
// CHECK-SAME: (tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>) -> ()
649+
"tf.Yield"(%barg0, %barg1, %barg2, %2) : (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>) -> ()
650+
// CHECK-NEXT: shape_invariant
651+
// CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<1xi32>)
652+
// CHECK-SAME: -> (tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>)
653+
}) {is_stateless = false, shape_invariant} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>)
654+
655+
return %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>, tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>
656+
}
657+
658+
// CHECK-LABEL: func @while_shape_invariant_func_propagate
659+
// CHECK-SAME: ({{%.+}}: tensor<*xf32>, {{%.+}}: tensor<*x!tf.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf.resource<tensor<8xf32>>>, {{%.+}}: tensor<?xi32>)
660+
// CHECK-SAME: -> tensor<i1>
661+
func @while_shape_invariant_func_propagate(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource>, %arg3: tensor<?xi32>) -> tensor<i1> {
662+
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
663+
return %0 : tensor<i1>
664+
}
665+
666+
// CHECK-LABEL: func @while_shape_invariant_body_func_propagate
667+
// CHECK-SAME: ({{%.+}}: tensor<*xf32>, {{%.+}}: tensor<*x!tf.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf.resource<tensor<8xf32>>>, {{%.+}}: tensor<?xi32>)
668+
// CHECK-SAME: -> (tensor<*xf32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<8xf32>>>, tensor<?xi32>)
669+
func @while_shape_invariant_body_func_propagate(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource>, %arg3: tensor<?xi32>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>) {
670+
%0 = "tf.SomeOp"(%arg3) : (tensor<?xi32>) -> tensor<?xi32>
671+
return %arg0, %arg1, %arg2, %0 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource>, tensor<?xi32>
672+
}
673+
674+
// Test shape invariant While with result type refinement.
675+
// CHECK-LABEL: func @while_shape_invariant_refine
676+
// CHECK-SAME: ({{%.+}}: tensor<2xi32>, {{%.+}}: tensor<8xf32>, {{%.+}}: tensor<?xi1>)
677+
// CHECK-SAME: -> (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>, tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
678+
func @while_shape_invariant_refine(%arg0: tensor<2xi32>, %arg1: tensor<8xf32>, %arg2: tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>, tensor<?xi32>, tensor<*xf32>, tensor<*xi1>) {
679+
// CHECK: "tf.While"
680+
// CHECK-SAME: (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
681+
// CHECK-SAME: -> (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
682+
%0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_shape_invariant_func_refine, body = @while_shape_invariant_body_func_refine, is_stateless = false, shape_invariant} : (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>)
683+
684+
// CHECK: "tf.WhileRegion"
685+
%1:3 = "tf.WhileRegion"(%arg0, %arg1, %arg2) ( {
686+
// CHECK-NEXT: ^{{.+}}({{%.+}}: tensor<2xi32>, {{%.+}}: tensor<8xf32>, {{%.+}}: tensor<?xi1>):
687+
^cond(%carg0: tensor<2xi32>, %carg1: tensor<8xf32>, %carg2: tensor<?xi1>):
688+
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
689+
"tf.Yield"(%2) : (tensor<i1>) -> ()
690+
}, {
691+
// CHECK: ^{{.+}}({{%.+}}: tensor<2xi32>, {{%.+}}: tensor<8xf32>, {{%.+}}: tensor<?xi1>):
692+
^body(%barg0: tensor<2xi32>, %barg1: tensor<8xf32>, %barg2: tensor<?xi1>):
693+
%2:3 = "tf.IdentityN"(%barg0, %barg1, %barg2) : (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>)
694+
// CHECK: "tf.Yield"
695+
// CHECK-SAME: (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>) -> ()
696+
"tf.Yield"(%2#0, %2#1, %2#2) : (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>) -> ()
697+
// CHECK-NEXT: shape_invariant
698+
// CHECK-SAME: (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
699+
// CHECK-SAME: -> (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
700+
}) {is_stateless = false, shape_invariant} : (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>)
701+
702+
return %0#0, %0#1, %0#2, %1#0, %1#1, %1#2 : tensor<?xi32>, tensor<*xf32>, tensor<*xi1>, tensor<?xi32>, tensor<*xf32>, tensor<*xi1>
703+
}
704+
705+
// CHECK-LABEL: func @while_shape_invariant_func_refine
706+
// CHECK-SAME: ({{%.+}}: tensor<2xi32>, {{%.+}}: tensor<8xf32>, {{%.+}}: tensor<?xi1>)
707+
// CHECK-SAME: -> tensor<i1>
708+
func @while_shape_invariant_func_refine(%arg0: tensor<2xi32>, %arg1: tensor<8xf32>, %arg2: tensor<?xi1>) -> tensor<i1> {
709+
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
710+
return %0 : tensor<i1>
711+
}
712+
713+
// CHECK-LABEL: func @while_shape_invariant_body_func_refine
714+
// CHECK-SAME: ({{%.+}}: tensor<2xi32>, {{%.+}}: tensor<8xf32>, {{%.+}}: tensor<?xi1>)
715+
// CHECK-SAME: -> (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>)
716+
func @while_shape_invariant_body_func_refine(%arg0: tensor<2xi32>, %arg1: tensor<8xf32>, %arg2: tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>) {
717+
%0:3 = "tf.IdentityN"(%arg0, %arg1, %arg2) : (tensor<2xi32>, tensor<8xf32>, tensor<?xi1>) -> (tensor<?xi32>, tensor<*xf32>, tensor<*xi1>)
718+
return %0#0, %0#1, %0#2 : tensor<?xi32>, tensor<*xf32>, tensor<*xi1>
719+
}
625720
}

tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc

Lines changed: 138 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,16 @@ class ShapeInference {
279279
results_[value_port] = value;
280280
}
281281

282+
// Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
283+
// set, operand types are set as result types if associated body result types
284+
// match the operand type (does not change per loop iteration). If operand and
285+
// body result types are not the same, only handle types are propagated to
286+
// result types. This is necessary to not incorrectly change result shapes
287+
// when the While op will have a different result shape. Otherwise operand
288+
// shapes are propagated to result shapes.
289+
template <typename WhileOpTy>
290+
bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
291+
282292
// Performs shape inference on the provided op and return true if the type of
283293
// at least one result has been changed.
284294
// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
@@ -298,16 +308,17 @@ class ShapeInference {
298308
// 1) They are never reused, ie. having a single use in module.
299309
// 2) Their input types match those of their parent ops (excluding inputs
300310
// like predicate).
301-
LogicalResult PropagateShapeToFunctions(
302-
ModuleOp module, Operation::operand_type_range input_types,
303-
ArrayRef<FuncOp> functions, int64_t max_iteration);
311+
LogicalResult PropagateShapeToFunctions(ModuleOp module,
312+
TypeRange input_types,
313+
ArrayRef<FuncOp> functions,
314+
int64_t max_iteration);
304315

305316
// Propagates shapes to regions given the shapes of the inputs of the regions.
306317
// All regions provided in `regions` are assumed to have inputs of type
307318
// `input_types`.
308-
LogicalResult PropagateShapeToRegions(
309-
Operation::operand_type_range input_types, ArrayRef<Region*> regions,
310-
int64_t max_iteration);
319+
LogicalResult PropagateShapeToRegions(TypeRange input_types,
320+
ArrayRef<Region*> regions,
321+
int64_t max_iteration);
311322

312323
// Shape propagation for call/control flow ops.
313324
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
@@ -757,15 +768,58 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
757768
return false;
758769
}
759770

771+
// Finds element type to be used for result from operand, with special handling
772+
// for handle types.
773+
Type GetElementTypeFromOperand(TensorType operand_type,
774+
TensorType result_type) {
775+
auto operand_handle_type =
776+
operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
777+
if (!operand_handle_type) return result_type.getElementType();
778+
auto result_handle_type =
779+
result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
780+
if (operand_handle_type.GetSubtypes().empty() ||
781+
!result_handle_type.GetSubtypes().empty())
782+
return result_type.getElementType();
783+
return operand_handle_type;
784+
}
785+
786+
template <typename WhileOpTy>
787+
bool ShapeInference::InferShapeForWhile(WhileOpTy op,
788+
TypeRange body_result_types) {
789+
if (!op.shape_invariant())
790+
return RefineTypeForPassThroughOperands(op, op.input(), op.output());
791+
792+
bool changed = false;
793+
for (auto entry :
794+
zip(op.input().getTypes(), op.output(), body_result_types)) {
795+
auto operand_type = std::get<0>(entry).template cast<TensorType>();
796+
Value result = std::get<1>(entry);
797+
auto body_result_type = std::get<2>(entry).template cast<TensorType>();
798+
if (operand_type == body_result_type) {
799+
changed |= RefineResultType(op, result, operand_type);
800+
continue;
801+
}
802+
auto result_type = result.getType().cast<TensorType>();
803+
Type element_type = GetElementTypeFromOperand(operand_type, result_type);
804+
Type potential_refined_type;
805+
if (result_type.hasRank())
806+
potential_refined_type =
807+
RankedTensorType::get(result_type.getShape(), element_type);
808+
else
809+
potential_refined_type = UnrankedTensorType::get(element_type);
810+
changed |= RefineResultType(op, result, potential_refined_type);
811+
}
812+
return changed;
813+
}
814+
760815
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
761816
LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
762817
llvm::dbgs() << "\n");
763818
assert(tf_dialect_ == op->getDialect());
764819
// The shape function of these ops sometimes does not propagate subtypes
765820
// (handle shapes) for resource and variant types. We use a simple passthrough
766821
// to make sure they are preserved in the output.
767-
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp,
768-
TF::WhileRegionOp>(op)) {
822+
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
769823
return RefineTypeForPassThroughOperands(op, op->getOperands(),
770824
op->getResults());
771825
}
@@ -799,6 +853,15 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
799853
if (auto if_region = dyn_cast<IfRegionOp>(op))
800854
return InferShapeForIfRegion(if_region);
801855

856+
if (auto while_op = dyn_cast<WhileOp>(op))
857+
return InferShapeForWhile(while_op,
858+
while_op.body_function().getType().getResults());
859+
860+
if (auto while_region = dyn_cast<WhileRegionOp>(op))
861+
return InferShapeForWhile(
862+
while_region,
863+
while_region.body().front().getTerminator()->getOperandTypes());
864+
802865
// Return operand as a constant attribute.
803866
auto operand_as_constant_fn = [&](Value operand) {
804867
ValuePort vp(operand);
@@ -851,8 +914,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
851914
}
852915

853916
LogicalResult ShapeInference::PropagateShapeToFunctions(
854-
ModuleOp module, Operation::operand_type_range input_types,
855-
ArrayRef<FuncOp> functions, int64_t max_iteration) {
917+
ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions,
918+
int64_t max_iteration) {
856919
bool all_succeeded = true;
857920
// If shape propagation fails for one function, return failure, but do not
858921
// early exit and attempt to propagate shapes for all provided functions to
@@ -885,9 +948,9 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
885948
return success(all_succeeded);
886949
}
887950

888-
LogicalResult ShapeInference::PropagateShapeToRegions(
889-
Operation::operand_type_range input_types, ArrayRef<Region*> regions,
890-
int64_t max_iteration) {
951+
LogicalResult ShapeInference::PropagateShapeToRegions(TypeRange input_types,
952+
ArrayRef<Region*> regions,
953+
int64_t max_iteration) {
891954
DCOMMENT("\tPropagating shapes to regions");
892955
bool all_succeeded = true;
893956
// If shape propagation fails for one region, return failure, but do not
@@ -965,23 +1028,68 @@ void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
9651028
}
9661029
}
9671030

1031+
// Finds compatible types to propagate into functions/regions of a shape variant
1032+
// tf.While/tf.WhileRegion. If operand and result types are the same, that type
1033+
// is returned. Otherwise functions/regions arguments are returned but with the
1034+
// handle type from the operand type.
1035+
// TODO(b/174145518): Support more granular shape refining of different shaped
1036+
// operands and results (e.g. if rank does not change or only some dimensions
1037+
// change).
1038+
llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
1039+
TypeRange operand_types, TypeRange result_types,
1040+
TypeRange region_argument_types) {
1041+
llvm::SmallVector<Type, 4> types;
1042+
types.reserve(operand_types.size());
1043+
for (auto entry :
1044+
llvm::zip(operand_types, result_types, region_argument_types)) {
1045+
Type operand_type = std::get<0>(entry);
1046+
Type result_type = std::get<1>(entry);
1047+
if (operand_type == result_type) {
1048+
types.push_back(operand_type);
1049+
} else {
1050+
auto region_argument_type = std::get<2>(entry).cast<TensorType>();
1051+
Type element_type = GetElementTypeFromOperand(
1052+
operand_type.cast<TensorType>(), region_argument_type);
1053+
Type potential_refined_type;
1054+
if (region_argument_type.hasRank())
1055+
potential_refined_type = RankedTensorType::get(
1056+
region_argument_type.getShape(), element_type);
1057+
else
1058+
potential_refined_type = UnrankedTensorType::get(element_type);
1059+
types.push_back(potential_refined_type);
1060+
}
1061+
}
1062+
return types;
1063+
}
1064+
9681065
LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
9691066
Operation* op, int64_t max_iteration) {
9701067
ModuleOp module = op->getParentOfType<ModuleOp>();
9711068
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
9721069
DCOMMENT("Propagating shapes into If");
9731070
return PropagateShapeToFunctions(
974-
module, drop_begin(if_op.getOperandTypes(), 1),
1071+
module, if_op.input().getTypes(),
9751072
{if_op.then_function(), if_op.else_function()}, max_iteration);
9761073
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
9771074
SmallVector<FuncOp, 4> branches;
9781075
case_op.get_branch_functions(branches);
979-
return PropagateShapeToFunctions(module,
980-
drop_begin(case_op.getOperandTypes(), 1),
1076+
return PropagateShapeToFunctions(module, case_op.input().getTypes(),
9811077
branches, max_iteration);
9821078
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
1079+
// If `shape_invariant` is set, operand shapes cannot be simply propagated
1080+
// to result shapes as the op may have different intermediate shapes (such
1081+
// While ops can have different result shapes from operand shapes).
1082+
// Compatible shapes must be determined before propagating them.
1083+
if (while_op.shape_invariant()) {
1084+
auto compatible_types = GetWhileCompatibleTypes(
1085+
while_op.input().getTypes(), while_op.output().getTypes(),
1086+
while_op.body_function().getType().getInputs());
1087+
return PropagateShapeToFunctions(
1088+
module, compatible_types,
1089+
{while_op.cond_function(), while_op.body_function()}, max_iteration);
1090+
}
9831091
return PropagateShapeToFunctions(
984-
module, while_op.getOperandTypes(),
1092+
module, while_op.input().getTypes(),
9851093
{while_op.cond_function(), while_op.body_function()}, max_iteration);
9861094
} else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
9871095
if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
@@ -1004,7 +1112,19 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
10041112
LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions(
10051113
Operation* op, int64_t max_iteration) {
10061114
if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
1007-
return PropagateShapeToRegions(while_op.getOperandTypes(),
1115+
// If `shape_invariant` is set, operand shapes cannot be simply propagated
1116+
// to result shapes as the op may have different intermediate shapes (such
1117+
// While ops can have different result shapes from operand shapes).
1118+
// Compatible shapes must be determined before propagating them.
1119+
if (while_op.shape_invariant()) {
1120+
auto compatible_types = GetWhileCompatibleTypes(
1121+
while_op.input().getTypes(), while_op.output().getTypes(),
1122+
while_op.body().getArgumentTypes());
1123+
return PropagateShapeToRegions(compatible_types,
1124+
{&while_op.cond(), &while_op.body()},
1125+
max_iteration);
1126+
}
1127+
return PropagateShapeToRegions(while_op.input().getTypes(),
10081128
{&while_op.cond(), &while_op.body()},
10091129
max_iteration);
10101130
}

0 commit comments

Comments
 (0)