Commit 02e86af

TroyGarden authored and facebook-github-bot committed
minor refactoring and adding comments to the TrainPipelineSparseDist (#2898)

Summary:

# context

* Add in-code comments to train_pipeline.TrainPipelineSparseDist to explain the actions in the sparse dist pipeline.
* Minor refactoring: make `_pipelined_forward_type` a class variable instead of an instance variable, since it is always set to the same constant in the class init.

Differential Revision: D63793825
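To make the refactoring concrete, here is a minimal, hypothetical sketch (plain Python, not the actual torchrec classes) of the pattern adopted below: the pipelined-forward wrapper type becomes a class variable that each pipeline subclass overrides once, instead of being reassigned on every instance inside __init__.

# Minimal sketch of the class-variable pattern; all names are stand-ins, not torchrec APIs.
from typing import Type


class PipelinedForward:
    """Stand-in for the default pipelined forward wrapper."""


class EmbeddingPipelinedForward(PipelinedForward):
    """Stand-in for the semi-sync variant of the wrapper."""


class BasePipeline:
    # Constant for the class, so a class variable is enough (no per-instance assignment).
    _pipelined_forward_type: Type[PipelinedForward] = PipelinedForward

    def rewrite_model(self) -> None:
        # Call sites read the class attribute instead of hard-coding a concrete forward type.
        print(f"wrapping forwards with {self._pipelined_forward_type.__name__}")


class SemiSyncPipeline(BasePipeline):
    # A subclass overrides the class variable exactly once.
    _pipelined_forward_type = EmbeddingPipelinedForward


BasePipeline().rewrite_model()      # wrapping forwards with PipelinedForward
SemiSyncPipeline().rewrite_model()  # wrapping forwards with EmbeddingPipelinedForward

The diff below applies the same idea to TrainPipelineSparseDist (PipelinedForward), TrainPipelineSemiSync (EmbeddingPipelinedForward), PrefetchTrainPipelineSparseDist (PrefetchPipelinedForward), and EvalPipelineSparseDist (PipelinedForward).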
1 parent f2791cd commit 02e86af

File tree

2 files changed: +115 −24

Diff for: torchrec/distributed/train_pipeline/train_pipelines.py

+81 −12
@@ -345,6 +345,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
     """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = PipelinedForward
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -413,7 +416,6 @@ def __init__(
         self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
             custom_model_fwd if custom_model_fwd else model
         )
-        self._pipelined_forward_type = PipelinedForward

         # DEPRECATED FIELDS
         self._batch_i: Optional[In] = None
@@ -423,7 +425,11 @@ def __init__(

     def detach(self) -> torch.nn.Module:
         """
-        Detaches the model from sparse data dist (SDD) pipeline.
+        Detaches the model from sparse data dist (SDD) pipeline. A user might want to get
+        the original model back after training. The original model.forward was previously
+        modified by the train pipeline. for more please see:
+        https://github.com/pytorch/torchrec/pull/2076
+
         To use the pipeline after detaching the model, pipeline.attach(model)
         needs to be called.
         Inflight batches are kept so pipeline.progress(data_iter) can be resumed normally.
@@ -445,6 +451,11 @@ def detach(self) -> torch.nn.Module:
     def attach(
         self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
     ) -> None:
+        """
+        should be used with detach function. these functions should only be used from user code,
+        when user want to switch the train pipeline. for more please see:
+        https://github.com/pytorch/torchrec/pull/2076
+        """
         if model:
             self._model = model

@@ -463,6 +474,12 @@ def attach(
             self._pipelined_postprocs = []

     def _set_module_context(self, context: TrainPipelineContext) -> None:
+        """
+        pipelined modules are the TorchRec's sparse modules like shardedEBC, shardedEC, etc.
+        the forward function is swapped with a PipelinedForward in the _rewrite_model call.
+        The PipelinedForward needs a context to correctly perform the forward behavior.
+        please check PipelinedForward for details.
+        """
         for module in self._pipelined_modules:
             module.forward.set_context(context)

@@ -471,6 +488,10 @@ def _set_module_context(self, context: TrainPipelineContext) -> None:
             postproc_module.set_context(context)

     def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
+        """
+        load a data batch from dataloader, and copy it from cpu to gpu
+        also create the context for this batch.
+        """
         batch, context = self.copy_batch_to_gpu(dataloader_iter)
         if batch is None:
             return False
@@ -481,30 +502,50 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         return True

     def dequeue_batch(self) -> None:
+        """
+        remove a processed batch from the batch queue, also set the module context if applicable
+        """
         self.batches.popleft()
         self.contexts.popleft()
-        # update PipelineForwards context to match next forward pass
+
+        # update PipelinedForward context to match next forward pass
         if len(self.batches) >= 1:
             self._set_module_context(self.contexts[0])

     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
-        # pipeline is already filled
+        """
+        This function is called in self.progress (one of the main APIs for running train pipeline)
+        Here we assume the max pipelined len(batches) == 2 (capacity), which will be the most common
+        scenario during the full training job, when this function is effectively doing nothing.
+        There would only be two other scenarios:
+        len(batches) == 0:
+            initialize the pipeline, fill in two batches, start input_dist for the first batch.
+        len(batches) == 1:
+            dataloader_iter stops, the last batch, do nothing
+        """
+
+        # pipeline is already filled with max capacity (2)
         if len(self.batches) >= 2:
             return
-        # executes last batch in pipeline
+
+        # executes last batch in pipeline, when there is only one batch in the pipeline
+        # TODO: this _execute_all_batches doesn't really work here D43546239. it will
+        # just throw an exception at copy_to_gpu when the dataloader is exhausted
         if self.batches and self._execute_all_batches:
             return

-        # batch i
+        # batch i, data (batch) and context
         if not self.enqueue_batch(dataloader_iter):
             return

+        # modify the (sharded) sparse module forward, and invoke the first part of input_dist
         self._init_pipelined_modules(
             # pyre-ignore [6]
             self.batches[0],
             self.contexts[0],
-            PipelinedForward,
+            self._pipelined_forward_type,
         )
+        # doing the second part of input_dist, the first part is invoked in _init_pipelined_modules
         self.wait_sparse_data_dist(self.contexts[0])

         # batch i+1
@@ -520,10 +561,22 @@ def _backward(self, losses: torch.Tensor) -> None:
             torch.sum(losses, dim=0).backward()

     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+        batches[0]: current batch, for emb_lookup, output_dist, and fwd/bwd/opt (expecting input_dist)
+        batches[1]: next batch, for input_dist (expecting copied to device)
+        batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses the pipeline.progress and detach the model for other purpose.
         if not self._model_attached:
             self.attach(self._model)

+        # fill the pipeline is only needed for the beginning when the pipeline (batches) is empty
         self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
         if not self.batches:
             raise StopIteration

@@ -534,19 +587,23 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         with record_function("## zero_grad ##"):
             self._optimizer.zero_grad()

+        # wait for batches[0] being available on device, this should always be completed since
+        # the input_dist of batches[0] has be invoked in previous iter. TODO: fact check
         self._wait_for_batch()

         if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
             self.start_sparse_data_dist(self.batches[1], self.contexts[1])

-        # batch i+2
+        # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

         # forward
         with record_function("## forward ##"):
             losses, output = self._model_fwd(self.batches[0])

         if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
             self.wait_sparse_data_dist(self.contexts[1])

         if self._model.training:
@@ -768,6 +825,9 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
     """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = EmbeddingPipelinedForward  # pyre-ignore
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -793,7 +853,6 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
         )
-        self._pipelined_forward_type = EmbeddingPipelinedForward
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
@@ -835,7 +894,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             self.batches[0],
             self.contexts[0],
             # pyre-ignore [6]
-            EmbeddingPipelinedForward,
+            self._pipelined_forward_type,
         )
         self.wait_sparse_data_dist(self.contexts[0])
         self._validate_optimizer()
@@ -865,6 +924,8 @@ def _mlp_optimizer_step(self, current_batch: int) -> None:
         self._optimizer.step()

     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses the pipeline.progress and detach the model for other purpose.
         if not self._model_attached:
             self.attach(self._model)

@@ -1074,6 +1135,9 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
     """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = PrefetchPipelinedForward  # pyre-ignore
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -1126,7 +1190,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             self._batch_i,
             self._context,
             # pyre-ignore
-            PrefetchPipelinedForward,
+            self._pipelined_forward_type,
         )
         self._start_sparse_data_dist(self._batch_i)
         self._wait_sparse_data_dist()
@@ -1228,6 +1292,9 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
     """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = PipelinedForward
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -1265,7 +1332,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
                 # pyre-ignore
                 self.batches[0],
                 self.contexts[0],
-                PipelinedForward,
+                self._pipelined_forward_type,
             )
             self.start_sparse_data_dist(self.batches[0], self.contexts[0])
             self.wait_sparse_data_dist(self.contexts[0])
@@ -1653,6 +1720,8 @@ def get_compiled_autograd_ctx(
         )

     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses the pipeline.progress and detach the model for other purpose.
         if not self._model_attached:
             self.attach(self._model)

Diff for: torchrec/distributed/train_pipeline/utils.py

+34 −12
@@ -535,6 +535,10 @@ def get_context(self) -> TForwardContext:


 class PipelinedForward(BaseForward[TrainPipelineContext]):
+    """
+    This pipeline is used in TrainPipelineSparseDist
+    """
+
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
         assert (
@@ -568,6 +572,10 @@ def __call__(self, *input, **kwargs) -> Awaitable:


 class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]):
+    """
+    This pipeline is used in TrainPipelineSemiSync
+    """
+
     def __call__(
         self,
         # pyre-ignore
@@ -593,8 +601,11 @@ def __call__(
         )
         ctx.record_stream(cur_stream)
         awaitable = self._context.embedding_a2a_requests.pop(self._name)
+        # in case of MC modules
+        is_mc_module: bool = isinstance(awaitable, Iterable)
         remapped_kjts: Optional[KeyedJaggedTensor] = None
-        if isinstance(awaitable, Iterable):
+
+        if is_mc_module:
             embeddings = awaitable[0].wait()
             remapped_kjts = awaitable[1].wait()
         else:
@@ -604,6 +615,7 @@ def __call__(
             )  # trigger awaitable manually for type checking
         tensors = []
         detached_tensors = []
+        # in case of EC, embeddings are Dict[str, JaggedTensor]
         if isinstance(embeddings, Dict):
             for jt in embeddings.values():
                 assert isinstance(jt, JaggedTensor)
@@ -617,6 +629,7 @@ def __call__(
             self._context.embedding_features.append(list(embeddings.keys()))
             self._context.detached_embedding_tensors.append(detached_tensors)
         else:
+            # in case of EBC, embeddings are KeyedTensor
             assert isinstance(embeddings, KeyedTensor)
             embeddings.record_stream(cur_stream)
             tensor = embeddings.values()
@@ -626,22 +639,28 @@ def __call__(
             tensors.append(tensor)
             detached_tensors.append(detached_tensor)
             self._context.embedding_tensors.append(tensors)
-            # KeyedTensor is returned by EmbeddingBagCollections and its variants
-            # KeyedTensor holds dense data from multiple features and .values()
-            # returns a single concatenated dense tensor. To ensure that
-            # context.embedding_tensors[i] has the same length as
-            # context.embedding_features[i], we pass in a list with a single item:
-            # a list containing all the embedding feature names.
+            """
+            KeyedTensor is returned by EmbeddingBagCollections and its variants
+            KeyedTensor holds dense data from multiple features and .values()
+            returns a single concatenated dense tensor. To ensure that
+            context.embedding_tensors[i] has the same length as
+            context.embedding_features[i], we pass in a list with a single item:
+            a list containing all the embedding feature names.
+            """
             self._context.embedding_features.append([list(embeddings.keys())])
             self._context.detached_embedding_tensors.append(detached_tensors)

-        if isinstance(awaitable, Iterable):
+        if is_mc_module:
             return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
         else:
             return LazyNoWait(embeddings)


 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
+    """
+    This pipeline is used in PrefetchTrainPipelineSparseDist
+    """
+
     def __init__(
         self,
         name: str,
@@ -811,10 +830,13 @@ def _start_data_dist(

     for module in pipelined_modules:
         forward = module.forward
-        assert (
-            isinstance(forward, PipelinedForward)
-            or isinstance(forward, PrefetchPipelinedForward)
-            or isinstance(forward, EmbeddingPipelinedForward)
+        assert isinstance(
+            forward,
+            (
+                PipelinedForward,
+                PrefetchPipelinedForward,
+                EmbeddingPipelinedForward,
+            ),
         )

         # Retrieve argument for the input_dist of EBC