Skip to content

Commit 37d8f7e

Browse files
[XLA] Insert reduce-precision nodes inside tail of fusion nodes, and don't add duplicates.
Reduce-precision instructions can be fused, so when adding them to the output of fusion nodes, we should add them inside the node rather than outside. In addition, since the same node may be selected by multiple passes (especially when we have fusion-by-contents passes with different trigger conditions), we want to avoid inserting redundant reduce-precision instructions. PiperOrigin-RevId: 166526816
1 parent 0236a36 commit 37d8f7e

File tree

2 files changed

+180
-7
lines changed

2 files changed

+180
-7
lines changed

tensorflow/compiler/xla/service/reduce_precision_insertion.cc

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,49 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
9393

9494
bool computation_changed = false;
9595
for (auto& instruction : instructions_to_suffix(computation.get())) {
96-
HloInstruction* reduced =
97-
computation->AddInstruction(HloInstruction::CreateReducePrecision(
98-
instruction->shape(), instruction, exponent_bits_,
99-
mantissa_bits_));
100-
TF_RETURN_IF_ERROR(
101-
computation->ReplaceUsesOfInstruction(instruction, reduced));
102-
VLOG(2) << "Inserted new op after instruction: "
96+
VLOG(2) << "Adding reduce-precision operation to output of instruction: "
10397
<< instruction->ToString();
98+
99+
// Check that we haven't already inserted an equivalent reduce-precision
100+
// operation after this instruction.
101+
if (instruction->user_count() == 1) {
102+
HloInstruction* user = instruction->users()[0];
103+
104+
if (user->opcode() == HloOpcode::kReducePrecision &&
105+
user->exponent_bits() == exponent_bits_ &&
106+
user->mantissa_bits() == mantissa_bits_) {
107+
VLOG(2) << "Skipped; instruction already followed by equivalent"
108+
" reduce-precision instruction:"
109+
<< user->ToString();
110+
continue;
111+
}
112+
}
113+
114+
if (instruction->opcode() == HloOpcode::kFusion) {
115+
// Insert the reduce-precision operation as the last operation inside
116+
// the fusion computation.
117+
instruction = instruction->fused_expression_root();
118+
119+
VLOG(2) << "Inserting new operation after existing fusion root: "
120+
<< instruction->ToString();
121+
122+
if (instruction->opcode() == HloOpcode::kReducePrecision &&
123+
instruction->exponent_bits() == exponent_bits_ &&
124+
instruction->mantissa_bits() == mantissa_bits_) {
125+
VLOG(2) << "Skipped; fused computation already ends in equivalent"
126+
" reduce-precision instruction:"
127+
<< instruction->ToString();
128+
continue;
129+
}
130+
}
131+
132+
HloInstruction* reduced = instruction->parent()->AddInstruction(
133+
HloInstruction::CreateReducePrecision(instruction->shape(),
134+
instruction, exponent_bits_,
135+
mantissa_bits_));
136+
137+
TF_RETURN_IF_ERROR(instruction->parent()->ReplaceUsesOfInstruction(
138+
instruction, reduced));
104139
computation_changed = true;
105140
}
106141

tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,144 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
192192
EXPECT_EQ(computation->root_instruction()->operand(0), b);
193193
}
194194

195+
TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecision) {
196+
auto builder = HloComputation::Builder(TestName());
197+
Shape shape = ShapeUtil::MakeShape(F32, {4});
198+
HloInstruction* x =
199+
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
200+
HloInstruction* y = builder.AddInstruction(
201+
HloInstruction::CreateReducePrecision(shape, x, 5, 10));
202+
203+
auto module = CreateNewModule();
204+
auto computation = module->AddEntryComputation(builder.Build());
205+
206+
// Confirm expected graph before adding ops.
207+
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
208+
EXPECT_EQ(computation->root_instruction(), y);
209+
210+
// Since the new reduce-precision operation would be redundant, this
211+
// should not change the graph.
212+
EXPECT_FALSE(
213+
InsertOps(module.get(), HloReducePrecisionOptions::BEFORE_OP_FUSION,
214+
[](const HloInstruction* instruction) {
215+
return instruction->opcode() == HloOpcode::kParameter;
216+
}));
217+
218+
// Confirm that graph has not changed.
219+
EXPECT_THAT(x->users(), UnorderedElementsAre(y));
220+
EXPECT_EQ(computation->root_instruction(), y);
221+
}
222+
223+
// Verifies that the redundancy check is precise: a following reduce-precision
// op with DIFFERENT bit widths (8/23 here vs. the pass's 5/10) does not
// suppress insertion, so a new reduce-precision op is added after x.
TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateReducePrecision(shape, x, 8, 23));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Confirm expected graph before adding ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(y));
  EXPECT_EQ(computation->root_instruction(), y);

  // Since the new reduce-precision operation is not the same as the existing
  // one, this should add a new one.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::BEFORE_OP_FUSION,
                        [](const HloInstruction* instruction) {
                          return instruction->opcode() == HloOpcode::kParameter;
                        }));

  // Confirm that graph is as expected: the new op is spliced between x and y.
  EXPECT_EQ(computation->root_instruction(), y);
  EXPECT_THAT(y->operand(0), op::ReducePrecision(x));
}
250+
251+
// Verifies that an AFTER_OP_FUSION pass filtering on an opcode that occurs
// only INSIDE a fusion computation (kCos) makes no change: the pass matches
// on top-level instructions, and the fusion node itself is not a kCos.
TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually fuse the kCos operation into a fusion operation.
  HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kLoop, y));
  EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z));
  EXPECT_IS_OK(computation->RemoveInstruction(y));

  // Confirm expected graph before adding reduce-precision ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  HloInstruction* y_fused = z->fused_expression_root();
  EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);

  // The ReducePrecisionInsertion pass should not see inside the fusion
  // operation, so this should not change the graph.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::AFTER_OP_FUSION,
                         [](const HloInstruction* instruction) {
                           return instruction->opcode() == HloOpcode::kCos;
                         }));

  // Confirm that graph has not changed.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  EXPECT_EQ(z->fused_expression_root(), y_fused);
}
286+
287+
// Verifies the two headline behaviors of the change: (1) a FUSION_BY_CONTENT
// pass inserts the reduce-precision op INSIDE the fusion computation, as its
// new root; (2) running the same pass again is a no-op because the fused
// computation already ends in an equivalent reduce-precision op.
TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCos, x));
  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Manually fuse the kCos operation into a fusion operation.
  HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kLoop, y));
  EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z));
  EXPECT_IS_OK(computation->RemoveInstruction(y));

  // Confirm expected graph before adding reduce-precision ops.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  HloInstruction* y_fused = z->fused_expression_root();
  EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);

  // This should see that the fusion computation contains a kCos operation,
  // and insert a new reduce-precision node at its root.
  EXPECT_TRUE(InsertOps(module.get(),
                        HloReducePrecisionOptions::FUSION_BY_CONTENT,
                        [](const HloInstruction* instruction) {
                          return instruction->opcode() == HloOpcode::kCos;
                        }));

  // This should refuse to insert a second reduce-precision operation, as
  // it would be redundant with the first.
  EXPECT_FALSE(InsertOps(module.get(),
                         HloReducePrecisionOptions::FUSION_BY_CONTENT,
                         [](const HloInstruction* instruction) {
                           return instruction->opcode() == HloOpcode::kCos;
                         }));

  // Confirm that the top-level computation still only contains the fusion
  // instruction, but that the fused computation now has a reduce-precision
  // instruction inserted as its root.
  EXPECT_THAT(x->users(), UnorderedElementsAre(z));
  EXPECT_EQ(computation->root_instruction(), z);
  EXPECT_THAT(z->fused_expression_root(), op::ReducePrecision(y_fused));
}
332+
195333
TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
196334
auto builder = HloComputation::Builder(TestName());
197335
Shape shape = ShapeUtil::MakeShape(F32, {4});

0 commit comments

Comments (0)