Refactor profiler in trainers #1833

Open · wants to merge 1 commit into base: main
11 changes: 2 additions & 9 deletions MaxText/elastic_train.py
@@ -236,10 +236,6 @@ def train_loop(config, elastic_manager, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
- first_profiling_step = prof.start_initial_profile_step
- if config.profiler != "" and first_profiling_step >= config.steps:
-   raise ValueError("Profiling requested but initial profiling step set past training final step")
- last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()
@@ -271,9 +267,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
# the step is restored back to the latest snapshot when a slice is lost
while step < config.steps:
try:
- if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-   optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-   prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+ prof.maybe_activate_profiler(step, state)

max_logging.log(f"{step=} {elastic_manager.elastic_down_event_count=} {elastic_manager.good_slice_count=}")
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(elastic_manager.default_device):
@@ -310,8 +304,7 @@ def train_loop(config, elastic_manager, recorder, state=None):

metric_logger.write_metrics(running_gcs_metrics, metrics, step)

- if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-   prof.deactivate(blocking_object=state)
+ prof.maybe_deactivate_profiler(step, state)

elastic_manager.maybe_snapshot(
step=step,
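The same pattern repeats in every trainer below: the step bookkeeping moves into the Profiler, and each loop body shrinks to two calls. A minimal sketch of the resulting call pattern follows; the surrounding config, state, and train_step are hypothetical stand-ins, and only the Profiler calls mirror this diff:

import numpy as np
from MaxText import profiler  # assumed import path, matching the diff's usage

prof = profiler.Profiler(config, offset_step=start_step)  # now also validates profile steps up front
for step in np.arange(start_step, config.steps):
  prof.maybe_activate_profiler(step, state)    # no-op unless this step starts a profile
  state, metrics = train_step(state)           # hypothetical per-step work
  prof.maybe_deactivate_profiler(step, state)  # no-op unless this step ends a profile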
11 changes: 2 additions & 9 deletions MaxText/experimental/rl/grpo_trainer.py
@@ -779,10 +779,6 @@ def train_loop(config, config_inference, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
- first_profiling_step = prof.start_initial_profile_step
- if config.profiler != "" and first_profiling_step >= config.steps:
-   raise ValueError("Profiling requested but initial profiling step set past training final step")
- last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()
@@ -799,9 +795,7 @@ def train_loop(config, config_inference, recorder, state=None):
metric_logger = MetricLogger(writer, config)
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
for step in np.arange(start_step, config.steps):
- if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-   optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-   prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+ prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -894,8 +888,7 @@ def train_loop(config, config_inference, recorder, state=None):
prof.deactivate()
break

- if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-   prof.deactivate(blocking_object=state)
+ prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")
21 changes: 21 additions & 0 deletions MaxText/profiler.py
@@ -40,6 +40,18 @@ def __init__(self, config, offset_step=0):
self.profile_period = config.profile_periodically_period
self.start_initial_profile_step = self._set_first_profiler_step(config.skip_first_n_steps_for_profiler, offset_step)
self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps)
+ if config.profiler != "" and self.start_initial_profile_step >= config.steps:
+   raise ValueError("Profiling requested but initial profiling step set past training final step")

+ def maybe_activate_profiler(self, step, state):
+   """Conditionally activates the profiler based on the current step.
+
+   This method checks if the current training step matches the step designated
+   for starting an initial profile, or if it meets the criteria for
+   activating a new periodic profile.
+   """
+   if step == self.start_initial_profile_step or self.should_activate_periodic_profile(step):
+     optional_postfix = f"step_{step}" if self.profile_period > 0 else ""
+     self.activate(blocking_object=state, optional_postfix=optional_postfix)

Review comment (Collaborator, author): Noticed that maybe_activate_profiler and maybe_deactivate_profiler should also have a config.profiler != "" check to ensure we only call activate() and deactivate() when profiling is enabled. Is that the right understanding, @bvandermoon?

def activate(self, blocking_object=None, optional_postfix=""):
"""Start the profiler.
@@ -60,6 +72,15 @@ def activate(self, blocking_object=None, optional_postfix=""):
elif self.mode == "xplane":
jax.profiler.start_trace(self.output_path)

+ def maybe_deactivate_profiler(self, step, state):
+   """Conditionally deactivates the profiler based on the current step.
+
+   This method checks if the current training step matches the step designated
+   for finishing the initial profile, or if it meets the criteria for
+   deactivating a periodic profile.
+   """
+   if step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step):
+     self.deactivate(blocking_object=state)

def deactivate(self, blocking_object=None):
"""End the profiler.
The result is uploaded to the output bucket."""
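A minimal sketch of the guard the review comment above asks about, assuming self.mode holds the value of config.profiler (as the self.mode == "xplane" branch in activate() suggests); this illustrates the proposal inside the Profiler class and is not code in this PR:

def maybe_activate_profiler(self, step, state):
  if self.mode == "":  # assumed guard: profiling disabled, so never activate
    return
  if step == self.start_initial_profile_step or self.should_activate_periodic_profile(step):
    optional_postfix = f"step_{step}" if self.profile_period > 0 else ""
    self.activate(blocking_object=state, optional_postfix=optional_postfix)

A symmetric early return on self.mode == "" would cover maybe_deactivate_profiler.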
11 changes: 2 additions & 9 deletions MaxText/sft_trainer.py
@@ -137,10 +137,6 @@ def train_loop(config, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
- first_profiling_step = prof.start_initial_profile_step
- if config.profiler != "" and first_profiling_step >= config.steps:
-   raise ValueError("Profiling requested but initial profiling step set past training final step")
- last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()
@@ -157,9 +153,7 @@ def train_loop(config, recorder, state=None):
metric_logger = MetricLogger(writer, config)
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
for step in np.arange(start_step, config.steps):
- if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-   optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-   prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+ prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -243,8 +237,7 @@ def train_loop(config, recorder, state=None):
prof.deactivate()
break

- if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-   prof.deactivate(blocking_object=state)
+ prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")
11 changes: 2 additions & 9 deletions MaxText/train.py
@@ -818,10 +818,6 @@ def train_loop(config, recorder, state=None):

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
- first_profiling_step = prof.start_initial_profile_step
- if config.profiler != "" and first_profiling_step >= config.steps:
-   raise ValueError("Profiling requested but initial profiling step set past training final step")
- last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()
@@ -838,9 +834,7 @@ def train_loop(config, recorder, state=None):
metric_logger = MetricLogger(writer, config)
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
for step in np.arange(start_step, config.steps):
- if step == first_profiling_step or prof.should_activate_periodic_profile(step):
-   optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
-   prof.activate(blocking_object=state, optional_postfix=optional_postfix)
+ prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -930,8 +924,7 @@ def train_loop(config, recorder, state=None):
prof.deactivate()
break

- if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
-   prof.deactivate(blocking_object=state)
+ prof.maybe_deactivate_profiler(step, state)

if step == start_step:
max_utils.print_mem_stats("After params initialized")
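For illustration, a config under which the relocated check fires once at construction time instead of in every trainer. The Config constructor and the assumption that the first profile step equals skip_first_n_steps_for_profiler plus the offset are hypothetical; the field names follow the diff:

config = Config(profiler="xplane", steps=100, skip_first_n_steps_for_profiler=150, profiler_steps=5)
# assumed first profile step (150) >= steps (100), so __init__ raises:
# ValueError: Profiling requested but initial profiling step set past training final step
prof = profiler.Profiler(config, offset_step=0)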