
Commit 8819393

aporialiao authored and facebook-github-bot committed
Enable random weights for unit test + Doc (#2889)
Summary:
Pull Request resolved: #2889

Enable random weights for unit test. When testing the DMP interface for Dynamic Sharding, I'm noticing discrepancies in predictions. Still debugging that case, but random weights will be enabled by default for the initial dynamic sharding interface, with the debug values kept behind an optional flag.

Main changes:
1. Added a docstring to `copy_state_dict` in `test_sharding` to make it clear that the global state_dict is being copied into the local one.
2. Removed the redundant `copy_state_dict` calls in the dynamic sharding unit test setup, since `load_state_dict` is already used.
3. Added a `use_debug_state_dict` flag, defaulted to `False`; when turned on, it forces the test models to use dummy int values in the embedding weights.
4. With `use_debug_state_dict` turned off, the weights are randomly generated upon initialization of the EBCs.
   1. Note: `torch.manual_seed(0)` is needed to force the EBCs to be initialized with the same float values across ranks in the distributed environment.
   2. An alternative approach would be to initialize the global EBCs outside of the distributed test process, but since this is just a unit test, keeping it as is.

Reviewed By: TroyGarden

Differential Revision: D73077322

fbshipit-source-id: 093f23c10b73a90b61429c4109e484627270bd46
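For reference, a minimal standalone sketch of the seeding behavior described in point 4 above (the module type and table sizes here are illustrative, not the torchrec test code): re-seeding immediately before constructing each module makes the randomly initialized weights identical, which is what lets every rank build the same model without copying a state dict.

import torch
import torch.nn as nn


def build_embedding(seed: int = 0) -> nn.EmbeddingBag:
    # Re-seeding right before construction makes the random weight
    # initialization deterministic, mimicking what each rank does in
    # the test before creating its EmbeddingBagCollection.
    torch.manual_seed(seed)
    return nn.EmbeddingBag(num_embeddings=4, embedding_dim=16)


# Two independently constructed modules end up with identical weights,
# so no explicit state_dict copy between them is needed.
m1 = build_embedding()
m2 = build_embedding()
assert torch.equal(m1.weight, m2.weight)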
1 parent 8581ea1 commit 8819393

File tree

2 files changed (+37, -30 lines)

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 3 additions & 0 deletions
@@ -254,6 +254,9 @@ def copy_state_dict(
     glob: Dict[str, torch.Tensor],
     exclude_predfix: Optional[str] = None,
 ) -> None:
+    """
+    Copies the contents of the global tensors in glob to the local tensors in loc.
+    """
     for name, tensor in loc.items():
         if exclude_predfix is not None and name.startswith(exclude_predfix):
             continue
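As a usage sketch of what the new docstring describes, the snippet below assumes the dense-tensor path of `copy_state_dict` copies the matching global values into the local tensors in place (the real helper also handles ShardedTensor entries and the `exclude_predfix` filter, not exercised here); the tensor name and shape are made up.

import torch
from torchrec.distributed.test_utils.test_sharding import copy_state_dict

# Local tensors start as zeros; the "global" tensors hold the reference values.
loc = {"embedding_bags.table_0.weight": torch.zeros(4, 16)}
glob = {"embedding_bags.table_0.weight": torch.rand(4, 16)}

copy_state_dict(loc=loc, glob=glob)

# After the call, the local tensor holds a copy of the global values.
assert torch.equal(
    loc["embedding_bags.table_0.weight"],
    glob["embedding_bags.table_0.weight"],
)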

torchrec/distributed/tests/test_dynamic_sharding.py

Lines changed: 34 additions & 30 deletions
@@ -213,10 +213,8 @@ def _test_ebc_resharding(
     trec_dist.comm_ops.set_gradient_division(False)
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank]
-
-        initial_state_dict = {
-            fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items()
-        }
+        # Set seed to be 0 to ensure models have the same initialization across ranks
+        torch.manual_seed(0)
         m1 = EmbeddingBagCollection(
             tables=tables,
             device=ctx.device,

@@ -226,19 +224,24 @@
             tables=tables,
             device=ctx.device,
         )
-
-        # Load initial State - making sure models are identical
-        m1.load_state_dict(initial_state_dict)
-        copy_state_dict(
-            loc=m1.state_dict(),
-            glob=copy.deepcopy(initial_state_dict),
-        )
-
-        m2.load_state_dict(initial_state_dict)
-        copy_state_dict(
-            loc=m2.state_dict(),
-            glob=copy.deepcopy(initial_state_dict),
-        )
+        if initial_state_dict is not None:
+            initial_state_dict = {
+                fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items()
+            }
+
+            # Load initial State - making sure models are identical
+            m1.load_state_dict(initial_state_dict)
+
+            m2.load_state_dict(initial_state_dict)
+
+        else:
+            # Note this is the only correct behavior due to setting random seed to 0 above
+            # Otherwise the weights generated in EBC initialization will be different on
+            # Each rank, resulting in different behavior after resharding
+            copy_state_dict(
+                loc=m2.state_dict(),
+                glob=m1.state_dict(),
+            )

         sharder = get_module_to_default_sharders()[type(m1)]

@@ -278,8 +281,8 @@
             feature_keys.extend(table.feature_names)

         # For current test model and inputs, the prediction should be the exact same
-        rtol = 0
-        atol = 0
+        # rtol = 0
+        # atol = 0

         for _ in range(world_size):
             # sharded model

@@ -301,9 +304,7 @@
             # their model. output from sharded_pred is correctly on the correct device.

             # Compare predictions of sharded vs unsharded models.
-            torch.testing.assert_close(
-                sharded_m1_pred.cpu(), resharded_m2_pred.cpu(), rtol=rtol, atol=atol
-            )
+            torch.testing.assert_close(sharded_m1_pred.cpu(), resharded_m2_pred.cpu())

             sharded_m1_pred.sum().backward()
             resharded_m2_pred.sum().backward()

@@ -320,6 +321,7 @@ def _run_ebc_resharding_test(
         data_type: DataType,
         embedding_dim: int = 16,
         num_embeddings: int = 4,
+        use_debug_state_dict: bool = False,  # Turn on to use dummy values for initial state dict
     ) -> None:
         embedding_bag_config = generate_embedding_bag_config(
             data_type, num_tables, embedding_dim, num_embeddings

@@ -359,14 +361,16 @@
             for _ in range(world_size)
         ]

-        # initial_state_dict filled with deterministic dummy values
-        initial_state_dict = create_test_initial_state_dict(
-            ShardedEmbeddingBagCollection,  # pyre-ignore
-            num_tables,
-            data_type,
-            embedding_dim,
-            num_embeddings,
-        )
+        initial_state_dict = None
+        if use_debug_state_dict:
+            # initial_state_dict filled with deterministic dummy values
+            initial_state_dict = create_test_initial_state_dict(
+                ShardedEmbeddingBagCollection,  # pyre-ignore
+                num_tables,
+                data_type,
+                embedding_dim,
+                num_embeddings,
+            )

         self._run_multi_process_test(
             callable=_test_ebc_resharding,
