@@ -345,6 +345,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
    """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = PipelinedForward
+
    def __init__(
        self,
        model: torch.nn.Module,
@@ -413,7 +416,6 @@ def __init__(
        self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
            custom_model_fwd if custom_model_fwd else model
        )
-        self._pipelined_forward_type = PipelinedForward

        # DEPRECATED FIELDS
        self._batch_i: Optional[In] = None
@@ -423,7 +425,11 @@ def __init__(

    def detach(self) -> torch.nn.Module:
        """
-        Detaches the model from sparse data dist (SDD) pipeline.
+        Detaches the model from the sparse data dist (SDD) pipeline. A user might want to get
+        the original model back after training, since the original model.forward was
+        modified by the train pipeline. For more detail, please see:
+        https://github.com/pytorch/torchrec/pull/2076
+
        To use the pipeline after detaching the model, pipeline.attach(model)
        needs to be called.
        Inflight batches are kept so pipeline.progress(data_iter) can be resumed normally.
@@ -445,6 +451,11 @@ def detach(self) -> torch.nn.Module:
    def attach(
        self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
    ) -> None:
+        """
+        Should be used with the detach function. These functions should only be called from
+        user code when the user wants to switch the train pipeline. For more detail, please see:
+        https://github.com/pytorch/torchrec/pull/2076
+        """
        if model:
            self._model = model
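For context, a minimal usage sketch of the detach/attach pair documented above; `pipeline`, `model`, and `run_eval` are hypothetical names, not part of this diff:

# Hypothetical usage sketch: recover the original model mid-training, use it,
# then re-attach so the pipeline can resume its in-flight batches.
original_model = pipeline.detach()  # restores the unmodified model.forward
run_eval(original_model)            # placeholder for any user-side work
pipeline.attach(original_model)     # re-applies pipelining; progress() resumes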
@@ -463,6 +474,12 @@ def attach(
            self._pipelined_postprocs = []

    def _set_module_context(self, context: TrainPipelineContext) -> None:
+        """
+        Pipelined modules are TorchRec's sparse modules like shardedEBC, shardedEC, etc.
+        Their forward functions are swapped with a PipelinedForward in the _rewrite_model call.
+        The PipelinedForward needs a context to correctly perform the forward behavior.
+        Please check PipelinedForward for details.
+        """
        for module in self._pipelined_modules:
            module.forward.set_context(context)
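To illustrate why set_context exists, here is a simplified stand-in (not the actual PipelinedForward implementation) for a forward wrapper that must be retargeted at each batch's context:

# Simplified stand-in for a PipelinedForward-style wrapper (illustrative only):
# the swapped-in forward can only run correctly once it knows which batch's
# pipeline context (prefetched input_dist results, awaitables, etc.) to read.
class ContextualForward:
    def __init__(self, original_forward):
        self._forward = original_forward
        self._context = None

    def set_context(self, context):
        self._context = context  # called by _set_module_context once per batch

    def __call__(self, *args, **kwargs):
        # The real implementation would consume dist results from self._context.
        assert self._context is not None, "context must be set before forward"
        return self._forward(*args, **kwargs)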
@@ -471,6 +488,10 @@ def _set_module_context(self, context: TrainPipelineContext) -> None:
            postproc_module.set_context(context)

    def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
+        """
+        Load a data batch from the dataloader and copy it from CPU to GPU;
+        also create the context for this batch.
+        """
        batch, context = self.copy_batch_to_gpu(dataloader_iter)
        if batch is None:
            return False
@@ -481,30 +502,50 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
        return True

    def dequeue_batch(self) -> None:
+        """
+        Remove a processed batch from the batch queue; also set the module context if applicable.
+        """
        self.batches.popleft()
        self.contexts.popleft()
-        # update PipelineForwards context to match next forward pass
+
+        # update PipelinedForward context to match next forward pass
        if len(self.batches) >= 1:
            self._set_module_context(self.contexts[0])

    def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
-        # pipeline is already filled
+        """
+        This function is called in self.progress (one of the main APIs for running the train
+        pipeline). Here we assume a max pipelined len(batches) == 2 (capacity), which is the most
+        common scenario during a full training job, in which case this function effectively does
+        nothing. There are only two other scenarios:
+        len(batches) == 0:
+            initialize the pipeline: fill in two batches and start input_dist for the first batch.
+        len(batches) == 1:
+            dataloader_iter is exhausted; this is the last batch, so do nothing.
+        """
+
+        # pipeline is already filled with max capacity (2)
        if len(self.batches) >= 2:
            return
-        # executes last batch in pipeline
+
+        # execute the last batch in the pipeline when there is only one batch left
+        # TODO: _execute_all_batches doesn't really work here (D43546239); it will
+        # just throw an exception at copy_to_gpu when the dataloader is exhausted
        if self.batches and self._execute_all_batches:
            return

-        # batch i
+        # batch i: fetch the data (batch) and create its context
        if not self.enqueue_batch(dataloader_iter):
            return

+        # modify the (sharded) sparse module forward, and invoke the first part of input_dist
        self._init_pipelined_modules(
            # pyre-ignore [6]
            self.batches[0],
            self.contexts[0],
            PipelinedForward,
        )
+        # do the second part of input_dist; the first part is invoked in _init_pipelined_modules
        self.wait_sparse_data_dist(self.contexts[0])

        # batch i+1
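Restating the capacity logic above as a comment-only sketch (our reading of this diff, not code from the PR):

# len(self.batches) == 2 -> steady state: pipeline already at capacity, return
# len(self.batches) == 1 -> dataloader exhausted: keep the last batch, do nothing
# len(self.batches) == 0 -> cold start: enqueue batch i, rewrite the sparse
#                           modules' forward (PipelinedForward), run both parts
#                           of its input_dist, then enqueue batch i+1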
@@ -520,10 +561,22 @@ def _backward(self, losses: torch.Tensor) -> None:
        torch.sum(losses, dim=0).backward()

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+        batches[0]: current batch, for emb_lookup, output_dist, and fwd/bwd/opt (expecting input_dist)
+        batches[1]: next batch, for input_dist (expecting copied to device)
+        batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgot to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
        if not self._model_attached:
            self.attach(self._model)

+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
        self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
        if not self.batches:
            raise StopIteration
@@ -534,19 +587,23 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
        with record_function("## zero_grad ##"):
            self._optimizer.zero_grad()

+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] was invoked in the previous iteration. TODO: fact check
        self._wait_for_batch()

        if len(self.batches) >= 2:
+            # invoke the splits all_to_all comms (first part of input_dist)
            self.start_sparse_data_dist(self.batches[1], self.contexts[1])

-        # batch i+2
+        # batch i+2: load data and copy to GPU; the dataloader iter will exhaust here first
        self.enqueue_batch(dataloader_iter)

        # forward
        with record_function("## forward ##"):
            losses, output = self._model_fwd(self.batches[0])

        if len(self.batches) >= 2:
+            # invoke the data (values, lengths, etc.) all_to_all comms (second part of input_dist)
            self.wait_sparse_data_dist(self.contexts[1])

        if self._model.training:
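A minimal driver loop for the progress API described above (`pipeline` and `dataloader` are assumed names, not part of this diff):

# Hypothetical driver loop: each progress() call overlaps the three stages
# across batches[0..2] and raises StopIteration once fill_pipeline finds
# no batches left in flight.
data_iter = iter(dataloader)
while True:
    try:
        output = pipeline.progress(data_iter)  # fwd/bwd/opt for batches[0]
    except StopIteration:
        break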
@@ -768,6 +825,9 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
            training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync")
    """

+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = EmbeddingPipelinedForward  # pyre-ignore
+
    def __init__(
        self,
        model: torch.nn.Module,
@@ -793,7 +853,6 @@ def __init__(
            pipeline_postproc=pipeline_postproc,
            custom_model_fwd=custom_model_fwd,
        )
-        self._pipelined_forward_type = EmbeddingPipelinedForward
        self._start_batch = start_batch
        self._stash_gradients = stash_gradients
        logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
@@ -865,6 +924,8 @@ def _mlp_optimizer_step(self, current_batch: int) -> None:
            self._optimizer.step()

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgot to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
        if not self._model_attached:
            self.attach(self._model)
@@ -1653,6 +1714,8 @@ def get_compiled_autograd_ctx(
        )

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        # attach the model just in case the user forgot to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
        if not self._model_attached:
            self.attach(self._model)