[Tune][Fix] Remove the clear_checkpoint function during Trial restoration error handling. (ray-project#48532)

This PR removes the `clear_checkpoint` function, so that Tune doesn't
try to "restart trials from scratch." `clear_checkpoint` solved for a
legacy use case that no longer applies, and "restoration failures" are
also now an edge case for function Trainables and Ray Train usage.

---------

Signed-off-by: Hongpeng Guo <[email protected]>
hongpeng-guo authored and JP-sDEV committed Nov 14, 2024
1 parent 979911e commit f26be00
Showing 2 changed files with 19 additions and 11 deletions.
10 changes: 2 additions & 8 deletions python/ray/tune/experiment/trial.py
@@ -793,11 +793,11 @@ def get_error(self) -> Optional[TuneError]:
         return None
 
     def _handle_restore_error(self, exc: Exception):
+        # For restoration errors, we only increment the trial failure count
+        # if the number of restore failures exceeds the restore retry limit.
         if self.temporary_state.num_restore_failures >= int(
             os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
         ):
-            # Restore was unsuccessful, try again without checkpoint.
-            self.clear_checkpoint()
             self.run_metadata.num_failures += 1
         else:
             self.temporary_state.num_restore_failures += 1
@@ -883,12 +883,6 @@ def should_checkpoint(self):
     def has_checkpoint(self) -> bool:
         return self.checkpoint is not None
 
-    def clear_checkpoint(self):
-        if self.latest_checkpoint_result:
-            self.latest_checkpoint_result.checkpoint = None
-        self.temporary_state.restoring_from = None
-        self.run_metadata.invalidate_cache()
-
     def on_checkpoint(self, checkpoint_result: _TrainingResult):
         """Hook for handling checkpoints taken by the Trainable.
20 changes: 17 additions & 3 deletions python/ray/tune/tests/test_tuner_restore.py
@@ -537,7 +537,21 @@ def test_tuner_restore_latest_available_checkpoint(
 
 @pytest.mark.parametrize("retry_num", [0, 2])
 def test_restore_retry(ray_start_2_cpus, tmpdir, retry_num):
-    """Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`."""
+    """
+    Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`.
+
+    This unit test uses the following hyperparameters:
+
+    - `retry_num`: Maximum number of retry attempts for restoring a trial.
+      This value is assigned to the environment variable `TUNE_RESTORE_RETRY_NUM`.
+      If restoration still fails after `retry_num` attempts, the trial increments
+      its total failure count by 1.
+    - `retry_num_to_fail`: Number of restore attempts to fail. In this test,
+      `retry_num_to_fail` is set to 2, causing the first two restore attempts to fail.
+    - `max_failures`: Maximum allowable failures during training. Here, `max_failures`
+      is set to 2, meaning the training process terminates after two total failures.
+    """
 
     class MockTrainable(Trainable):
         """A trainable that can generate one failure during training and
@@ -546,7 +560,7 @@ class MockTrainable(Trainable):
         def setup(self, config):
             self.idx = 0
             self.tag_file_path = config["tag_file_path"]
-            self.retry_num_to_fail = config.get("retry_num_to_fail", 2)
+            self.retry_num_to_fail = 2
             self._is_restored = False
 
         def step(self):
@@ -592,7 +606,7 @@ def load_checkpoint(self, checkpoint_dir):
             name="tryout_restore",
             stop={"training_iteration": 5},
             storage_path=str(tmpdir),
-            failure_config=FailureConfig(max_failures=1),
+            failure_config=FailureConfig(max_failures=2),
             checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
         ),
         param_space={"tag_file_path": tag_file},
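To make the parametrization above concrete, the following standalone sketch (not part of the test suite; the `simulate` helper is hypothetical) replays the counting logic of `_handle_restore_error` from the trial.py diff for the two `retry_num` values used by the test.

# Hypothetical illustration of the retry accounting in `_handle_restore_error`
# (see the trial.py diff above); not Ray code.

def simulate(retry_num: int, restore_failures: int = 2) -> int:
    num_restore_failures = 0  # mirrors temporary_state.num_restore_failures
    num_failures = 0          # mirrors run_metadata.num_failures
    for _ in range(restore_failures):
        if num_restore_failures >= retry_num:
            num_failures += 1
        else:
            num_restore_failures += 1
    return num_failures


# retry_num=0: both restore errors count as trial failures (returns 2).
# retry_num=2: both restore errors are absorbed as retries (returns 0).
for retry_num in (0, 2):
    print(retry_num, simulate(retry_num))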
