diff --git a/torchrec/distributed/mapped_embedding.py b/torchrec/distributed/mapped_embedding.py
new file mode 100644
index 000000000..1121b8684
--- /dev/null
+++ b/torchrec/distributed/mapped_embedding.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+
+from torchrec.distributed.embedding import (
+    EmbeddingCollectionSharder,
+    ShardedEmbeddingCollection,
+)
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.types import (
+    ParameterSharding,
+    QuantizedCommCodecs,
+    ShardingEnv,
+    ShardingType,
+)
+from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
+
+
+class ShardedMappedEmbeddingCollection(ShardedEmbeddingCollection):
+    def __init__(
+        self,
+        module: MappedEmbeddingCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device],
+        fused_params: Optional[Dict[str, Any]] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+        use_index_dedup: bool = False,
+        module_fqn: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            module=module,
+            table_name_to_parameter_sharding=table_name_to_parameter_sharding,
+            env=env,
+            device=device,
+            fused_params=fused_params,
+            qcomm_codecs_registry=qcomm_codecs_registry,
+            use_index_dedup=use_index_dedup,
+            module_fqn=module_fqn,
+        )
+
+
+class MappedEmbeddingCollectionSharder(EmbeddingCollectionSharder):
+    def __init__(
+        self,
+        fused_params: Optional[Dict[str, Any]] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+        use_index_dedup: bool = False,
+    ) -> None:
+        super().__init__(
+            fused_params=fused_params,
+            qcomm_codecs_registry=qcomm_codecs_registry,
+            use_index_dedup=use_index_dedup,
+        )
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [ShardingType.ROW_WISE.value]
+
+    def compute_kernels(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.KEY_VALUE.value]
+
+    @property
+    def module_type(self) -> Type[MappedEmbeddingCollection]:
+        return MappedEmbeddingCollection
+
+    # pyre-ignore: Inconsistent override [14]
+    def shard(
+        self,
+        module: MappedEmbeddingCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedMappedEmbeddingCollection:
+        return ShardedMappedEmbeddingCollection(
+            module=module,
+            table_name_to_parameter_sharding=params,
+            env=env,
+            device=device,
+            fused_params=self.fused_params,
+            qcomm_codecs_registry=self.qcomm_codecs_registry,
+            use_index_dedup=self._use_index_dedup,
+            module_fqn=module_fqn,
+        )
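
For reference, a minimal sketch of how the new sharder is expected to be wired up. This is illustrative only (not part of the patch); the table name, sizes, and world size mirror the unit test added later in this diff:

import torch
from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection

# The sharder only advertises ROW_WISE sharding backed by the KEY_VALUE compute
# kernel, so every table in the collection is row-wise sharded on that kernel.
ec = MappedEmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=16,
        )
    ],
    device=torch.device("meta"),
)
plan = construct_module_sharding_plan(
    ec,
    per_param_sharding={"table_0": row_wise()},
    local_size=2,
    world_size=2,
    device_type="cuda",
    sharder=MappedEmbeddingCollectionSharder(),
)
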
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 792fdeb0a..a4253a71c 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -994,6 +994,9 @@ class QuantEmbeddingCollectionSharder(
     This implementation uses non-fused EmbeddingCollection
     """
 
+    def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__(fused_params)
+
     def shard(
         self,
         module: QuantEmbeddingCollection,
diff --git a/torchrec/distributed/quant_mapped_embedding.py b/torchrec/distributed/quant_mapped_embedding.py
new file mode 100644
index 000000000..228fe37b0
--- /dev/null
+++ b/torchrec/distributed/quant_mapped_embedding.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Dict, Optional, Type
+
+import torch
+
+from torchrec.distributed.quant_embedding import (
+    QuantEmbeddingCollectionSharder,
+    ShardedQuantEmbeddingCollection,
+)
+from torchrec.distributed.types import ParameterSharding, ShardingEnv
+from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection
+
+
+class ShardedQuantMappedEmbeddingCollection(ShardedQuantEmbeddingCollection):
+    def __init__(
+        self,
+        module: QuantMappedEmbeddingCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        fused_params: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(
+            module,
+            table_name_to_parameter_sharding,
+            env,
+            fused_params,
+            device,
+        )
+
+
+class QuantMappedEmbeddingCollectionSharder(QuantEmbeddingCollectionSharder):
+    def __init__(
+        self,
+        fused_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(fused_params)
+
+    @property
+    def module_type(self) -> Type[QuantMappedEmbeddingCollection]:
+        return QuantMappedEmbeddingCollection
+
+    # pyre-ignore [14]
+    def shard(
+        self,
+        module: QuantMappedEmbeddingCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        fused_params: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+    ) -> ShardedQuantMappedEmbeddingCollection:
+        return ShardedQuantMappedEmbeddingCollection(
+            module,
+            table_name_to_parameter_sharding,
+            env,
+            fused_params=fused_params,
+            device=device,
+        )
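
A condensed sketch of the float-to-quantized conversion this class enables. It is not part of the patch; the qconfig below is only the standard placeholder-observer pattern used by torchrec quantization, and the end-to-end flow (including sharding with QuantMappedEmbeddingCollectionSharder) is exercised by the new unit test later in this diff:

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection

ec = MappedEmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=16,
        )
    ],
    device=torch.device("cpu"),
)
# from_float requires a qconfig on the float module: the weight dtype selects
# the quantized data type, the activation dtype selects the output dtype.
ec.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.float),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qec = QuantMappedEmbeddingCollection.from_float(ec)
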
diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index 27b011300..5b94f7379 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -23,6 +23,7 @@
     FeatureProcessedEmbeddingBagCollectionSharder,
 )
 from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
+from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder
 from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
@@ -34,6 +35,9 @@
     QuantManagedCollisionEmbeddingCollectionSharder,
 )
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
+from torchrec.distributed.quant_mapped_embedding import (
+    QuantMappedEmbeddingCollectionSharder,
+)
 from torchrec.distributed.types import (
     EmbeddingModuleShardingPlan,
     EnumerableShardingSpec,
@@ -62,6 +66,8 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]:
                 InferManagedCollisionCollectionSharder(),
             ),
         ),
+        cast(ModuleSharder[nn.Module], MappedEmbeddingCollectionSharder()),
+        cast(ModuleSharder[nn.Module], QuantMappedEmbeddingCollectionSharder()),
     ]
 
 
@@ -361,6 +367,7 @@ def _get_parameter_sharding(
     sharder: ModuleSharder[nn.Module],
     placements: Optional[List[str]] = None,
     compute_kernel: Optional[str] = None,
+    bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None,
 ) -> ParameterSharding:
     return ParameterSharding(
         sharding_spec=(
@@ -371,6 +378,8 @@ def _get_parameter_sharding(
                     ShardMetadata(
                         shard_sizes=size,
                         shard_offsets=offset,
+                        bucket_id_offset=bucket_id_offset,
+                        num_buckets=num_buckets,
                         placement=(
                             placement(
                                 device_type,
@@ -381,9 +390,17 @@ def _get_parameter_sharding(
                             else device_placement
                         ),
                     )
-                    for (size, offset, rank), device_placement in zip(
+                    for (size, offset, rank), device_placement, (
+                        num_buckets,
+                        bucket_id_offset,
+                    ) in zip(
                         size_offset_ranks,
                         placements if placements else [None] * len(size_offset_ranks),
+                        (
+                            bucket_offset_sizes
+                            if bucket_offset_sizes
+                            else [(None, None)] * len(size_offset_ranks)
+                        ),
                     )
                 ]
             )
@@ -512,7 +529,8 @@ def _parameter_sharding_generator(
 
 
 def row_wise(
-    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
+    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
+    num_buckets_per_rank: Optional[List[int]] = None,  # number of buckets on each rank, for bucket-level row-wise sharding
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
@@ -545,6 +563,7 @@ def _parameter_sharding_generator(
         device_type: str,
         sharder: ModuleSharder[nn.Module],
     ) -> ParameterSharding:
+        bucket_offset_sizes = None
         if sizes_placement is None:
             size_and_offsets = _get_parameter_size_offsets(
                 param,
@@ -558,17 +577,34 @@ def _parameter_sharding_generator(
                 size_offset_ranks.append((size, offset, rank))
         else:
             size_offset_ranks = []
+            bucket_offset_sizes = None if num_buckets_per_rank is None else []
             sizes = sizes_placement[0]
+            if num_buckets_per_rank is not None:
+                assert len(sizes) == len(
+                    num_buckets_per_rank
+                ), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively"
             (rows, cols) = param.shape
             cur_offset = 0
             prev_offset = 0
+            prev_bucket_offset = 0
+            cur_bucket_offset = 0
             for rank, size in enumerate(sizes):
                 per_rank_row = size
+                per_rank_bucket_size = None
+                if num_buckets_per_rank is not None:
+                    per_rank_bucket_size = num_buckets_per_rank[rank]
+                    cur_bucket_offset += per_rank_bucket_size
                 cur_offset += per_rank_row
                 cur_offset = min(cur_offset, rows)
                 per_rank_row = cur_offset - prev_offset
                 size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank))
                 prev_offset = cur_offset
+                if num_buckets_per_rank is not None:
+                    # buckets currently partition the row dimension only
+                    none_throws(bucket_offset_sizes).append(
+                        (per_rank_bucket_size, prev_bucket_offset)
+                    )
+                    prev_bucket_offset = cur_bucket_offset
 
             if cur_offset < rows:
                 raise ValueError(
@@ -590,6 +626,13 @@ def _parameter_sharding_generator(
                 if device_type == "cuda":
                     index += 1
 
+        compute_kernel = None
+        if sizes_placement is not None:
+            if num_buckets_per_rank is not None:
+                compute_kernel = EmbeddingComputeKernel.KEY_VALUE.value
+            else:
+                compute_kernel = EmbeddingComputeKernel.QUANT.value
+
         return _get_parameter_sharding(
             param,
             ShardingType.ROW_WISE.value,
@@ -598,9 +641,8 @@ def _parameter_sharding_generator(
             device_type,
             sharder,
             placements=placements if sizes_placement else None,
-            compute_kernel=(
-                EmbeddingComputeKernel.QUANT.value if sizes_placement else None
-            ),
+            compute_kernel=compute_kernel,
+            bucket_offset_sizes=bucket_offset_sizes,
         )
 
     return _parameter_sharding_generator
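
As a quick reference for the new arguments, a sketch of bucket-level row-wise plan construction. It is not part of the patch and mirrors test_row_wise_bucket_level_sharding added below; the sizes and bucket counts are illustrative:

from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=64,
            num_embeddings=4096,
        )
    ]
)
plan = construct_module_sharding_plan(
    ec,
    per_param_sharding={
        "table_0": row_wise(
            sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]),
            num_buckets_per_rank=[20, 30, 40],
        ),
    },
    local_size=1,
    world_size=2,
    device_type="cuda",
)
# Each shard now carries bucket metadata and uses the KEY_VALUE compute kernel:
# num_buckets 20 / 30 / 40 and bucket_id_offset 0 / 20 / 50 (a running sum).
shards = plan["table_0"].sharding_spec.shards
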
diff --git a/torchrec/distributed/tests/test_quant_mapped_embedding.py b/torchrec/distributed/tests/test_quant_mapped_embedding.py
new file mode 100644
index 000000000..d7e2fc7e6
--- /dev/null
+++ b/torchrec/distributed/tests/test_quant_mapped_embedding.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import copy
+import multiprocessing
+import unittest
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+from hypothesis import settings
+from pyre_extensions import none_throws
+from torch import nn
+
+from torchrec import EmbeddingConfig, inference as trec_infer, KeyedJaggedTensor
+from torchrec.distributed.global_settings import set_propogate_device
+from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder
+from torchrec.distributed.quant_mapped_embedding import (
+    QuantMappedEmbeddingCollectionSharder,
+)
+from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
+
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
+from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
+from torchrec.inference.modules import (
+    DEFAULT_FUSED_PARAMS,
+    trim_torch_package_prefix_from_typename,
+)
+from torchrec.modules.embedding_modules import EmbeddingCollection
+from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
+from torchrec.quant.embedding_modules import (
+    EmbeddingCollection as QuantEmbeddingCollection,
+    QuantMappedEmbeddingCollection,
+)
+
+QUANTIZATION_MAPPING: Dict[str, Type[torch.nn.Module]] = {
+    trim_torch_package_prefix_from_typename(
+        torch.typename(EmbeddingCollection)
+    ): QuantEmbeddingCollection,
+    trim_torch_package_prefix_from_typename(
+        torch.typename(MappedEmbeddingCollection)
+    ): QuantMappedEmbeddingCollection,
+}
+
+
+class SparseArch(nn.Module):
+    def __init__(
+        self,
+        tables: List[EmbeddingConfig],
+        device: torch.device,
+        buckets: int,
+        input_hash_size: int = 4000,
+        is_inference: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self._ec: MappedEmbeddingCollection = MappedEmbeddingCollection(
+            tables=tables,
+            device=device,
+        )
+
+    def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        ec_out = self._ec(kjt)
+        pred: torch.Tensor = torch.cat(
+            [ec_out[key].values() for key in ["feature_0", "feature_1"]],
+            dim=0,
+        )
+        loss = pred.mean()
+        return loss, pred
+
+
+class TestQuantMappedEmbedding(MultiProcessTestBase):
+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @settings(deadline=None)
+    def test_quant_sharding_mapped_ec(self) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        train_input_per_rank = [
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1"],
+                values=torch.cat([torch.randint(16, (8,)), torch.randint(32, (16,))]),
+                lengths=torch.LongTensor([1] * 8 + [2] * 8),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1"],
+                values=torch.cat([torch.randint(16, (8,)), torch.randint(32, (16,))]),
+                lengths=torch.LongTensor([1] * 8 + [2] * 8),
+                weights=None,
+            ),
+        ]
+
+        infer_input = KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            values=torch.randint(
+                32,
+                (8,),
+            ),
+            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1]),
+            weights=None,
+        )
+
+        train_info = multiprocessing.Manager().dict()
+
+        # Train model on GPU
+        self._run_multi_process_test(
+            callable=_train_model,
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            num_buckets=2,
+            kjt_input_per_rank=train_input_per_rank,
+            sharder=MappedEmbeddingCollectionSharder(),
+            return_dict=train_info,
+            backend="nccl",
+            infer_input=infer_input,
+        )
+        print(f"train_info: {train_info}")
+        # Load Train Model State Dict into Inference Model
+        inference_model = SparseArch(
+            tables=embedding_config,
+            device=torch.device("cpu"),
+            input_hash_size=0,
+            buckets=4,
+        )
+
+        merged_state_dict = {
+            "_ec.embeddings.table_0.weight": torch.cat(
+                [value for key, value in train_info.items() if "table_0" in key],
+                dim=0,
+            ),
+            "_ec.embeddings.table_1.weight": torch.cat(
+                [value for key, value in train_info.items() if "table_1" in key],
+                dim=0,
+            ),
+        }
+
+        inference_model.load_state_dict(merged_state_dict)
+
+        # Get Train Model Output
+        # train_output = inference_model(infer_input)
+
+        # Quantize Inference Model
+        quant_inference_model = trec_infer.modules.quantize_inference_model(
+            inference_model, QUANTIZATION_MAPPING, None, torch.quint4x2
+        )
+
+        # Get Quantized Inference Model Output
+        _, quant_output = quant_inference_model(infer_input)
+
+        # Verify Quantized Inference Model Output is close to Train Model Output
+        # TODO: [Kaus] Check why this fails
+        # self.assertTrue(
+        #     torch.allclose(
+        #         train_info["train_output_0"],
+        #         quant_output,
+        #         atol=1e-02,
+        #     )
+        # )
+
+        # Shard Quantized Inference Model
+        sharder = QuantMappedEmbeddingCollectionSharder(
+            fused_params=DEFAULT_FUSED_PARAMS
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            quant_inference_model._ec,  # pyre-ignore
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=2,
+            world_size=WORLD_SIZE,
+            device_type="cpu",
+            sharder=sharder,  # pyre-ignore
+        )
+        set_propogate_device(True)
+
+        sharded_quant_inference_model = _shard_modules(
+            module=copy.deepcopy(quant_inference_model),
+            plan=ShardingPlan({"_ec": module_sharding_plan}),
+            env=ShardingEnv.from_local(
+                WORLD_SIZE,
+                0,
+            ),
+            sharders=[sharder],  # pyre-ignore
+            device=torch.device("cpu"),
+        )
+
+        _, sharded_quant_output = sharded_quant_inference_model(infer_input)
+        self.assertTrue(
+            torch.allclose(
+                sharded_quant_output,
+                quant_output,
+                atol=0,
+            )
+        )
+
+
+def _train_model(
+    tables: List[EmbeddingConfig],
+    num_buckets: int,
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    return_dict: Dict[str, Any],
+    infer_input: KeyedJaggedTensor,
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+
+        train_model = SparseArch(
+            tables=tables,
+            device=torch.device("cuda"),
+            input_hash_size=0,
+            buckets=num_buckets,
+        )
+        train_sharding_plan = construct_module_sharding_plan(
+            train_model._ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda",
+            sharder=sharder,
+        )
+        print(f"train_sharding_plan: {train_sharding_plan}")
+        sharded_train_model = _shard_modules(
+            module=copy.deepcopy(train_model),
+            plan=ShardingPlan({"_ec": train_sharding_plan}),
+            env=ShardingEnv.from_process_group(none_throws(ctx.pg)),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        optim = torch.optim.SGD(sharded_train_model.parameters(), lr=0.1)
+        # train
+        optim.zero_grad()
+        sharded_train_model.train(True)
+        loss, output = sharded_train_model(kjt_input.to(ctx.device))
+        loss.backward()
+        optim.step()
+
+        # infer
+        with torch.no_grad():
+            sharded_train_model.train(False)
+            _, infer_output = sharded_train_model(infer_input.to(ctx.device))
+
+        return_dict[f"train_output_{rank}"] = infer_output.cpu()
+
+        for (
+            key,
+            value,
+            # pyre-ignore
+        ) in sharded_train_model._ec.embeddings.state_dict().items():
+            return_dict[f"ec_{key}_{rank}"] = value.cpu()
diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py
index 5dc18885a..34a84d761 100644
--- a/torchrec/distributed/tests/test_sharding_plan.py
+++ b/torchrec/distributed/tests/test_sharding_plan.py
@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
                 0,
             )
 
+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None:
+
+        embedding_config = [
+            EmbeddingBagConfig(
+                name=f"table_{idx}",
+                feature_names=[f"feature_{idx}"],
+                embedding_dim=64,
+                num_embeddings=4096,
+                data_type=data_type,
+            )
+            for idx in range(2)
+        ]
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingCollection(tables=embedding_config),
+            per_param_sharding={
+                "table_0": row_wise(
+                    sizes_placement=(
+                        [2048, 1024, 1024],
+                        ["cpu", "cuda", "cuda"],
+                    ),
+                    num_buckets_per_rank=[20, 30, 40],
+                ),
+                "table_1": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
+                ),
+            },
+            local_size=1,
+            world_size=2,
+            device_type="cuda",
+        )
+
+        # Make sure the per_param_sharding settings override the default device_type
+        device_table_0_shard_0 = (
+            # pyre-ignore[16]
+            module_sharding_plan["table_0"]
+            .sharding_spec.shards[0]
+            .placement
+        )
+        self.assertEqual(
+            device_table_0_shard_0.device().type,
+            "cpu",
+        )
+        # cpu always has rank 0
+        self.assertEqual(
+            device_table_0_shard_0.rank(),
+            0,
+        )
+        for i in range(1, 3):
+            device_table_0_shard_i = (
+                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_0_shard_i.device().type,
+                "cuda",
+            )
+            # first rank is assigned to cpu so index = rank - 1
+            self.assertEqual(
+                device_table_0_shard_i.device().index,
+                i - 1,
+            )
+            self.assertEqual(
+                device_table_0_shard_i.rank(),
+                i,
+            )
+        for i in range(3):
+            device_table_1_shard_i = (
+                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_1_shard_i.device().type,
+                "cpu",
+            )
+            # cpu always has rank 0
+            self.assertEqual(
+                device_table_1_shard_i.rank(),
+                0,
+            )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="key_value",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=0,
+                            num_buckets=20,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:1/cuda:0",
+                            bucket_id_offset=20,
+                            num_buckets=30,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:2/cuda:1",
+                            bucket_id_offset=50,
+                            num_buckets=40,
+                        ),
+                    ]
+                ),
+            ),
+            "table_1": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="quant",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
     # pyre-fixme[56]
     @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
     @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
@@ -929,18 +1082,89 @@ def test_str(self) -> None:
         )
         expected = """module: ebc
 
- param   | sharding type | compute kernel | ranks
+ param   | sharding type | compute kernel | ranks  
 -------- | ------------- | -------------- | ------
 user_id  | table_wise    | dense          | [0]
 movie_id | row_wise      | dense          | [0, 1]
 
- param   | shard offsets | shard sizes |   placement
+ param   | shard offsets | shard sizes |   placement  
 -------- | ------------- | ----------- | -------------
 user_id  | [0, 0]        | [4096, 32]  | rank:0/cuda:0
 movie_id | [0, 0]        | [2048, 32]  | rank:0/cuda:0
 movie_id | [2048, 0]     | [2048, 32]  | rank:0/cuda:1
 """
         self.maxDiff = None
+        print("STR PLAN")
+        print(str(plan))
+        print("=======")
+        for i in range(len(expected.splitlines())):
+            self.assertEqual(
+                expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
+            )
+
+    def test_str_bucket_wise_sharding(self) -> None:
+        plan = ShardingPlan(
+            {
+                "ebc": EmbeddingModuleShardingPlan(
+                    {
+                        "user_id": ParameterSharding(
+                            sharding_type="table_wise",
+                            compute_kernel="dense",
+                            ranks=[0],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[4096, 32],
+                                        placement="rank:0/cuda:0",
+                                    ),
+                                ]
+                            ),
+                        ),
+                        "movie_id": ParameterSharding(
+                            sharding_type="row_wise",
+                            compute_kernel="dense",
+                            ranks=[0, 1],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:0",
+                                        bucket_id_offset=0,
+                                        num_buckets=20,
+                                    ),
+                                    ShardMetadata(
+                                        shard_offsets=[2048, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:1",
+                                        bucket_id_offset=20,
+                                        num_buckets=30,
+                                    ),
+                                ]
+                            ),
+                        ),
+                    }
+                )
+            }
+        )
+        expected = """module: ebc
+        
+ param   | sharding type | compute kernel | ranks  
+-------- | ------------- | -------------- | ------
+user_id  | table_wise    | dense          | [0]
+movie_id | row_wise      | dense          | [0, 1]
+
+ param   | shard offsets | shard sizes |   placement   | bucket id offset | num buckets
+-------- | ------------- | ----------- | ------------- | ---------------- | -----------
+user_id  | [0, 0]        | [4096, 32]  | rank:0/cuda:0 | None             | None       
+movie_id | [0, 0]        | [2048, 32]  | rank:0/cuda:0 | 0                | 20       
+movie_id | [2048, 0]     | [2048, 32]  | rank:0/cuda:1 | 20               | 30       
+"""
+        self.maxDiff = None
+        print("STR PLAN BUCKET WISE")
+        print(str(plan))
+        print("=======")
         for i in range(len(expected.splitlines())):
             self.assertEqual(
                 expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 1d45bff69..2be9a9c86 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -737,6 +737,7 @@ def __str__(self) -> str:
         out = ""
         param_table = []
         shard_table = []
+        contains_bucket_wise_shards = False
         for param_name, param_sharding in self.items():
             param_table.append(
                 [
@@ -749,20 +750,54 @@ def __str__(self) -> str:
             if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                 shards = param_sharding.sharding_spec.shards
                 if shards is not None:
+                    param_sharding_contains_bucket_info = any(
+                        shard.bucket_id_offset is not None for shard in shards
+                    )
+                    if param_sharding_contains_bucket_info:
+                        contains_bucket_wise_shards = True
                     for shard in shards:
-                        shard_table.append(
+                        cols = (
                             [
                                 param_name,
                                 shard.shard_offsets,
                                 shard.shard_sizes,
                                 shard.placement,
                             ]
+                            if not param_sharding_contains_bucket_info
+                            else [
+                                param_name,
+                                shard.shard_offsets,
+                                shard.shard_sizes,
+                                shard.placement,
+                                shard.bucket_id_offset,
+                                shard.num_buckets,
+                            ]
                         )
+                        shard_table.append(cols)
+        if contains_bucket_wise_shards:
+            for i in range(len(shard_table)):
+                if len(shard_table[i]) == 4:
+                    # add None for the tables that don't have bucket info
+                    shard_table[i].append(None)
+                    shard_table[i].append(None)
         out += "\n\n" + _tabulate(
             param_table, ["param", "sharding type", "compute kernel", "ranks"]
         )
+        column_str = (
+            ["param", "shard offsets", "shard sizes", "placement"]
+            if not contains_bucket_wise_shards
+            else [
+                "param",
+                "shard offsets",
+                "shard sizes",
+                "placement",
+                "bucket id offset",
+                "num buckets",
+            ]
+        )
         out += "\n\n" + _tabulate(
-            shard_table, ["param", "shard offsets", "shard sizes", "placement"]
+            shard_table,
+            column_str,
         )
         return out
 
diff --git a/torchrec/modules/mapped_embedding_module.py b/torchrec/modules/mapped_embedding_module.py
new file mode 100644
index 000000000..a9173592e
--- /dev/null
+++ b/torchrec/modules/mapped_embedding_module.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import List, Optional
+
+import torch
+
+from torchrec.modules.embedding_configs import EmbeddingConfig
+
+from torchrec.modules.embedding_modules import EmbeddingCollection
+
+
+class MappedEmbeddingCollection(EmbeddingCollection):
+    """ """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingConfig],
+        device: Optional[torch.device] = None,
+        need_indices: bool = False,
+    ) -> None:
+        super().__init__(tables=tables, need_indices=need_indices, device=device)
+        torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
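
At the module level the new class behaves like a plain EmbeddingCollection; a small usage sketch (illustrative only, not part of the patch):

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ec = MappedEmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=16,
        )
    ],
    device=torch.device("cpu"),
)
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.randint(16, (4,)),
    lengths=torch.LongTensor([1, 1, 1, 1]),
)
out = ec(kjt)  # Dict[str, JaggedTensor], same contract as EmbeddingCollection
emb = out["feature_0"].values()  # shape [4, 8]
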
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 81b9b8bfc..108f4ae1d 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -55,6 +55,7 @@
 from torchrec.modules.fp_embedding_modules import (
     FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection,
 )
+from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection
 from torchrec.modules.mc_embedding_modules import (
     ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
 )
@@ -1007,6 +1008,82 @@ def device(self) -> torch.device:
         return self._device
 
 
+class QuantMappedEmbeddingCollection(EmbeddingCollection):
+    """ """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingConfig],
+        device: torch.device,
+        need_indices: bool = False,
+        output_dtype: torch.dtype = torch.float,
+        table_name_to_quantized_weights: Optional[
+            Dict[str, Tuple[Tensor, Tensor]]
+        ] = None,
+        register_tbes: bool = False,
+        quant_state_dict_split_scale_bias: bool = False,
+        row_alignment: int = DEFAULT_ROW_ALIGNMENT,
+        cache_features_order: bool = False,
+    ) -> None:
+        super().__init__(
+            tables,
+            device,
+            need_indices,
+            output_dtype,
+            table_name_to_quantized_weights,
+            register_tbes,
+            quant_state_dict_split_scale_bias,
+            row_alignment,
+            cache_features_order,
+        )
+
+    @classmethod
+    # pyre-ignore
+    def from_float(
+        cls,
+        module: MappedEmbeddingCollection,
+        use_precomputed_fake_quant: bool = False,
+    ) -> "QuantMappedEmbeddingCollection":
+        assert hasattr(
+            module, "qconfig"
+        ), "MappedEmbeddingCollection input float module must have qconfig defined"
+        embedding_configs = copy.deepcopy(module.embedding_configs())
+        _update_embedding_configs(
+            cast(List[BaseEmbeddingConfig], embedding_configs),
+            # pyre-fixme[6]: For 2nd argument expected `Union[QuantConfig, QConfig]`
+            #  but got `Union[Module, Tensor]`.
+            module.qconfig,
+        )
+        table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {}
+        device = quantize_state_dict(
+            module,
+            table_name_to_quantized_weights,
+            {table.name: table.data_type for table in embedding_configs},
+        )
+        return cls(
+            embedding_configs,
+            device=device,
+            need_indices=module.need_indices(),
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            #  `activation`.
+            output_dtype=module.qconfig.activation().dtype,
+            table_name_to_quantized_weights=table_name_to_quantized_weights,
+            register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False),
+            quant_state_dict_split_scale_bias=getattr(
+                module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False
+            ),
+            row_alignment=getattr(
+                module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
+            ),
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
+        )
+
+    def _get_name(self) -> str:
+        return "QuantMappedEmbeddingCollection"
+
+
 class QuantManagedCollisionEmbeddingCollection(EmbeddingCollection):
     """
     QuantManagedCollisionEmbeddingCollection represents a quantized EC module and a set of managed collision modules.