From 9e9ad1faa20972ae7491e7e7a2113da5bfb2e3c2 Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Sat, 12 Apr 2025 16:43:15 -0700 Subject: [PATCH 1/2] MappedEmbeddingModule Differential Revision: D72263389 --- torchrec/distributed/mapped_embedding.py | 97 ++++++ torchrec/distributed/quant_embedding.py | 3 + .../distributed/quant_mapped_embedding.py | 66 ++++ torchrec/distributed/sharding_plan.py | 6 + .../tests/test_quant_mapped_embedding.py | 287 ++++++++++++++++++ torchrec/modules/mapped_embedding_module.py | 29 ++ torchrec/quant/embedding_modules.py | 77 +++++ 7 files changed, 565 insertions(+) create mode 100644 torchrec/distributed/mapped_embedding.py create mode 100644 torchrec/distributed/quant_mapped_embedding.py create mode 100644 torchrec/distributed/tests/test_quant_mapped_embedding.py create mode 100644 torchrec/modules/mapped_embedding_module.py diff --git a/torchrec/distributed/mapped_embedding.py b/torchrec/distributed/mapped_embedding.py new file mode 100644 index 000000000..1121b8684 --- /dev/null +++ b/torchrec/distributed/mapped_embedding.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +from typing import Any, Dict, List, Optional, Type + +import torch + +from torchrec.distributed.embedding import ( + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.types import ( + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, + ShardingType, +) +from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection + + +class ShardedMappedEmbeddingCollection(ShardedEmbeddingCollection): + def __init__( + self, + module: MappedEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device], + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + use_index_dedup: bool = False, + module_fqn: Optional[str] = None, + ) -> None: + super().__init__( + module=module, + table_name_to_parameter_sharding=table_name_to_parameter_sharding, + env=env, + device=device, + fused_params=fused_params, + qcomm_codecs_registry=qcomm_codecs_registry, + use_index_dedup=use_index_dedup, + module_fqn=module_fqn, + ) + + +class MappedEmbeddingCollectionSharder(EmbeddingCollectionSharder): + def __init__( + self, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + use_index_dedup: bool = False, + ) -> None: + super().__init__( + fused_params=fused_params, + qcomm_codecs_registry=qcomm_codecs_registry, + use_index_dedup=use_index_dedup, + ) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value] + + def compute_kernels( + self, + *args: Any, + **kwargs: Any, + ) -> List[str]: + return [EmbeddingComputeKernel.KEY_VALUE.value] + + @property + def module_type(self) -> Type[MappedEmbeddingCollection]: + return MappedEmbeddingCollection + + # pyre-ignore: Inconsistent override [14] + def shard( + self, + module: MappedEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> 
ShardedMappedEmbeddingCollection: + return ShardedMappedEmbeddingCollection( + module=module, + table_name_to_parameter_sharding=params, + env=env, + device=device, + fused_params=self.fused_params, + qcomm_codecs_registry=self.qcomm_codecs_registry, + use_index_dedup=self._use_index_dedup, + module_fqn=module_fqn, + ) diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 792fdeb0a..a4253a71c 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -994,6 +994,9 @@ class QuantEmbeddingCollectionSharder( This implementation uses non-fused EmbeddingCollection """ + def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None: + super().__init__(fused_params) + def shard( self, module: QuantEmbeddingCollection, diff --git a/torchrec/distributed/quant_mapped_embedding.py b/torchrec/distributed/quant_mapped_embedding.py new file mode 100644 index 000000000..228fe37b0 --- /dev/null +++ b/torchrec/distributed/quant_mapped_embedding.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +from typing import Any, Dict, Optional, Type + +import torch + +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollectionSharder, + ShardedQuantEmbeddingCollection, +) +from torchrec.distributed.types import ParameterSharding, ShardingEnv +from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection + + +class ShardedQuantMappedEmbeddingCollection(ShardedQuantEmbeddingCollection): + def __init__( + self, + module: QuantMappedEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, + table_name_to_parameter_sharding, + env, + fused_params, + device, + ) + + +class QuantMappedEmbeddingCollectionSharder(QuantEmbeddingCollectionSharder): + def __init__( + self, + fused_params: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(fused_params) + + @property + def module_type(self) -> Type[QuantMappedEmbeddingCollection]: + return QuantMappedEmbeddingCollection + + # pyre-ignore [14] + def shard( + self, + module: QuantMappedEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> ShardedQuantMappedEmbeddingCollection: + return ShardedQuantMappedEmbeddingCollection( + module, + table_name_to_parameter_sharding, + env, + fused_params=fused_params, + device=device, + ) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 27b011300..a2e448272 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -23,6 +23,7 @@ FeatureProcessedEmbeddingBagCollectionSharder, ) from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder +from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, @@ -34,6 +35,9 @@ 
QuantManagedCollisionEmbeddingCollectionSharder, ) from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder +from torchrec.distributed.quant_mapped_embedding import ( + QuantMappedEmbeddingCollectionSharder, +) from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, EnumerableShardingSpec, @@ -62,6 +66,8 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: InferManagedCollisionCollectionSharder(), ), ), + cast(ModuleSharder[nn.Module], MappedEmbeddingCollectionSharder()), + cast(ModuleSharder[nn.Module], QuantMappedEmbeddingCollectionSharder()), ] diff --git a/torchrec/distributed/tests/test_quant_mapped_embedding.py b/torchrec/distributed/tests/test_quant_mapped_embedding.py new file mode 100644 index 000000000..d7e2fc7e6 --- /dev/null +++ b/torchrec/distributed/tests/test_quant_mapped_embedding.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import copy +import multiprocessing +import unittest +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from hypothesis import settings +from libfb.py.pyre import none_throws +from torch import nn + +from torchrec import EmbeddingConfig, inference as trec_infer, KeyedJaggedTensor +from torchrec.distributed.global_settings import set_propogate_device +from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder +from torchrec.distributed.quant_mapped_embedding import ( + QuantMappedEmbeddingCollectionSharder, +) +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.inference.modules import ( + DEFAULT_FUSED_PARAMS, + trim_torch_package_prefix_from_typename, +) +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection +from torchrec.quant.embedding_modules import ( + EmbeddingCollection as QuantEmbeddingCollection, + QuantMappedEmbeddingCollection, +) + +QUANTIZATION_MAPPING: Dict[str, Type[torch.nn.Module]] = { + trim_torch_package_prefix_from_typename( + torch.typename(EmbeddingCollection) + ): QuantEmbeddingCollection, + trim_torch_package_prefix_from_typename( + torch.typename(MappedEmbeddingCollection) + ): QuantMappedEmbeddingCollection, +} + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + buckets: int, + input_hash_size: int = 4000, + is_inference: bool = False, + ) -> None: + super().__init__() + + self._ec: MappedEmbeddingCollection = MappedEmbeddingCollection( + tables=tables, + device=device, + ) + + def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]: + + ec_out = self._ec(kjt) + pred: torch.Tensor = torch.cat( + [ec_out[key].values() for key in ["feature_0", "feature_1"]], + dim=0, + ) + loss = pred.mean() + return loss, pred + + +class TestQuantMappedEmbedding(MultiProcessTestBase): + # pyre-ignore[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test 
requires at least two GPUs", + ) + @settings(deadline=None) + def test_quant_sharding_mapped_ec(self) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + train_input_per_rank = [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.cat([torch.randint(16, (8,)), torch.randint(32, (16,))]), + lengths=torch.LongTensor([1] * 8 + [2] * 8), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.cat([torch.randint(16, (8,)), torch.randint(32, (16,))]), + lengths=torch.LongTensor([1] * 8 + [2] * 8), + weights=None, + ), + ] + + infer_input = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.randint( + 32, + (8,), + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ) + + train_info = multiprocessing.Manager().dict() + + # Train Model with ZCH on GPU + import fbvscode + + fbvscode.attach_debugger() + self._run_multi_process_test( + callable=_train_model, + world_size=WORLD_SIZE, + tables=embedding_config, + num_buckets=2, + kjt_input_per_rank=train_input_per_rank, + sharder=MappedEmbeddingCollectionSharder(), + return_dict=train_info, + backend="nccl", + infer_input=infer_input, + ) + print(f"train_info: {train_info}") + # Load Train Model State Dict into Inference Model + inference_model = SparseArch( + tables=embedding_config, + device=torch.device("cpu"), + input_hash_size=0, + buckets=4, + ) + + merged_state_dict = { + "_ec.embeddings.table_0.weight": torch.cat( + [value for key, value in train_info.items() if "table_0" in key], + dim=0, + ), + "_ec.embeddings.table_1.weight": torch.cat( + [value for key, value in train_info.items() if "table_1" in key], + dim=0, + ), + } + + inference_model.load_state_dict(merged_state_dict) + + # Get Train Model Output + # train_output = inference_model(infer_input) + + # Quantize Inference Model + quant_inference_model = trec_infer.modules.quantize_inference_model( + inference_model, QUANTIZATION_MAPPING, None, torch.quint4x2 + ) + + # Get Quantized Inference Model Output + _, quant_output = quant_inference_model(infer_input) + + # Verify Quantized Inference Model Output is close to Train Model Output + # TODO: [Kaus] Check why this fails + # self.assertTrue( + # torch.allclose( + # train_info["train_output_0"], + # quant_output, + # atol=1e-02, + # ) + # ) + + # Shard Quantized Inference Model + sharder = QuantMappedEmbeddingCollectionSharder( + fused_params=DEFAULT_FUSED_PARAMS + ) + module_sharding_plan = construct_module_sharding_plan( + quant_inference_model._ec, # pyre-ignore + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=2, + world_size=WORLD_SIZE, + device_type="cpu", + sharder=sharder, # pyre-ignore + ) + set_propogate_device(True) + + sharded_quant_inference_model = _shard_modules( + module=copy.deepcopy(quant_inference_model), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + env=ShardingEnv.from_local( + WORLD_SIZE, + 0, + ), + sharders=[sharder], # pyre-ignore + device=torch.device("cpu"), + ) + + _, sharded_quant_output = sharded_quant_inference_model(infer_input) + self.assertTrue( + torch.allclose( + sharded_quant_output, + quant_output, + atol=0, + ) + ) + + +def _train_model( + tables: List[EmbeddingConfig], + 
num_buckets: int, + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + return_dict: Dict[str, Any], + infer_input: KeyedJaggedTensor, + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + import fbvscode + + fbvscode.attach_debugger() + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + + train_model = SparseArch( + tables=tables, + device=torch.device("cuda"), + input_hash_size=0, + buckets=num_buckets, + ) + train_sharding_plan = construct_module_sharding_plan( + train_model._ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda", + sharder=sharder, + ) + print(f"train_sharding_plan: {train_sharding_plan}") + sharded_train_model = _shard_modules( + module=copy.deepcopy(train_model), + plan=ShardingPlan({"_ec": train_sharding_plan}), + env=ShardingEnv.from_process_group(none_throws(ctx.pg)), + sharders=[sharder], + device=ctx.device, + ) + optim = torch.optim.SGD(sharded_train_model.parameters(), lr=0.1) + # train + optim.zero_grad() + sharded_train_model.train(True) + loss, output = sharded_train_model(kjt_input.to(ctx.device)) + loss.backward() + optim.step() + + # infer + with torch.no_grad(): + sharded_train_model.train(False) + _, infer_output = sharded_train_model(infer_input.to(ctx.device)) + + return_dict[f"train_output_{rank}"] = infer_output.cpu() + + for ( + key, + value, + # pyre-ignore + ) in sharded_train_model._ec.embeddings.state_dict().items(): + return_dict[f"ec_{key}_{rank}"] = value.cpu() diff --git a/torchrec/modules/mapped_embedding_module.py b/torchrec/modules/mapped_embedding_module.py new file mode 100644 index 000000000..a9173592e --- /dev/null +++ b/torchrec/modules/mapped_embedding_module.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +#!/usr/bin/env python3 + +from typing import List, Optional + +import torch + +from torchrec.modules.embedding_configs import EmbeddingConfig + +from torchrec.modules.embedding_modules import EmbeddingCollection + + +class MappedEmbeddingCollection(EmbeddingCollection): + """ """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: Optional[torch.device] = None, + need_indices: bool = False, + ) -> None: + super().__init__(tables=tables, need_indices=need_indices, device=device) + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 81b9b8bfc..108f4ae1d 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -55,6 +55,7 @@ from torchrec.modules.fp_embedding_modules import ( FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection, ) +from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection from torchrec.modules.mc_embedding_modules import ( ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection, ) @@ -1007,6 +1008,82 @@ def device(self) -> torch.device: return self._device +class QuantMappedEmbeddingCollection(EmbeddingCollection): + """ """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + need_indices: bool = False, + output_dtype: torch.dtype = torch.float, + table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + cache_features_order: bool = False, + ) -> None: + super().__init__( + tables, + device, + need_indices, + output_dtype, + table_name_to_quantized_weights, + register_tbes, + quant_state_dict_split_scale_bias, + row_alignment, + cache_features_order, + ) + + @classmethod + # pyre-ignore + def from_float( + cls, + module: MappedEmbeddingCollection, + use_precomputed_fake_quant: bool = False, + ) -> "QuantMappedEmbeddingCollection": + assert hasattr( + module, "qconfig" + ), "MappedEmbeddingCollection input float module must have qconfig defined" + embedding_configs = copy.deepcopy(module.embedding_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_configs), + # pyre-fixme[6]: For 2nd argument expected `Union[QuantConfig, QConfig]` + # but got `Union[Module, Tensor]`. + module.qconfig, + ) + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {} + device = quantize_state_dict( + module, + table_name_to_quantized_weights, + {table.name: table.data_type for table in embedding_configs}, + ) + return cls( + embedding_configs, + device=device, + need_indices=module.need_indices(), + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `activation`. 
+ output_dtype=module.qconfig.activation().dtype, + table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + cache_features_order=getattr( + module, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ), + ) + + def _get_name(self) -> str: + return "QuantMappedEmbeddingCollection" + + class QuantManagedCollisionEmbeddingCollection(EmbeddingCollection): """ QuantManagedCollisionEmbeddingCollection represents a quantized EC module and a set of managed collision modules. From 71554bbccc93f6650259a83cb70996fb1873a901 Mon Sep 17 00:00:00 2001 From: Faran Ahmad Date: Sat, 12 Apr 2025 16:43:15 -0700 Subject: [PATCH 2/2] Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding (#2884) Summary: Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding for ZCH v.Next Differential Revision: D72702173 --- torchrec/distributed/sharding_plan.py | 46 +++- .../distributed/tests/test_sharding_plan.py | 228 +++++++++++++++++- torchrec/distributed/types.py | 39 ++- 3 files changed, 304 insertions(+), 9 deletions(-) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index a2e448272..5b94f7379 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -367,6 +367,7 @@ def _get_parameter_sharding( sharder: ModuleSharder[nn.Module], placements: Optional[List[str]] = None, compute_kernel: Optional[str] = None, + bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None, ) -> ParameterSharding: return ParameterSharding( sharding_spec=( @@ -377,6 +378,8 @@ def _get_parameter_sharding( ShardMetadata( shard_sizes=size, shard_offsets=offset, + bucket_id_offset=bucket_id_offset, + num_buckets=num_buckets, placement=( placement( device_type, @@ -387,9 +390,17 @@ def _get_parameter_sharding( else device_placement ), ) - for (size, offset, rank), device_placement in zip( + for (size, offset, rank), device_placement, ( + num_buckets, + bucket_id_offset, + ) in zip( size_offset_ranks, placements if placements else [None] * len(size_offset_ranks), + ( + bucket_offset_sizes + if bucket_offset_sizes + else [(None, None)] * len(size_offset_ranks) + ), ) ] ) @@ -518,7 +529,8 @@ def _parameter_sharding_generator( def row_wise( - sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None + sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None, + num_buckets_per_rank: Optional[List[int]] = None, # propagate num buckets per rank ) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan. 
@@ -551,6 +563,7 @@ def _parameter_sharding_generator( device_type: str, sharder: ModuleSharder[nn.Module], ) -> ParameterSharding: + bucket_offset_sizes = None if sizes_placement is None: size_and_offsets = _get_parameter_size_offsets( param, @@ -564,17 +577,34 @@ def _parameter_sharding_generator( size_offset_ranks.append((size, offset, rank)) else: size_offset_ranks = [] + bucket_offset_sizes = None if num_buckets_per_rank is None else [] sizes = sizes_placement[0] + if num_buckets_per_rank is not None: + assert len(sizes) == len( + num_buckets_per_rank + ), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively" (rows, cols) = param.shape cur_offset = 0 prev_offset = 0 + prev_bucket_offset = 0 + cur_bucket_offset = 0 for rank, size in enumerate(sizes): per_rank_row = size + per_rank_bucket_size = None + if num_buckets_per_rank is not None: + per_rank_bucket_size = num_buckets_per_rank[rank] + cur_bucket_offset += per_rank_bucket_size cur_offset += per_rank_row cur_offset = min(cur_offset, rows) per_rank_row = cur_offset - prev_offset size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank)) prev_offset = cur_offset + if num_buckets_per_rank is not None: + # bucket has only one col for now + none_throws(bucket_offset_sizes).append( + (per_rank_bucket_size, prev_bucket_offset) + ) + prev_bucket_offset = cur_bucket_offset if cur_offset < rows: raise ValueError( @@ -596,6 +626,13 @@ def _parameter_sharding_generator( if device_type == "cuda": index += 1 + compute_kernel = None + if sizes_placement is not None: + if num_buckets_per_rank is not None: + compute_kernel = EmbeddingComputeKernel.KEY_VALUE.value + else: + compute_kernel = EmbeddingComputeKernel.QUANT.value + return _get_parameter_sharding( param, ShardingType.ROW_WISE.value, @@ -604,9 +641,8 @@ def _parameter_sharding_generator( device_type, sharder, placements=placements if sizes_placement else None, - compute_kernel=( - EmbeddingComputeKernel.QUANT.value if sizes_placement else None - ), + compute_kernel=compute_kernel, + bucket_offset_sizes=bucket_offset_sizes, ) return _parameter_sharding_generator diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 5dc18885a..34a84d761 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None: 0, ) + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None: + + embedding_config = [ + EmbeddingBagConfig( + name=f"table_{idx}", + feature_names=[f"feature_{idx}"], + embedding_dim=64, + num_embeddings=4096, + data_type=data_type, + ) + for idx in range(2) + ] + module_sharding_plan = construct_module_sharding_plan( + EmbeddingCollection(tables=embedding_config), + per_param_sharding={ + "table_0": row_wise( + sizes_placement=( + [2048, 1024, 1024], + ["cpu", "cuda", "cuda"], + ), + num_buckets_per_rank=[20, 30, 40], + ), + "table_1": row_wise( + sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"]) + ), + }, + local_size=1, + world_size=2, + device_type="cuda", + ) + + # Make sure per_param_sharding setting override the default device_type + device_table_0_shard_0 = ( + # pyre-ignore[16] + 
module_sharding_plan["table_0"] + .sharding_spec.shards[0] + .placement + ) + self.assertEqual( + device_table_0_shard_0.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_0_shard_0.rank(), + 0, + ) + for i in range(1, 3): + device_table_0_shard_i = ( + module_sharding_plan["table_0"].sharding_spec.shards[i].placement + ) + self.assertEqual( + device_table_0_shard_i.device().type, + "cuda", + ) + # first rank is assigned to cpu so index = rank - 1 + self.assertEqual( + device_table_0_shard_i.device().index, + i - 1, + ) + self.assertEqual( + device_table_0_shard_i.rank(), + i, + ) + for i in range(3): + device_table_1_shard_i = ( + module_sharding_plan["table_1"].sharding_spec.shards[i].placement + ) + self.assertEqual( + device_table_1_shard_i.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_1_shard_i.rank(), + 0, + ) + + expected = { + "table_0": ParameterSharding( + sharding_type="row_wise", + compute_kernel="key_value", + ranks=[ + 0, + 1, + 2, + ], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 64], + placement="rank:0/cpu", + bucket_id_offset=0, + num_buckets=20, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[1024, 64], + placement="rank:1/cuda:0", + bucket_id_offset=20, + num_buckets=30, + ), + ShardMetadata( + shard_offsets=[3072, 0], + shard_sizes=[1024, 64], + placement="rank:2/cuda:1", + bucket_id_offset=50, + num_buckets=40, + ), + ] + ), + ), + "table_1": ParameterSharding( + sharding_type="row_wise", + compute_kernel="quant", + ranks=[ + 0, + 1, + 2, + ], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[1024, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ShardMetadata( + shard_offsets=[3072, 0], + shard_sizes=[1024, 64], + placement="rank:0/cpu", + bucket_id_offset=None, + num_buckets=None, + ), + ] + ), + ), + } + self.assertDictEqual(expected, module_sharding_plan) + # pyre-fixme[56] @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) @@ -929,18 +1082,89 @@ def test_str(self) -> None: ) expected = """module: ebc - param | sharding type | compute kernel | ranks + param | sharding type | compute kernel | ranks -------- | ------------- | -------------- | ------ user_id | table_wise | dense | [0] movie_id | row_wise | dense | [0, 1] - param | shard offsets | shard sizes | placement + param | shard offsets | shard sizes | placement -------- | ------------- | ----------- | ------------- user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 """ self.maxDiff = None + print("STR PLAN") + print(str(plan)) + print("=======") + for i in range(len(expected.splitlines())): + self.assertEqual( + expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() + ) + + def test_str_bucket_wise_sharding(self) -> None: + plan = ShardingPlan( + { + "ebc": EmbeddingModuleShardingPlan( + { + "user_id": ParameterSharding( + sharding_type="table_wise", + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[4096, 32], + placement="rank:0/cuda:0", + 
), + ] + ), + ), + "movie_id": ParameterSharding( + sharding_type="row_wise", + compute_kernel="dense", + ranks=[0, 1], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:0", + bucket_id_offset=0, + num_buckets=20, + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:1", + bucket_id_offset=20, + num_buckets=30, + ), + ] + ), + ), + } + ) + } + ) + expected = """module: ebc + + param | sharding type | compute kernel | ranks +-------- | ------------- | -------------- | ------ +user_id | table_wise | dense | [0] +movie_id | row_wise | dense | [0, 1] + + param | shard offsets | shard sizes | placement | bucket id offset | num buckets +-------- | ------------- | ----------- | ------------- | ---------------- | ----------- +user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 | None | None +movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 | 0 | 20 +movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 | 20 | 30 +""" + self.maxDiff = None + print("STR PLAN BUCKET WISE") + print(str(plan)) + print("=======") for i in range(len(expected.splitlines())): self.assertEqual( expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 1d45bff69..2be9a9c86 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -737,6 +737,7 @@ def __str__(self) -> str: out = "" param_table = [] shard_table = [] + contains_bucket_wise_shards = False for param_name, param_sharding in self.items(): param_table.append( [ @@ -749,20 +750,54 @@ def __str__(self) -> str: if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): shards = param_sharding.sharding_spec.shards if shards is not None: + param_sharding_contains_bucket_info = any( + shard.bucket_id_offset is not None for shard in shards + ) + if param_sharding_contains_bucket_info: + contains_bucket_wise_shards = True for shard in shards: - shard_table.append( + cols = ( [ param_name, shard.shard_offsets, shard.shard_sizes, shard.placement, ] + if param_sharding_contains_bucket_info is None + else [ + param_name, + shard.shard_offsets, + shard.shard_sizes, + shard.placement, + shard.bucket_id_offset, + shard.num_buckets, + ] ) + shard_table.append(cols) + if contains_bucket_wise_shards: + for i in range(len(shard_table)): + if len(shard_table[i]) == 4: + # add None for the tables that don't have bucket info + shard_table[i].append(None) + shard_table[i].append(None) out += "\n\n" + _tabulate( param_table, ["param", "sharding type", "compute kernel", "ranks"] ) + column_str = ( + ["param", "shard offsets", "shard sizes", "placement"] + if not contains_bucket_wise_shards + else [ + "param", + "shard offsets", + "shard sizes", + "placement", + "bucket id offset", + "num buckets", + ] + ) out += "\n\n" + _tabulate( - shard_table, ["param", "shard offsets", "shard sizes", "placement"] + shard_table, + column_str, ) return out
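
Usage note for reviewers of PATCH 1/2. The patch adds MappedEmbeddingCollection, its quantized counterpart, and the matching sharders, but the intended end-to-end flow is only visible inside test_quant_sharding_mapped_ec. The sketch below condenses that flow into one place. It is an illustration, not part of the patch: the InferenceArch wrapper, its `_ec` attribute name, the table config, and the world size are assumptions made for the example; only the torchrec imports and calls come from the diffs above.

import copy
from typing import List

import torch
from torch import nn

from torchrec import EmbeddingConfig, inference as trec_infer, KeyedJaggedTensor
from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.quant_mapped_embedding import (
    QuantMappedEmbeddingCollectionSharder,
)
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.distributed.types import ShardingEnv, ShardingPlan
from torchrec.inference.modules import (
    DEFAULT_FUSED_PARAMS,
    trim_torch_package_prefix_from_typename,
)
from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection


class InferenceArch(nn.Module):
    # Hypothetical wrapper (not in the patch) that owns the collection at `_ec`,
    # mirroring the SparseArch used by the new test.
    def __init__(self, tables: List[EmbeddingConfig]) -> None:
        super().__init__()
        self._ec = MappedEmbeddingCollection(tables=tables, device=torch.device("cpu"))

    def forward(self, kjt: KeyedJaggedTensor):
        return self._ec(kjt)


tables = [
    EmbeddingConfig(
        name="table_0", feature_names=["feature_0"], embedding_dim=8, num_embeddings=16
    )
]
model = InferenceArch(tables)

# Map the float module type to its quantized counterpart and quantize, using the
# same mapping style and dtype argument as the test.
quant_mapping = {
    trim_torch_package_prefix_from_typename(
        torch.typename(MappedEmbeddingCollection)
    ): QuantMappedEmbeddingCollection,
}
quant_model = trec_infer.modules.quantize_inference_model(
    model, quant_mapping, None, torch.quint4x2
)

# Shard the quantized collection row-wise with the new sharder, as the test does
# for the inference model.
sharder = QuantMappedEmbeddingCollectionSharder(fused_params=DEFAULT_FUSED_PARAMS)
plan = construct_module_sharding_plan(
    quant_model._ec,
    per_param_sharding={"table_0": row_wise()},
    local_size=2,
    world_size=2,
    device_type="cpu",
    sharder=sharder,
)
set_propogate_device(True)
sharded_quant_model = _shard_modules(
    module=copy.deepcopy(quant_model),
    plan=ShardingPlan({"_ec": plan}),
    env=ShardingEnv.from_local(2, 0),
    sharders=[sharder],
    device=torch.device("cpu"),
)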
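
Usage note for reviewers of PATCH 2/2. The patch threads num_buckets and bucket_id_offset into ShardMetadata, adds num_buckets_per_rank to row_wise(), and extends the plan printer with two new columns. A minimal sketch of the intended call pattern, mirroring test_row_wise_bucket_level_sharding and test_str_bucket_wise_sharding; the table name, row splits, and bucket counts here are made up for illustration and are not values from the patch.

from torchrec import EmbeddingConfig
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.distributed.types import ShardingPlan
from torchrec.modules.embedding_modules import EmbeddingCollection

# Illustrative table: 4096 rows split 2048/1024/1024 across one CPU rank and two
# CUDA ranks, with 20/30/40 ZCH buckets owned by those ranks respectively.
tables = [
    EmbeddingConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=64,
        num_embeddings=4096,
    )
]

module_sharding_plan = construct_module_sharding_plan(
    EmbeddingCollection(tables=tables),
    per_param_sharding={
        # sizes_placement fixes per-rank row counts and devices; num_buckets_per_rank
        # must have the same length and records how many buckets each rank owns.
        "table_0": row_wise(
            sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]),
            num_buckets_per_rank=[20, 30, 40],
        ),
    },
    local_size=1,
    world_size=2,
    device_type="cuda",
)

# Each ShardMetadata now carries num_buckets and bucket_id_offset (0, 20, 50 for the
# counts above), and the compute kernel becomes "key_value" rather than "quant" when
# bucket counts are supplied. The printed shard table gains the
# "bucket id offset" and "num buckets" columns added in types.py.
print(ShardingPlan({"ec": module_sharding_plan}))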