diff --git a/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst b/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst
index f420e19eea8..da4dd49159a 100644
--- a/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst
+++ b/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst
@@ -879,6 +879,7 @@ To run Trainer API solely on-cluster, the code is much simpler:
       trial_inst = model.MNistTrial(train_context)
       trainer = det.pytorch.Trainer(trial_inst, train_context)
       trainer.fit(
+          max_length=pytorch.Epoch(11),
           checkpoint_period=pytorch.Batch(100),
           validation_period=pytorch.Batch(100),
           latest_checkpoint=det.get_cluster_info().latest_checkpoint,
diff --git a/examples/tutorials/mnist_pytorch/adaptive.yaml b/examples/tutorials/mnist_pytorch/adaptive.yaml
index d7612374ea7..5953cdad5d6 100644
--- a/examples/tutorials/mnist_pytorch/adaptive.yaml
+++ b/examples/tutorials/mnist_pytorch/adaptive.yaml
@@ -25,6 +25,6 @@ searcher:
   metric: validation_loss
   smaller_is_better: true
   max_trials: 16
-  time_metric: batch
+  time_metric: batches
   max_time: 937 # 60,000 training images with batch size 64
 entrypoint: python3 train.py --epochs 1
diff --git a/examples/tutorials/mnist_pytorch/train.py b/examples/tutorials/mnist_pytorch/train.py
index f7c9a7e61f6..78d7eadd53c 100644
--- a/examples/tutorials/mnist_pytorch/train.py
+++ b/examples/tutorials/mnist_pytorch/train.py
@@ -80,7 +80,6 @@ def evaluate_batch(self, batch: pytorch.TorchData, batch_idx: int) -> Dict[str,
         return {
             "validation_loss": validation_loss,
             "accuracy": accuracy,
-            "batch": self.context.current_train_batch(),
         }
 
 
@@ -115,7 +114,11 @@ def run(max_length, local: bool = False):
     with pytorch.init() as train_context:
         trial = MNistTrial(train_context, hparams=hparams)
         trainer = pytorch.Trainer(trial, train_context)
-        trainer.fit(max_length=max_length, latest_checkpoint=latest_checkpoint)
+        trainer.fit(
+            max_length=max_length,
+            latest_checkpoint=latest_checkpoint,
+            validation_period=pytorch.Batch(100),
+        )
 
 
 if __name__ == "__main__":
diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py
index 1363f321483..0e65b4f7cda 100644
--- a/harness/determined/pytorch/_pytorch_trial.py
+++ b/harness/determined/pytorch/_pytorch_trial.py
@@ -70,11 +70,6 @@ def _from_searcher_unit(
         else:
             raise ValueError(f"unrecognized searcher unit {unit}")
 
-    def _to_searcher_unit(self) -> core.Unit:
-        if isinstance(self, Batch):
-            return core.Unit.BATCHES
-        return core.Unit.EPOCHS
-
     @staticmethod
     def _from_values(
         batches: Optional[int] = None,
@@ -124,7 +119,8 @@ class Epoch(TrainUnit):
     Epoch step type (e.g. Epoch(1) defines 1 epoch)
     """
 
-    pass
+    def __str__(self) -> str:
+        return f"Epoch({self.value})"
 
 
 class Batch(TrainUnit):
@@ -136,6 +132,9 @@ class Batch(TrainUnit):
     def _from_records(records: int, global_batch_size: int) -> "Batch":
         return Batch(max(records // global_batch_size, 1))
 
+    def __str__(self) -> str:
+        return f"Batch({self.value})"
+
 
 class _TrainBoundaryType(enum.Enum):
     CHECKPOINT = "CHECKPOINT"
@@ -195,7 +194,7 @@ def __init__(
         searcher_metric_name: Optional[str],
         checkpoint_policy: str,
         step_zero_validation: bool,
-        max_length: Optional[TrainUnit],
+        max_length: TrainUnit,
         global_batch_size: Optional[int],
         profiling_enabled: Optional[bool],
     ) -> None:
@@ -219,18 +218,7 @@ def __init__(
         self.reporting_period = reporting_period
 
         # Training loop state
-        if local_training:
-            self.trial_id = 0
-            assert self.max_length, "max_length must be specified for local-training mode."
-            self.searcher_unit = self.max_length._to_searcher_unit()
-        else:
-            self.trial_id = self.core_context.train._trial_id
-            configured_units = self.core_context.searcher.get_configured_units()
-            if configured_units is None:
-                raise ValueError(
-                    "Searcher units must be configured for training with PyTorchTrial."
-                )
-            self.searcher_unit = configured_units
+        self.trial_id = 0 if local_training else self.core_context.train._trial_id
 
         # Don't initialize the state here because it will be invalid until we load a checkpoint.
         self.state = None  # type: Optional[_TrialState]
@@ -247,10 +235,6 @@ def __init__(
         self.global_batch_size = global_batch_size
         self.profiling_enabled = profiling_enabled
 
-        if self.searcher_unit == core.Unit.RECORDS:
-            if self.global_batch_size is None:
-                raise ValueError("global_batch_size required for searcher unit RECORDS.")
-
         self.callbacks = self.trial.build_callbacks()
         for callback in self.callbacks.values():
             if util.is_overridden(callback.on_checkpoint_end, pytorch.PyTorchCallback):
@@ -513,17 +497,18 @@ def _stop_requested(self) -> None:
         if self.context.get_stop_requested():
             raise ShouldExit()
 
-    def _report_searcher_progress(
-        self, op: core.SearcherOperation, unit: Optional[core.Unit]
-    ) -> None:
+    def _report_training_progress(self) -> None:
         assert self.state
-        if unit == core.Unit.BATCHES:
-            op.report_progress(self.state.batches_trained)
-        elif unit == core.Unit.RECORDS:
-            assert self.global_batch_size, "global_batch_size must be specified for RECORDS"
-            op.report_progress(self.global_batch_size * self.state.batches_trained)
-        elif unit == core.Unit.EPOCHS:
-            op.report_progress(self.state.epochs_trained)
+        assert isinstance(self.max_length.value, int)
+
+        if isinstance(self.max_length, Batch):
+            progress = self.state.batches_trained / self.max_length.value
+        elif isinstance(self.max_length, Epoch):
+            progress = self.state.epochs_trained / self.max_length.value
+        else:
+            raise ValueError(f"unexpected train unit type {type(self.max_length)}")
+
+        self.core_context.train.report_progress(progress=progress)
 
     def _checkpoint_is_current(self) -> bool:
         assert self.state
@@ -615,7 +600,6 @@ def cleanup_iterator() -> None:
             self._run()
 
     def _run(self) -> None:
-        ops: Iterator[det.core.SearcherOperation]
         assert self.state
 
         try:
@@ -626,47 +610,24 @@ def _run(self) -> None:
             ):
                 self._validate()
 
-            if self.local_training:
-                assert self.max_length and isinstance(self.max_length.value, int)
-                ops = iter(
-                    [
-                        det.core.DummySearcherOperation(
-                            length=self.max_length.value, is_chief=self.is_chief
-                        )
-                    ]
-                )
-            else:
-                ops = self.core_context.searcher.operations()
-
-            for op in ops:
-                if self.local_training:
-                    train_unit = self.max_length
-                else:
-                    train_unit = TrainUnit._from_searcher_unit(
-                        op.length, self.searcher_unit, self.global_batch_size
-                    )
-                assert train_unit
-
-                self._train_for_op(
-                    op=op,
-                    train_boundaries=[
-                        _TrainBoundary(
-                            step_type=_TrainBoundaryType.TRAIN,
-                            unit=train_unit,
-                        ),
-                        _TrainBoundary(
-                            step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period
-                        ),
-                        _TrainBoundary(
-                            step_type=_TrainBoundaryType.CHECKPOINT,
-                            unit=self.checkpoint_period,
-                        ),
-                        # Scheduling unit is always configured in batches
-                        _TrainBoundary(
-                            step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period
-                        ),
-                    ],
-                )
+            self._train(
+                length=Batch(1) if self.test_mode else self.max_length,
+                train_boundaries=[
+                    _TrainBoundary(
+                        step_type=_TrainBoundaryType.TRAIN,
+                        unit=self.max_length,
+                    ),
+                    _TrainBoundary(
+                        step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period
+                    ),
+                    _TrainBoundary(
+                        step_type=_TrainBoundaryType.CHECKPOINT,
+                        unit=self.checkpoint_period,
+                    ),
+                    # Scheduling unit is always configured in batches
+                    _TrainBoundary(step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period),
+                ],
+            )
         except ShouldExit as e:
             # Checkpoint unsaved work and exit.
             if not e.skip_exit_checkpoint and not self._checkpoint_is_current():
@@ -733,20 +694,8 @@ def _train_with_boundaries(
                 # True epoch end
                 return train_boundaries, training_metrics
 
-    def _train_for_op(
-        self, op: core.SearcherOperation, train_boundaries: List[_TrainBoundary]
-    ) -> None:
-        if self.test_mode:
-            train_length = Batch(1)
-        elif self.local_training:
-            train_length = self.max_length  # type: ignore
-        else:
-            train_length = TrainUnit._from_searcher_unit(
-                op.length, self.searcher_unit, self.global_batch_size
-            )  # type: ignore
-        assert train_length
-
-        while self._steps_until_complete(train_length) > 0:
+    def _train(self, length: TrainUnit, train_boundaries: List[_TrainBoundary]) -> None:
+        while self._steps_until_complete(length) > 0:
             train_boundaries, training_metrics = self._train_with_boundaries(
                 self.training_enumerator, train_boundaries
             )
@@ -767,16 +716,16 @@ def _train_for_op(
 
                 # Train step limits reached, proceed accordingly.
                 if train_boundary.step_type == _TrainBoundaryType.TRAIN:
-                    if not op._completed and self.is_chief and not step_reported:
-                        self._report_searcher_progress(op, self.searcher_unit)
+                    if self.is_chief and not step_reported:
+                        self._report_training_progress()
                         step_reported = True
                 elif train_boundary.step_type == _TrainBoundaryType.REPORT:
-                    if not op._completed and self.is_chief and not step_reported:
-                        self._report_searcher_progress(op, self.searcher_unit)
+                    if self.is_chief and not step_reported:
+                        self._report_training_progress()
                         step_reported = True
                 elif train_boundary.step_type == _TrainBoundaryType.VALIDATE:
                     if not self._validation_is_current():
-                        self._validate(op)
+                        self._validate()
                 elif train_boundary.step_type == _TrainBoundaryType.CHECKPOINT:
                     if not self._checkpoint_is_current():
                         self._checkpoint(already_exiting=False)
@@ -789,20 +738,12 @@ def _train_for_op(
             self._upload_tb_files()
             self._stop_requested()
 
-        # Finished training for op. Perform final checkpoint/validation if necessary.
+        # Finished training. Perform final checkpoint/validation if necessary.
         if not self._validation_is_current():
-            self._validate(op)
+            self._validate()
         if not self._checkpoint_is_current():
             self._checkpoint(already_exiting=False)
 
-        # Test mode will break after one batch despite not completing op.
-        if self.is_chief and not self.test_mode:
-            # The only case where op isn't reported as completed is if we restarted but
-            # op.length was already trained for and validated on; in that case just raise
-            # ShouldExit; we have nothing to do.
-            if not op._completed:
-                raise ShouldExit(skip_exit_checkpoint=True)
-
     def _check_searcher_metric(self, val_metrics: Dict) -> Any:
         if self.searcher_metric_name not in val_metrics:
             raise RuntimeError(
@@ -913,7 +854,7 @@ def _train_batch(
         return training_metrics
 
     @torch.no_grad()
-    def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dict[str, Any]:
+    def _validate(self) -> Dict[str, Any]:
         # Report a validation step is starting.
         if self.is_chief:
             self.core_context.train.set_status("validating")
@@ -1049,24 +990,14 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
         # Get best validation before reporting metrics.
         best_validation_before = self.core_context.train.get_experiment_best_validation()
 
-        self.core_context.train.report_validation_metrics(self.state.batches_trained, metrics)
+        # We report "batches" and "epochs" only if these keys are not already reported in user
+        # metrics.
+        metrics["batches"] = metrics.get("batches", self.state.batches_trained)
+        metrics["epochs"] = metrics.get("epochs", self.state.epochs_trained)
 
-        searcher_metric = None
-
-        # Report searcher status.
-        if self.is_chief and searcher_op:
-            if self.local_training:
-                searcher_length = self.max_length
-            else:
-                searcher_length = TrainUnit._from_searcher_unit(
-                    searcher_op.length, self.searcher_unit, self.global_batch_size
-                )
-            if self.searcher_metric_name:
-                searcher_metric = self._check_searcher_metric(metrics)
-
-            assert searcher_length
-            if self._steps_until_complete(searcher_length) < 1 and not searcher_op._completed:
-                searcher_op.report_completed(searcher_metric)
+        self.core_context.train.report_validation_metrics(
+            steps_completed=self.state.batches_trained, metrics=metrics
+        )
 
         should_checkpoint = False
 
@@ -1079,6 +1010,7 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
             assert (
                 self.searcher_metric_name
             ), "checkpoint policy 'best' but searcher metric name not defined"
+            searcher_metric = self._check_searcher_metric(metrics)
             assert searcher_metric is not None
 
             if self._is_best_validation(now=searcher_metric, before=best_validation_before):
diff --git a/harness/determined/pytorch/_trainer.py b/harness/determined/pytorch/_trainer.py
index 180b4ca214a..bba0f21bd00 100644
--- a/harness/determined/pytorch/_trainer.py
+++ b/harness/determined/pytorch/_trainer.py
@@ -2,6 +2,7 @@
 import logging
 import random
 import sys
+import warnings
 from typing import Any, Dict, Iterator, Optional
 
 import numpy as np
@@ -94,11 +95,16 @@ def fit(
                 of ``collections.abc.Container`` (list, tuple, etc.). For example,
                 ``Batch(100)`` would validate every 100 batches, while ``Batch([5, 30, 45])`` would
                 validate after every 5th, 30th, and 45th batch.
-            max_length: The maximum number of steps to train for. This value is required and
-                only applicable in local training mode. For on-cluster training, this value will
-                be ignored; the searcher’s ``max_length`` must be configured from the experiment
-                configuration. This is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which takes an
-                ``int``. For example, ``Epoch(1)`` would train for a maximum length of one epoch.
+            max_length: The maximum number of steps to train for. This is a ``TrainUnit`` type
+                (``Batch`` or ``Epoch``) which takes an ``int``. For example, ``Epoch(1)`` would
+                train for a maximum length of one epoch.
+
+                .. note::
+
+                   If using an ASHA searcher, this value should match the searcher config values in
+                   the experiment config (i.e. ``Epoch(1)`` = ``max_time: 1`` and ``time_metric:
+                   "epochs"``).
+
             reporting_period: The number of steps to train for before reporting metrics and
                 searcher progress. For local training mode, metrics are printed to stdout.
                 This is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or
@@ -146,8 +152,13 @@ def fit(
             if max_length is None:
                 raise ValueError("max_length must be defined in local training mode.")
 
-            if not isinstance(max_length.value, int):
-                raise TypeError("max_length must be configured in TrainUnit(int) types.")
+            if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance(
+                max_length.value, int
+            ):
+                raise TypeError(
+                    "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) "
+                    "type."
+                )
 
             if profiling_enabled:
                 logger.warning("Profiling is not supported in local training mode.")
@@ -160,12 +171,6 @@ def fit(
             if test_mode:
                 raise ValueError("test_mode is only supported in local training mode.")
 
-            if max_length is not None:
-                logger.warning(
-                    "max_length is ignored when training on-cluster. Please configure the "
-                    "searcher instead."
-                )
-
             assert self._info, "Unable to detect cluster info."
             if latest_checkpoint is None and self._info.latest_checkpoint is not None:
                 logger.warning(
@@ -176,11 +181,45 @@ def fit(
 
             smaller_is_better = bool(self._info.trial._config["searcher"]["smaller_is_better"])
             searcher_metric_name = self._info.trial._config["searcher"]["metric"]
+            steps_completed = int(self._info.trial._steps_completed)
 
             global_batch_size = self._info.trial.hparams.get("global_batch_size", None)
             if global_batch_size:
                 global_batch_size = int(global_batch_size)
 
+            # Backwards compatibility: try to parse legacy `searcher.max_length` if `max_length`
+            # isn't passed in.
+            if max_length is None:
+                max_length_val = core._parse_searcher_max_length(self._info.trial._config)
+                if max_length_val:
+                    warnings.warn(
+                        "Configuring `max_length` from the `searcher.max_length` experiment "
+                        "config, which was deprecated in XXYYZZ and will be removed in a future "
+                        "release. Please set `fit(max_length=X)` with your desired training length "
+                        "directly.",
+                        FutureWarning,
+                        stacklevel=2,
+                    )
+                    max_length_unit = core._parse_searcher_units(self._info.trial._config)
+                    max_length = pytorch.TrainUnit._from_searcher_unit(
+                        max_length_val, max_length_unit, global_batch_size
+                    )
+
+            # If we couldn't parse the legacy `searcher.max_length`, raise an error.
+            if not max_length:
+                raise ValueError(
+                    "`fit(max_length=X)` must be set with your desired training length."
+                )
+            if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance(
+                max_length.value, int
+            ):
+                raise TypeError(
+                    "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) "
+                    "type."
+                )
+
+            _check_searcher_length(exp_conf=self._info.trial._config, max_length=max_length)
+
             trial_controller = pytorch._PyTorchTrialController(
                 trial_inst=self._trial,
                 context=self._context,
@@ -203,6 +242,43 @@ def fit(
 
         trial_controller.run()
 
 
+def _check_searcher_length(
+    exp_conf: Dict[str, Any],
+    max_length: pytorch.TrainUnit,
+) -> None:
+    """
+    Certain searchers (ASHA and Adaptive ASHA) require configuring the maximum training length in
+    the experiment config. We check that the `max_length` passed to `fit()` matches the experiment
+    config and log warnings if it doesn't.
+    """
+    time_metric = exp_conf["searcher"].get("time_metric")
+    if time_metric is not None:
+        max_time = exp_conf["searcher"].get("max_time")
+        assert max_time, "`searcher.max_time` not configured"
+        if time_metric == "batches":
+            if not isinstance(max_length, pytorch.Batch) or max_length.value != max_time:
+                logger.warning(
+                    f"`max_length` passed into `fit()` method ({max_length}) does not match "
+                    f"`searcher.max_time` and `searcher.time_metric` from the experiment config "
+                    f"(Batch(value={max_time})). This may result in unexpected hyperparameter "
+                    f"search behavior."
+                )
+        elif time_metric == "epochs":
+            if not isinstance(max_length, pytorch.Epoch) or max_length.value != max_time:
+                logger.warning(
+                    f"`max_length` passed into `fit()` method ({max_length}) does not match "
+                    f"`searcher.max_time` and `searcher.time_metric` from the experiment config "
+                    f"(Epoch(value={max_time})). This may result in unexpected hyperparameter "
+                    f"search behavior."
+                )
+        else:
+            logger.warning(
+                "`searcher.time_metric` must be either 'batches' or 'epochs' "
+                f"for training with PyTorchTrials, but got {time_metric}. "
+                f"Training will proceed with {max_length} but may result in unexpected behavior."
+            )
+
+
 def _initialize_distributed_backend() -> Optional[core.DistributedContext]:
     info = det.get_cluster_info()
diff --git a/harness/tests/experiment/pytorch/test_pytorch_trial.py b/harness/tests/experiment/pytorch/test_pytorch_trial.py
index 4af913a4333..5c591a9f9e9 100644
--- a/harness/tests/experiment/pytorch/test_pytorch_trial.py
+++ b/harness/tests/experiment/pytorch/test_pytorch_trial.py
@@ -1228,14 +1228,14 @@ def test_trial_validation_checkpointing(self, tmp_path: pathlib.Path):
                 return_value=checkpoint_condition["best_validation"]
             )
             controller._checkpoint = mock.MagicMock()
-            controller._validate(det.core.DummySearcherOperation(length=100, is_chief=True))
+            controller._validate()
             controller.core_context.train.get_experiment_best_validation.assert_called_once()
             if checkpoint_condition["checkpoint"]:
                 controller._checkpoint.assert_called_once()
             controller.core_context.train.get_experiment_best_validation.reset_mock()
             controller._checkpoint.reset_mock()
 
-    @mock.patch.object(det.core.DummySearcherOperation, "report_progress")
+    @mock.patch.object(det.core.DummyTrainContext, "report_progress")
     def test_searcher_progress_reporting(self, mock_report_progress: mock.MagicMock):
         trial, controller = pytorch_utils.create_trial_and_trial_controller(
             trial_class=pytorch_onevar_model.OneVarTrial,
@@ -1246,8 +1246,8 @@ def test_searcher_progress_reporting(self, mock_report_progress: mock.MagicMock)
         )
         controller.run()
 
-        exp_prog = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
-        got_prog = [x.args[0] for x in mock_report_progress.call_args_list]
+        exp_prog = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
+        got_prog = [x.kwargs["progress"] for x in mock_report_progress.call_args_list]
         assert exp_prog == got_prog
 
     def test_test_mode(self):
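Reviewer note: below is a minimal sketch (not part of the diff itself) of how the reworked `fit(max_length=...)` contract is meant to be called after this change, both locally and on-cluster. It assumes the `MNistTrial` class from `examples/tutorials/mnist_pytorch` (imported here as `model`, mirroring the docs snippet above); the `Epoch(1)` and `Batch(100)` values are illustrative only.

```python
# Sketch only: `model.MNistTrial` is assumed to come from the mnist_pytorch tutorial;
# the Epoch/Batch values below are illustrative, not prescriptive.
import determined as det
from determined import pytorch

import model  # hypothetical module providing MNistTrial, as in the tutorial


def run(local: bool = False) -> None:
    with pytorch.init() as train_context:
        trial_inst = model.MNistTrial(train_context)
        trainer = pytorch.Trainer(trial_inst, train_context)

        if local:
            # Local mode: max_length is required and must be a Batch(int) or Epoch(int).
            trainer.fit(
                max_length=pytorch.Epoch(1),
                validation_period=pytorch.Batch(100),
            )
        else:
            # On-cluster: max_length is now passed to fit() as well; the legacy
            # searcher.max_length config is only parsed for backwards compatibility.
            trainer.fit(
                max_length=pytorch.Epoch(1),
                checkpoint_period=pytorch.Batch(100),
                validation_period=pytorch.Batch(100),
                latest_checkpoint=det.get_cluster_info().latest_checkpoint,
            )


if __name__ == "__main__":
    run(local=True)
```

When pairing on-cluster training with an (Adaptive) ASHA searcher, `_check_searcher_length` expects the experiment config to agree with the value passed to `fit()`, e.g. `time_metric: epochs` and `max_time: 1` for `Epoch(1)`; otherwise a warning is logged.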
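The progress reported to the master also changes from absolute searcher units to a fraction of `max_length`, which is what the updated expectations in `test_searcher_progress_reporting` (0.1 through 1.0) encode. A standalone illustration of the new calculation follows; `training_progress` is a hypothetical helper that mirrors `_report_training_progress`, not an actual function in the harness.

```python
from determined import pytorch


def training_progress(
    max_length: pytorch.TrainUnit, batches_trained: int, epochs_trained: int
) -> float:
    # Fraction of max_length completed, as now passed to core_context.train.report_progress().
    if isinstance(max_length, pytorch.Batch):
        return batches_trained / max_length.value
    if isinstance(max_length, pytorch.Epoch):
        return epochs_trained / max_length.value
    raise ValueError(f"unexpected train unit type {type(max_length)}")


# Example: training toward Batch(100) with 10 batches completed reports progress 0.1,
# matching the first entry of exp_prog in the updated test.
assert training_progress(pytorch.Batch(100), batches_trained=10, epochs_trained=0) == 0.1
```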