@@ -18,12 +18,14 @@ limitations under the License.
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
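Aside (not part of the commit): the two added headers appear to back the changes further down — llvm/ADT/STLExtras.h presumably for llvm::zip, used to walk cluster_func operands and function-block arguments in lockstep, and mlir/IR/BuiltinAttributes.h presumably for BoolAttr, used to read the "use_spmd_for_xla_partitioning" attribute. A minimal, self-contained sketch of the llvm::zip idiom; all names below are illustrative only.

#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Each element produced by llvm::zip is a tuple; std::get<N> reads the Nth
// zipped range, mirroring the operand/argument loops added in this pass.
int SumPairwiseProducts() {
  llvm::SmallVector<int, 4> lhs{1, 2, 3};
  llvm::SmallVector<int, 4> rhs{4, 5, 6};
  int sum = 0;
  for (auto pair : llvm::zip(lhs, rhs))
    sum += std::get<0>(pair) * std::get<1>(pair);
  return sum;  // 1*4 + 2*5 + 3*6 = 32
}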
@@ -40,24 +42,43 @@ namespace TFTPU {
 namespace {
 
 constexpr char kShardingAttr[] = "mhlo.sharding";
+constexpr char kReplicateSharding[] = "";
 
 struct TPUShardingIdentificationPass
     : public PassWrapper<TPUShardingIdentificationPass,
                          OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
-// Finds XlaSharding op connected to an argument value. If value is a resource
-// type then XlaSharding op will be connected to a ReadVariable op. XlaSharding
-// op may be direct user of inputs but it may also be followed by an Identity op
-// and, in the case where bfloat16 type is used, Cast op may be added right
-// after the input.
+// Returns XLA sharding from TPUPartitionedInput op connected to a
+// `tf_device.cluster_func` operand value. If value is a resource type then
+// TPUPartitionedInput op will be connected to a ReadVariable op that feeds into
+// a `tf_device.cluster_func`.
+llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
+  Value value_to_visit = value;
+  if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
+          value_to_visit.getDefiningOp()))
+    value_to_visit = read_var.resource();
+
+  if (auto partitioned_input =
+          llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
+              value_to_visit.getDefiningOp()))
+    return partitioned_input._XlaSharding();
+
+  return llvm::None;
+}
+
+// Returns XLA sharding from a XlaSharding op connected to an argument value. If
+// value is a resource type then XlaSharding op will be connected to a
+// ReadVariable op. XlaSharding op may be direct user of inputs but it may also
+// be followed by an Identity op and, in the case where bfloat16 type is used,
+// Cast op may be added right after the input.
 //
 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
 // Case, While) ops and Caller return values.
 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
 // inputs.
-llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(const Value& value) {
+llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
   llvm::SmallPtrSet<Value, 4> visited_values;
   llvm::SmallVector<Value, 4> values_to_visit{value};
   while (!values_to_visit.empty()) {
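Aside (not part of the commit): GetXlaShardingFromOperand uses llvm::dyn_cast_or_null rather than llvm::dyn_cast because Value::getDefiningOp() returns null whenever the value is a block argument rather than an op result. A minimal sketch of that null-tolerant probe; the helper name and the tf_ops.h include path are illustrative assumptions.

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Returns true only if `value` is produced by a TPUPartitionedInput op; a
// block argument has no defining op, so the null check short-circuits first.
bool IsFedByPartitionedInput(mlir::Value value) {
  mlir::Operation* defining_op = value.getDefiningOp();
  return defining_op && llvm::isa<mlir::TF::TPUPartitionedInputOp>(defining_op);
}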
@@ -90,22 +111,29 @@ llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(const Value& value) {
   return llvm::None;
 }
 
-// Walks the graph from the arguments of the `cluster_func_op` and extracts
-// sharding configurations for all inputs by parsing XlaSharding op connected to
-// the arguments. If argument to the `cluster_func_op` directly feeds into
-// another function call op, then recursively walk the function definition to
-// find the connected XlaSharding op.
+// Extracts sharding configurations for all inputs by parsing XlaSharding/
+// TPUPartitionedInput op connected to the operands/arguments. If argument to
+// the `cluster_func` directly feeds into another function call op, then
+// recursively walk the function definition to find the connected XlaSharding
+// op.
 void IdentifyXlaShardingForComputationInputs(
-    StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
-    FuncOp cluster_function, Builder* builder) {
+    StringRef logical_core_0_sharding, bool use_spmd,
+    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
   // Look up function definition from module.
-  Block& cluster_function_block = cluster_function.front();
+  Block& function_block = func.front();
 
-  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args(
-      cluster_function_block.getNumArguments(), logical_core_0_sharding);
+  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
+  sharding_for_args.reserve(function_block.getNumArguments());
 
+  // Iterate through operands of `cluster_func`.
+  // The computation operand can either be:
+  //   1) a TPUPartitionedInput Op if the input has a non-resource type;
+  //   2) a ReadVariableOp else.
+  //
+  // Replicate sharding is used if `use_spmd` is set.
+  //
   // Iterate through input arguments to the entry block of
-  // tf_device.ClusterFunc. For input ops, look for following XlaSharding ops.
+  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
   // XlaSharding ops can:
   //   1) Directly follow the input argument if input argument has non-resource
   //      types.
@@ -114,36 +142,70 @@ void IdentifyXlaShardingForComputationInputs(
   //
   // Sharding configurations are added to the tf_device.ClusterFunc as an
   // attribute and the function as an argument attribute.
-  for (auto& arg : cluster_function_block.getArguments()) {
-    auto arg_sharding = GetXlaShardingFromArg(arg);
+  for (auto operand_and_arg :
+       llvm::zip(cluster_func.operands(), function_block.getArguments())) {
+    Value operand = std::get<0>(operand_and_arg);
+    BlockArgument arg = std::get<1>(operand_and_arg);
     const int index = arg.getArgNumber();
 
+    if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
+      sharding_for_args.push_back(operand_sharding.getValue());
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(operand_sharding.getValue()));
+      continue;
+    }
+
+    if (use_spmd) {
+      // If XLA SPMD is enabled, host variables or non-variable per-replica
+      // inputs should take on replicate sharding, unless another sharding is
+      // set via a TPUPartitionedInput op.
+      sharding_for_args.push_back(kReplicateSharding);
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(kReplicateSharding));
+      continue;
+    }
+
+    auto arg_sharding = GetXlaShardingFromArg(arg);
     if (arg_sharding) {
-      sharding_for_args[index] = arg_sharding.getValue();
-      cluster_function.setArgAttr(
-          index, kShardingAttr,
-          builder->getStringAttr(arg_sharding.getValue()));
-    } else {
-      cluster_function.setArgAttr(
-          index, kShardingAttr,
-          builder->getStringAttr(logical_core_0_sharding));
+      sharding_for_args.push_back(arg_sharding.getValue());
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(arg_sharding.getValue()));
+      continue;
     }
+
+    // Default to maximal sharding core 0 if no sharding is present.
+    sharding_for_args.push_back(logical_core_0_sharding);
+    func.setArgAttr(index, kShardingAttr,
+                    builder->getStringAttr(logical_core_0_sharding));
   }
 
-  cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
-                           builder->getStrArrayAttr(sharding_for_args));
+  cluster_func->setAttr(tensorflow::kInputShardingAttr,
+                        builder->getStrArrayAttr(sharding_for_args));
 }
 
-// Finds XlaSharding op connected to a result value. XlaSharding op may be
-// direct user of inputs but it may also be followed by an Identity op and, in
-// the case where bfloat16 type is used, Cast op may be added right after the
-// input.
+// Returns XLA sharding from TPUPartitionedOutput op connected to a
+// `tf_device.cluster_func` result value.
+llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
+  if (!value.hasOneUse()) return llvm::None;
+
+  Operation* user = *value.getUsers().begin();
+  if (auto partitioned_output =
+          llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
+    return partitioned_output._XlaSharding();
+
+  return llvm::None;
+}
+
+// Returns XLA sharding from XlaSharding op connected to a result value.
+// XlaSharding op may be direct user of inputs but it may also be followed by an
+// Identity op and, in the case where bfloat16 type is used, Cast op may be
+// added right after the input.
 //
 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
 // Case, While) ops and Caller argument values.
 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
 // inputs.
-llvm::Optional<StringRef> GetXlaShardingFromRetval(const Value& value) {
+llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
   llvm::SmallPtrSet<Value, 4> visited_values;
   Value value_to_visit = value;
   while (value_to_visit) {
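Aside (not part of the commit): after this hunk the per-argument sharding follows a fixed precedence — a TPUPartitionedInput sharding on the cluster_func operand wins, then replicate sharding when use_spmd is set, then an XlaSharding op reachable from the block argument, and finally the maximal core-0 default. A compressed sketch of that decision order only, reusing the helpers introduced above; the free-standing function itself is illustrative and not part of the pass.

// Distills the argument-sharding precedence implemented in
// IdentifyXlaShardingForComputationInputs; attribute writing is omitted.
llvm::StringRef PickArgSharding(Value operand, BlockArgument arg, bool use_spmd,
                                llvm::StringRef logical_core_0_sharding) {
  if (auto operand_sharding = GetXlaShardingFromOperand(operand))
    return operand_sharding.getValue();
  if (use_spmd) return kReplicateSharding;
  if (auto arg_sharding = GetXlaShardingFromArg(arg))
    return arg_sharding.getValue();
  return logical_core_0_sharding;
}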
@@ -172,34 +234,58 @@ llvm::Optional<StringRef> GetXlaShardingFromRetval(const Value& value) {
   return llvm::None;
 }
 
-// Parses XlaSharding op directly connected from the outputs of the
-// `cluster_func` and extract sharding configurations for outputs.
+// Extracts sharding configurations for all outputs by parsing XlaSharding/
+// TPUPartitionedOutput op connected to the retvals/results.
 void IdentifyXlaShardingForComputationOutputs(
-    StringRef logical_core_0_sharding, FuncOp func,
-    tf_device::ClusterFuncOp cluster_func, Builder* builder) {
-  // By default return values from logical core 0 is used if no sharding
-  // configuration is defined.
+    StringRef logical_core_0_sharding, bool use_spmd,
+    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
   Block& function_block = func.front();
   Operation* terminator = function_block.getTerminator();
-  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets(
-      terminator->getNumOperands(), logical_core_0_sharding);
+  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
+  sharding_for_rets.reserve(terminator->getNumOperands());
 
+  // Iterate through results of `cluster_func`. For output ops, look for
+  // TPUPartitionedOutput ops.
+  //
+  // Replicate sharding is used if `use_spmd` is set.
+  //
   // Iterate through operands of the terminator. If the preceding op is
   // XlaShardingOp, then the provided sharding configuration is added to the
   // tf_device.ClusterFunc as an attribute and the function as a result
   // attribute.
-  for (auto& ret : terminator->getOpOperands()) {
-    auto ret_sharding = GetXlaShardingFromRetval(ret.get());
-    const int index = ret.getOperandNumber();
+  for (auto result_and_retval :
+       llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
+    Value result = std::get<0>(result_and_retval);
+    OpOperand& retval = std::get<1>(result_and_retval);
+    const int index = retval.getOperandNumber();
+
+    if (auto result_sharding = GetXlaShardingFromResult(result)) {
+      sharding_for_rets.push_back(result_sharding.getValue());
+      func.setResultAttr(index, kShardingAttr,
+                         builder->getStringAttr(result_sharding.getValue()));
+      continue;
+    }
 
-    if (ret_sharding) {
-      sharding_for_rets[index] = ret_sharding.getValue();
+    if (use_spmd) {
+      // If XLA SPMD is enabled, outputs all should have replicate sharding,
+      // unless another sharding is set via a TPUPartitionedOutput op.
+      sharding_for_rets.push_back(kReplicateSharding);
       func.setResultAttr(index, kShardingAttr,
-                         builder->getStringAttr(ret_sharding.getValue()));
-    } else {
+                         builder->getStringAttr(kReplicateSharding));
+      continue;
+    }
+
+    if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
+      sharding_for_rets.push_back(retval_sharding.getValue());
       func.setResultAttr(index, kShardingAttr,
-                         builder->getStringAttr(logical_core_0_sharding));
+                         builder->getStringAttr(retval_sharding.getValue()));
+      continue;
     }
+
+    // Default to maximal sharding core 0 if no sharding is present.
+    sharding_for_rets.push_back(logical_core_0_sharding);
+    func.setResultAttr(index, kShardingAttr,
+                       builder->getStringAttr(logical_core_0_sharding));
   }
 
   cluster_func->setAttr(tensorflow::kOutputShardingAttr,
@@ -219,11 +305,16 @@ void IdentifyXlaShardingForTPUComputation(
   const std::string logical_core_0_sharding =
       xla::sharding_builder::AssignDevice(0).SerializeAsString();
 
-  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, cluster_func,
-                                          func, builder);
+  bool use_spmd = false;
+  if (auto use_spmd_attr =
+          cluster_func.getAttrOfType<BoolAttr>("use_spmd_for_xla_partitioning"))
+    use_spmd = use_spmd_attr.getValue();
+
+  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
+                                          cluster_func, func, builder);
 
-  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, func,
-                                           cluster_func, builder);
+  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
+                                           cluster_func, func, builder);
 }
 
 void TPUShardingIdentificationPass::runOnOperation() {
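Aside (not part of the commit): the core-0 default attached by both functions is the serialized xla::OpSharding proto built by xla::sharding_builder::AssignDevice(0), i.e. a MAXIMAL sharding pinned to logical core 0, while the SPMD branches attach the empty kReplicateSharding string instead. A minimal sketch of how that default string is produced, assuming the sharding_builder header path.

#include <string>

#include "tensorflow/compiler/xla/client/sharding_builder.h"

// Serializes the MAXIMAL OpSharding for logical core 0 that the pass falls
// back to when neither a partitioned-input/output op, the SPMD flag, nor an
// XlaSharding op supplies a sharding.
std::string DefaultCore0Sharding() {
  return xla::sharding_builder::AssignDevice(0).SerializeAsString();
}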