
Commit 6976e16

andyly authored and tensorflower-gardener committed
Update TPUShardingIdentificationPass to walk ops for finding XlaSharding ops from computation results.
It is possible for XlaSharding ops to be in functions and/or feed into passthrough ops like Identity.

PiperOrigin-RevId: 337518511
Change-Id: I1c9fd1de30d40c0e938eb0c1cb6a8c026cdd4d2f
1 parent f61013b commit 6976e16
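
In IR terms, the case this change handles looks like the sketch below (trimmed from the test added in this commit): the value returned from the cluster function is produced by a tf.PartitionedCall, and inside the callee the tf.XlaSharding op sits behind a passthrough tf.Identity. The previous output-side check only inspected the direct defining op of the return operand, so it missed this shape.

func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  // The sharding annotation is not visible here; it lives in @func_body.
  %0 = "tf.PartitionedCall"(%arg0) {f = @func_body, config = "", config_proto = "", executor_type = ""} : (tensor<*xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

func @func_body(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  // XlaSharding feeds a passthrough Identity before being returned.
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
  return %1 : tensor<*xi32>
}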

File tree

2 files changed: +80 -28 lines


tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir

Lines changed: 25 additions & 0 deletions
@@ -227,3 +227,28 @@ func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor<i32> {
   %2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor<i32>)
   return %2 : tensor<i32>
 }
+
+// -----
+
+// Tests that output sharding inside a functional op is parsed correctly.
+
+// CHECK-LABEL: func @check_sharding_inside_functional_op
+func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>) {
+  "tf_device.cluster_func"(%arg0) {func = @cluster_func, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
+  // CHECK: input_sharding_configuration
+  // CHECK-SAME: ["\01\02\03"]
+  // CHECK: output_sharding_configuration
+  // CHECK-SAME: ["\01\02\03"]
+  return
+}
+
+func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  %0 = "tf.PartitionedCall"(%arg0) {f = @func_body, config = "", config_proto = "", executor_type = ""} : (tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+func @func_body(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
+  %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
+  return %1 : tensor<*xi32>
+}
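
As a usage note, FileCheck tests like this one are driven by a RUN line at the top of the file (not shown in this hunk); a plausible invocation, assuming the pass is registered as -tf-tpu-sharding-identification, would be:

// RUN: tf-opt %s -split-input-file -tf-tpu-sharding-identification | FileCheck %s

The "// -----" marker added above is the -split-input-file separator, so the new case runs as an independent module.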

tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc

Lines changed: 55 additions & 28 deletions
@@ -48,17 +48,17 @@ struct TPUShardingIdentificationPass
   void runOnOperation() override;
 };

-// Finds XlaSharding op connected to a value. If value is a resource type then
-// XlaSharding op will be connected to a ReadVariable op. XlaSharding op may be
-// direct user of inputs but it may also be followed by an Identity op and, in
-// the case where bfloat16 type is used, Cast op may be added right after the
-// input.
+// Finds XlaSharding op connected to an argument value. If value is a resource
+// type then XlaSharding op will be connected to a ReadVariable op. XlaSharding
+// op may be direct user of inputs but it may also be followed by an Identity op
+// and, in the case where bfloat16 type is used, Cast op may be added right
+// after the input.
 //
 // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
 // Case, While) ops and Caller return values.
 // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
 // inputs.
-llvm::Optional<llvm::StringRef> GetXlaSharding(const Value& value) {
+llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(const Value& value) {
   llvm::SmallPtrSet<Value, 4> visited_values;
   llvm::SmallVector<Value, 4> values_to_visit{value};
   while (!values_to_visit.empty()) {
@@ -67,16 +67,16 @@ llvm::Optional<llvm::StringRef> GetXlaSharding(const Value& value) {
       if (!visited_values.insert(value_to_visit).second) continue;

       for (auto& use : value_to_visit.getUses()) {
-        if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(use.getOwner()))
+        Operation* owner = use.getOwner();
+        if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
           return sharding._XlaSharding();

-        if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(
-                use.getOwner())) {
+        if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
           next_values_to_visit.push_back(use.getOwner()->getResult(0));
           continue;
         }

-        if (auto call_op = llvm::dyn_cast<CallOpInterface>(use.getOwner())) {
+        if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
           FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
           if (!func) continue;
           next_values_to_visit.push_back(
@@ -91,18 +91,6 @@ llvm::Optional<llvm::StringRef> GetXlaSharding(const Value& value) {
   return llvm::None;
 }

-// Returns the provided sharding configuration if operand of return value of
-// tf_device.ClusterFunc op is directly from XlaSharding op,
-llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
-                                                   const int output_index,
-                                                   const OpOperand& operand) {
-  if (auto sharding_op = llvm::dyn_cast_or_null<TF::XlaShardingOp>(
-          operand.get().getDefiningOp()))
-    return sharding_op._XlaSharding();
-
-  return llvm::Optional<StringRef>();
-}
-
 // Walks the graph from the arguments of the `cluster_func_op` and extracts
 // sharding configurations for all inputs by parsing XlaSharding op connected to
 // the arguments. If argument to the `cluster_func_op` directly feeds into
@@ -128,17 +116,17 @@ void IdentifyXlaShardingForComputationInputs(
   // Sharding configurations are added to the tf_device.ClusterFunc as an
   // attribute and the function as an argument attribute.
   for (auto& arg : cluster_function_block.getArguments()) {
-    auto arg_sharding = GetXlaSharding(arg);
-    const int arg_index = arg.getArgNumber();
+    auto arg_sharding = GetXlaShardingFromArg(arg);
+    const int index = arg.getArgNumber();

     if (arg_sharding) {
-      sharding_for_args[arg_index] = arg_sharding.getValue();
+      sharding_for_args[index] = arg_sharding.getValue();
       cluster_function.setArgAttr(
-          arg_index, kShardingAttr,
+          index, kShardingAttr,
           builder->getStringAttr(arg_sharding.getValue()));
     } else {
       cluster_function.setArgAttr(
-          arg_index, kShardingAttr,
+          index, kShardingAttr,
           builder->getStringAttr(logical_core_0_sharding));
     }
   }
@@ -147,6 +135,44 @@ void IdentifyXlaShardingForComputationInputs(
                        builder->getStrArrayAttr(sharding_for_args));
 }

+// Finds XlaSharding op connected to a result value. XlaSharding op may be
+// direct user of inputs but it may also be followed by an Identity op and, in
+// the case where bfloat16 type is used, Cast op may be added right after the
+// input.
+//
+// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
+// Case, While) ops and Caller argument values.
+// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
+// inputs.
+llvm::Optional<StringRef> GetXlaShardingFromRetval(const Value& value) {
+  llvm::SmallPtrSet<Value, 4> visited_values;
+  Value value_to_visit = value;
+  while (value_to_visit) {
+    if (!visited_values.insert(value_to_visit).second) return llvm::None;
+
+    Operation* def = value_to_visit.getDefiningOp();
+    if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
+      return sharding._XlaSharding();
+
+    if (llvm::isa_and_nonnull<TF::IdentityOp, TF::CastOp>(def)) {
+      value_to_visit = def->getOperand(0);
+      continue;
+    }
+
+    if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
+      FuncOp func = llvm::dyn_cast<FuncOp>(call_op.resolveCallable());
+      if (!func) continue;
+      value_to_visit = func.front().getTerminator()->getOperand(
+          value_to_visit.cast<OpResult>().getResultNumber());
+      continue;
+    }
+
+    break;
+  }
+
+  return llvm::None;
+}
+
 // Parses XlaSharding op directly connected from the outputs of the
 // `cluster_func` and extract sharding configurations for outputs.
 void IdentifyXlaShardingForComputationOutputs(
@@ -164,8 +190,8 @@ void IdentifyXlaShardingForComputationOutputs(
   // tf_device.ClusterFunc as an attribute and the function as a result
   // attribute.
   for (auto& ret : terminator->getOpOperands()) {
+    auto ret_sharding = GetXlaShardingFromRetval(ret.get());
     const int index = ret.getOperandNumber();
-    auto ret_sharding = ParseReturnValueSharding(func, index, ret);

     if (ret_sharding) {
       sharding_for_rets[index] = ret_sharding.getValue();
@@ -176,6 +202,7 @@ void IdentifyXlaShardingForComputationOutputs(
           builder->getStringAttr(logical_core_0_sharding));
     }
   }
+
   cluster_func.setAttr(tensorflow::kOutputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_rets));
 }
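
For comparison with the result-side walk above, the input-side patterns described in GetXlaShardingFromArg's comment would look roughly like the following sketches (illustrative MLIR with hypothetical function names and shapes; only the ReadVariable/Cast/XlaSharding chain shapes come from the comment):

// Resource argument: the XlaSharding op hangs off the ReadVariableOp result
// rather than the function argument itself.
func @resource_input(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32> {
  %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<4xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// bfloat16 input: a tf.Cast may sit between the argument and the sharding op.
func @bf16_input(%arg0: tensor<4xf32>) -> tensor<4xbf16> {
  %0 = "tf.Cast"(%arg0) : (tensor<4xf32>) -> tensor<4xbf16>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<4xbf16>) -> tensor<4xbf16>
  return %1 : tensor<4xbf16>
}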
