Skip to content

Commit 97cbf5a

Browse files
ukoxyztensorflower-gardener
authored andcommitted
[XLA:SPMD] Fix pad masking again
PiperOrigin-RevId: 367527318 Change-Id: Ie3599df55d259abff1e1260db849e34a9399c895
1 parent 6f3ce45 commit 97cbf5a

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2681,7 +2681,8 @@ Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
26812681
dim->set_padding_high(pd.edge_padding_high());
26822682
dim->set_base_dilation(pd.interior_padding() + 1);
26832683
needs_masking |= hlo->sharding().tile_assignment().dim(i) > 1 &&
2684-
(pd.edge_padding_low() < 0 || pd.edge_padding_high() < 0);
2684+
(pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 ||
2685+
pd.interior_padding() > 0);
26852686
}
26862687

26872688
auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))

tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2122,7 +2122,7 @@ ENTRY entry {
21222122
op::Concatenate(param0, op::CollectivePermute(op::Slice(param0))));
21232123
auto pad = AllOf(op::Shape("f32[14,131]"),
21242124
op::Pad(after_halo_exchange, op::Constant()));
2125-
EXPECT_THAT(root, op::DynamicSlice(pad, op::Constant(), _));
2125+
EXPECT_THAT(root, op::Select(_, op::DynamicSlice(pad, op::Constant(), _), _));
21262126
}
21272127

21282128
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {

0 commit comments

Comments
 (0)