Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding (#2884) #2885

Open · wants to merge 2 commits into base: main
312 changes: 301 additions & 11 deletions torchrec/distributed/batched_embedding_kernel.py

Large diffs are not rendered by default.

26 changes: 21 additions & 5 deletions torchrec/distributed/embedding.py
@@ -259,6 +259,8 @@ def create_sharding_infos_by_sharding(
embedding_names=embedding_names,
weight_init_max=config.weight_init_max,
weight_init_min=config.weight_init_min,
total_num_buckets=config.total_num_buckets,
zero_collision=config.zero_collision,
),
param_sharding=parameter_sharding,
param=param,
@@ -353,6 +355,8 @@ def create_sharding_infos_by_sharding_device_group(
embedding_names=embedding_names,
weight_init_max=config.weight_init_max,
weight_init_min=config.weight_init_min,
total_num_buckets=config.total_num_buckets,
zero_collision=config.zero_collision,
),
param_sharding=parameter_sharding,
param=param,
@@ -767,11 +771,13 @@ def _initialize_torch_state(self) -> None:  # noqa
)

self._name_to_table_size = {}
table_zero_collision = {}
for table in self._embedding_configs:
self._name_to_table_size[table.name] = (
table.num_embeddings,
table.embedding_dim,
)
table_zero_collision[table.name] = table.zero_collision

for sharding_type, lookup in zip(
self._sharding_type_to_sharding.keys(), self._lookups
@@ -871,8 +877,9 @@ def _initialize_torch_state(self) -> None:  # noqa
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
# for collision-free TBE, the shard sizes should be recalculated during ShardedTensor initialization
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
ShardedTensor._init_from_local_shards_and_reset_offsets(
local_shards,
self._name_to_table_size[table_name],
process_group=(
@@ -925,20 +932,29 @@ def post_state_dict_hook(
return

sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
sharded_id_buckets_state_dict = None
for lookup, sharding_type in zip(
module._lookups, module._sharding_type_to_sharding.keys()
):
if sharding_type != ShardingType.DATA_PARALLEL.value:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
key,
v,
) in lookup.get_named_split_embedding_weights_snapshot(): # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
if key in sharded_kvtensors_copy:
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
else:
d_k = f"{prefix}embeddings.{key}"
destination[d_k] = v
logger.info(f"add sharded tensor key {d_k} to state dict")
for (
table_name,
sharded_kvtensor,
) in sharded_kvtensors_copy.items():
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = sharded_kvtensor
if sharded_id_buckets_state_dict:
destination.update(sharded_id_buckets_state_dict)

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
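The post_state_dict_hook change above routes snapshot keys two ways: a key that matches a known sharded kv-tensor swaps the snapshot tensor in as that tensor's local shard, while any other snapshot key is written straight into the state dict under `{prefix}embeddings.{key}`, and every sharded kv-tensor is finally exposed under `{prefix}embeddings.{table_name}.weight`. A toy sketch of that routing, using stand-in classes rather than TorchRec's ShardedTensor:

```python
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class _Shard:
    tensor: torch.Tensor


@dataclass
class _ShardedKVTensor:
    shards: List[_Shard] = field(default_factory=list)

    def local_shards(self) -> List[_Shard]:
        return self.shards


def route_snapshot(
    sharded_kvtensors_copy: Dict[str, _ShardedKVTensor],
    snapshot: Dict[str, torch.Tensor],
    destination: Dict[str, object],
    prefix: str = "",
) -> None:
    # Keys backed by a known sharded kv-tensor: swap the snapshot tensor in as
    # the (single) local shard. Any other snapshot key: write it directly into
    # the state dict under "<prefix>embeddings.<key>".
    for key, tensor in snapshot.items():
        if key in sharded_kvtensors_copy:
            sharded_kvtensors_copy[key].local_shards()[0].tensor = tensor
        else:
            destination[f"{prefix}embeddings.{key}"] = tensor
    # Finally expose every sharded kv-tensor under "<prefix>embeddings.<table>.weight".
    for table_name, kvtensor in sharded_kvtensors_copy.items():
        destination[f"{prefix}embeddings.{table_name}.weight"] = kvtensor
```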
13 changes: 9 additions & 4 deletions torchrec/distributed/embedding_kernel.py
@@ -8,6 +8,7 @@
# pyre-strict

import abc
import copy
import logging
from collections import defaultdict, OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -103,9 +104,10 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
qbias = param[2]
param = param[0]

assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]
if not embedding_table.zero_collision:
assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]

if qscale is not None:
assert embedding_table.local_cols == param.size(1) # pyre-ignore[16]
@@ -128,14 +130,17 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
param.requires_grad # pyre-ignore[16]
)
key_to_global_metadata[key] = embedding_table.global_metadata
local_metadata = copy.deepcopy(embedding_table.local_metadata)
local_metadata.shard_sizes = list(param.size())

key_to_local_shards[key].append(
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
# pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got
# `Optional[ShardMetadata]`.
Shard(param, embedding_table.local_metadata)
Shard(param, local_metadata)
)

else:
destination[key] = param
if qscale is not None:
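The key change above for zero-collision tables is that the planned ShardMetadata is deep-copied and its shard_sizes are overwritten with the actual parameter size, since a collision-free kernel may hold fewer rows than the planner assigned (the local_rows assertion is skipped for the same reason). A minimal sketch of that adjustment; the helper name is illustrative, not TorchRec API, and ShardMetadata is assumed importable from torch.distributed._shard.metadata:

```python
import copy

import torch
from torch.distributed._shard.metadata import ShardMetadata


def local_metadata_for_param(
    planned: ShardMetadata, param: torch.Tensor
) -> ShardMetadata:
    # Keep the planner's offsets/placement, but record the real local shard size.
    local_metadata = copy.deepcopy(planned)
    local_metadata.shard_sizes = list(param.size())
    return local_metadata
```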
15 changes: 15 additions & 0 deletions torchrec/distributed/embedding_sharding.py
@@ -483,6 +483,16 @@ def _prefetch_and_cached(
)


def _is_kv_tbe(
table: ShardedEmbeddingTable,
) -> bool:
"""
Return True if this embedding table has bucketized sharding enabled for KV-style TBE, to support ZCH v.Next.
https://docs.google.com/document/d/13atWlDEkrkRulgC_gaoLv8ZsogQefdvsTdwlyam7ed0/edit?tab=t.0#heading=h.lxb1lainm4tc
"""
return table.zero_collision


def _all_tables_are_quant_kernel(
tables: List[ShardedEmbeddingTable],
) -> bool:
@@ -558,7 +568,9 @@ def _group_tables_per_rank(
table.data_type,
),
_prefetch_and_cached(table),
_is_kv_tbe(table),
)
print(f"line debug: grouping_key: {grouping_key}")
# micromanage the order of we traverse the groups to ensure backwards compatibility
if grouping_key not in groups:
grouping_keys.append(grouping_key)
@@ -573,6 +585,7 @@
compute_kernel_type,
_,
_,
is_kv_tbe,
) = grouping_key
grouped_tables = groups[grouping_key]
# remove non-native fused params
@@ -581,6 +594,8 @@
for k, v in fused_params_tuple
if k not in ["_batch_key", USE_ONE_TBE_PER_TABLE]
}
if is_kv_tbe:
per_tbe_fused_params["enable_zero_collision_tbe"] = True
cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables)
if cache_load_factor is not None:
per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor
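With `_is_kv_tbe(table)` appended to the grouping key, zero-collision (KV-style) tables can no longer be fused into the same TBE group as regular tables, and their group's per-TBE fused params are tagged with `enable_zero_collision_tbe`. A simplified sketch of that grouping behaviour, using an abbreviated key and plain dicts instead of the real table and grouping-key types:

```python
from typing import Dict, List, Tuple


def group_by_kv_flag(
    tables: List[dict],  # each: {"name": str, "kernel": str, "zero_collision": bool}
) -> Dict[Tuple[str, bool], dict]:
    groups: Dict[Tuple[str, bool], dict] = {}
    for table in tables:
        # Abbreviated grouping key: the real one also carries pooling, data type,
        # fused params, compute kernel, and the prefetch/caching flag.
        key = (table["kernel"], table["zero_collision"])
        group = groups.setdefault(key, {"tables": [], "per_tbe_fused_params": {}})
        group["tables"].append(table["name"])
        if table["zero_collision"]:
            group["per_tbe_fused_params"]["enable_zero_collision_tbe"] = True
    return groups


# Tables that differ only in zero_collision end up in separate groups:
# group_by_kv_flag([
#     {"name": "t0", "kernel": "fused", "zero_collision": True},
#     {"name": "t1", "kernel": "fused", "zero_collision": False},
# ])
```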
15 changes: 15 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -231,6 +231,21 @@ def feature_hash_sizes(self) -> List[int]:
feature_hash_sizes.extend(table.num_features() * [table.num_embeddings])
return feature_hash_sizes

def feature_total_num_buckets(self) -> Optional[List[int]]:
feature_total_num_buckets = []
for table in self.embedding_tables:
if table.total_num_buckets:
feature_total_num_buckets.extend(
table.num_features() * [table.total_num_buckets]
)
return feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None

def _is_zero_collision(self) -> bool:
for table in self.embedding_tables:
if table.zero_collision:
return True
return False

def num_features(self) -> int:
num_features = 0
for table in self.embedding_tables:
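feature_total_num_buckets() above mirrors feature_hash_sizes(): the per-table bucket count is repeated once per feature of that table, and None is returned when no table is bucketized so existing callers are unaffected. A small sketch of the expansion, with a stand-in table type in place of ShardedEmbeddingTable:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Table:
    feature_names: List[str]
    total_num_buckets: Optional[int] = None


def feature_total_num_buckets(tables: List[Table]) -> Optional[List[int]]:
    out: List[int] = []
    for table in tables:
        if table.total_num_buckets:
            # One entry per feature served by this table.
            out.extend(len(table.feature_names) * [table.total_num_buckets])
    return out if out else None


# Two features on a 40-bucket table plus one un-bucketized table:
# feature_total_num_buckets([Table(["f1", "f2"], 40), Table(["f3"])]) -> [40, 40]
```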
3 changes: 3 additions & 0 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -495,6 +495,9 @@ def __init__(
)
if device is not None:
self._emb_module.initialize_weights()
self._enable_zero_collision_tbe: bool = any(
table.zero_collision for table in config.embedding_tables
)

@property
def emb_module(
12 changes: 12 additions & 0 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -126,6 +126,11 @@ def create_input_dist(
) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()
is_zero_collision = any(
emb_config._is_zero_collision()
for emb_config in self._grouped_embedding_configs
)
feature_total_num_buckets = self._get_feature_total_num_buckets()
return RwSparseFeaturesDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
@@ -136,6 +141,8 @@
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
feature_total_num_buckets=feature_total_num_buckets,
keep_original_indices=is_zero_collision,
)

def create_lookup(
@@ -265,6 +272,10 @@ def create_input_dist(
(emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
self._grouped_embedding_configs_per_rank
)
is_zero_collision = any(
emb_config._is_zero_collision()
for emb_config in self._grouped_embedding_configs
)

return InferRwSparseFeaturesDist(
world_size=self._world_size,
@@ -275,6 +286,7 @@
has_feature_processor=self._has_feature_processor,
need_pos=False,
embedding_shard_metadata=emb_sharding if not is_even_sharding else None,
keep_original_indices=is_zero_collision,
)

def create_lookup(
26 changes: 26 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -56,6 +56,7 @@
ShardingType,
ShardMetadata,
)
from torchrec.distributed.utils import none_throws
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable

@@ -137,6 +138,9 @@ def __init__(
[]
)
self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank)
logger.info(
f"self._grouped_embedding_configs_per_rank: {self._grouped_embedding_configs_per_rank}"
)
self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
self._grouped_embedding_configs_per_rank[self._rank]
)
@@ -146,6 +150,9 @@
if group_config.has_feature_processor:
self._has_feature_processor = True

for group_config in self._grouped_embedding_configs:
group_config.feature_names

def _shard(
self,
sharding_infos: List[EmbeddingShardingInfo],
@@ -217,6 +224,8 @@ def _shard(
weight_init_min=info.embedding_config.weight_init_min,
fused_params=info.fused_params,
num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
total_num_buckets=info.embedding_config.total_num_buckets,
zero_collision=info.embedding_config.zero_collision,
)
)
return tables_per_rank
@@ -272,6 +281,23 @@ def _get_feature_hash_sizes(self) -> List[int]:
feature_hash_sizes.extend(group_config.feature_hash_sizes())
return feature_hash_sizes

def _get_feature_total_num_buckets(self) -> Optional[List[int]]:
feature_total_num_buckets: List[int] = []
for group_config in self._grouped_embedding_configs:
if group_config.feature_total_num_buckets() is not None:
feature_total_num_buckets.extend(
none_throws(group_config.feature_total_num_buckets())
)
return (
feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None
) # If no feature_total_num_buckets is provided, we return None to keep backward compatibility.

def _is_zero_collision(self) -> bool:
for group_config in self._grouped_embedding_configs:
if group_config._is_zero_collision():
return True
return False


class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
"""
35 changes: 33 additions & 2 deletions torchrec/distributed/sharding_plan.py
@@ -361,6 +361,7 @@ def _get_parameter_sharding(
sharder: ModuleSharder[nn.Module],
placements: Optional[List[str]] = None,
compute_kernel: Optional[str] = None,
bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None,
) -> ParameterSharding:
return ParameterSharding(
sharding_spec=(
@@ -371,6 +372,8 @@
ShardMetadata(
shard_sizes=size,
shard_offsets=offset,
bucket_id_offset=bucket_id_offset,
num_buckets=num_buckets,
placement=(
placement(
device_type,
@@ -381,9 +384,17 @@
else device_placement
),
)
for (size, offset, rank), device_placement in zip(
for (size, offset, rank), device_placement, (
num_buckets,
bucket_id_offset,
) in zip(
size_offset_ranks,
placements if placements else [None] * len(size_offset_ranks),
(
bucket_offset_sizes
if bucket_offset_sizes
else [(None, None)] * len(size_offset_ranks)
),
)
]
)
@@ -512,7 +523,8 @@ def _parameter_sharding_generator(


def row_wise(
sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
num_buckets_per_rank: Optional[List[int]] = None, # propagate num buckets per rank
) -> ParameterShardingGenerator:
"""
Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
@@ -545,6 +557,7 @@ def _parameter_sharding_generator(
device_type: str,
sharder: ModuleSharder[nn.Module],
) -> ParameterSharding:
bucket_offset_sizes = None
if sizes_placement is None:
size_and_offsets = _get_parameter_size_offsets(
param,
@@ -558,17 +571,34 @@
size_offset_ranks.append((size, offset, rank))
else:
size_offset_ranks = []
bucket_offset_sizes = None if num_buckets_per_rank is None else []
sizes = sizes_placement[0]
if num_buckets_per_rank is not None:
assert len(sizes) == len(
num_buckets_per_rank
), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively"
(rows, cols) = param.shape
cur_offset = 0
prev_offset = 0
prev_bucket_offset = 0
cur_bucket_offset = 0
for rank, size in enumerate(sizes):
per_rank_row = size
per_rank_bucket_size = None
if num_buckets_per_rank is not None:
per_rank_bucket_size = num_buckets_per_rank[rank]
cur_bucket_offset += per_rank_bucket_size
cur_offset += per_rank_row
cur_offset = min(cur_offset, rows)
per_rank_row = cur_offset - prev_offset
size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank))
prev_offset = cur_offset
if num_buckets_per_rank is not None:
# bucket has only one col for now
none_throws(bucket_offset_sizes).append(
(per_rank_bucket_size, prev_bucket_offset)
)
prev_bucket_offset = cur_bucket_offset

if cur_offset < rows:
raise ValueError(
Expand Down Expand Up @@ -601,6 +631,7 @@ def _parameter_sharding_generator(
compute_kernel=(
EmbeddingComputeKernel.QUANT.value if sizes_placement else None
),
bucket_offset_sizes=bucket_offset_sizes,
)

return _parameter_sharding_generator
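The new num_buckets_per_rank path in row_wise() above accumulates a running bucket-id offset alongside the row offsets, producing one (num_buckets, bucket_id_offset) pair per rank, which _get_parameter_sharding() then stamps onto each ShardMetadata as num_buckets / bucket_id_offset. A standalone sketch of just that bookkeeping; the helper name and return shape are illustrative, not TorchRec API:

```python
from typing import List, Tuple


def bucket_offsets_per_rank(
    num_buckets_per_rank: List[int],
) -> List[Tuple[int, int]]:
    """Return (num_buckets, bucket_id_offset) per rank, mirroring the loop above."""
    pairs: List[Tuple[int, int]] = []
    bucket_id_offset = 0
    for num_buckets in num_buckets_per_rank:
        pairs.append((num_buckets, bucket_id_offset))
        bucket_id_offset += num_buckets
    return pairs


# Four ranks with 8 buckets each:
# bucket_offsets_per_rank([8, 8, 8, 8]) -> [(8, 0), (8, 8), (8, 16), (8, 24)]
```

In a sharding plan this would be driven by a call such as `row_wise(sizes_placement=(rows_per_rank, "cpu"), num_buckets_per_rank=buckets_per_rank)`, and the added assert requires the two per-rank lists to have the same length.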