
support bucket offsets and sizes in shard metadata #2884

Open
wants to merge 2 commits into base: main
97 changes: 97 additions & 0 deletions torchrec/distributed/mapped_embedding.py
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3

from typing import Any, Dict, List, Optional, Type

import torch

from torchrec.distributed.embedding import (
EmbeddingCollectionSharder,
ShardedEmbeddingCollection,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import (
ParameterSharding,
QuantizedCommCodecs,
ShardingEnv,
ShardingType,
)
from torchrec.modules.mapped_embedding_module import MappedEmbeddingCollection


class ShardedMappedEmbeddingCollection(ShardedEmbeddingCollection):
def __init__(
self,
module: MappedEmbeddingCollection,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device],
fused_params: Optional[Dict[str, Any]] = None,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
use_index_dedup: bool = False,
module_fqn: Optional[str] = None,
) -> None:
super().__init__(
module=module,
table_name_to_parameter_sharding=table_name_to_parameter_sharding,
env=env,
device=device,
fused_params=fused_params,
qcomm_codecs_registry=qcomm_codecs_registry,
use_index_dedup=use_index_dedup,
module_fqn=module_fqn,
)


class MappedEmbeddingCollectionSharder(EmbeddingCollectionSharder):
def __init__(
self,
fused_params: Optional[Dict[str, Any]] = None,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
use_index_dedup: bool = False,
) -> None:
super().__init__(
fused_params=fused_params,
qcomm_codecs_registry=qcomm_codecs_registry,
use_index_dedup=use_index_dedup,
)

def sharding_types(self, compute_device_type: str) -> List[str]:
return [ShardingType.ROW_WISE.value]

def compute_kernels(
self,
*args: Any,
**kwargs: Any,
) -> List[str]:
return [EmbeddingComputeKernel.KEY_VALUE.value]

@property
def module_type(self) -> Type[MappedEmbeddingCollection]:
return MappedEmbeddingCollection

# pyre-ignore: Inconsistent override [14]
def shard(
self,
module: MappedEmbeddingCollection,
params: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device] = None,
module_fqn: Optional[str] = None,
) -> ShardedMappedEmbeddingCollection:
return ShardedMappedEmbeddingCollection(
module=module,
table_name_to_parameter_sharding=params,
env=env,
device=device,
fused_params=self.fused_params,
qcomm_codecs_registry=self.qcomm_codecs_registry,
use_index_dedup=self._use_index_dedup,
module_fqn=module_fqn,
)
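
Reviewer note: the snippet below is a minimal sketch, not part of this diff, showing what the new sharder advertises to the planner. It only uses names introduced above; ROW_WISE plus the KEY_VALUE compute kernel is the single combination this module supports.

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder
from torchrec.distributed.types import ShardingType

# The sharder restricts MappedEmbeddingCollection to row-wise sharding on the
# key-value compute kernel.
sharder = MappedEmbeddingCollectionSharder(use_index_dedup=True)
assert sharder.sharding_types("cuda") == [ShardingType.ROW_WISE.value]
assert sharder.compute_kernels() == [EmbeddingComputeKernel.KEY_VALUE.value]
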
3 changes: 3 additions & 0 deletions torchrec/distributed/quant_embedding.py
@@ -994,6 +994,9 @@ class QuantEmbeddingCollectionSharder(
This implementation uses non-fused EmbeddingCollection
"""

def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
super().__init__(fused_params)

def shard(
self,
module: QuantEmbeddingCollection,
66 changes: 66 additions & 0 deletions torchrec/distributed/quant_mapped_embedding.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3

from typing import Any, Dict, Optional, Type

import torch

from torchrec.distributed.quant_embedding import (
QuantEmbeddingCollectionSharder,
ShardedQuantEmbeddingCollection,
)
from torchrec.distributed.types import ParameterSharding, ShardingEnv
from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection


class ShardedQuantMappedEmbeddingCollection(ShardedQuantEmbeddingCollection):
def __init__(
self,
module: QuantMappedEmbeddingCollection,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
env: ShardingEnv,
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
) -> None:
super().__init__(
module,
table_name_to_parameter_sharding,
env,
fused_params,
device,
)


class QuantMappedEmbeddingCollectionSharder(QuantEmbeddingCollectionSharder):
def __init__(
self,
fused_params: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(fused_params)

@property
def module_type(self) -> Type[QuantMappedEmbeddingCollection]:
return QuantMappedEmbeddingCollection

# pyre-ignore [14]
def shard(
self,
module: QuantMappedEmbeddingCollection,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
env: ShardingEnv,
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
) -> ShardedQuantMappedEmbeddingCollection:
return ShardedQuantMappedEmbeddingCollection(
module,
table_name_to_parameter_sharding,
env,
fused_params=fused_params,
device=device,
)
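
Reviewer note: a companion sketch (also not part of this diff) for the quantized inference path. The sharder itself only forwards fused_params to QuantEmbeddingCollectionSharder; the substantive piece is the module_type mapping, which lets the planner route QuantMappedEmbeddingCollection modules to ShardedQuantMappedEmbeddingCollection.

from torchrec.distributed.quant_mapped_embedding import (
    QuantMappedEmbeddingCollectionSharder,
)
from torchrec.quant.embedding_modules import QuantMappedEmbeddingCollection

# module_type is what get_default_sharders() / the planner keys on when
# picking this sharder for a QuantMappedEmbeddingCollection.
sharder = QuantMappedEmbeddingCollectionSharder()
assert sharder.module_type is QuantMappedEmbeddingCollection
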
52 changes: 47 additions & 5 deletions torchrec/distributed/sharding_plan.py
@@ -23,6 +23,7 @@
FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
from torchrec.distributed.mapped_embedding import MappedEmbeddingCollectionSharder
from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder
from torchrec.distributed.mc_embeddingbag import (
ManagedCollisionEmbeddingBagCollectionSharder,
@@ -34,6 +35,9 @@
QuantManagedCollisionEmbeddingCollectionSharder,
)
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.quant_mapped_embedding import (
QuantMappedEmbeddingCollectionSharder,
)
from torchrec.distributed.types import (
EmbeddingModuleShardingPlan,
EnumerableShardingSpec,
@@ -62,6 +66,8 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]:
InferManagedCollisionCollectionSharder(),
),
),
cast(ModuleSharder[nn.Module], MappedEmbeddingCollectionSharder()),
cast(ModuleSharder[nn.Module], QuantMappedEmbeddingCollectionSharder()),
]


@@ -361,6 +367,7 @@ def _get_parameter_sharding(
sharder: ModuleSharder[nn.Module],
placements: Optional[List[str]] = None,
compute_kernel: Optional[str] = None,
bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None,
) -> ParameterSharding:
return ParameterSharding(
sharding_spec=(
@@ -371,6 +378,8 @@
ShardMetadata(
shard_sizes=size,
shard_offsets=offset,
bucket_id_offset=bucket_id_offset,
num_buckets=num_buckets,
placement=(
placement(
device_type,
@@ -381,9 +390,17 @@
else device_placement
),
)
for (size, offset, rank), device_placement in zip(
for (size, offset, rank), device_placement, (
num_buckets,
bucket_id_offset,
) in zip(
size_offset_ranks,
placements if placements else [None] * len(size_offset_ranks),
(
bucket_offset_sizes
if bucket_offset_sizes
else [(None, None)] * len(size_offset_ranks)
),
)
]
)
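
# Reviewer aside (not part of this diff): how the three-way zip above degrades
# when no bucket information is supplied. Each shard is paired with (None, None),
# so ShardMetadata is built with num_buckets=None and bucket_id_offset=None,
# preserving the previous behaviour. The values below are made up for illustration.
size_offset_ranks = [([10, 4], [0, 0], 0), ([10, 4], [10, 0], 1)]
placements = None
bucket_offset_sizes = None
pairs = list(
    zip(
        size_offset_ranks,
        placements if placements else [None] * len(size_offset_ranks),
        (
            bucket_offset_sizes
            if bucket_offset_sizes
            else [(None, None)] * len(size_offset_ranks)
        ),
    )
)
assert pairs[0] == (([10, 4], [0, 0], 0), None, (None, None))
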
@@ -512,7 +529,8 @@ def _parameter_sharding_generator(


def row_wise(
sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
num_buckets_per_rank: Optional[List[int]] = None, # propagate num buckets per rank
) -> ParameterShardingGenerator:
"""
Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
Expand Down Expand Up @@ -545,6 +563,7 @@ def _parameter_sharding_generator(
device_type: str,
sharder: ModuleSharder[nn.Module],
) -> ParameterSharding:
bucket_offset_sizes = None
if sizes_placement is None:
size_and_offsets = _get_parameter_size_offsets(
param,
@@ -558,17 +577,34 @@
size_offset_ranks.append((size, offset, rank))
else:
size_offset_ranks = []
bucket_offset_sizes = None if num_buckets_per_rank is None else []
sizes = sizes_placement[0]
if num_buckets_per_rank is not None:
assert len(sizes) == len(
num_buckets_per_rank
), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively"
(rows, cols) = param.shape
cur_offset = 0
prev_offset = 0
prev_bucket_offset = 0
cur_bucket_offset = 0
for rank, size in enumerate(sizes):
per_rank_row = size
per_rank_bucket_size = None
if num_buckets_per_rank is not None:
per_rank_bucket_size = num_buckets_per_rank[rank]
cur_bucket_offset += per_rank_bucket_size
cur_offset += per_rank_row
cur_offset = min(cur_offset, rows)
per_rank_row = cur_offset - prev_offset
size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank))
prev_offset = cur_offset
if num_buckets_per_rank is not None:
# bucket has only one col for now
none_throws(bucket_offset_sizes).append(
(per_rank_bucket_size, prev_bucket_offset)
)
prev_bucket_offset = cur_bucket_offset

if cur_offset < rows:
raise ValueError(
@@ -590,6 +626,13 @@
if device_type == "cuda":
index += 1

compute_kernel = None
if sizes_placement is not None:
if num_buckets_per_rank is not None:
compute_kernel = EmbeddingComputeKernel.KEY_VALUE.value
else:
compute_kernel = EmbeddingComputeKernel.QUANT.value

return _get_parameter_sharding(
param,
ShardingType.ROW_WISE.value,
@@ -598,9 +641,8 @@
device_type,
sharder,
placements=placements if sizes_placement else None,
compute_kernel=(
EmbeddingComputeKernel.QUANT.value if sizes_placement else None
),
compute_kernel=compute_kernel,
bucket_offset_sizes=bucket_offset_sizes,
)

return _parameter_sharding_generator
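
Reviewer note: the loop above interleaves row-offset and bucket-offset bookkeeping, which makes it a little hard to eyeball. The stand-alone mirror below (not part of this diff; the helper name and example values are made up) reproduces just the bucket part and shows the (num_buckets, bucket_id_offset) pair that ends up on each rank's ShardMetadata.

from typing import List, Tuple

def bucket_offsets_per_rank(num_buckets_per_rank: List[int]) -> List[Tuple[int, int]]:
    # Mirrors the (per_rank_bucket_size, prev_bucket_offset) pairs appended to
    # bucket_offset_sizes in row_wise() above: a running prefix sum of buckets.
    out: List[Tuple[int, int]] = []
    bucket_id_offset = 0
    for num_buckets in num_buckets_per_rank:
        out.append((num_buckets, bucket_id_offset))
        bucket_id_offset += num_buckets
    return out

# E.g. 4 buckets on rank 0 and 2 buckets on rank 1: rank 1's buckets start at id 4.
# row_wise() forwards these pairs through _get_parameter_sharding() into
# ShardMetadata.num_buckets / ShardMetadata.bucket_id_offset.
assert bucket_offsets_per_rank([4, 2]) == [(4, 0), (2, 4)]
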