[Jobs] Add option to specify max_restarts_on_errors #4169

Merged
merged 35 commits into from
Oct 29, 2024
Changes from 33 commits
8eba87b
Add option to specify `max_retry_on_failure`
Michaelvll Oct 24, 2024
7294204
fix recover counts
Michaelvll Oct 24, 2024
3ab9619
fix log streaming
Michaelvll Oct 25, 2024
7145842
fix docs
Michaelvll Oct 25, 2024
8bfd59a
fix
Michaelvll Oct 25, 2024
e459271
fix
Michaelvll Oct 25, 2024
de78310
fix
Michaelvll Oct 25, 2024
23345c0
fix
Michaelvll Oct 25, 2024
3709cd6
fix default value
Michaelvll Oct 25, 2024
92e7c35
Fix spinner
Michaelvll Oct 25, 2024
935491e
Add unit test for default strategy
Michaelvll Oct 25, 2024
90f95b1
fix test
Michaelvll Oct 25, 2024
ceff8cd
format
Michaelvll Oct 25, 2024
a20fa5c
Update docs/source/examples/managed-jobs.rst
Michaelvll Oct 25, 2024
b5b35f4
rename to restarts
Michaelvll Oct 25, 2024
149c9fd
Merge branch 'jobs-max-retry-on-failure' of github.com:skypilot-org/s…
Michaelvll Oct 25, 2024
1947605
Merge branch 'master' of github.com:skypilot-org/skypilot into jobs-m…
Michaelvll Oct 25, 2024
7cf2b17
Update docs/source/examples/managed-jobs.rst
Michaelvll Oct 25, 2024
44882fe
update docs
Michaelvll Oct 25, 2024
599a838
Merge branch 'master' of github.com:skypilot-org/skypilot into jobs-m…
Michaelvll Oct 25, 2024
a7d266b
warning instead of error out
Michaelvll Oct 28, 2024
087414b
Update docs/source/examples/managed-jobs.rst
Michaelvll Oct 28, 2024
3ffadb1
rename
Michaelvll Oct 28, 2024
da26fc1
add comment
Michaelvll Oct 28, 2024
bea7fe0
Merge branch 'jobs-max-retry-on-failure' of github.com:skypilot-org/s…
Michaelvll Oct 28, 2024
987df3d
fix
Michaelvll Oct 28, 2024
c3e88a2
rename
Michaelvll Oct 28, 2024
92acfe3
Update sky/execution.py
Michaelvll Oct 28, 2024
71e9518
Update sky/execution.py
Michaelvll Oct 28, 2024
acd96ab
address comments
Michaelvll Oct 28, 2024
02b0b19
Merge branch 'jobs-max-retry-on-failure' of github.com:skypilot-org/s…
Michaelvll Oct 28, 2024
86d0d64
format
Michaelvll Oct 28, 2024
0573c33
commit changes for docs
Michaelvll Oct 28, 2024
df62ee4
Format
Michaelvll Oct 29, 2024
d79cd2f
Merge branch 'master' of github.com:skypilot-org/skypilot into jobs-m…
Michaelvll Oct 29, 2024
18 changes: 18 additions & 0 deletions docs/source/examples/managed-jobs.rst
@@ -282,6 +282,24 @@ candidate resources for a job. See documentation :ref:`here
In this example, SkyPilot will perform cost optimizations to select the resource to use, which almost certainly
will be spot instances. If spot instances are not available, SkyPilot will fall back to launch on-demand instances.


Jobs Restarts on User Code Failure
-----------------------------------

By default, SkyPilot tries to recover a job only when its underlying cluster is preempted or fails. User code failures (non-zero exit codes) are not auto-recovered.

In some cases, you may want a job to restart automatically on its own failures, e.g., when a training job crashes due to an NVIDIA driver issue or NCCL timeouts. To enable this,
set :code:`max_restarts_on_errors` in :code:`resources.job_recovery` in the job YAML file.

.. code-block:: yaml

  resources:
    accelerators: A100:8
    job_recovery:
      # Restart the job up to 3 times on user code errors.
      max_restarts_on_errors: 3


More advanced policies for resource selection, such as the `Can't Be Late
<https://www.usenix.org/conference/nsdi24/presentation/wu-zhanghao>`__ (NSDI'24)
paper, may be supported in the future.
4 changes: 4 additions & 0 deletions docs/source/reference/yaml-spec.rst
@@ -107,6 +107,10 @@ Available fields:
#
# default: EAGER_NEXT_REGION
job_recovery: none
# Or, to allow up to 3 restarts (default: 0) on user code errors:
# job_recovery:
#   strategy: EAGER_NEXT_REGION
#   max_restarts_on_errors: 3

# Disk size in GB to allocate for OS (mounted at /). Increase this if you
# have a large working directory or tasks that write out large outputs.
9 changes: 5 additions & 4 deletions sky/execution.py
@@ -171,10 +171,11 @@ def _execute(
task = dag.tasks[0]

if any(r.job_recovery is not None for r in task.resources):
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Job recovery is specified in the task. To launch a '
'managed job, please use: sky jobs launch')
logger.warning(
f'{colorama.Style.DIM}The task has `job_recovery` specified, '
'but is launched as an unmanaged job. It will be ignored. '
'To enable job recovery, use managed jobs: sky jobs launch.'
f'{colorama.Style.RESET_ALL}')
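
The warning above is assembled from adjacent string literals, which Python joins with no separator, so every fragment that ends a sentence needs its own trailing space. A minimal sketch of the same message, assuming placeholder strings in place of the `colorama` style codes (the names `DIM` and `RESET` are illustrative only):

```python
# Placeholders standing in for colorama.Style.DIM / colorama.Style.RESET_ALL.
DIM, RESET = '<dim>', '<reset>'

# Adjacent literals are concatenated as-is; note the trailing space after
# 'ignored.' so the two sentences do not run together.
message = (
    f'{DIM}The task has `job_recovery` specified, '
    'but is launched as an unmanaged job. It will be ignored. '
    'To enable job recovery, use managed jobs: sky jobs launch.'
    f'{RESET}')
```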

cluster_exists = False
if cluster_name is not None:
60 changes: 38 additions & 22 deletions sky/jobs/controller.py
@@ -160,22 +160,26 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if task_id == 0:
submitted_at = backend_utils.get_timestamp_from_run_timestamp(
self._backend.run_timestamp)
assert task.name is not None, task
cluster_name = managed_job_utils.generate_managed_job_cluster_name(
task.name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
cluster_name, self._backend, task, self._retry_until_up)
managed_job_state.set_submitted(
self._job_id,
task_id,
self._backend.run_timestamp,
submitted_at,
resources_str=backend_utils.get_task_resources_str(
task, is_managed_job=True),
specs={
'max_restarts_on_errors':
self._strategy_executor.max_restarts_on_errors
},
callback_func=callback_func)
logger.info(
f'Submitted managed job {self._job_id} (task: {task_id}, name: '
f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
assert task.name is not None, task
cluster_name = managed_job_utils.generate_managed_job_cluster_name(
task.name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
cluster_name, self._backend, task, self._retry_until_up)

logger.info('Started monitoring.')
managed_job_state.set_starting(job_id=self._job_id,
@@ -283,23 +287,35 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
failure_reason = (
'To see the details, run: '
f'sky jobs logs --controller {self._job_id}')

managed_job_state.set_failed(
self._job_id,
task_id,
failure_type=managed_job_status,
failure_reason=failure_reason,
end_time=end_time,
callback_func=callback_func)
return False
# Although the cluster is healthy, we fail to access the
# job status. Try to recover the job (will not restart the
# cluster, if the cluster is healthy).
assert job_status is None, job_status
logger.info('Failed to fetch the job status while the '
'cluster is healthy. Try to recover the job '
'(the cluster will not be restarted).')

should_restart_on_failure = (
self._strategy_executor.should_restart_on_failure())
if should_restart_on_failure:
max_restarts = (
self._strategy_executor.max_restarts_on_errors)
logger.info(
Collaborator

One of our users mentioned backoff between restarts - any thoughts on adding it here?

Collaborator Author

We do have backoff between launches if the resources are not available across all regions/clouds. I feel adding additional backoff between job restarts is not that clean.

f'User program crashed '
f'({managed_job_status.value}). '
f'Retry the job as max_restarts_on_errors is '
f'set to {max_restarts}. '
f'[{self._strategy_executor.restart_cnt_on_failure}'
f'/{max_restarts}]')
else:
managed_job_state.set_failed(
self._job_id,
task_id,
failure_type=managed_job_status,
failure_reason=failure_reason,
end_time=end_time,
callback_func=callback_func)
return False
else:
# Although the cluster is healthy, we fail to access the
# job status. Try to recover the job (will not restart the
# cluster, if the cluster is healthy).
assert job_status is None, job_status
logger.info('Failed to fetch the job status while the '
'cluster is healthy. Try to recover the job '
'(the cluster will not be restarted).')
# When the handle is None, the cluster should be cleaned up already.
if handle is not None:
resources = handle.launched_resources
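
The control flow this hunk adds can be sketched in isolation: on a user-code failure the controller consumes one unit of restart budget and either retries or marks the job failed. The function `run_with_restarts`, its `run_once` callback, and the returned status strings are hypothetical stand-ins. The real controller also recovers the cluster and records state through `managed_job_state`, which is omitted here.

```python
from typing import Callable


def run_with_restarts(run_once: Callable[[], int],
                      max_restarts_on_errors: int) -> str:
    """Run `run_once` until it exits 0 or the restart budget is exhausted."""
    restart_cnt_on_failure = 0
    while True:
        exit_code = run_once()
        if exit_code == 0:
            return 'SUCCEEDED'
        # Count the failure first, then compare against the budget,
        # mirroring should_restart_on_failure() in this PR.
        restart_cnt_on_failure += 1
        if restart_cnt_on_failure > max_restarts_on_errors:
            return 'FAILED'
        print('User program crashed. Retry the job '
              f'[{restart_cnt_on_failure}/{max_restarts_on_errors}]')


# A job that fails twice and then succeeds, with a budget of 3 restarts.
exit_codes = iter([1, 1, 0])
result = run_with_restarts(lambda: next(exit_codes), max_restarts_on_errors=3)
```

With the default budget of 0, the first non-zero exit code immediately yields the failed state, which matches the behavior before this PR.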
35 changes: 30 additions & 5 deletions sky/jobs/recovery_strategy.py
@@ -66,7 +66,8 @@ class StrategyExecutor:
RETRY_INIT_GAP_SECONDS = 60

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool) -> None:
task: 'task_lib.Task', retry_until_up: bool,
max_restarts_on_errors: int) -> None:
"""Initialize the strategy executor.

Args:
@@ -82,6 +83,8 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend',
self.cluster_name = cluster_name
self.backend = backend
self.retry_until_up = retry_until_up
self.max_restarts_on_errors = max_restarts_on_errors
self.restart_cnt_on_failure = 0

def __init_subclass__(cls, name: str, default: bool = False):
RECOVERY_STRATEGIES[name] = cls
@@ -109,8 +112,17 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
# set the new_task_resources to be the same type (list or set) as the
# original task.resources
task.set_resources(type(task.resources)(new_resources_list))
return RECOVERY_STRATEGIES[job_recovery](cluster_name, backend, task,
retry_until_up)
if isinstance(job_recovery, dict):
job_recovery_name = job_recovery.pop('strategy',
DEFAULT_RECOVERY_STRATEGY)
max_restarts_on_errors = job_recovery.pop('max_restarts_on_errors',
0)
else:
job_recovery_name = job_recovery
max_restarts_on_errors = 0
return RECOVERY_STRATEGIES[job_recovery_name](cluster_name, backend,
task, retry_until_up,
max_restarts_on_errors)
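
To make the two accepted shapes of `job_recovery` concrete, here is a minimal standalone sketch of the parsing step above. `parse_job_recovery` is a hypothetical helper, and the `DEFAULT_RECOVERY_STRATEGY` value is an assumption taken from the `default: EAGER_NEXT_REGION` comment in `yaml-spec.rst`. Unlike the diff, the sketch reads with `.get()` rather than `.pop()`, so the caller's dict is left unmodified.

```python
from typing import Any, Dict, Optional, Tuple, Union

# Assumed default, mirroring the comment in docs/source/reference/yaml-spec.rst.
DEFAULT_RECOVERY_STRATEGY = 'EAGER_NEXT_REGION'

JobRecovery = Optional[Union[str, Dict[str, Any]]]


def parse_job_recovery(job_recovery: JobRecovery) -> Tuple[Optional[str], int]:
    """Return (strategy_name, max_restarts_on_errors) for either input shape."""
    if isinstance(job_recovery, dict):
        name = job_recovery.get('strategy', DEFAULT_RECOVERY_STRATEGY)
        max_restarts = job_recovery.get('max_restarts_on_errors', 0)
    else:
        # A bare string (or None) is passed through and carries no
        # restart budget, as in the else branch of the diff.
        name = job_recovery
        max_restarts = 0
    return name, max_restarts
```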

def launch(self) -> float:
"""Launch the cluster for the first time.
@@ -368,6 +380,17 @@ def _launch(self,
f'{gap_seconds:.1f} seconds.')
time.sleep(gap_seconds)

    def should_restart_on_failure(self) -> bool:
        """Increments counter & checks if job should be restarted on a failure.

        Returns:
            True if the job should be restarted, otherwise False.
        """
        self.restart_cnt_on_failure += 1
        if self.restart_cnt_on_failure > self.max_restarts_on_errors:
            return False
        return True
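
Because the counter is incremented before the comparison, a budget of N permits exactly N restarts, i.e. N + 1 runs in total. The sketch below checks that arithmetic with a hypothetical `RestartBudget` class that copies only the two counters from `StrategyExecutor`; `count_runs` is likewise an illustrative helper, not part of SkyPilot.

```python
class RestartBudget:
    """Hypothetical stand-in holding only the two restart counters."""

    def __init__(self, max_restarts_on_errors: int) -> None:
        self.max_restarts_on_errors = max_restarts_on_errors
        self.restart_cnt_on_failure = 0

    def should_restart_on_failure(self) -> bool:
        # Same logic as the method above: count the failure, then compare.
        self.restart_cnt_on_failure += 1
        return self.restart_cnt_on_failure <= self.max_restarts_on_errors


def count_runs(max_restarts_on_errors: int) -> int:
    """Count how many times a job that always fails would be run."""
    budget = RestartBudget(max_restarts_on_errors)
    runs = 1  # The initial launch.
    while budget.should_restart_on_failure():
        runs += 1
    return runs
```

With `max_restarts_on_errors: 3`, a job that always fails is run four times before it is marked failed.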


class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
default=False):
@@ -376,8 +399,10 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
_MAX_RETRY_CNT = 240 # Retry for 4 hours.

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool) -> None:
super().__init__(cluster_name, backend, task, retry_until_up)
task: 'task_lib.Task', retry_until_up: bool,
max_restarts_on_errors: int) -> None:
super().__init__(cluster_name, backend, task, retry_until_up,
max_restarts_on_errors)
# Note down the cloud/region of the launched cluster, so that we can
# first retry in the same cloud/region. (Inside recover() we may not
# rely on cluster handle, as it can be None if the cluster is
38 changes: 33 additions & 5 deletions sky/jobs/state.py
@@ -2,6 +2,7 @@
# TODO(zhwu): maybe use file based status instead of database, so
# that we can easily switch to a s3-based storage.
import enum
import json
import pathlib
import sqlite3
import time
@@ -65,7 +66,8 @@ def _get_db_path() -> str:
failure_reason TEXT,
spot_job_id INTEGER,
task_id INTEGER DEFAULT 0,
task_name TEXT)""")
task_name TEXT,
specs TEXT)""")
_CONN.commit()

db_utils.add_column_to_table(_CURSOR, _CONN, 'spot', 'failure_reason', 'TEXT')
@@ -92,6 +94,17 @@ def _get_db_path() -> str:
'TEXT',
copy_from='job_name')

# Specs is some useful information about the task, e.g., the
# max_restarts_on_errors value. It is stored in JSON format.
db_utils.add_column_to_table(_CURSOR,
_CONN,
'spot',
'specs',
'TEXT',
value_to_replace_existing_entries=json.dumps({
'max_restarts_on_errors': 0,
}))
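
`db_utils.add_column_to_table` is SkyPilot-internal and its body is not part of this diff. The sketch below shows, with plain `sqlite3`, one way such a migration could behave: add the column if it is missing, then backfill the rows that already existed. The helper `add_column_with_backfill` and the reduced `spot` schema are assumptions for illustration only.

```python
import json
import sqlite3


def add_column_with_backfill(conn: sqlite3.Connection, table: str,
                             column: str, column_type: str,
                             backfill_value: str) -> None:
    """Add `column` to `table` if absent and backfill pre-existing rows."""
    # PRAGMA table_info yields one row per column; index 1 is the name.
    existing = [row[1] for row in conn.execute(f'PRAGMA table_info({table})')]
    if column in existing:
        return  # Already migrated; keeps the call idempotent.
    conn.execute(f'ALTER TABLE {table} ADD COLUMN {column} {column_type}')
    conn.execute(f'UPDATE {table} SET {column}=(?)', (backfill_value,))
    conn.commit()


# A table created before the `specs` column existed, with one legacy row.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE spot (spot_job_id INTEGER, task_id INTEGER)')
conn.execute('INSERT INTO spot VALUES (1, 0)')

add_column_with_backfill(conn, 'spot', 'specs', 'TEXT',
                         json.dumps({'max_restarts_on_errors': 0}))
legacy_specs = json.loads(conn.execute('SELECT specs FROM spot').fetchone()[0])
```

Backfilling with `max_restarts_on_errors: 0` means jobs submitted before the upgrade keep the old no-restart behavior.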

# `job_info` contains the mapping from job_id to the job_name.
# In the future, it may contain more information about each job.
_CURSOR.execute("""\
@@ -130,7 +143,8 @@ def _get_db_path() -> str:
'task_name',
# columns from the job_info table
'_job_info_job_id', # This should be the same as job_id
'job_name'
'job_name',
'specs',
]


@@ -283,7 +297,8 @@ def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):

def set_submitted(job_id: int, task_id: int, run_timestamp: str,
submit_time: float, resources_str: str,
callback_func: CallbackType):
specs: Dict[str, Union[str,
int]], callback_func: CallbackType):
"""Set the task to submitted.

Args:
@@ -293,6 +308,8 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str,
determine the log directory of the managed task.
submit_time: The time when the managed task is submitted.
resources_str: The resources string of the managed task.
specs: The specs of the managed task.
callback_func: The callback function.
"""
# Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
# the log directory and submission time align with each other, so as to
@@ -306,11 +323,12 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str,
resources=(?),
submitted_at=(?),
status=(?),
run_timestamp=(?)
run_timestamp=(?),
specs=(?)
WHERE spot_job_id=(?) AND
task_id=(?)""",
(resources_str, submit_time, ManagedJobStatus.SUBMITTED.value,
run_timestamp, job_id, task_id))
run_timestamp, json.dumps(specs), job_id, task_id))
callback_func('SUBMITTED')


@@ -619,3 +637,13 @@ def get_latest_job_id() -> Optional[int]:
for (job_id,) in rows:
return job_id
return None


def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
    with db_utils.safe_cursor(_DB_PATH) as cursor:
        task_specs = cursor.execute(
            """\
            SELECT specs FROM spot
            WHERE spot_job_id=(?) AND task_id=(?)""",
            (job_id, task_id)).fetchone()
        return json.loads(task_specs[0])
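
The specs are stored as a JSON string in a `TEXT` column and decoded on read. The sketch below reproduces that round trip with an in-memory database and adds a guard that the function above does not have: `fetchone()` returns `None` when no row matches, so indexing the result would raise. `get_task_specs_or_default` and the reduced schema are hypothetical, written only to illustrate the pattern.

```python
import json
import sqlite3
from typing import Any, Dict

conn = sqlite3.connect(':memory:')
conn.execute(
    'CREATE TABLE spot (spot_job_id INTEGER, task_id INTEGER, specs TEXT)')
# Write side: serialize the dict before binding it as a parameter.
conn.execute('INSERT INTO spot VALUES (?, ?, ?)',
             (7, 0, json.dumps({'max_restarts_on_errors': 3})))


def get_task_specs_or_default(job_id: int, task_id: int) -> Dict[str, Any]:
    """Read the JSON specs for a task; fall back to {} if nothing is stored."""
    row = conn.execute(
        'SELECT specs FROM spot WHERE spot_job_id=(?) AND task_id=(?)',
        (job_id, task_id)).fetchone()
    if row is None or row[0] is None:
        return {}
    return json.loads(row[0])
```

Callers can then use `specs.get('max_restarts_on_errors', 0)`, as the log-streaming code in this PR does, without a separate existence check.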
32 changes: 28 additions & 4 deletions sky/jobs/utils.py
@@ -70,7 +70,7 @@
# state, after the job finished. This is a safeguard to avoid the case where
# the managed job status fails to be updated and keep the `sky jobs logs`
# blocking for a long time.
_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 20
_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 25


class UserSignal(enum.Enum):
@@ -392,8 +392,12 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
status_display.update('Waiting for the next task: '
f'{task_id + 1}.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
romilbhardwaj marked this conversation as resolved.
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
@@ -405,7 +409,27 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
break
Contributor

It seems the retry logic was added in the wrong branch. It is currently in the `else` branch of `if task_id < num_tasks - 1 and follow`, which means it only triggers when we want to terminate. The retry check should be in the outer `else` branch where we handle cluster failures.

task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
# We don't need to wait for the managed job status
# update, as the job is guaranteed to be in terminal
# state afterwards.
break
print()
status_display.update(
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
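
The inner loop added to `stream_logs_by_id` polls the managed job status until it is no longer `RUNNING`, then resumes streaming. A minimal sketch of that polling pattern follows. The function `wait_until_status_leaves`, the string statuses, and in particular the timeout are assumptions for illustration; the loop in this diff has no timeout and compares against `ManagedJobStatus.RUNNING`.

```python
import time
from typing import Callable, Optional


def wait_until_status_leaves(get_status: Callable[[], str],
                             blocking_status: str = 'RUNNING',
                             gap_seconds: float = 0.01,
                             timeout_seconds: Optional[float] = 1.0) -> str:
    """Poll `get_status` until it stops returning `blocking_status`."""
    deadline = (None if timeout_seconds is None else
                time.monotonic() + timeout_seconds)
    while True:
        status = get_status()
        if status != blocking_status:
            return status
        # The timeout is an addition of this sketch, not of the PR.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f'Status stayed {blocking_status!r} for {timeout_seconds}s.')
        time.sleep(gap_seconds)


# Simulated status feed: the controller notices the failure on the third poll.
statuses = iter(['RUNNING', 'RUNNING', 'RECOVERING'])
final_status = wait_until_status_leaves(lambda: next(statuses))
```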