@@ -48,17 +48,17 @@ struct TPUShardingIdentificationPass
4848 void runOnOperation () override ;
4949};
5050
51- // Finds XlaSharding op connected to a value. If value is a resource type then
52- // XlaSharding op will be connected to a ReadVariable op. XlaSharding op may be
53- // direct user of inputs but it may also be followed by an Identity op and, in
54- // the case where bfloat16 type is used, Cast op may be added right after the
55- // input.
51+ // Finds XlaSharding op connected to an argument value. If value is a resource
52+ // type then XlaSharding op will be connected to a ReadVariable op. XlaSharding
53+ // op may be direct user of inputs but it may also be followed by an Identity op
54+ // and, in the case where bfloat16 type is used, Cast op may be added right
55+ // after the input.
5656//
5757// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
5858// Case, While) ops and Caller return values.
5959// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
6060// inputs.
61- llvm::Optional<llvm::StringRef> GetXlaSharding (const Value& value) {
61+ llvm::Optional<llvm::StringRef> GetXlaShardingFromArg (const Value& value) {
6262 llvm::SmallPtrSet<Value, 4 > visited_values;
6363 llvm::SmallVector<Value, 4 > values_to_visit{value};
6464 while (!values_to_visit.empty ()) {
@@ -67,16 +67,16 @@ llvm::Optional<llvm::StringRef> GetXlaSharding(const Value& value) {
6767 if (!visited_values.insert (value_to_visit).second ) continue ;
6868
6969 for (auto & use : value_to_visit.getUses ()) {
70- if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(use.getOwner ()))
70+ Operation* owner = use.getOwner ();
71+ if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
7172 return sharding._XlaSharding ();
7273
73- if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(
74- use.getOwner ())) {
74+ if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
7575 next_values_to_visit.push_back (use.getOwner ()->getResult (0 ));
7676 continue ;
7777 }
7878
79- if (auto call_op = llvm::dyn_cast<CallOpInterface>(use. getOwner () )) {
79+ if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner )) {
8080 FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable ());
8181 if (!func) continue ;
8282 next_values_to_visit.push_back (
@@ -91,18 +91,6 @@ llvm::Optional<llvm::StringRef> GetXlaSharding(const Value& value) {
9191 return llvm::None;
9292}
9393
94- // Returns the provided sharding configuration if operand of return value of
95- // tf_device.ClusterFunc op is directly from XlaSharding op,
96- llvm::Optional<StringRef> ParseReturnValueSharding (FuncOp func,
97- const int output_index,
98- const OpOperand& operand) {
99- if (auto sharding_op = llvm::dyn_cast_or_null<TF::XlaShardingOp>(
100- operand.get ().getDefiningOp ()))
101- return sharding_op._XlaSharding ();
102-
103- return llvm::Optional<StringRef>();
104- }
105-
10694// Walks the graph from the arguments of the `cluster_func_op` and extracts
10795// sharding configurations for all inputs by parsing XlaSharding op connected to
10896// the arguments. If argument to the `cluster_func_op` directly feeds into
@@ -128,17 +116,17 @@ void IdentifyXlaShardingForComputationInputs(
128116 // Sharding configurations are added to the tf_device.ClusterFunc as an
129117 // attribute and the function as an argument attribute.
130118 for (auto & arg : cluster_function_block.getArguments ()) {
131- auto arg_sharding = GetXlaSharding (arg);
132- const int arg_index = arg.getArgNumber ();
119+ auto arg_sharding = GetXlaShardingFromArg (arg);
120+ const int index = arg.getArgNumber ();
133121
134122 if (arg_sharding) {
135- sharding_for_args[arg_index ] = arg_sharding.getValue ();
123+ sharding_for_args[index ] = arg_sharding.getValue ();
136124 cluster_function.setArgAttr (
137- arg_index , kShardingAttr ,
125+ index , kShardingAttr ,
138126 builder->getStringAttr (arg_sharding.getValue ()));
139127 } else {
140128 cluster_function.setArgAttr (
141- arg_index , kShardingAttr ,
129+ index , kShardingAttr ,
142130 builder->getStringAttr (logical_core_0_sharding));
143131 }
144132 }
@@ -147,6 +135,44 @@ void IdentifyXlaShardingForComputationInputs(
147135 builder->getStrArrayAttr (sharding_for_args));
148136}
149137
138+ // Finds XlaSharding op connected to a result value. XlaSharding op may be
139+ // the direct defining op of the result but it may also be behind an Identity
140+ // op and, in the case where bfloat16 type is used, a Cast op added right
141+ // before the output.
142+ //
143+ // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
144+ // Case, While) ops and Caller argument values.
145+ // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
146+ // inputs.
147+ llvm::Optional<StringRef> GetXlaShardingFromRetval (const Value& value) {
148+ llvm::SmallPtrSet<Value, 4 > visited_values;
149+ Value value_to_visit = value;
150+ while (value_to_visit) {
151+ if (!visited_values.insert (value_to_visit).second ) return llvm::None;
152+
153+ Operation* def = value_to_visit.getDefiningOp ();
154+ if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
155+ return sharding._XlaSharding ();
156+
157+ if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
158+ value_to_visit = def->getOperand (0 );
159+ continue ;
160+ }
161+
162+ if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
163+ FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable ());
164+ if (!func) continue ;
165+ value_to_visit = func.front ().getTerminator ()->getOperand (
166+ value_to_visit.cast <OpResult>().getResultNumber ());
167+ continue ;
168+ }
169+
170+ break ;
171+ }
172+
173+ return llvm::None;
174+ }
175+
150176// Parses XlaSharding op directly connected from the outputs of the
151177// `cluster_func` and extract sharding configurations for outputs.
152178void IdentifyXlaShardingForComputationOutputs (
@@ -164,8 +190,8 @@ void IdentifyXlaShardingForComputationOutputs(
164190 // tf_device.ClusterFunc as an attribute and the function as a result
165191 // attribute.
166192 for (auto & ret : terminator->getOpOperands ()) {
193+ auto ret_sharding = GetXlaShardingFromRetval (ret.get ());
167194 const int index = ret.getOperandNumber ();
168- auto ret_sharding = ParseReturnValueSharding (func, index, ret);
169195
170196 if (ret_sharding) {
171197 sharding_for_rets[index] = ret_sharding.getValue ();
@@ -176,6 +202,7 @@ void IdentifyXlaShardingForComputationOutputs(
176202 builder->getStringAttr (logical_core_0_sharding));
177203 }
178204 }
205+
179206 cluster_func.setAttr (tensorflow::kOutputShardingAttr ,
180207 builder->getStrArrayAttr (sharding_for_rets));
181208}
0 commit comments