Commit 636ca6b

andyly authored and tensorflower-gardener committed
Update TPU sharding identification pass to support checking for sharding attributes from tf.TPUPartitionedInput/tf.TPUPartitionedOutput ops.

When XLA SPMD is enabled, these ops are generated and hold pre-partitioned inputs/outputs. The computation inputs and outputs should take on these shardings; otherwise, sharding should be set to replicate sharding.

PiperOrigin-RevId: 348698987
Change-Id: If075108d6753d09018509862572be87247fdce95
1 parent 75d14ad commit 636ca6b
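
For illustration, the pattern this change teaches the pass to recognize looks roughly like the MLIR below. This is a minimal sketch adapted from the new test cases in this commit; the function names @example and @computation are placeholders, not part of the change. A tf.TPUPartitionedInput op carrying an _XlaSharding attribute feeds a tf_device.cluster_func marked with use_spmd_for_xla_partitioning = true, and the pass is expected to copy that sharding into input_sharding_configuration while assigning replicate sharding ("") to the remaining input and to the output.

// Minimal sketch, adapted from the new tests in this commit.
// @example and @computation are placeholder names; the IR is illustrative,
// not taken verbatim from the diff.
func @example(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // Pre-partitioned input carrying an explicit sharding.
  %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32>
  // Expected after the pass: input_sharding_configuration = ["\01\02\03", ""]
  // and output_sharding_configuration = [""] attached to this op.
  %1 = "tf_device.cluster_func"(%0, %arg1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %1 : tensor<*xi32>
}

func @computation(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  return %arg0 : tensor<*xi32>
}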

File tree: 2 files changed (+217 lines, -58 lines)


tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir

Lines changed: 72 additions & 4 deletions
@@ -21,7 +21,7 @@ func @empty_func() {
 // gets default maximal(0) sharding configuration.
 // CHECK-LABEL: func @check_default_sharding_for_block_arg_inputs_outputs
 func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) {
-  "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> ()
+  "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
   // CHECK: input_sharding_configuration
   // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
   // CHECK: output_sharding_configuration
@@ -42,7 +42,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
 // default maximal(0) sharding configuration.
 // CHECK-LABEL: func @check_default_sharding_for_inputs_outputs
 func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
-  "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> ()
+  "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
   // CHECK: input_sharding_configuration
   // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
   // CHECK: output_sharding_configuration
@@ -63,7 +63,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
 // Tests with a input arg connected to XlaSharding op.
 // CHECK-LABEL: func @check_sharding_for_input_correctly_identified
 func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
-  "tf_device.cluster_func"(%arg0) {func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> ()
+  "tf_device.cluster_func"(%arg0) {func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
   // CHECK: input_sharding_configuration
   // CHECK-SAME: ["\01\02\03"]
   // CHECK: output_sharding_configuration
@@ -90,7 +90,7 @@ func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: te
   // CHECK-SAME: ["\01\02\03", "\04\05\06"]
   // CHECK: output_sharding_configuration
   // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
-  return
+  return
 }
 
 // CHECK-LABEL: func @func_with_sharding
@@ -252,3 +252,71 @@ func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> {
   %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
   return %1 : tensor<*xi32>
 }
+
+// -----
+
+// Tests partitioned data inputs/outputs are set correctly (via XLA SPMD) is
+// enabled. Non replicated inputs/outputs should have shardings set to be
+// replicate sharding ("").
+
+// CHECK-LABEL: func @partitioned_input_output
+func @partitioned_input_output(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+  %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32>
+  // CHECK: tf_device.cluster_func
+  // CHECK-SAME: input_sharding_configuration = ["\01\02\03", ""]
+  // CHECK-SAME: output_sharding_configuration = ["", "\04\05\06"]
+  %1:2 = "tf_device.cluster_func"(%0, %arg1) {func = @cluster_func, use_spmd_for_xla_partitioning = true} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
+  %2 = "tf.TPUPartitionedOutput"(%1#1) {_XlaSharding = "\04\05\06", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32>
+  return %1#0, %2 : tensor<*xi32>, tensor<*xi32>
+}
+
+// CHECK-LABEL: func @cluster_func
+// CHECK-SAME: ({{.+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, {{.+}}: tensor<*xi32> {mhlo.sharding = ""})
+// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = ""}, tensor<*xi32> {mhlo.sharding = "\04\05\06"})
+func @cluster_func(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+  return %arg0, %arg1 : tensor<*xi32>, tensor<*xi32>
+}
+
+// -----
+
+// Tests partitioned variables (via XLA SPMD) propagates shardings correctly.
+
+// CHECK-LABEL: func @partitioned_variable
+func @partitioned_variable(%arg0: tensor<!tf.resource<tensor<*xf32>>>) {
+  %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<!tf.resource<tensor<*xf32>>>
+  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
+  // CHECK: tf_device.cluster_func
+  // CHECK-SAME: input_sharding_configuration = ["\01\02\03"]
+  // CHECK-SAME: output_sharding_configuration = []
+  "tf_device.cluster_func"(%1) {func = @cluster_func, use_spmd_for_xla_partitioning = true} : (tensor<*xf32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @cluster_func
+// CHECK-SAME: ({{.+}}: tensor<*xf32> {mhlo.sharding = "\01\02\03"})
+func @cluster_func(%arg0: tensor<*xf32>) {
+  return
+}
+
+// -----
+
+// Tests partitioned inputs/outputs with no sharding (via XLA SPMD) defaults to
+// replicate sharding ("").
+
+// CHECK-LABEL: func @partitioned_input_output
+func @partitioned_input_output(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  %0 = "tf.TPUPartitionedInput"(%arg0) {partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32>
+  // CHECK: tf_device.cluster_func
+  // CHECK-SAME: input_sharding_configuration = [""]
+  // CHECK-SAME: output_sharding_configuration = [""]
+  %1 = "tf_device.cluster_func"(%0) {func = @cluster_func, use_spmd_for_xla_partitioning = true} : (tensor<*xi32>) -> tensor<*xi32>
+  %2 = "tf.TPUPartitionedOutput"(%1) {partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32>
+  return %2 : tensor<*xi32>
+}
+
+// CHECK-LABEL: func @cluster_func
+// CHECK-SAME: ({{.+}}: tensor<*xi32> {mhlo.sharding = ""})
+// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = ""})
+func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  return %arg0 : tensor<*xi32>
+}
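
Note that the CHECK-LABEL/CHECK-SAME directives in these tests are exercised by lit and FileCheck through a RUN line at the top of the test file, which this hunk does not show. For a pass test like this it would typically look something like the line below; the exact tool flags are an assumption based on similar MLIR pass tests, not something visible in this diff.

// RUN: tf-opt %s -split-input-file -tf-tpu-sharding-identification | FileCheck %s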

tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc

Lines changed: 145 additions & 54 deletions
@@ -18,12 +18,14 @@ limitations under the License.
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -40,24 +42,43 @@ namespace TFTPU {
 namespace {
 
 constexpr char kShardingAttr[] = "mhlo.sharding";
+constexpr char kReplicateSharding[] = "";
 
 struct TPUShardingIdentificationPass
     : public PassWrapper<TPUShardingIdentificationPass,
                          OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
-// Finds XlaSharding op connected to an argument value. If value is a resource
-// type then XlaSharding op will be connected to a ReadVariable op. XlaSharding
-// op may be direct user of inputs but it may also be followed by an Identity op
-// and, in the case where bfloat16 type is used, Cast op may be added right
-// after the input.
+// Returns XLA sharding from TPUPartitionedInput op connected to a
+// `tf_device.cluster_func` operand value. If value is a resource type then
+// TPUPartitionedInput op will be connected to a ReadVariable op that feeds into
+// a `tf_device.cluster_func`.
+llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
+  Value value_to_visit = value;
+  if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
+          value_to_visit.getDefiningOp()))
+    value_to_visit = read_var.resource();
+
+  if (auto partitioned_input =
+          llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
+              value_to_visit.getDefiningOp()))
+    return partitioned_input._XlaSharding();
+
+  return llvm::None;
+}
+
+// Returns XLA sharding from a XlaSharding op connected to an argument value. If
+// value is a resource type then XlaSharding op will be connected to a
+// ReadVariable op. XlaSharding op may be direct user of inputs but it may also
+// be followed by an Identity op and, in the case where bfloat16 type is used,
+// Cast op may be added right after the input.
 //
 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
 // Case, While) ops and Caller return values.
 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
 // inputs.
-llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(const Value& value) {
+llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
   llvm::SmallPtrSet<Value, 4> visited_values;
   llvm::SmallVector<Value, 4> values_to_visit{value};
   while (!values_to_visit.empty()) {
@@ -90,22 +111,29 @@ llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(const Value& value) {
   return llvm::None;
 }
 
-// Walks the graph from the arguments of the `cluster_func_op` and extracts
-// sharding configurations for all inputs by parsing XlaSharding op connected to
-// the arguments. If argument to the `cluster_func_op` directly feeds into
-// another function call op, then recursively walk the function definition to
-// find the connected XlaSharding op.
+// Extracts sharding configurations for all inputs by parsing XlaSharding/
+// TPUPartitionedInput op connected to the operands/arguments. If argument to
+// the `cluster_func` directly feeds into another function call op, then
+// recursively walk the function definition to find the connected XlaSharding
+// op.
 void IdentifyXlaShardingForComputationInputs(
-    StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
-    FuncOp cluster_function, Builder* builder) {
+    StringRef logical_core_0_sharding, bool use_spmd,
+    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
   // Look up function definition from module.
-  Block& cluster_function_block = cluster_function.front();
+  Block& function_block = func.front();
 
-  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args(
-      cluster_function_block.getNumArguments(), logical_core_0_sharding);
+  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
+  sharding_for_args.reserve(function_block.getNumArguments());
 
+  // Iterate through operands of `cluster_func`.
+  // The computation operand can either be:
+  //   1) a TPUPartitionedInput Op if the input has a non-resource type;
+  //   2) a ReadVariableOp else.
+  //
+  // Replicate sharding is used if `use_spmd` is set.
+  //
   // Iterate through input arguments to the entry block of
-  // tf_device.ClusterFunc. For input ops, look for following XlaSharding ops.
+  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
   // XlaSharding ops can:
   //   1) Directly follow the input argument if input argument has non-resource
   //      types.
@@ -114,36 +142,70 @@ void IdentifyXlaShardingForComputationInputs(
   //
   // Sharding configurations are added to the tf_device.ClusterFunc as an
   // attribute and the function as an argument attribute.
-  for (auto& arg : cluster_function_block.getArguments()) {
-    auto arg_sharding = GetXlaShardingFromArg(arg);
+  for (auto operand_and_arg :
+       llvm::zip(cluster_func.operands(), function_block.getArguments())) {
+    Value operand = std::get<0>(operand_and_arg);
+    BlockArgument arg = std::get<1>(operand_and_arg);
     const int index = arg.getArgNumber();
 
+    if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
+      sharding_for_args.push_back(operand_sharding.getValue());
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(operand_sharding.getValue()));
+      continue;
+    }
+
+    if (use_spmd) {
+      // If XLA SPMD is enabled, host variables or non-variable per-replica
+      // inputs should take on replicate sharding, unless another sharding is
+      // set via a TPUPartitionedInput op.
+      sharding_for_args.push_back(kReplicateSharding);
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(kReplicateSharding));
+      continue;
+    }
+
+    auto arg_sharding = GetXlaShardingFromArg(arg);
    if (arg_sharding) {
-      sharding_for_args[index] = arg_sharding.getValue();
-      cluster_function.setArgAttr(
-          index, kShardingAttr,
-          builder->getStringAttr(arg_sharding.getValue()));
-    } else {
-      cluster_function.setArgAttr(
-          index, kShardingAttr,
-          builder->getStringAttr(logical_core_0_sharding));
+      sharding_for_args.push_back(arg_sharding.getValue());
+      func.setArgAttr(index, kShardingAttr,
+                      builder->getStringAttr(arg_sharding.getValue()));
+      continue;
    }
+
+    // Default to maximal sharding core 0 if no sharding is present.
+    sharding_for_args.push_back(logical_core_0_sharding);
+    func.setArgAttr(index, kShardingAttr,
+                    builder->getStringAttr(logical_core_0_sharding));
  }
 
-  cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
-                           builder->getStrArrayAttr(sharding_for_args));
+  cluster_func->setAttr(tensorflow::kInputShardingAttr,
+                        builder->getStrArrayAttr(sharding_for_args));
 }
 
-// Finds XlaSharding op connected to a result value. XlaSharding op may be
-// direct user of inputs but it may also be followed by an Identity op and, in
-// the case where bfloat16 type is used, Cast op may be added right after the
-// input.
+// Returns XLA sharding from TPUPartitionedOutput op connected to a
+// `tf_device.cluster_func` result value.
+llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
+  if (!value.hasOneUse()) return llvm::None;
+
+  Operation* user = *value.getUsers().begin();
+  if (auto partitioned_output =
+          llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
+    return partitioned_output._XlaSharding();
+
+  return llvm::None;
+}
+
+// Returns XLA sharding from XlaSharding op connected to a result value.
+// XlaSharding op may be direct user of inputs but it may also be followed by an
+// Identity op and, in the case where bfloat16 type is used, Cast op may be
+// added right after the input.
 //
 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
 // Case, While) ops and Caller argument values.
 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
 // inputs.
-llvm::Optional<StringRef> GetXlaShardingFromRetval(const Value& value) {
+llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
   llvm::SmallPtrSet<Value, 4> visited_values;
   Value value_to_visit = value;
   while (value_to_visit) {
@@ -172,34 +234,58 @@ llvm::Optional<StringRef> GetXlaShardingFromRetval(const Value& value) {
   return llvm::None;
 }
 
-// Parses XlaSharding op directly connected from the outputs of the
-// `cluster_func` and extract sharding configurations for outputs.
+// Extracts sharding configurations for all outputs by parsing XlaSharding/
+// TPUPartitionedOutput op connected to the retvals/results.
 void IdentifyXlaShardingForComputationOutputs(
-    StringRef logical_core_0_sharding, FuncOp func,
-    tf_device::ClusterFuncOp cluster_func, Builder* builder) {
-  // By default return values from logical core 0 is used if no sharding
-  // configuration is defined.
+    StringRef logical_core_0_sharding, bool use_spmd,
+    tf_device::ClusterFuncOp cluster_func, FuncOp func, Builder* builder) {
   Block& function_block = func.front();
   Operation* terminator = function_block.getTerminator();
-  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets(
-      terminator->getNumOperands(), logical_core_0_sharding);
+  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
+  sharding_for_rets.reserve(terminator->getNumOperands());
 
+  // Iterate through results of `cluster_func`. For output ops, look for
+  // TPUPartitionedOutput ops.
+  //
+  // Replicate sharding is used if `use_spmd` is set.
+  //
   // Iterate through operands of the terminator. If the preceding op is
   // XlaShardingOp, then the provided sharding configuration is added to the
   // tf_device.ClusterFunc as an attribute and the function as a result
   // attribute.
-  for (auto& ret : terminator->getOpOperands()) {
-    auto ret_sharding = GetXlaShardingFromRetval(ret.get());
-    const int index = ret.getOperandNumber();
+  for (auto result_and_retval :
+       llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
+    Value result = std::get<0>(result_and_retval);
+    OpOperand& retval = std::get<1>(result_and_retval);
+    const int index = retval.getOperandNumber();
+
+    if (auto result_sharding = GetXlaShardingFromResult(result)) {
+      sharding_for_rets.push_back(result_sharding.getValue());
+      func.setResultAttr(index, kShardingAttr,
+                         builder->getStringAttr(result_sharding.getValue()));
+      continue;
+    }
 
-    if (ret_sharding) {
-      sharding_for_rets[index] = ret_sharding.getValue();
+    if (use_spmd) {
+      // If XLA SPMD is enabled, outputs all should have replicate sharding,
+      // unless another sharding is set via a TPUPartitionedOutput op.
+      sharding_for_rets.push_back(kReplicateSharding);
      func.setResultAttr(index, kShardingAttr,
-                         builder->getStringAttr(ret_sharding.getValue()));
-    } else {
+                         builder->getStringAttr(kReplicateSharding));
+      continue;
+    }
+
+    if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
+      sharding_for_rets.push_back(retval_sharding.getValue());
      func.setResultAttr(index, kShardingAttr,
-                         builder->getStringAttr(logical_core_0_sharding));
+                         builder->getStringAttr(retval_sharding.getValue()));
+      continue;
    }
+
+    // Default to maximal sharding core 0 if no sharding is present.
+    sharding_for_rets.push_back(logical_core_0_sharding);
+    func.setResultAttr(index, kShardingAttr,
+                       builder->getStringAttr(logical_core_0_sharding));
  }
 
  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
@@ -219,11 +305,16 @@ void IdentifyXlaShardingForTPUComputation(
   const std::string logical_core_0_sharding =
       xla::sharding_builder::AssignDevice(0).SerializeAsString();
 
-  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, cluster_func,
-                                          func, builder);
+  bool use_spmd = false;
+  if (auto use_spmd_attr =
+          cluster_func.getAttrOfType<BoolAttr>("use_spmd_for_xla_partitioning"))
+    use_spmd = use_spmd_attr.getValue();
+
+  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
+                                          cluster_func, func, builder);
 
-  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, func,
-                                           cluster_func, builder);
+  IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, use_spmd,
+                                           cluster_func, func, builder);
 }
 
 void TPUShardingIdentificationPass::runOnOperation() {
