From 8d0c58d024c117239a5eebd522a39ef4649775a2 Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Mon, 2 Dec 2024 17:19:58 +0000 Subject: [PATCH] Remove recorder calls from trainer for now --- axlearn/common/trainer.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 4ffcb810..a6056076 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -238,7 +238,6 @@ def __init__( utils.validate_float_dtype(cfg.train_dtype) # Create the device mesh. - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) if devices is None: self._step_log( "Devices: global=%s local=%s %s", @@ -325,7 +324,6 @@ def __init__( model=self.model, model_param_partition_specs=model_param_partition_specs, ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) @property def step(self): @@ -830,7 +828,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. self.restore_checkpoint(restore_step=None) - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) if self.step is None: # If we didn't restore from checkpoint, attempt to build initial state according # to `cfg.init_state_builder` and initialize the remaining parameters. @@ -850,7 +847,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f: f.write(model_analysis) - self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) # Log config. self.summary_writer.log_config(cfg, step=self.step) @@ -887,7 +883,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int restore_input_iter = cfg.save_input_iterator try: # Try to restore with `input_iter`. - self._maybe_record_event(measurement.Event.START_DATA_LOADING) step, ckpt_state = self.checkpointer.restore( step=restore_step, state=( @@ -901,7 +896,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int step, restore_input_iter, ) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) except ValueError as e: logging.warning( "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", @@ -909,7 +903,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int e, ) # Restore with a different restore_input_iter setting. - self._maybe_record_event(measurement.Event.START_DATA_LOADING) restore_input_iter = not restore_input_iter step, ckpt_state = self.checkpointer.restore( step=restore_step, @@ -924,7 +917,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int step, restore_input_iter, ) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) if step is not None: self._step = step self._trainer_state = TrainerState(