@@ -192,6 +192,144 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) {
192192 EXPECT_EQ (computation->root_instruction ()->operand (0 ), b);
193193}
194194
195+ TEST_F (ReducePrecisionInsertionTest, SkipRedundantReducePrecision) {
196+ auto builder = HloComputation::Builder (TestName ());
197+ Shape shape = ShapeUtil::MakeShape (F32, {4 });
198+ HloInstruction* x =
199+ builder.AddInstruction (HloInstruction::CreateParameter (0 , shape, " x" ));
200+ HloInstruction* y = builder.AddInstruction (
201+ HloInstruction::CreateReducePrecision (shape, x, 5 , 10 ));
202+
203+ auto module = CreateNewModule ();
204+ auto computation = module ->AddEntryComputation (builder.Build ());
205+
206+ // Confirm expected graph before adding ops.
207+ EXPECT_THAT (x->users (), UnorderedElementsAre (y));
208+ EXPECT_EQ (computation->root_instruction (), y);
209+
210+ // Since the new reduce-precision operation would be redundant, this
211+ // should not change the graph.
212+ EXPECT_FALSE (
213+ InsertOps (module .get (), HloReducePrecisionOptions::BEFORE_OP_FUSION,
214+ [](const HloInstruction* instruction) {
215+ return instruction->opcode () == HloOpcode::kParameter ;
216+ }));
217+
218+ // Confirm that graph has not changed.
219+ EXPECT_THAT (x->users (), UnorderedElementsAre (y));
220+ EXPECT_EQ (computation->root_instruction (), y);
221+ }
222+
223+ TEST_F (ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) {
224+ auto builder = HloComputation::Builder (TestName ());
225+ Shape shape = ShapeUtil::MakeShape (F32, {4 });
226+ HloInstruction* x =
227+ builder.AddInstruction (HloInstruction::CreateParameter (0 , shape, " x" ));
228+ HloInstruction* y = builder.AddInstruction (
229+ HloInstruction::CreateReducePrecision (shape, x, 8 , 23 ));
230+
231+ auto module = CreateNewModule ();
232+ auto computation = module ->AddEntryComputation (builder.Build ());
233+
234+ // Confirm expected graph before adding ops.
235+ EXPECT_THAT (x->users (), UnorderedElementsAre (y));
236+ EXPECT_EQ (computation->root_instruction (), y);
237+
238+ // Since the new reduce-precision operation is not the same as the existing
239+ // one, this should add a new one.
240+ EXPECT_TRUE (InsertOps (module .get (),
241+ HloReducePrecisionOptions::BEFORE_OP_FUSION,
242+ [](const HloInstruction* instruction) {
243+ return instruction->opcode () == HloOpcode::kParameter ;
244+ }));
245+
246+ // Confirm that graph is as expected.
247+ EXPECT_EQ (computation->root_instruction (), y);
248+ EXPECT_THAT (y->operand (0 ), op::ReducePrecision (x));
249+ }
250+
251+ TEST_F (ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) {
252+ auto builder = HloComputation::Builder (TestName ());
253+ Shape shape = ShapeUtil::MakeShape (F32, {4 });
254+ HloInstruction* x =
255+ builder.AddInstruction (HloInstruction::CreateParameter (0 , shape, " x" ));
256+ HloInstruction* y = builder.AddInstruction (
257+ HloInstruction::CreateUnary (shape, HloOpcode::kCos , x));
258+ auto module = CreateNewModule ();
259+ auto computation = module ->AddEntryComputation (builder.Build ());
260+
261+ // Manually fuse the kCos operation into a fusion operation.
262+ HloInstruction* z = computation->AddInstruction (HloInstruction::CreateFusion (
263+ shape, HloInstruction::FusionKind::kLoop , y));
264+ EXPECT_IS_OK (computation->ReplaceUsesOfInstruction (y, z));
265+ EXPECT_IS_OK (computation->RemoveInstruction (y));
266+
267+ // Confirm expected graph before adding reduce-precision ops.
268+ EXPECT_THAT (x->users (), UnorderedElementsAre (z));
269+ EXPECT_EQ (computation->root_instruction (), z);
270+ HloInstruction* y_fused = z->fused_expression_root ();
271+ EXPECT_EQ (y_fused->opcode (), HloOpcode::kCos );
272+
273+ // The ReducePrecisionInsertion pass should not see inside the fusion
274+ // operation, so this should not change the graph.
275+ EXPECT_FALSE (InsertOps (module .get (),
276+ HloReducePrecisionOptions::AFTER_OP_FUSION,
277+ [](const HloInstruction* instruction) {
278+ return instruction->opcode () == HloOpcode::kCos ;
279+ }));
280+
281+ // Confirm that graph has not changed.
282+ EXPECT_THAT (x->users (), UnorderedElementsAre (z));
283+ EXPECT_EQ (computation->root_instruction (), z);
284+ EXPECT_EQ (z->fused_expression_root (), y_fused);
285+ }
286+
287+ TEST_F (ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
288+ auto builder = HloComputation::Builder (TestName ());
289+ Shape shape = ShapeUtil::MakeShape (F32, {4 });
290+ HloInstruction* x =
291+ builder.AddInstruction (HloInstruction::CreateParameter (0 , shape, " x" ));
292+ HloInstruction* y = builder.AddInstruction (
293+ HloInstruction::CreateUnary (shape, HloOpcode::kCos , x));
294+ auto module = CreateNewModule ();
295+ auto computation = module ->AddEntryComputation (builder.Build ());
296+
297+ // Manually fuse the kCos operation into a fusion operation.
298+ HloInstruction* z = computation->AddInstruction (HloInstruction::CreateFusion (
299+ shape, HloInstruction::FusionKind::kLoop , y));
300+ EXPECT_IS_OK (computation->ReplaceUsesOfInstruction (y, z));
301+ EXPECT_IS_OK (computation->RemoveInstruction (y));
302+
303+ // Confirm expected graph before adding reduce-precision ops.
304+ EXPECT_THAT (x->users (), UnorderedElementsAre (z));
305+ EXPECT_EQ (computation->root_instruction (), z);
306+ HloInstruction* y_fused = z->fused_expression_root ();
307+ EXPECT_EQ (y_fused->opcode (), HloOpcode::kCos );
308+
309+ // This should see that the fusion computation contains a kCos operation,
310+ // and insert a new reduce-precision node at its root.
311+ EXPECT_TRUE (InsertOps (module .get (),
312+ HloReducePrecisionOptions::FUSION_BY_CONTENT,
313+ [](const HloInstruction* instruction) {
314+ return instruction->opcode () == HloOpcode::kCos ;
315+ }));
316+
317+ // This should refuse to insert a second reduce-precision operation, as
318+ // it would be redundant with the first.
319+ EXPECT_FALSE (InsertOps (module .get (),
320+ HloReducePrecisionOptions::FUSION_BY_CONTENT,
321+ [](const HloInstruction* instruction) {
322+ return instruction->opcode () == HloOpcode::kCos ;
323+ }));
324+
325+ // Confirm that the top-level computation still only contains the fusion
326+ // instruction, but that the fused computation now has a reduce-precision
327+ // instruction inserted as its root.
328+ EXPECT_THAT (x->users (), UnorderedElementsAre (z));
329+ EXPECT_EQ (computation->root_instruction (), z);
330+ EXPECT_THAT (z->fused_expression_root (), op::ReducePrecision (y_fused));
331+ }
332+
195333TEST_F (ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
196334 auto builder = HloComputation::Builder (TestName ());
197335 Shape shape = ShapeUtil::MakeShape (F32, {4 });
0 commit comments