Skip to content

Commit a1300a0

Browse files
tomnatan30tensorflower-gardener
authored andcommitted
#sdy If auto partitioning is enabled and there is no registered auto partitioner, register Alpa as the default.
PiperOrigin-RevId: 782366951
1 parent e160cc1 commit a1300a0

File tree

6 files changed

+25
-1
lines changed

6 files changed

+25
-1
lines changed

third_party/xla/xla/hlo/experimental/auto_sharding/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package_group(
1717
packages = xla_internal(["..."]) + [
1818
"//xla/hlo/experimental/auto_sharding/...",
1919
"//xla/service/gpu/...",
20+
"//xla/service/spmd/...",
2021
],
2122
)
2223

@@ -554,6 +555,7 @@ cc_library(
554555
name = "auto_sharding_stablehlo_pass",
555556
srcs = ["auto_sharding_stablehlo_pass.cc"],
556557
hdrs = ["auto_sharding_stablehlo_pass.h"],
558+
compatible_with = get_compatible_with_libtpu_portable(),
557559
deps = [
558560
":auto_sharding",
559561
":auto_sharding_option",

third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_stablehlo_pass.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,11 @@ void RegisterAutoSharding() {
193193
/*dialectsDependenciesCallback=*/&RegisterDialectDependencies);
194194
}
195195

196+
void RegisterAutoShardingIfRegistryEmpty() {
197+
if (!sdy::AutoPartitionerRegistry::isRegistered()) {
198+
RegisterAutoSharding();
199+
}
200+
}
201+
196202
} // namespace spmd
197203
} // namespace xla

third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_stablehlo_pass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ namespace spmd {
2424
void RegisterDialectDependencies(mlir::DialectRegistry& registry);
2525
void AddAutoShardingToPipeline(mlir::OpPassManager& pm);
2626
void RegisterAutoSharding();
27+
// Register Alpa auto partitioner in case no other auto partitioner is already
28+
// registered.
29+
// TODO(b/431368844): Remove when there is a way for users to register Alpa.
30+
void RegisterAutoShardingIfRegistryEmpty();
2731
} // namespace spmd
2832
} // namespace xla
2933

third_party/xla/xla/service/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,7 @@ cc_library(
17171717
] + xla_internal(["service:export_hlo"]) + if_google([
17181718
"//xla/hlo/experimental/auto_sharding",
17191719
"//xla/hlo/experimental/auto_sharding:auto_sharding_option",
1720+
"//xla/hlo/experimental/auto_sharding:auto_sharding_stablehlo_pass",
17201721
]),
17211722
)
17221723

third_party/xla/xla/service/gpu/gpu_compiler.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ limitations under the License.
303303
#ifdef PLATFORM_GOOGLE
304304
#include "xla/hlo/experimental/auto_sharding/auto_sharding.h"
305305
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
306+
#include "xla/hlo/experimental/auto_sharding/auto_sharding_stablehlo_pass.h"
306307
#endif // PLATFORM_GOOGLE
307308

308309
namespace xla {
@@ -644,7 +645,13 @@ absl::Status RunSPMDPasses(
644645
spmd_pipeline,
645646
#ifdef PLATFORM_GOOGLE
646647
[&](HloPassPipeline& pipeline) {
647-
if (auto_sharding) {
648+
if (!auto_sharding) {
649+
return;
650+
}
651+
if (hlo_module->config().use_shardy_partitioner()) {
652+
// Register Alpa auto partitioner if registry is empty.
653+
spmd::RegisterAutoShardingIfRegistryEmpty();
654+
} else {
648655
spmd_pipeline.AddPass<AutoSharding>(
649656
DefaultAutoShardingOptionFromModuleConfig(hlo_module->config()),
650657
alias_info);

third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ void AddSPMDPasses(
9090
const HloModuleConfig& config = hlo_module->config();
9191

9292
if (config.use_shardy_partitioner()) {
93+
// This will make sure an auto partitioner is registered.
94+
if (auto_sharding_func.has_value()) {
95+
(*auto_sharding_func)(spmd_pipeline);
96+
}
9397
spmd_pipeline.AddPass<sdy::ShardyXLA>();
9498
} else {
9599
spmd_pipeline.AddPass<HloConstantSplitter>();

0 commit comments

Comments
 (0)