Skip to content

Commit

Permalink
Remove recorder calls from trainer for now
Browse files: browse the repository at this point in the history
  • Loading branch information
dipannita08 committed Dec 2, 2024
1 parent 0221426 commit 8d0c58d
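Showing 1 changed file with 0 additions and 8 deletions. (The diff's `-` markers were lost in extraction; the 8 deleted lines are the `self._maybe_record_event(...)` calls shown below.)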
8 changes: 0 additions & 8 deletions axlearn/common/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -238,7 +238,6 @@ def __init__(
utils.validate_float_dtype(cfg.train_dtype)

# Create the device mesh.
self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT)
if devices is None:
self._step_log(
"Devices: global=%s local=%s %s",
Expand Down Expand Up @@ -325,7 +324,6 @@ def __init__(
model=self.model,
model_param_partition_specs=model_param_partition_specs,
)
self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT)

@property
def step(self):
Expand Down Expand Up @@ -830,7 +828,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
# Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`.
self.restore_checkpoint(restore_step=None)

self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION)
if self.step is None:
# If we didn't restore from checkpoint, attempt to build initial state according
# to `cfg.init_state_builder` and initialize the remaining parameters.
Expand All @@ -850,7 +847,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
f.write(model_analysis)

self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
# Log config.
self.summary_writer.log_config(cfg, step=self.step)

Expand Down Expand Up @@ -887,7 +883,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
restore_input_iter = cfg.save_input_iterator
try:
# Try to restore with `input_iter`.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(
Expand All @@ -901,15 +896,13 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
except ValueError as e:
logging.warning(
"Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
restore_input_iter,
e,
)
# Restore with a different restore_input_iter setting.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
restore_input_iter = not restore_input_iter
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
Expand All @@ -924,7 +917,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
if step is not None:
self._step = step
self._trainer_state = TrainerState(
Expand Down

0 comments on commit 8d0c58d

Please sign in to comment.