File tree Expand file tree Collapse file tree 6 files changed +25
-1
lines changed
hlo/experimental/auto_sharding Expand file tree Collapse file tree 6 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ package_group(
17
17
packages = xla_internal (["..." ]) + [
18
18
"//xla/hlo/experimental/auto_sharding/..." ,
19
19
"//xla/service/gpu/..." ,
20
+ "//xla/service/spmd/..." ,
20
21
],
21
22
)
22
23
@@ -554,6 +555,7 @@ cc_library(
554
555
name = "auto_sharding_stablehlo_pass" ,
555
556
srcs = ["auto_sharding_stablehlo_pass.cc" ],
556
557
hdrs = ["auto_sharding_stablehlo_pass.h" ],
558
+ compatible_with = get_compatible_with_libtpu_portable (),
557
559
deps = [
558
560
":auto_sharding" ,
559
561
":auto_sharding_option" ,
Original file line number Diff line number Diff line change @@ -193,5 +193,11 @@ void RegisterAutoSharding() {
193
193
/* dialectsDependenciesCallback=*/ &RegisterDialectDependencies);
194
194
}
195
195
196
+ void RegisterAutoShardingIfRegistryEmpty () {
197
+ if (!sdy::AutoPartitionerRegistry::isRegistered ()) {
198
+ RegisterAutoSharding ();
199
+ }
200
+ }
201
+
196
202
} // namespace spmd
197
203
} // namespace xla
Original file line number Diff line number Diff line change @@ -24,6 +24,10 @@ namespace spmd {
24
24
void RegisterDialectDependencies (mlir::DialectRegistry& registry);
25
25
void AddAutoShardingToPipeline (mlir::OpPassManager& pm);
26
26
void RegisterAutoSharding ();
27
+ // Register Alpa auto partitioner in case no other auto partitioner is already
28
+ // registered.
29
+ // TODO(b/431368844): Remove when there is a way for users to register Alpa.
30
+ void RegisterAutoShardingIfRegistryEmpty ();
27
31
} // namespace spmd
28
32
} // namespace xla
29
33
Original file line number Diff line number Diff line change @@ -1717,6 +1717,7 @@ cc_library(
1717
1717
] + xla_internal (["service:export_hlo" ]) + if_google ([
1718
1718
"//xla/hlo/experimental/auto_sharding" ,
1719
1719
"//xla/hlo/experimental/auto_sharding:auto_sharding_option" ,
1720
+ "//xla/hlo/experimental/auto_sharding:auto_sharding_stablehlo_pass" ,
1720
1721
]),
1721
1722
)
1722
1723
Original file line number Diff line number Diff line change @@ -303,6 +303,7 @@ limitations under the License.
303
303
#ifdef PLATFORM_GOOGLE
304
304
#include " xla/hlo/experimental/auto_sharding/auto_sharding.h"
305
305
#include " xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
306
+ #include " xla/hlo/experimental/auto_sharding/auto_sharding_stablehlo_pass.h"
306
307
#endif // PLATFORM_GOOGLE
307
308
308
309
namespace xla {
@@ -644,7 +645,13 @@ absl::Status RunSPMDPasses(
644
645
spmd_pipeline,
645
646
#ifdef PLATFORM_GOOGLE
646
647
[&](HloPassPipeline& pipeline) {
647
- if (auto_sharding) {
648
+ if (!auto_sharding) {
649
+ return ;
650
+ }
651
+ if (hlo_module->config ().use_shardy_partitioner ()) {
652
+ // Register Alpa auto partitioner if registry is empty.
653
+ spmd::RegisterAutoShardingIfRegistryEmpty ();
654
+ } else {
648
655
spmd_pipeline.AddPass <AutoSharding>(
649
656
DefaultAutoShardingOptionFromModuleConfig (hlo_module->config ()),
650
657
alias_info);
Original file line number Diff line number Diff line change @@ -90,6 +90,10 @@ void AddSPMDPasses(
90
90
const HloModuleConfig& config = hlo_module->config ();
91
91
92
92
if (config.use_shardy_partitioner ()) {
93
+ // This will make sure an auto partitioner is registered.
94
+ if (auto_sharding_func.has_value ()) {
95
+ (*auto_sharding_func)(spmd_pipeline);
96
+ }
93
97
spmd_pipeline.AddPass <sdy::ShardyXLA>();
94
98
} else {
95
99
spmd_pipeline.AddPass <HloConstantSplitter>();
You can’t perform that action at this time.
0 commit comments