From 065b0074d4bf80a292b8abbfe56a8f4a3dc2d778 Mon Sep 17 00:00:00 2001
From: Surbhi Jain
Date: Sat, 14 Jun 2025 00:04:09 +0000
Subject: [PATCH] Refactor profiler and only activate/deactivate profiling when enabled

---
 MaxText/elastic_train.py                | 11 ++---------
 MaxText/experimental/rl/grpo_trainer.py | 11 ++---------
 MaxText/profiler.py                     | 21 +++++++++++++++++++++
 MaxText/sft_trainer.py                  | 11 ++---------
 MaxText/train.py                        | 11 ++---------
 5 files changed, 29 insertions(+), 36 deletions(-)

diff --git a/MaxText/elastic_train.py b/MaxText/elastic_train.py
index 0956a5cb6..672c72db0 100644
--- a/MaxText/elastic_train.py
+++ b/MaxText/elastic_train.py
@@ -236,10 +236,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  first_profiling_step = prof.start_initial_profile_step
-  if config.profiler != "" and first_profiling_step >= config.steps:
-    raise ValueError("Profiling requested but initial profiling step set past training final step")
-  last_profiling_step = prof.finished_initial_profile_step
 
   example_batch = None
   last_step_completion = datetime.datetime.now()
@@ -271,9 +267,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
   # the step is restored back to the latest snapshot when a slice is lost
   while step < config.steps:
     try:
-      if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-        optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-        prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+      prof.maybe_activate_profiler(step, state)
 
       max_logging.log(f"{step=} {elastic_manager.elastic_down_event_count=} {elastic_manager.good_slice_count=}")
       with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(elastic_manager.default_device):
@@ -310,8 +304,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
 
       metric_logger.write_metrics(running_gcs_metrics, metrics, step)
 
-      if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-        prof.deactivate(blocking_object=state)
+      prof.maybe_deactivate_profiler(step, state)
 
       elastic_manager.maybe_snapshot(
           step=step,
diff --git a/MaxText/experimental/rl/grpo_trainer.py b/MaxText/experimental/rl/grpo_trainer.py
index 38517ee61..77a8d6267 100644
--- a/MaxText/experimental/rl/grpo_trainer.py
+++ b/MaxText/experimental/rl/grpo_trainer.py
@@ -779,10 +779,6 @@ def train_loop(config, config_inference, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  first_profiling_step = prof.start_initial_profile_step
-  if config.profiler != "" and first_profiling_step >= config.steps:
-    raise ValueError("Profiling requested but initial profiling step set past training final step")
-  last_profiling_step = prof.finished_initial_profile_step
 
   example_batch = None
   last_step_completion = datetime.datetime.now()
@@ -799,9 +795,7 @@ def train_loop(config, config_inference, recorder, state=None):
   metric_logger = MetricLogger(writer, config)
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
   for step in np.arange(start_step, config.steps):
-    if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-      optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-      prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+    prof.maybe_activate_profiler(step, state)
 
     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -894,8 +888,7 @@ def train_loop(config, config_inference, recorder, state=None):
         prof.deactivate()
         break
 
-    if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-      prof.deactivate(blocking_object=state)
+    prof.maybe_deactivate_profiler(step, state)
 
     if step == start_step:
       max_utils.print_mem_stats("After params initialized")
diff --git a/MaxText/profiler.py b/MaxText/profiler.py
index 39667c02c..856f2783c 100644
--- a/MaxText/profiler.py
+++ b/MaxText/profiler.py
@@ -40,6 +40,18 @@ def __init__(self, config, offset_step=0):
     self.profile_period = config.profile_periodically_period
     self.start_initial_profile_step = self._set_first_profiler_step(config.skip_first_n_steps_for_profiler, offset_step)
     self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps)
+    if config.profiler != "" and self.start_initial_profile_step >= config.steps:
+      raise ValueError("Profiling requested but initial profiling step set past training final step")
+
+  def maybe_activate_profiler(self, step, state):
+    """Conditionally activates the profiler based on the current step.
+    This method checks whether the current training step matches the step
+    designated for starting the initial profile, or whether it meets the
+    criteria for activating a new periodic profile.
+    """
+    if self.mode != "" and (step == self.start_initial_profile_step or self.should_activate_periodic_profile(step)):
+      optional_postfix = f"step_{step}" if self.profile_period > 0 else ""
+      self.activate(blocking_object=state, optional_postfix=optional_postfix)
 
   def activate(self, blocking_object=None, optional_postfix=""):
     """Start the profiler.
@@ -60,6 +72,15 @@ def activate(self, blocking_object=None, optional_postfix=""):
     elif self.mode == "xplane":
       jax.profiler.start_trace(self.output_path)
 
+  def maybe_deactivate_profiler(self, step, state):
+    """Conditionally deactivates the profiler based on the current step.
+    This method checks whether the current training step matches the step
+    designated for finishing the initial profile, or whether it meets the
+    criteria for deactivating a periodic profile.
+    """
+    if self.mode != "" and (step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step)):
+      self.deactivate(blocking_object=state)
+
   def deactivate(self, blocking_object=None):
     """End the profiler.
     The result is uploaded to the output bucket."""
diff --git a/MaxText/sft_trainer.py b/MaxText/sft_trainer.py
index 5219fcfae..21f6460aa 100644
--- a/MaxText/sft_trainer.py
+++ b/MaxText/sft_trainer.py
@@ -137,10 +137,6 @@ def train_loop(config, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  first_profiling_step = prof.start_initial_profile_step
-  if config.profiler != "" and first_profiling_step >= config.steps:
-    raise ValueError("Profiling requested but initial profiling step set past training final step")
-  last_profiling_step = prof.finished_initial_profile_step
 
   example_batch = None
   last_step_completion = datetime.datetime.now()
@@ -157,9 +153,7 @@ def train_loop(config, recorder, state=None):
   metric_logger = MetricLogger(writer, config)
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
   for step in np.arange(start_step, config.steps):
-    if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-      optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-      prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+    prof.maybe_activate_profiler(step, state)
 
     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -243,8 +237,7 @@ def train_loop(config, recorder, state=None):
         prof.deactivate()
         break
 
-    if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-      prof.deactivate(blocking_object=state)
+    prof.maybe_deactivate_profiler(step, state)
 
     if step == start_step:
      max_utils.print_mem_stats("After params initialized")
diff --git a/MaxText/train.py b/MaxText/train.py
index e6a3cb9ab..f95456ff1 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -818,10 +818,6 @@ def train_loop(config, recorder, state=None):
 
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  first_profiling_step = prof.start_initial_profile_step
-  if config.profiler != "" and first_profiling_step >= config.steps:
-    raise ValueError("Profiling requested but initial profiling step set past training final step")
-  last_profiling_step = prof.finished_initial_profile_step
 
   example_batch = None
   last_step_completion = datetime.datetime.now()
@@ -838,9 +834,7 @@ def train_loop(config, recorder, state=None):
   metric_logger = MetricLogger(writer, config)
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
   for step in np.arange(start_step, config.steps):
-    if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-      optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-      prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+    prof.maybe_activate_profiler(step, state)
 
     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -930,8 +924,7 @@ def train_loop(config, recorder, state=None):
         prof.deactivate()
         break
 
-    if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-      prof.deactivate(blocking_object=state)
+    prof.maybe_deactivate_profiler(step, state)
 
     if step == start_step:
       max_utils.print_mem_stats("After params initialized")
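
Note: the sketch below illustrates the gating behavior this patch introduces. It is a self-contained toy, not the real MaxText classes: the Config dataclass, the step arithmetic (a stand-in for _set_first_profiler_step/_set_last_profiler_step), and the periodic should_activate/should_deactivate predicates are simplified assumptions. What it demonstrates is the new guard: when config.profiler is "", maybe_activate_profiler and maybe_deactivate_profiler are no-ops, so a disabled profiler can no longer be toggled just because the step numbers line up.

from dataclasses import dataclass


@dataclass
class Config:
  # Simplified stand-in for the MaxText config; field names match the keys
  # referenced in the patch above.
  profiler: str = ""                    # "" disables profiling; e.g. "xplane" enables it
  steps: int = 20                       # total training steps
  skip_first_n_steps_for_profiler: int = 2
  profiler_steps: int = 3               # length of the initial profile window
  profile_periodically_period: int = 0  # > 0 turns on periodic re-profiling


class Profiler:
  def __init__(self, config, offset_step=0):
    self.mode = config.profiler
    self.profile_period = config.profile_periodically_period
    # Assumed step math; the real class delegates to helper methods.
    self.start_initial_profile_step = offset_step + config.skip_first_n_steps_for_profiler
    self.finished_initial_profile_step = min(
        self.start_initial_profile_step + config.profiler_steps - 1, config.steps - 1
    )
    # Validation moved here from the four train loops.
    if config.profiler != "" and self.start_initial_profile_step >= config.steps:
      raise ValueError("Profiling requested but initial profiling step set past training final step")

  def should_activate_periodic_profile(self, step):
    # Simplified periodic predicate (an assumption, not MaxText's exact formula).
    return self.profile_period > 0 and step > self.start_initial_profile_step and (
        step - self.start_initial_profile_step) % self.profile_period == 0

  def should_deactivate_periodic_profile(self, step):
    return self.profile_period > 0 and step > self.finished_initial_profile_step and (
        step - self.finished_initial_profile_step) % self.profile_period == 0

  def maybe_activate_profiler(self, step, state):
    # The `self.mode != ""` check is the bug fix: with profiling disabled this
    # is a no-op even when the step numbers happen to match.
    if self.mode != "" and (step == self.start_initial_profile_step or self.should_activate_periodic_profile(step)):
      optional_postfix = f"step_{step}" if self.profile_period > 0 else ""
      self.activate(blocking_object=state, optional_postfix=optional_postfix)

  def maybe_deactivate_profiler(self, step, state):
    if self.mode != "" and (step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step)):
      self.deactivate(blocking_object=state)

  def activate(self, blocking_object=None, optional_postfix=""):
    print(f"profiler on (postfix={optional_postfix!r})")  # the real class starts a trace

  def deactivate(self, blocking_object=None):
    print("profiler off")  # the real class stops the trace and uploads the result


config = Config(profiler="xplane")
prof = Profiler(config)
state = object()  # stand-in for the training state
for step in range(config.steps):
  prof.maybe_activate_profiler(step, state)    # replaces the inlined activation check
  # ... train_step(state, example_batch) would run here ...
  prof.maybe_deactivate_profiler(step, state)  # replaces the inlined deactivation check

Moving the range check into Profiler.__init__ and hiding the step matching behind the maybe_* methods removes four copies of the same logic from elastic_train.py, grpo_trainer.py, sft_trainer.py, and train.py, and keeps the activation policy next to the step bookkeeping it depends on.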