@@ -279,6 +279,16 @@ class ShapeInference {
279279 results_[value_port] = value;
280280 }
281281
282+ // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
283+ // set, operand types are set as result types if associated body result types
284+ // match the operand type (does not change per loop iteration). If operand and
285+ // body result types are not the same, only handle types are propagated to
286+ // result types. This avoids incorrectly changing result shapes when the
287+ // While op has a different result shape. If `shape_invariant` is not set,
288+ // operand shapes are propagated to result shapes.
289+ template <typename WhileOpTy>
290+ bool InferShapeForWhile (WhileOpTy op, TypeRange body_result_types);
291+
282292 // Performs shape inference on the provided op and returns true if the type
283293 // of at least one result has been changed.
284294 // A tf.Cast() is inserted for any uses that aren't in the TensorFlow dialect.
@@ -298,16 +308,17 @@ class ShapeInference {
298308 // 1) They are never reused, i.e. they have a single use in the module.
299309 // 2) Their input types match those of their parent ops (excluding inputs
300310 // like predicate).
301- LogicalResult PropagateShapeToFunctions (
302- ModuleOp module , Operation::operand_type_range input_types,
303- ArrayRef<FuncOp> functions, int64_t max_iteration);
311+ LogicalResult PropagateShapeToFunctions (ModuleOp module ,
312+ TypeRange input_types,
313+ ArrayRef<FuncOp> functions,
314+ int64_t max_iteration);
304315
305316 // Propagates shapes to regions given the shapes of the inputs of the regions.
306317 // All regions provided in `regions` are assumed to have inputs of type
307318 // `input_types`.
308- LogicalResult PropagateShapeToRegions (
309- Operation::operand_type_range input_types, ArrayRef<Region*> regions,
310- int64_t max_iteration);
319+ LogicalResult PropagateShapeToRegions (TypeRange input_types,
320+ ArrayRef<Region*> regions,
321+ int64_t max_iteration);
311322
312323 // Shape propagation for call/control flow ops.
313324 LogicalResult PropagateShapeIntoAttachedFunctions (Operation* op,
@@ -757,15 +768,58 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
757768 return false ;
758769}
759770
771+ // Finds element type to be used for result from operand, with special handling
772+ // for handle types.
773+ Type GetElementTypeFromOperand (TensorType operand_type,
774+ TensorType result_type) {
775+ auto operand_handle_type =
776+ operand_type.getElementType ().dyn_cast <TensorFlowTypeWithSubtype>();
777+ if (!operand_handle_type) return result_type.getElementType ();
778+ auto result_handle_type =
779+ result_type.getElementType ().cast <TensorFlowTypeWithSubtype>();
780+ if (operand_handle_type.GetSubtypes ().empty () ||
781+ !result_handle_type.GetSubtypes ().empty ())
782+ return result_type.getElementType ();
783+ return operand_handle_type;
784+ }
785+
786+ template <typename WhileOpTy>
787+ bool ShapeInference::InferShapeForWhile (WhileOpTy op,
788+ TypeRange body_result_types) {
789+ if (!op.shape_invariant ())
790+ return RefineTypeForPassThroughOperands (op, op.input (), op.output ());
791+
792+ bool changed = false ;
793+ for (auto entry :
794+ zip (op.input ().getTypes (), op.output (), body_result_types)) {
795+ auto operand_type = std::get<0 >(entry).template cast <TensorType>();
796+ Value result = std::get<1 >(entry);
797+ auto body_result_type = std::get<2 >(entry).template cast <TensorType>();
798+ if (operand_type == body_result_type) {
799+ changed |= RefineResultType (op, result, operand_type);
800+ continue ;
801+ }
802+ auto result_type = result.getType ().cast <TensorType>();
803+ Type element_type = GetElementTypeFromOperand (operand_type, result_type);
804+ Type potential_refined_type;
805+ if (result_type.hasRank ())
806+ potential_refined_type =
807+ RankedTensorType::get (result_type.getShape (), element_type);
808+ else
809+ potential_refined_type = UnrankedTensorType::get (element_type);
810+ changed |= RefineResultType (op, result, potential_refined_type);
811+ }
812+ return changed;
813+ }
814+
760815bool ShapeInference::InferShapeForSingleOperation (Operation* op) {
761816 LLVM_DEBUG (op->print (llvm::dbgs () << " InferShapeForSingleOperation for " );
762817 llvm::dbgs () << " \n " );
763818 assert (tf_dialect_ == op->getDialect ());
764819 // The shape function of these ops sometimes does not propagate subtypes
765820 // (handle shapes) for resource and variant types. We use a simple passthrough
766821 // to make sure they are preserved in the output.
767- if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp,
768- TF::WhileRegionOp>(op)) {
822+ if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
769823 return RefineTypeForPassThroughOperands (op, op->getOperands (),
770824 op->getResults ());
771825 }
@@ -799,6 +853,15 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
799853 if (auto if_region = dyn_cast<IfRegionOp>(op))
800854 return InferShapeForIfRegion (if_region);
801855
856+ if (auto while_op = dyn_cast<WhileOp>(op))
857+ return InferShapeForWhile (while_op,
858+ while_op.body_function ().getType ().getResults ());
859+
860+ if (auto while_region = dyn_cast<WhileRegionOp>(op))
861+ return InferShapeForWhile (
862+ while_region,
863+ while_region.body ().front ().getTerminator ()->getOperandTypes ());
864+
802865 // Return operand as a constant attribute.
803866 auto operand_as_constant_fn = [&](Value operand) {
804867 ValuePort vp (operand);
@@ -851,8 +914,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
851914}
852915
853916LogicalResult ShapeInference::PropagateShapeToFunctions (
854- ModuleOp module , Operation::operand_type_range input_types,
855- ArrayRef<FuncOp> functions, int64_t max_iteration) {
917+ ModuleOp module , TypeRange input_types, ArrayRef<FuncOp> functions ,
918+ int64_t max_iteration) {
856919 bool all_succeeded = true ;
857920 // If shape propagation fails for one function, return failure, but do not
858921 // early exit and attempt to propagate shapes for all provided functions to
@@ -885,9 +948,9 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
885948 return success (all_succeeded);
886949}
887950
888- LogicalResult ShapeInference::PropagateShapeToRegions (
889- Operation::operand_type_range input_types, ArrayRef<Region*> regions,
890- int64_t max_iteration) {
951+ LogicalResult ShapeInference::PropagateShapeToRegions (TypeRange input_types,
952+ ArrayRef<Region*> regions,
953+ int64_t max_iteration) {
891954 DCOMMENT (" \t Propagating shapes to regions" );
892955 bool all_succeeded = true ;
893956 // If shape propagation fails for one region, return failure, but do not
@@ -965,23 +1028,68 @@ void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
9651028 }
9661029}
9671030
1031+ // Finds compatible types to propagate into functions/regions of a shape variant
1032+ // tf.While/tf.WhileRegion. If operand and result types are the same, that type
1033+ // is returned. Otherwise functions/regions arguments are returned but with the
1034+ // handle type from the operand type.
1035+ // TODO(b/174145518): Support more granular shape refining of different shaped
1036+ // operands and results (e.g. if rank does not change or only some dimensions
1037+ // change).
1038+ llvm::SmallVector<Type, 4 > GetWhileCompatibleTypes (
1039+ TypeRange operand_types, TypeRange result_types,
1040+ TypeRange region_argument_types) {
1041+ llvm::SmallVector<Type, 4 > types;
1042+ types.reserve (operand_types.size ());
1043+ for (auto entry :
1044+ llvm::zip (operand_types, result_types, region_argument_types)) {
1045+ Type operand_type = std::get<0 >(entry);
1046+ Type result_type = std::get<1 >(entry);
1047+ if (operand_type == result_type) {
1048+ types.push_back (operand_type);
1049+ } else {
1050+ auto region_argument_type = std::get<2 >(entry).cast <TensorType>();
1051+ Type element_type = GetElementTypeFromOperand (
1052+ operand_type.cast <TensorType>(), region_argument_type);
1053+ Type potential_refined_type;
1054+ if (region_argument_type.hasRank ())
1055+ potential_refined_type = RankedTensorType::get (
1056+ region_argument_type.getShape (), element_type);
1057+ else
1058+ potential_refined_type = UnrankedTensorType::get (element_type);
1059+ types.push_back (potential_refined_type);
1060+ }
1061+ }
1062+ return types;
1063+ }
1064+
9681065LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions (
9691066 Operation* op, int64_t max_iteration) {
9701067 ModuleOp module = op->getParentOfType <ModuleOp>();
9711068 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
9721069 DCOMMENT (" Propagating shapes into If" );
9731070 return PropagateShapeToFunctions (
974- module , drop_begin ( if_op.getOperandTypes (), 1 ),
1071+ module , if_op.input (). getTypes ( ),
9751072 {if_op.then_function (), if_op.else_function ()}, max_iteration);
9761073 } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
9771074 SmallVector<FuncOp, 4 > branches;
9781075 case_op.get_branch_functions (branches);
979- return PropagateShapeToFunctions (module ,
980- drop_begin (case_op.getOperandTypes (), 1 ),
1076+ return PropagateShapeToFunctions (module , case_op.input ().getTypes (),
9811077 branches, max_iteration);
9821078 } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
1079+ // If `shape_invariant` is set, operand shapes cannot be simply propagated
1080+ // to result shapes as the op may have different intermediate shapes (such
1081+ // While ops can have different result shapes from operand shapes).
1082+ // Compatible shapes must be determined before propagating them.
1083+ if (while_op.shape_invariant ()) {
1084+ auto compatible_types = GetWhileCompatibleTypes (
1085+ while_op.input ().getTypes (), while_op.output ().getTypes (),
1086+ while_op.body_function ().getType ().getInputs ());
1087+ return PropagateShapeToFunctions (
1088+ module , compatible_types,
1089+ {while_op.cond_function (), while_op.body_function ()}, max_iteration);
1090+ }
9831091 return PropagateShapeToFunctions (
984- module , while_op.getOperandTypes (),
1092+ module , while_op.input (). getTypes (),
9851093 {while_op.cond_function (), while_op.body_function ()}, max_iteration);
9861094 } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
9871095 if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable ())) {
@@ -1004,7 +1112,19 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
10041112LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions (
10051113 Operation* op, int64_t max_iteration) {
10061114 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
1007- return PropagateShapeToRegions (while_op.getOperandTypes (),
1115+ // If `shape_invariant` is set, operand shapes cannot be simply propagated
1116+ // to result shapes as the op may have different intermediate shapes (such
1117+ // While ops can have different result shapes from operand shapes).
1118+ // Compatible shapes must be determined before propagating them.
1119+ if (while_op.shape_invariant ()) {
1120+ auto compatible_types = GetWhileCompatibleTypes (
1121+ while_op.input ().getTypes (), while_op.output ().getTypes (),
1122+ while_op.body ().getArgumentTypes ());
1123+ return PropagateShapeToRegions (compatible_types,
1124+ {&while_op.cond (), &while_op.body ()},
1125+ max_iteration);
1126+ }
1127+ return PropagateShapeToRegions (while_op.input ().getTypes (),
10081128 {&while_op.cond (), &while_op.body ()},
10091129 max_iteration);
10101130 }
0 commit comments