diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 1aff0ecf6..15fba983b 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -13,6 +13,7 @@ import itertools import logging import tempfile +from collections import defaultdict, OrderedDict from dataclasses import dataclass from typing import ( Any, @@ -65,7 +66,7 @@ ShardMetadata, TensorProperties, ) -from torchrec.distributed.utils import append_prefix +from torchrec.distributed.utils import append_prefix, none_throws from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, pooling_type_to_pooling_mode, @@ -88,6 +89,10 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ssd_tbe_params: Dict[str, Any] = {} + for table in config.embedding_tables: + if table.zero_collision: + ssd_tbe_params["enable_zero_collision_tbe"] = True + logger.info("Enabling zero collision TBE") # drop the non-ssd tbe fused params ssd_tbe_signature = inspect.signature( SSDTableBatchedEmbeddingBags.__init__ @@ -904,7 +909,7 @@ def __init__( embedding_location = compute_kernel_to_embedding_location(compute_kernel) self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( - embedding_specs=list(zip(self._local_rows, self._local_cols)), + embedding_specs=list(zip(self._num_embeddings, self._local_cols)), feature_table_map=self._feature_table_map, ssd_cache_location=embedding_location, pooling_mode=PoolingMode.NONE, @@ -926,6 +931,18 @@ def __init__( ) self.init_parameters() + self._enable_zero_collision_tbe: bool = ssd_tbe_params[ + "enable_zero_collision_tbe" + ] + self._tracked_ids: Optional[KeyedJaggedTensor] = None + self._sharded_local_buckets: Optional[List[Tuple[int, int, int]]] = None + if self._enable_zero_collision_tbe: + self._sharded_local_buckets = self.get_sharded_local_buckets() + # temporary tensors auto generated for checkpointing + # once training is resumed and forward is called, these tensors will be reset to None + # since the value can be changed by backward pass, we don't want to duplicate memory + self._split_weights: Optional[List[Dict[str, ShardedTensor]]] = None + def init_parameters(self) -> None: """ An advantage of SSD TBE is that we don't need to init weights. Hence skipping. @@ -963,19 +980,96 @@ def state_dict( # in the case no_snapshot=False, a flush is required. 
we rely on the flush operation in # ShardedEmbeddingBagCollection._pre_state_dict_hook() - emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_tables = self.split_embedding_weights_with_id_buckets( + no_snapshot=no_snapshot + ) + weights = [emb_table[0] for emb_table in emb_tables] emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) for emb_table in emb_table_config_copy: emb_table.local_metadata.placement._device = torch.device("cpu") ret = get_state_dict( emb_table_config_copy, - emb_tables, + weights, self._pg, destination, prefix, ) return ret + def get_sharded_split_tensors( + self, + prefix: str, + emb_table_inx: int, + weight_tensor: torch.Tensor, + bucket_tensor: torch.Tensor, + id_tensor: torch.Tensor, + ) -> Dict[str, Any]: + if not self._enable_zero_collision_tbe: + return {} + + table_config = copy.deepcopy(self._config.embedding_tables[emb_table_inx]) + table_config.local_metadata.placement._device = torch.device("cpu") + ret: Dict[str, Any] = {} + + weight_key = append_prefix(prefix, f"{table_config.name}.weight") + weight_local_metadata = copy.deepcopy(table_config.local_metadata) + weight_local_metadata.shard_sizes = list(weight_tensor.size()) + weight_local_shards = [Shard(weight_tensor, weight_local_metadata)] + weight_global_size = ( + self._num_embeddings[emb_table_inx], + self._local_cols[emb_table_inx], + ) + if self._pg is not None: + ret[weight_key] = ShardedTensor._init_from_local_shards_and_reset_offsets( + weight_local_shards, + weight_global_size, + process_group=self._pg, + ) + + id_key = append_prefix(prefix, f"{table_config.name}.weight_id") + id_local_metadata = copy.deepcopy(table_config.local_metadata) + id_local_metadata.shard_offsets[1] = 0 + id_local_metadata.shard_sizes = list(id_tensor.size()) # one column tensor + id_local_shards = [Shard(id_tensor, id_local_metadata)] + id_global_size = (self._num_embeddings[emb_table_inx], 1) + if self._pg is not None: + ret[id_key] = ShardedTensor._init_from_local_shards_and_reset_offsets( + id_local_shards, + id_global_size, + process_group=self._pg, + ) + + bucket_key = append_prefix(prefix, f"{table_config.name}.bucket") + bucket_global_metadata = copy.deepcopy(table_config.global_metadata) + bucket_global_metadata.tensor_properties.dtype = torch.int64 + bucket_global_metadata.tensor_properties.requires_grad = False + bucket_global_metadata.size = torch.Size((table_config.total_num_buckets, 1)) + # prototype: assuming even sharding here + bucket_length = self._sharded_local_buckets[emb_table_inx][1] + for j, shard in enumerate(bucket_global_metadata.shards_metadata): + shard.shard_offsets[0] = j * bucket_length + shard.shard_offsets[1] = 0 + shard.shard_sizes = list(bucket_tensor.size()) + bucket_local_metadata = copy.deepcopy(table_config.local_metadata) + bucket_local_metadata.shard_offsets[0] = self._sharded_local_buckets[ + emb_table_inx + ][0] + bucket_local_metadata.shard_offsets[1] = 0 + bucket_local_metadata.shard_sizes[0] = bucket_length + bucket_local_metadata.shard_sizes[1] = 1 + local_shards = [Shard(bucket_tensor, bucket_local_metadata)] + if self._pg is not None: + ret[bucket_key] = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=bucket_global_metadata, + process_group=self._pg, + ) + + logger.info( + f"get_sharded_split_id_bucket_tensors generated two additiona tensors: {ret.keys()}" + ) + return ret + def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = 
True ) -> Iterator[Tuple[str, nn.Parameter]]: @@ -1002,22 +1096,62 @@ def named_split_embedding_weights( ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" for config, tensor in zip( self._config.embedding_tables, - self.split_embedding_weights(), + self.split_embedding_weights_with_id_buckets(), ): + weight_tensor = tensor[0] key = append_prefix(prefix, f"{config.name}.weight") - yield key, tensor + yield key, weight_tensor def get_named_split_embedding_weights_snapshot( self, prefix: str = "" - ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, ShardedTensor]]]: """ Return an iterator over embedding tables, yielding both the table name as well as the embedding table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid RocksDB snapshot to support windowed access. """ - for config, tensor in zip( - self._config.embedding_tables, - self.split_embedding_weights(no_snapshot=False), + if self._enable_zero_collision_tbe: + if self._split_weights is not None: + split_weights = self._split_weights + for splits in split_weights: + for key, tensor in splits.items(): + yield key, tensor + return + else: + self._split_weights = [] + split_weights = [] + + for i, (config, tensors) in enumerate( + zip( + self._config.embedding_tables, + self.split_embedding_weights_with_id_buckets(no_snapshot=False), + ) + ): + weight_tensor = tensors[0] + bucket_tensor = tensors[1] + id_tensor = tensors[2] + + if not self._enable_zero_collision_tbe: + key = append_prefix(prefix, f"{config.name}") + yield key, weight_tensor + else: + if id_tensor is None: + continue + + sharded_tensors = self.get_sharded_split_tensors( + prefix, i, weight_tensor, bucket_tensor, id_tensor + ) + split_weights.append(sharded_tensors) + for key, tensor in sharded_tensors.items(): + yield key, tensor + self._split_weights = split_weights + return + + for config, tensor in enumerate( + zip( + self._config.embedding_tables, + self.split_embedding_weights(no_snapshot=False), + ) ): key = append_prefix(prefix, f"{config.name}") yield key, tensor @@ -1036,12 +1170,168 @@ def purge(self) -> None: self.emb_module.lxu_cache_weights.zero_() self.emb_module.lxu_cache_state.fill_(-1) - # pyre-ignore [15] + # pyre-ignore[15] def split_embedding_weights( self, no_snapshot: bool = True ) -> List[PartiallyMaterializedTensor]: return self.emb_module.split_embedding_weights(no_snapshot) + # TODO: read result from torchrec sharding plan + def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: + """ + utils to get bucket offset, bucket length, bucket size based on embedding sharding spec + """ + sharded_local_buckets: List[Tuple[int, int, int]] = [] + world_size = dist.get_world_size(self._pg) + local_rank = dist.get_rank(self._pg) + + for table in self._config.embedding_tables: + # temporary before uneven sharding utils is ready + assert ( + table.num_embeddings % world_size == 0 + ), "total_num_embeddings must be divisible by world_size" + total_num_buckets = none_throws(table.total_num_buckets) + bucket_offset = total_num_buckets // world_size * local_rank + bucket_length = total_num_buckets // world_size + bucket_size = table.num_embeddings // total_num_buckets + sharded_local_buckets.append((bucket_offset, bucket_length, bucket_size)) + logger.info( + f"bucket_offset: {bucket_offset}, bucket_length: {bucket_length}, bucket_size: {bucket_size} for table {table.name}" + ) + return 
sharded_local_buckets + + @torch.jit.export + def split_embedding_weights_with_id_buckets( + self, + no_snapshot: bool = True, + should_flush: bool = False, + ) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]: + """ + copied from SSDTableBatchedEmbeddingBags.split_embedding_weights + for debugging purpose, otherwise, need to rebuild light package + if change in SSDTableBatchedEmbeddingBags directly + return bucket, id, and weight tensors + """ + if not self._enable_zero_collision_tbe: + return [ + (weight, None, None) + for weight in self.emb_module.split_embedding_weights(no_snapshot) + ] + + # TODO: move the logic to SSD TBE when debugging is done + # Force device synchronize for now + # torch.cuda.synchronize() + # # Create a snapshot + # if no_snapshot: + # snapshot_handle = None + # else: + # if should_flush: + # # Flush L1 and L2 caches + # self.emb_module.flush() + # snapshot_handle = self.emb_module.ssd_db.create_snapshot() + dtype = self.emb_module.weights_precision.as_dtype() + splits = [] + assert ( + len(self._config.embedding_tables) == 1 + ), "only support 1 table in prototype" + if self._tracked_ids is None: + bucket_length = self._sharded_local_buckets[0][1] + bucket_tensor = torch.zeros( + (bucket_length, 1), + dtype=torch.int64, + device=torch.device("cpu"), + ) + splits.append( + ( + torch.empty( + (1, self._local_cols[0]), + device=torch.device("cpu"), + dtype=dtype, + ), + bucket_tensor, + torch.empty((1, 1), device=torch.device("cpu"), dtype=torch.int64), + ) + ) + return splits + + # TODO: to support multiple tables + # 1. split ids per table + # 2. unique and sort ids per table + # 3. linearize ids per table + # 4. query weight with get_cuda + # 5. return sorted ids as weight_id tensor, and queried weight as weight ensor + # when we get ids from embedding backend directly, need to do: + # 1. sort ids + # 2. split ids per table based on table fusion + # 3. query weight with get_cuda using split ids per table + # 4. 
deduct table offset from split ids + id = self._tracked_ids.values().long().cpu() + sorted_id = torch.unique(id, sorted=True).view(-1, 1) + # test size mismatch: select half of ids + sorted_id, _ = torch.chunk(sorted_id, 2, dim=0) + + bucket_offset = self._sharded_local_buckets[0][0] + bucket_length = self._sharded_local_buckets[0][1] + bucket_size = self._sharded_local_buckets[0][2] + + def get_bucket_tensor( + ids, bucket_offset, bucket_length, bucket_size + ) -> torch.Tensor: + # Step 1: Compute bucket index for each id + ids = ids.flatten() + bucket_ids = ids // bucket_size + + # Step 2: Verify bucket range + min_bucket = bucket_offset + max_bucket = bucket_offset + bucket_length + assert torch.all( + (bucket_ids >= min_bucket) & (bucket_ids < max_bucket) + ), f"Some IDs fall outside the expected bucket range [{min_bucket}, {max_bucket})" + + # Step 3: Normalize bucket indices to 0-based range + norm_bucket_ids = bucket_ids - bucket_offset + + # Step 4: Count occurrences + counts = torch.bincount(norm_bucket_ids, minlength=bucket_length) + + # Step 5: Return as 2D tensor [bucket_length, 1] + return counts.view(-1, 1) + + for i, _ in enumerate(self._config.embedding_tables): + bucket_tensor = get_bucket_tensor( + sorted_id, bucket_offset, bucket_length, bucket_size + ) + + # get weight tensor from tracked global id + weight_tensor = torch.empty( + (sorted_id.size(0), self.emb_module.max_D), + dtype=dtype, + ) + # this row throws, comments out to unblock downstream works + self.emb_module.ssd_db.get_cuda( + sorted_id.to(torch.int64), + weight_tensor, + torch.as_tensor(sorted_id.size(0)), + ) + splits.append((weight_tensor, bucket_tensor, sorted_id)) + return splits + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + if self._enable_zero_collision_tbe: + # track the last access ids for testing purpose + self._tracked_ids = KeyedJaggedTensor.from_lengths_sync( + keys=features.keys().copy(), + values=features.values().clone(), + lengths=features.lengths().clone(), + ) + # reset split weights during training + self._split_weights = None + + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + ) + class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): def __init__( diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 4f79bd6b9..d989456c9 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -259,6 +259,8 @@ def create_sharding_infos_by_sharding( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, + total_num_buckets=config.total_num_buckets, + zero_collision=config.zero_collision, ), param_sharding=parameter_sharding, param=param, @@ -353,6 +355,8 @@ def create_sharding_infos_by_sharding_device_group( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, + total_num_buckets=config.total_num_buckets, + zero_collision=config.zero_collision, ), param_sharding=parameter_sharding, param=param, @@ -767,11 +771,13 @@ def _initialize_torch_state(self) -> None: # noqa ) self._name_to_table_size = {} + table_zero_collision = {} for table in self._embedding_configs: self._name_to_table_size[table.name] = ( table.num_embeddings, table.embedding_dim, ) + table_zero_collision[table.name] = table.zero_collision for sharding_type, lookup in zip( self._sharding_type_to_sharding.keys(), self._lookups @@ -871,8 +877,9 @@ def 
_initialize_torch_state(self) -> None: # noqa # created ShardedTensors once in init, use in post_state_dict_hook # note: at this point kvstore backed tensors don't own valid snapshots, so no read # access is allowed on them. + # for collision free TBE, the shard sizes should be recalculated during ShardedTensor initilization self._model_parallel_name_to_sharded_tensor[table_name] = ( - ShardedTensor._init_from_local_shards( + ShardedTensor._init_from_local_shards_and_reset_offsets( local_shards, self._name_to_table_size[table_name], process_group=( @@ -925,20 +932,29 @@ def post_state_dict_hook( return sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + sharded_id_buckets_state_dict = None for lookup, sharding_type in zip( module._lookups, module._sharding_type_to_sharding.keys() ): if sharding_type != ShardingType.DATA_PARALLEL.value: - # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. - for key, v in lookup.get_named_split_embedding_weights_snapshot(): - assert key in sharded_kvtensors_copy - sharded_kvtensors_copy[key].local_shards()[0].tensor = v + for ( + key, + v, + ) in lookup.get_named_split_embedding_weights_snapshot(): # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + if key in sharded_kvtensors_copy: + sharded_kvtensors_copy[key].local_shards()[0].tensor = v + else: + d_k = f"{prefix}embeddings.{key}" + destination[d_k] = v + logger.info(f"add sharded tensor key {d_k} to state dict") for ( table_name, sharded_kvtensor, ) in sharded_kvtensors_copy.items(): destination_key = f"{prefix}embeddings.{table_name}.weight" destination[destination_key] = sharded_kvtensor + if sharded_id_buckets_state_dict: + destination.update(sharded_id_buckets_state_dict) self.register_state_dict_pre_hook(self._pre_state_dict_hook) self._register_state_dict_hook(post_state_dict_hook) diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index f3bb60619..7ba953d1e 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -8,6 +8,7 @@ # pyre-strict import abc +import copy import logging from collections import defaultdict, OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union @@ -103,9 +104,10 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str: qbias = param[2] param = param[0] - assert embedding_table.local_rows == param.size( # pyre-ignore[16] - 0 - ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16] + if not embedding_table.zero_collision: + assert embedding_table.local_rows == param.size( # pyre-ignore[16] + 0 + ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16] if qscale is not None: assert embedding_table.local_cols == param.size(1) # pyre-ignore[16] @@ -128,14 +130,17 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str: param.requires_grad # pyre-ignore[16] ) key_to_global_metadata[key] = embedding_table.global_metadata + local_metadata = copy.deepcopy(embedding_table.local_metadata) + local_metadata.shard_sizes = list(param.size()) key_to_local_shards[key].append( # pyre-fixme[6]: For 1st argument expected `Tensor` but got # `Union[Module, Tensor]`. # pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got # `Optional[ShardMetadata]`. 
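# For zero-collision (KV) tables the local weight tensor may hold fewer rows than
# embedding_table.local_rows (the row-count assert above is skipped for them), so the
# local shard metadata below is deep-copied and its shard_sizes rebuilt from the actual
# param size rather than reused verbatim from the planner.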
- Shard(param, embedding_table.local_metadata) + Shard(param, local_metadata) ) + else: destination[key] = param if qscale is not None: diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 98fa2d15f..59ebf472d 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -483,6 +483,16 @@ def _prefetch_and_cached( ) +def _is_kv_tbe( + table: ShardedEmbeddingTable, +) -> bool: + """ + Return true if this embedding enabled bucketized sharding for kv style TBE to support ZCH v.Next. + https://docs.google.com/document/d/13atWlDEkrkRulgC_gaoLv8ZsogQefdvsTdwlyam7ed0/edit?tab=t.0#heading=h.lxb1lainm4tc + """ + return table.zero_collision + + def _all_tables_are_quant_kernel( tables: List[ShardedEmbeddingTable], ) -> bool: @@ -558,7 +568,9 @@ def _group_tables_per_rank( table.data_type, ), _prefetch_and_cached(table), + _is_kv_tbe(table), ) + print(f"line debug: grouping_key: {grouping_key}") # micromanage the order of we traverse the groups to ensure backwards compatibility if grouping_key not in groups: grouping_keys.append(grouping_key) @@ -573,6 +585,7 @@ def _group_tables_per_rank( compute_kernel_type, _, _, + is_kv_tbe, ) = grouping_key grouped_tables = groups[grouping_key] # remove non-native fused params @@ -581,6 +594,8 @@ def _group_tables_per_rank( for k, v in fused_params_tuple if k not in ["_batch_key", USE_ONE_TBE_PER_TABLE] } + if is_kv_tbe: + per_tbe_fused_params["enable_zero_collision_tbe"] = True cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables) if cache_load_factor is not None: per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 1e155b8ad..903615ca9 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -231,6 +231,21 @@ def feature_hash_sizes(self) -> List[int]: feature_hash_sizes.extend(table.num_features() * [table.num_embeddings]) return feature_hash_sizes + def feature_total_num_buckets(self) -> Optional[List[int]]: + feature_total_num_buckets = [] + for table in self.embedding_tables: + if table.total_num_buckets: + feature_total_num_buckets.extend( + table.num_features() * [table.total_num_buckets] + ) + return feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None + + def _is_zero_collision(self) -> bool: + for table in self.embedding_tables: + if table.zero_collision: + return True + return False + def num_features(self) -> int: num_features = 0 for table in self.embedding_tables: diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index cc324d52a..8ff887f98 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -495,6 +495,9 @@ def __init__( ) if device is not None: self._emb_module.initialize_weights() + self._enable_zero_collision_tbe: bool = any( + table.zero_collision for table in config.embedding_tables + ) @property def emb_module( diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 4029d9aa6..8a550328a 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -126,6 +126,11 @@ def create_input_dist( ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: num_features = self._get_num_features() feature_hash_sizes 
= self._get_feature_hash_sizes() + is_zero_collision = any( + emb_config._is_zero_collision() + for emb_config in self._grouped_embedding_configs + ) + feature_total_num_buckets = self._get_feature_total_num_buckets() return RwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. @@ -136,6 +141,8 @@ def create_input_dist( is_sequence=True, has_feature_processor=self._has_feature_processor, need_pos=False, + feature_total_num_buckets=feature_total_num_buckets, + keep_original_indices=is_zero_collision, ) def create_lookup( @@ -265,6 +272,10 @@ def create_input_dist( (emb_sharding, is_even_sharding) = get_embedding_shard_metadata( self._grouped_embedding_configs_per_rank ) + is_zero_collision = any( + emb_config._is_zero_collision() + for emb_config in self._grouped_embedding_configs + ) return InferRwSparseFeaturesDist( world_size=self._world_size, @@ -275,6 +286,7 @@ def create_input_dist( has_feature_processor=self._has_feature_processor, need_pos=False, embedding_shard_metadata=emb_sharding if not is_even_sharding else None, + keep_original_indices=is_zero_collision, ) def create_lookup( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index b62609da1..cc01b9ba1 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -56,6 +56,7 @@ ShardingType, ShardMetadata, ) +from torchrec.distributed.utils import none_throws from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable @@ -137,6 +138,9 @@ def __init__( [] ) self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + logger.info( + f"self._grouped_embedding_configs_per_rank: {self._grouped_embedding_configs_per_rank}" + ) self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( self._grouped_embedding_configs_per_rank[self._rank] ) @@ -146,6 +150,9 @@ def __init__( if group_config.has_feature_processor: self._has_feature_processor = True + for group_config in self._grouped_embedding_configs: + group_config.feature_names + def _shard( self, sharding_infos: List[EmbeddingShardingInfo], @@ -217,6 +224,8 @@ def _shard( weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning, + total_num_buckets=info.embedding_config.total_num_buckets, + zero_collision=info.embedding_config.zero_collision, ) ) return tables_per_rank @@ -272,6 +281,23 @@ def _get_feature_hash_sizes(self) -> List[int]: feature_hash_sizes.extend(group_config.feature_hash_sizes()) return feature_hash_sizes + def _get_feature_total_num_buckets(self) -> Optional[List[int]]: + feature_total_num_buckets: List[int] = [] + for group_config in self._grouped_embedding_configs: + if group_config.feature_total_num_buckets() is not None: + feature_total_num_buckets.extend( + none_throws(group_config.feature_total_num_buckets()) + ) + return ( + feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None + ) # If no feature_total_num_buckets is provided, we return None to keep backward compatibility. 
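# Illustrative sketch (not part of the patch): the even bucket-sharding arithmetic that
# get_sharded_local_buckets() in batched_embedding_kernel.py relies on, written as a
# standalone helper with hypothetical numbers. The real inputs come from the table
# config (num_embeddings, total_num_buckets) and the process group (world_size, rank).
from typing import Tuple

def even_bucket_shard(
    num_embeddings: int, total_num_buckets: int, world_size: int, rank: int
) -> Tuple[int, int, int]:
    # assumes even divisibility, mirroring the temporary asserts in the prototype
    bucket_length = total_num_buckets // world_size    # buckets owned by this rank
    bucket_offset = bucket_length * rank               # first bucket id on this rank
    bucket_size = num_embeddings // total_num_buckets  # rows per bucket
    return bucket_offset, bucket_length, bucket_size

# e.g. 4000 rows split into 40 buckets across 4 ranks: rank 1 owns buckets [10, 20),
# each bucket covering 100 rows.
assert even_bucket_shard(4000, 40, 4, 1) == (10, 10, 100)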
+ + def _is_zero_collision(self) -> bool: + for group_config in self._grouped_embedding_configs: + if group_config._is_zero_collision(): + return True + return False + class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 27b011300..e0374b534 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -361,6 +361,7 @@ def _get_parameter_sharding( sharder: ModuleSharder[nn.Module], placements: Optional[List[str]] = None, compute_kernel: Optional[str] = None, + bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None, ) -> ParameterSharding: return ParameterSharding( sharding_spec=( @@ -371,6 +372,8 @@ def _get_parameter_sharding( ShardMetadata( shard_sizes=size, shard_offsets=offset, + bucket_id_offset=bucket_id_offset, + num_buckets=num_buckets, placement=( placement( device_type, @@ -381,9 +384,17 @@ def _get_parameter_sharding( else device_placement ), ) - for (size, offset, rank), device_placement in zip( + for (size, offset, rank), device_placement, ( + num_buckets, + bucket_id_offset, + ) in zip( size_offset_ranks, placements if placements else [None] * len(size_offset_ranks), + ( + bucket_offset_sizes + if bucket_offset_sizes + else [(None, None)] * len(size_offset_ranks) + ), ) ] ) @@ -512,7 +523,8 @@ def _parameter_sharding_generator( def row_wise( - sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None + sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None, + num_buckets_per_rank: Optional[List[int]] = None, # propagate num buckets per rank ) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan. 
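# Illustrative usage sketch (not part of the patch), mirroring the unit test added in
# test_sharding_plan.py below: num_buckets_per_rank rides alongside sizes_placement, and
# the generator accumulates the per-rank bucket counts into (num_buckets,
# bucket_id_offset) pairs the same way shard_offsets are accumulated from per-rank rows.
embedding_config = [
    EmbeddingBagConfig(
        name=f"table_{idx}",
        feature_names=[f"feature_{idx}"],
        embedding_dim=64,
        num_embeddings=4096,
    )
    for idx in range(2)
]
module_sharding_plan = construct_module_sharding_plan(
    EmbeddingCollection(tables=embedding_config),
    per_param_sharding={
        "table_0": row_wise(
            sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]),
            num_buckets_per_rank=[20, 30, 40],
        ),
    },
    local_size=1,
    world_size=2,
    device_type="cuda",
)
# Expected ShardMetadata for table_0 (see the test's `expected` dict):
#   rank 0: shard_offsets=[0, 0],    shard_sizes=[2048, 64], num_buckets=20, bucket_id_offset=0
#   rank 1: shard_offsets=[2048, 0], shard_sizes=[1024, 64], num_buckets=30, bucket_id_offset=20
#   rank 2: shard_offsets=[3072, 0], shard_sizes=[1024, 64], num_buckets=40, bucket_id_offset=50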
@@ -545,6 +557,7 @@ def _parameter_sharding_generator( device_type: str, sharder: ModuleSharder[nn.Module], ) -> ParameterSharding: + bucket_offset_sizes = None if sizes_placement is None: size_and_offsets = _get_parameter_size_offsets( param, @@ -558,17 +571,34 @@ def _parameter_sharding_generator( size_offset_ranks.append((size, offset, rank)) else: size_offset_ranks = [] + bucket_offset_sizes = None if num_buckets_per_rank is None else [] sizes = sizes_placement[0] + if num_buckets_per_rank is not None: + assert len(sizes) == len( + num_buckets_per_rank + ), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively" (rows, cols) = param.shape cur_offset = 0 prev_offset = 0 + prev_bucket_offset = 0 + cur_bucket_offset = 0 for rank, size in enumerate(sizes): per_rank_row = size + per_rank_bucket_size = None + if num_buckets_per_rank is not None: + per_rank_bucket_size = num_buckets_per_rank[rank] + cur_bucket_offset += per_rank_bucket_size cur_offset += per_rank_row cur_offset = min(cur_offset, rows) per_rank_row = cur_offset - prev_offset size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank)) prev_offset = cur_offset + if num_buckets_per_rank is not None: + # bucket has only one col for now + none_throws(bucket_offset_sizes).append( + (per_rank_bucket_size, prev_bucket_offset) + ) + prev_bucket_offset = cur_bucket_offset if cur_offset < rows: raise ValueError( @@ -601,6 +631,7 @@ def _parameter_sharding_generator( compute_kernel=( EmbeddingComputeKernel.QUANT.value if sizes_placement else None ), + bucket_offset_sizes=bucket_offset_sizes, ) return _parameter_sharding_generator diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 5dc18885a..f6c1c568a 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None: 0, ) + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None: + + embedding_config = [ + EmbeddingBagConfig( + name=f"table_{idx}", + feature_names=[f"feature_{idx}"], + embedding_dim=64, + num_embeddings=4096, + data_type=data_type, + ) + for idx in range(2) + ] + module_sharding_plan = construct_module_sharding_plan( + EmbeddingCollection(tables=embedding_config), + per_param_sharding={ + "table_0": row_wise( + sizes_placement=( + [2048, 1024, 1024], + ["cpu", "cuda", "cuda"], + ), + num_buckets_per_rank=[20, 30, 40], + ), + "table_1": row_wise( + sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"]) + ), + }, + local_size=1, + world_size=2, + device_type="cuda", + ) + + # Make sure per_param_sharding setting override the default device_type + device_table_0_shard_0 = ( + # pyre-ignore[16] + module_sharding_plan["table_0"] + .sharding_spec.shards[0] + .placement + ) + self.assertEqual( + device_table_0_shard_0.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_0_shard_0.rank(), + 0, + ) + for i in range(1, 3): + device_table_0_shard_i = ( + module_sharding_plan["table_0"].sharding_spec.shards[i].placement + ) + self.assertEqual( + device_table_0_shard_i.device().type, + "cuda", + ) + # first rank is assigned to cpu so index = rank - 1 + 
self.assertEqual( + device_table_0_shard_i.device().index, + i - 1, + ) + self.assertEqual( + device_table_0_shard_i.rank(), + i, + ) + for i in range(3): + device_table_1_shard_i = ( + module_sharding_plan["table_1"].sharding_spec.shards[i].placement + ) + self.assertEqual( + device_table_1_shard_i.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_1_shard_i.rank(), + 0, + ) + + expected = { + "table_0": ParameterSharding( + sharding_type="row_wise", + compute_kernel="quant", + ranks=[ + 0, + 1, + 2, + ], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 64], + placement="rank:0/cpu", + bucket_id_offset=0, + num_buckets=20, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[1024, 64], + placement="rank:1/cuda:0", + bucket_id_offset=20, + num_buckets=30, + ), + ShardMetadata( + shard_offsets=[3072, 0], + shard_sizes=[1024, 64], + placement="rank:2/cuda:1", + bucket_id_offset=50, + num_buckets=40, + ), + ] + ), + ), + "table_1": ParameterSharding( + sharding_type="row_wise", + compute_kernel="quant", + ranks=[ + 0, + 1, + 2, + ], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[1024, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ShardMetadata( + shard_offsets=[3072, 0], + shard_sizes=[1024, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ] + ), + ), + } + self.assertDictEqual(expected, module_sharding_plan) + # pyre-fixme[56] @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) @@ -929,18 +1082,85 @@ def test_str(self) -> None: ) expected = """module: ebc - param | sharding type | compute kernel | ranks + param | sharding type | compute kernel | ranks -------- | ------------- | -------------- | ------ user_id | table_wise | dense | [0] movie_id | row_wise | dense | [0, 1] - param | shard offsets | shard sizes | placement + param | shard offsets | shard sizes | placement -------- | ------------- | ----------- | ------------- user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 +""" + for i in range(len(expected.splitlines())): + self.assertEqual( + expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() + ) + + def test_str_bucket_wise_sharding(self) -> None: + plan = ShardingPlan( + { + "ebc": EmbeddingModuleShardingPlan( + { + "user_id": ParameterSharding( + sharding_type="table_wise", + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[4096, 32], + placement="rank:0/cuda:0", + ), + ] + ), + ), + "movie_id": ParameterSharding( + sharding_type="row_wise", + compute_kernel="dense", + ranks=[0, 1], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:0", + bucket_id_offset=0, + num_buckets=20, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:1", + bucket_id_offset=20, + num_buckets=30, + ), + ] + ), + ), + } + ) + } + ) + expected = """module: ebc + + param | sharding type | compute kernel | ranks +-------- | 
------------- | -------------- | ------ +user_id | table_wise | dense | [0] +movie_id | row_wise | dense | [0, 1] + + param | shard offsets | shard sizes | placement | bucket id offset | num buckets +-------- | ------------- | ----------- | ------------- | ---------------- | ----------- +user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 | None | None +movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 | 0 | 20 +movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 | 20 | 30 """ self.maxDiff = None + print("STR PLAN BUCKET WISE") + print(str(plan)) + print("=======") for i in range(len(expected.splitlines())): self.assertEqual( expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 1d45bff69..2be9a9c86 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -737,6 +737,7 @@ def __str__(self) -> str: out = "" param_table = [] shard_table = [] + contains_bucket_wise_shards = False for param_name, param_sharding in self.items(): param_table.append( [ @@ -749,20 +750,54 @@ def __str__(self) -> str: if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): shards = param_sharding.sharding_spec.shards if shards is not None: + param_sharding_contains_bucket_info = any( + shard.bucket_id_offset is not None for shard in shards + ) + if param_sharding_contains_bucket_info: + contains_bucket_wise_shards = True for shard in shards: - shard_table.append( + cols = ( [ param_name, shard.shard_offsets, shard.shard_sizes, shard.placement, ] + if param_sharding_contains_bucket_info is None + else [ + param_name, + shard.shard_offsets, + shard.shard_sizes, + shard.placement, + shard.bucket_id_offset, + shard.num_buckets, + ] ) + shard_table.append(cols) + if contains_bucket_wise_shards: + for i in range(len(shard_table)): + if len(shard_table[i]) == 4: + # add None for the tables that don't have bucket info + shard_table[i].append(None) + shard_table[i].append(None) out += "\n\n" + _tabulate( param_table, ["param", "sharding type", "compute kernel", "ranks"] ) + column_str = ( + ["param", "shard offsets", "shard sizes", "placement"] + if not contains_bucket_wise_shards + else [ + "param", + "shard offsets", + "shard sizes", + "placement", + "bucket id offset", + "num buckets", + ] + ) out += "\n\n" + _tabulate( - shard_table, ["param", "shard offsets", "shard sizes", "placement"] + shard_table, + column_str, ) return out diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py index b665257a8..82c739afd 100644 --- a/torchrec/modules/embedding_configs.py +++ b/torchrec/modules/embedding_configs.py @@ -195,6 +195,8 @@ class BaseEmbeddingConfig: # handle the special case input_dim: Optional[int] = None + total_num_buckets: Optional[int] = None + zero_collision: bool = False def get_weight_init_max(self) -> float: if self.weight_init_max is None:
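# Illustrative sketch (not part of the patch) of the per-bucket id histogram that
# split_embedding_weights_with_id_buckets() builds through its get_bucket_tensor()
# helper: ids map to buckets by integer division, are shifted into this rank's local
# bucket range, and counted with torch.bincount. The numbers below are hypothetical.
import torch

ids = torch.tensor([1000, 1001, 1250, 1999])             # sorted, de-duplicated ids on this rank
bucket_size, bucket_offset, bucket_length = 100, 10, 10  # rank owns buckets [10, 20)

bucket_ids = ids // bucket_size                          # tensor([10, 10, 12, 19])
assert torch.all((bucket_ids >= bucket_offset) & (bucket_ids < bucket_offset + bucket_length))
counts = torch.bincount(bucket_ids - bucket_offset, minlength=bucket_length).view(-1, 1)
# counts has shape (bucket_length, 1); here rows 0, 2 and 9 hold 2, 1 and 1 respectively,
# i.e. buckets 10, 12 and 19 of the global bucket space.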