chore: refactor max length out of pytorch trials (#10091)
azhou-determined authored and rb-determined-ai committed Oct 23, 2024
1 parent ee11e57 commit 0210b1c
Showing 6 changed files with 154 additions and 142 deletions.
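At a glance, the user-facing effect of this refactor (as the docs and tutorial hunks below show) is that the training length is no longer derived from the searcher configuration inside the trial; it is passed straight to Trainer.fit() as max_length. A rough on-cluster sketch mirroring the docs hunk, assuming the MNIST tutorial's model module (model.MNistTrial) -- not a verified script:

import determined as det
from determined import pytorch

import model  # MNIST tutorial module (assumed); defines MNistTrial

with pytorch.init() as train_context:
    trial_inst = model.MNistTrial(train_context)
    trainer = pytorch.Trainer(trial_inst, train_context)
    trainer.fit(
        max_length=pytorch.Epoch(11),  # training length now set here, not in the searcher config
        checkpoint_period=pytorch.Batch(100),
        validation_period=pytorch.Batch(100),
        latest_checkpoint=det.get_cluster_info().latest_checkpoint,
    )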
@@ -879,6 +879,7 @@ To run Trainer API solely on-cluster, the code is much simpler:
trial_inst = model.MNistTrial(train_context)
trainer = det.pytorch.Trainer(trial_inst, train_context)
trainer.fit(
max_length=pytorch.Epoch(11),
checkpoint_period=pytorch.Batch(100),
validation_period=pytorch.Batch(100),
latest_checkpoint=det.get_cluster_info().latest_checkpoint,
2 changes: 1 addition & 1 deletion examples/tutorials/mnist_pytorch/adaptive.yaml
@@ -25,6 +25,6 @@ searcher:
metric: validation_loss
smaller_is_better: true
max_trials: 16
time_metric: batch
time_metric: batches
max_time: 937 # 60,000 training images with batch size 64
entrypoint: python3 train.py --epochs 1
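Note that time_metric changes from "batch" to "batches" here, matching the metric keys the harness now attaches to every validation report (see the _pytorch_trial.py hunks below). The max_time of 937 is measured in that metric; a quick arithmetic check, using the batch size named in the comment above (illustrative only):

num_records = 60_000          # MNIST training images, per the comment above
global_batch_size = 64        # assumed from the same comment
batches_per_epoch = num_records // global_batch_size
print(batches_per_epoch)      # 937, matching max_time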
7 changes: 5 additions & 2 deletions examples/tutorials/mnist_pytorch/train.py
@@ -80,7 +80,6 @@ def evaluate_batch(self, batch: pytorch.TorchData, batch_idx: int) -> Dict[str,
return {
"validation_loss": validation_loss,
"accuracy": accuracy,
"batch": self.context.current_train_batch(),
}


@@ -115,7 +114,11 @@ def run(max_length, local: bool = False):
with pytorch.init() as train_context:
trial = MNistTrial(train_context, hparams=hparams)
trainer = pytorch.Trainer(trial, train_context)
trainer.fit(max_length=max_length, latest_checkpoint=latest_checkpoint)
trainer.fit(
max_length=max_length,
latest_checkpoint=latest_checkpoint,
validation_period=pytorch.Batch(100),
)


if __name__ == "__main__":
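A hypothetical local invocation of the updated run() helper above (assuming train.py is importable as a module and that the --epochs 1 entrypoint maps to one epoch of training):

from determined import pytorch

import train  # the tutorial's train.py shown above (assumed importable)

# Local mode: max_length is now forwarded explicitly to trainer.fit().
train.run(max_length=pytorch.Epoch(1), local=True)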
176 changes: 54 additions & 122 deletions harness/determined/pytorch/_pytorch_trial.py
@@ -70,11 +70,6 @@ def _from_searcher_unit(
else:
raise ValueError(f"unrecognized searcher unit {unit}")

def _to_searcher_unit(self) -> core.Unit:
if isinstance(self, Batch):
return core.Unit.BATCHES
return core.Unit.EPOCHS

@staticmethod
def _from_values(
batches: Optional[int] = None,
@@ -124,7 +119,8 @@ class Epoch(TrainUnit):
Epoch step type (e.g. Epoch(1) defines 1 epoch)
"""

pass
def __str__(self) -> str:
return f"Epoch({self.value})"


class Batch(TrainUnit):
@@ -136,6 +132,9 @@ class Batch(TrainUnit):
def _from_records(records: int, global_batch_size: int) -> "Batch":
return Batch(max(records // global_batch_size, 1))

def __str__(self) -> str:
return f"Batch({self.value})"


class _TrainBoundaryType(enum.Enum):
CHECKPOINT = "CHECKPOINT"
@@ -195,7 +194,7 @@ def __init__(
searcher_metric_name: Optional[str],
checkpoint_policy: str,
step_zero_validation: bool,
max_length: Optional[TrainUnit],
max_length: TrainUnit,
global_batch_size: Optional[int],
profiling_enabled: Optional[bool],
) -> None:
@@ -219,18 +218,7 @@ def __init__(
self.reporting_period = reporting_period

# Training loop state
if local_training:
self.trial_id = 0
assert self.max_length, "max_length must be specified for local-training mode."
self.searcher_unit = self.max_length._to_searcher_unit()
else:
self.trial_id = self.core_context.train._trial_id
configured_units = self.core_context.searcher.get_configured_units()
if configured_units is None:
raise ValueError(
"Searcher units must be configured for training with PyTorchTrial."
)
self.searcher_unit = configured_units
self.trial_id = 0 if local_training else self.core_context.train._trial_id

# Don't initialize the state here because it will be invalid until we load a checkpoint.
self.state = None # type: Optional[_TrialState]
@@ -247,10 +235,6 @@ def __init__(
self.global_batch_size = global_batch_size
self.profiling_enabled = profiling_enabled

if self.searcher_unit == core.Unit.RECORDS:
if self.global_batch_size is None:
raise ValueError("global_batch_size required for searcher unit RECORDS.")

self.callbacks = self.trial.build_callbacks()
for callback in self.callbacks.values():
if util.is_overridden(callback.on_checkpoint_end, pytorch.PyTorchCallback):
@@ -513,17 +497,18 @@ def _stop_requested(self) -> None:
if self.context.get_stop_requested():
raise ShouldExit()

def _report_searcher_progress(
self, op: core.SearcherOperation, unit: Optional[core.Unit]
) -> None:
def _report_training_progress(self) -> None:
assert self.state
if unit == core.Unit.BATCHES:
op.report_progress(self.state.batches_trained)
elif unit == core.Unit.RECORDS:
assert self.global_batch_size, "global_batch_size must be specified for RECORDS"
op.report_progress(self.global_batch_size * self.state.batches_trained)
elif unit == core.Unit.EPOCHS:
op.report_progress(self.state.epochs_trained)
assert isinstance(self.max_length.value, int)

if isinstance(self.max_length, Batch):
progress = self.state.batches_trained / self.max_length.value
elif isinstance(self.max_length, Epoch):
progress = self.state.epochs_trained / self.max_length.value
else:
raise ValueError(f"unexpected train unit type {type(self.max_length)}")

self.core_context.train.report_progress(progress=progress)

def _checkpoint_is_current(self) -> bool:
assert self.state
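Where the removed _report_searcher_progress forwarded raw unit counts to a searcher operation, the new _report_training_progress reports a completed fraction of max_length through core_context.train.report_progress. An illustrative calculation with made-up numbers, not harness code:

max_length_batches = 937   # hypothetical max_length=Batch(937)
batches_trained = 300
progress = batches_trained / max_length_batches
print(round(progress, 2))  # 0.32 reported to the master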
@@ -615,7 +600,6 @@ def cleanup_iterator() -> None:
self._run()

def _run(self) -> None:
ops: Iterator[det.core.SearcherOperation]
assert self.state

try:
@@ -626,47 +610,24 @@ def _run(self) -> None:
):
self._validate()

if self.local_training:
assert self.max_length and isinstance(self.max_length.value, int)
ops = iter(
[
det.core.DummySearcherOperation(
length=self.max_length.value, is_chief=self.is_chief
)
]
)
else:
ops = self.core_context.searcher.operations()

for op in ops:
if self.local_training:
train_unit = self.max_length
else:
train_unit = TrainUnit._from_searcher_unit(
op.length, self.searcher_unit, self.global_batch_size
)
assert train_unit

self._train_for_op(
op=op,
train_boundaries=[
_TrainBoundary(
step_type=_TrainBoundaryType.TRAIN,
unit=train_unit,
),
_TrainBoundary(
step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period
),
_TrainBoundary(
step_type=_TrainBoundaryType.CHECKPOINT,
unit=self.checkpoint_period,
),
# Scheduling unit is always configured in batches
_TrainBoundary(
step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period
),
],
)
self._train(
length=Batch(1) if self.test_mode else self.max_length,
train_boundaries=[
_TrainBoundary(
step_type=_TrainBoundaryType.TRAIN,
unit=self.max_length,
),
_TrainBoundary(
step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period
),
_TrainBoundary(
step_type=_TrainBoundaryType.CHECKPOINT,
unit=self.checkpoint_period,
),
# Scheduling unit is always configured in batches
_TrainBoundary(step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period),
],
)
except ShouldExit as e:
# Checkpoint unsaved work and exit.
if not e.skip_exit_checkpoint and not self._checkpoint_is_current():
@@ -733,20 +694,8 @@ def _train_with_boundaries(
# True epoch end
return train_boundaries, training_metrics

def _train_for_op(
self, op: core.SearcherOperation, train_boundaries: List[_TrainBoundary]
) -> None:
if self.test_mode:
train_length = Batch(1)
elif self.local_training:
train_length = self.max_length # type: ignore
else:
train_length = TrainUnit._from_searcher_unit(
op.length, self.searcher_unit, self.global_batch_size
) # type: ignore
assert train_length

while self._steps_until_complete(train_length) > 0:
def _train(self, length: TrainUnit, train_boundaries: List[_TrainBoundary]) -> None:
while self._steps_until_complete(length) > 0:
train_boundaries, training_metrics = self._train_with_boundaries(
self.training_enumerator, train_boundaries
)
@@ -767,16 +716,16 @@ def _train_for_op(

# Train step limits reached, proceed accordingly.
if train_boundary.step_type == _TrainBoundaryType.TRAIN:
if not op._completed and self.is_chief and not step_reported:
self._report_searcher_progress(op, self.searcher_unit)
if self.is_chief and not step_reported:
self._report_training_progress()
step_reported = True
elif train_boundary.step_type == _TrainBoundaryType.REPORT:
if not op._completed and self.is_chief and not step_reported:
self._report_searcher_progress(op, self.searcher_unit)
if self.is_chief and not step_reported:
self._report_training_progress()
step_reported = True
elif train_boundary.step_type == _TrainBoundaryType.VALIDATE:
if not self._validation_is_current():
self._validate(op)
self._validate()
elif train_boundary.step_type == _TrainBoundaryType.CHECKPOINT:
if not self._checkpoint_is_current():
self._checkpoint(already_exiting=False)
@@ -789,20 +738,12 @@ def _train_for_op(
self._upload_tb_files()
self._stop_requested()

# Finished training for op. Perform final checkpoint/validation if necessary.
# Finished training. Perform final checkpoint/validation if necessary.
if not self._validation_is_current():
self._validate(op)
self._validate()
if not self._checkpoint_is_current():
self._checkpoint(already_exiting=False)

# Test mode will break after one batch despite not completing op.
if self.is_chief and not self.test_mode:
# The only case where op isn't reported as completed is if we restarted but
# op.length was already trained for and validated on; in that case just raise
# ShouldExit; we have nothing to do.
if not op._completed:
raise ShouldExit(skip_exit_checkpoint=True)

def _check_searcher_metric(self, val_metrics: Dict) -> Any:
if self.searcher_metric_name not in val_metrics:
raise RuntimeError(
@@ -913,7 +854,7 @@ def _train_batch(
return training_metrics

@torch.no_grad()
def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dict[str, Any]:
def _validate(self) -> Dict[str, Any]:
# Report a validation step is starting.
if self.is_chief:
self.core_context.train.set_status("validating")
@@ -1049,24 +990,14 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
# Get best validation before reporting metrics.
best_validation_before = self.core_context.train.get_experiment_best_validation()

self.core_context.train.report_validation_metrics(self.state.batches_trained, metrics)
# We report "batch" and "epoch" only if these keys are not already reported in user
# metrics.
metrics["batches"] = metrics.get("batches", self.state.batches_trained)
metrics["epochs"] = metrics.get("epochs", self.state.epochs_trained)

searcher_metric = None

# Report searcher status.
if self.is_chief and searcher_op:
if self.local_training:
searcher_length = self.max_length
else:
searcher_length = TrainUnit._from_searcher_unit(
searcher_op.length, self.searcher_unit, self.global_batch_size
)
if self.searcher_metric_name:
searcher_metric = self._check_searcher_metric(metrics)

assert searcher_length
if self._steps_until_complete(searcher_length) < 1 and not searcher_op._completed:
searcher_op.report_completed(searcher_metric)
self.core_context.train.report_validation_metrics(
steps_completed=self.state.batches_trained, metrics=metrics
)

should_checkpoint = False
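With the searcher-operation plumbing removed, validation reporting now simply fills in "batches" and "epochs" (when the trial has not already reported them) and calls report_validation_metrics with steps_completed. A hypothetical payload for the tutorial's evaluate_batch after one epoch, with made-up metric values:

metrics = {"validation_loss": 0.05, "accuracy": 0.98}   # made-up values
metrics["batches"] = metrics.get("batches", 937)
metrics["epochs"] = metrics.get("epochs", 1)
# reported as validation metrics with steps_completed=937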

@@ -1079,6 +1010,7 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
assert (
self.searcher_metric_name
), "checkpoint policy 'best' but searcher metric name not defined"
searcher_metric = self._check_searcher_metric(metrics)
assert searcher_metric is not None

if self._is_best_validation(now=searcher_metric, before=best_validation_before):