
Commit 17ea61c

TroyGarden authored and facebook-github-bot committed
minor refactoring and adding comments to the TrainPipelineSparseDist
Summary:

# context
* add in-code comments to train_pipeline.TrainPipelineSparseDist to explain the actions in the sparse data dist pipeline.
* minor refactoring: make `_pipelined_forward_type` a class variable instead of an instance variable, since it is always constant in the class init.

Differential Revision: D63793825
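To make the refactoring concrete, here is a minimal sketch of the pattern (illustrative stand-in classes only, not the actual torchrec definitions): a value that every instance sets to the same constant in __init__ can instead live on the class, and subclasses override it at class level.

class PipelinedForward:
    ...

class EmbeddingPipelinedForward(PipelinedForward):
    ...

class TrainPipelineSparseDist:
    # class variable: constant for every instance, no per-instance assignment in __init__
    _pipelined_forward_type = PipelinedForward

class TrainPipelineSemiSync(TrainPipelineSparseDist):
    # subclasses override the class variable instead of re-assigning it in __init__
    _pipelined_forward_type = EmbeddingPipelinedForward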
1 parent f2791cd commit 17ea61c

File tree

2 files changed (+83, -8 lines)

Diff for: torchrec/distributed/train_pipeline/train_pipelines.py

+71-8
@@ -345,6 +345,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
     """
 
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = PipelinedForward
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -413,7 +416,6 @@ def __init__(
         self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
             custom_model_fwd if custom_model_fwd else model
         )
-        self._pipelined_forward_type = PipelinedForward
 
         # DEPRECATED FIELDS
         self._batch_i: Optional[In] = None
@@ -423,7 +425,11 @@ def __init__(
 
     def detach(self) -> torch.nn.Module:
         """
-        Detaches the model from sparse data dist (SDD) pipeline.
+        Detaches the model from the sparse data dist (SDD) pipeline. A user might want to get
+        the original model back after training. The original model.forward was previously
+        modified by the train pipeline. For more details please see:
+        https://github.com/pytorch/torchrec/pull/2076
+
         To use the pipeline after detaching the model, pipeline.attach(model)
         needs to be called.
         Inflight batches are kept so pipeline.progress(data_iter) can be resumed normally.
@@ -445,6 +451,11 @@ def detach(self) -> torch.nn.Module:
     def attach(
         self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
     ) -> None:
+        """
+        Should be used together with detach. These functions should only be called from
+        user code, when the user wants to switch the train pipeline. For more details please see:
+        https://github.com/pytorch/torchrec/pull/2076
+        """
         if model:
             self._model = model
 
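The detach/attach docstrings above describe a user-facing workflow; a minimal usage sketch, assuming a TrainPipelineSparseDist instance named pipeline and an iterator data_iter already exist (neither is defined in this diff):

# train with the pipelined (rewritten) forward
for _ in range(10):
    pipeline.progress(data_iter)

# take the original model back, e.g. for eval or checkpointing;
# the pipelined forward is removed from the sharded modules
model = pipeline.detach()

# ... use model.forward directly here ...

# re-attach before resuming pipelined training; inflight batches are kept,
# so pipeline.progress(data_iter) resumes normally
pipeline.attach(model)
pipeline.progress(data_iter)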
@@ -463,6 +474,12 @@ def attach(
             self._pipelined_postprocs = []
 
     def _set_module_context(self, context: TrainPipelineContext) -> None:
+        """
+        Pipelined modules are TorchRec's sparse modules like shardedEBC, shardedEC, etc.
+        Their forward functions are swapped with a PipelinedForward in the _rewrite_model call.
+        The PipelinedForward needs a context to correctly perform the forward behavior.
+        Please check PipelinedForward for details.
+        """
         for module in self._pipelined_modules:
             module.forward.set_context(context)
 
@@ -471,6 +488,10 @@ def _set_module_context(self, context: TrainPipelineContext) -> None:
             postproc_module.set_context(context)
 
     def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
+        """
+        Load a data batch from the dataloader and copy it from CPU to GPU;
+        also create the context for this batch.
+        """
         batch, context = self.copy_batch_to_gpu(dataloader_iter)
         if batch is None:
             return False
@@ -481,30 +502,50 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         return True
 
     def dequeue_batch(self) -> None:
+        """
+        Remove a processed batch from the batch queue; also set the module context if applicable.
+        """
         self.batches.popleft()
         self.contexts.popleft()
-        # update PipelineForwards context to match next forward pass
+
+        # update PipelinedForward context to match next forward pass
         if len(self.batches) >= 1:
             self._set_module_context(self.contexts[0])
 
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
-        # pipeline is already filled
+        """
+        This function is called in self.progress (one of the main APIs for running the train pipeline).
+        Here we assume the max pipelined len(batches) == 2 (capacity), which is the most common
+        scenario during the full training job, in which case this function effectively does nothing.
+        There are only two other scenarios:
+        len(batches) == 0:
+            initialize the pipeline, fill in two batches, start input_dist for the first batch.
+        len(batches) == 1:
+            dataloader_iter is exhausted; this is the last batch, so do nothing.
+        """
+
+        # pipeline is already filled with max capacity (2)
         if len(self.batches) >= 2:
             return
-        # executes last batch in pipeline
+
+        # executes the last batch in the pipeline, when there is only one batch in the pipeline
+        # TODO: _execute_all_batches doesn't really work here (D43546239); it will
+        # just throw an exception at copy_to_gpu when the dataloader is exhausted
         if self.batches and self._execute_all_batches:
             return
 
-        # batch i
+        # batch i, data (batch) and context
         if not self.enqueue_batch(dataloader_iter):
             return
 
+        # modify the (sharded) sparse module forward, and invoke the first part of input_dist
        self._init_pipelined_modules(
             # pyre-ignore [6]
             self.batches[0],
             self.contexts[0],
             PipelinedForward,
         )
+        # do the second part of input_dist; the first part is invoked in _init_pipelined_modules
         self.wait_sparse_data_dist(self.contexts[0])
 
         # batch i+1
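Condensing the fill_pipeline hunk above into one place, the control flow is roughly the following simplified sketch (structure taken from the diff; the helper methods are defined elsewhere in the file, and the steps after the "# batch i+1" comment are not shown in the hunk):

def fill_pipeline(self, dataloader_iter):
    if len(self.batches) >= 2:
        # steady state: pipeline already holds its max capacity of two batches
        return
    if self.batches and self._execute_all_batches:
        # one batch left and the dataloader is exhausted: let progress() drain it
        return
    # cold start: batch i - load, copy to GPU, rewrite sparse module forwards,
    # and kick off the first part of its input_dist
    if not self.enqueue_batch(dataloader_iter):
        return
    self._init_pipelined_modules(self.batches[0], self.contexts[0], PipelinedForward)
    self.wait_sparse_data_dist(self.contexts[0])
    # batch i+1 is enqueued next (that part of the function is not shown in the hunk above)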
@@ -520,10 +561,22 @@ def _backward(self, losses: torch.Tensor) -> None:
             torch.sum(losses, dim=0).backward()
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+            batches[0]: current batch, for emb_lookup, output_dist, and fwd/bwd/opt (expecting input_dist)
+            batches[1]: next batch, for input_dist (expecting copied to device)
+            batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes
         if not self._model_attached:
             self.attach(self._model)
 
+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
         self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
         if not self.batches:
             raise StopIteration
 
@@ -534,19 +587,23 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         with record_function("## zero_grad ##"):
             self._optimizer.zero_grad()
 
+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
         self._wait_for_batch()
 
         if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
             self.start_sparse_data_dist(self.batches[1], self.contexts[1])
 
-        # batch i+2
+        # batch i+2: load data and copy to gpu; the dataloader iter will first exhaust here
         self.enqueue_batch(dataloader_iter)
 
         # forward
         with record_function("## forward ##"):
             losses, output = self._model_fwd(self.batches[0])
 
         if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
             self.wait_sparse_data_dist(self.contexts[1])
 
         if self._model.training:
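Reading the two progress() hunks together, one call interleaves the three in-flight batches roughly as in this simplified sketch (record_function blocks omitted; the tail after "if self._model.training:" is not part of the hunks shown and is only summarized in the trailing comment):

def progress(self, dataloader_iter):
    if not self._model_attached:
        self.attach(self._model)                 # safety net if the user detached the model
    self.fill_pipeline(dataloader_iter)          # only does real work while the pipeline is empty
    if not self.batches:
        raise StopIteration                      # dataloader exhausted and pipeline drained
    self._optimizer.zero_grad()
    self._wait_for_batch()                       # batches[0]: H2D copy and input_dist already done
    if len(self.batches) >= 2:
        self.start_sparse_data_dist(self.batches[1], self.contexts[1])  # splits all_to_all
    self.enqueue_batch(dataloader_iter)          # batches[2]: copy_batch_to_gpu
    losses, output = self._model_fwd(self.batches[0])                   # forward on batches[0]
    if len(self.batches) >= 2:
        self.wait_sparse_data_dist(self.contexts[1])                    # data all_to_all
    if self._model.training:
        # backward and optimizer step on batches[0], then dequeue_batch() for the next iteration
        ...
    return output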
@@ -768,6 +825,9 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
     """
 
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = EmbeddingPipelinedForward  # pyre-ignore
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -793,7 +853,6 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
         )
-        self._pipelined_forward_type = EmbeddingPipelinedForward
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
@@ -865,6 +924,8 @@ def _mlp_optimizer_step(self, current_batch: int) -> None:
             self._optimizer.step()
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes
         if not self._model_attached:
             self.attach(self._model)
 
@@ -1653,6 +1714,8 @@ def get_compiled_autograd_ctx(
         )
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes
         if not self._model_attached:
             self.attach(self._model)
 
Diff for: torchrec/distributed/train_pipeline/utils.py

+12
@@ -535,6 +535,10 @@ def get_context(self) -> TForwardContext:
 
 
 class PipelinedForward(BaseForward[TrainPipelineContext]):
+    """
+    This pipelined forward is used in TrainPipelineSparseDist
+    """
+
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
         assert (
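The one-line docstrings added here are terse; conceptually, each *PipelinedForward replaces a sharded module's forward so it consumes the input_dist result the pipeline has already started, reading it out of the per-batch context set via set_context. A hypothetical, greatly simplified illustration of that pattern (ToyPipelinedForward and module.compute are stand-ins, not the torchrec implementation):

from typing import Any, Dict


class ToyPipelinedForward:
    def __init__(self, name: str, module: Any) -> None:
        self._name = name          # key under which the pipeline stashes this module's dist input
        self._module = module      # the sharded module whose forward is being replaced
        self._context: Dict[str, Any] = {}

    def set_context(self, context: Dict[str, Any]) -> None:
        # the pipeline points this forward at the context of the batch about to run
        self._context = context

    def __call__(self, *input: Any, **kwargs: Any) -> Any:
        # the original arguments are ignored; the pre-computed input_dist output is used instead
        dist_input = self._context[self._name]
        return self._module.compute(dist_input)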
@@ -568,6 +572,10 @@ def __call__(self, *input, **kwargs) -> Awaitable:
 
 
 class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]):
+    """
+    This pipelined forward is used in TrainPipelineSemiSync
+    """
+
     def __call__(
         self,
         # pyre-ignore
@@ -642,6 +650,10 @@ def __call__(
 
 
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
+    """
+    This pipelined forward is used in PrefetchTrainPipelineSparseDist
+    """
+
     def __init__(
         self,
         name: str,
