
import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
)
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
        self._env = env
        # output parameters as DTensor in state dict
        self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
        )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
        self._embedding_shardings: List[
            EmbeddingSharding[
                EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                permute_embeddings=True,
                qcomm_codecs_registry=self.qcomm_codecs_registry,
            )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
        ]

        self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                lookup = lookup.module
            lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
        """
        This provides consistency between this class and the EmbeddingBagCollection's
        nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
@@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None:
            self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
        self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
        embedding_shard_offsets: List[int] = [
            meta.shard_offsets[1] if meta is not None else 0
            for meta in embedding_shard_metadata
@@ -1179,6 +1191,38 @@ def _create_output_dist(self) -> None:
                embedding_shard_offsets[i],
            ),
        )
+
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        permute_indices = sorted(
+            range(len(self._uncombined_embedding_names)),
+            key=lambda i: (
+                embedding_name_order[self._uncombined_embedding_names[i]],
+                embedding_shard_offsets[i],
+            ),
+        )
+
        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
            self._uncombined_embedding_dims, permute_indices, self._device
        )
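
The permute computed in _create_output_dist, and recomputed by the new _update_output_dist, orders pooled outputs by the first appearance of each table name and then by column shard offset, so the column-wise shards of a table end up contiguous and in offset order before PermutePooledEmbeddings is applied. A minimal standalone sketch of that ordering rule (the table names and offsets below are illustrative, not values from this commit):

from typing import Dict, List

# One entry per shard: the table it belongs to and its column offset within that table.
uncombined_names: List[str] = ["t1", "t0", "t1", "t0"]
shard_offsets: List[int] = [128, 0, 0, 64]

# First-appearance order of each table name, as in _update_output_dist.
name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_names):
    name_order.setdefault(name, i)

permute_indices = sorted(
    range(len(uncombined_names)),
    key=lambda i: (name_order[uncombined_names[i]], shard_offsets[i]),
)
print(permute_indices)  # [2, 0, 1, 3]: t1@0, t1@128, t0@0, t0@64
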
@@ -1396,13 +1440,119 @@ def compute_and_output_dist(

        return awaitable

+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call the shards_all_to_all collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups and load in the updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("DTensor is not yet supported for resharding")
+
+        current_state = self.state_dict()
+        # TODO: Save optimizers
+
+        saved_weights = {}
+        # TODO: Save lookup tensors to CPU to eventually avoid recreating them from scratch
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Delete all lookups
+        self._lookups.clear()
+
+        local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            state_dict=current_state,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            shard_names_by_src_rank=local_shard_names_by_src_rank,
+            output_tensor=local_output_tensor,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting the old sharding type when the sharding type changes
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+
    @property
    def fused_optimizer(self) -> KeyedOptimizer:
        return self._optim

    def create_context(self) -> EmbeddingBagCollectionContext:
        return EmbeddingBagCollectionContext()

+    @staticmethod
+    def extend_shard_name(shard_name: str) -> str:
+        return f"embedding_bags.{shard_name}.weight"
+

class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
    """
@@ -1435,6 +1585,33 @@ def shardable_parameters(
            for name, param in module.embedding_bags.named_parameters()
        }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on changed_shard_to_params, which
+        contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update.
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment.
+            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module.
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
    @property
    def module_type(self) -> Type[EmbeddingBagCollection]:
        return EmbeddingBagCollection
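
For context, a hedged sketch of how the new reshard entry point might be driven from training code. The helper apply_resharding and the variable names are hypothetical, not part of this commit; it assumes a ShardedEmbeddingBagCollection already produced by the sharder and a delta plan that lists only the tables whose placement changes.

from typing import Dict

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ParameterSharding, ShardingEnv


def apply_resharding(
    sharder: EmbeddingBagCollectionSharder,
    sharded_ebc,  # ShardedEmbeddingBagCollection previously returned by sharder.shard(...)
    new_plan_delta: Dict[str, ParameterSharding],  # only the tables that move (hypothetical name)
    env: ShardingEnv,
    device: torch.device,
):
    # reshard() is a no-op for an empty delta; otherwise it calls update_shards(),
    # which redistributes weights via shards_all_to_all and rebuilds lookups/output dists.
    return sharder.reshard(sharded_ebc, new_plan_delta, env, device)
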