[core][state] Proper report of failure when job finishes and for finished tasks (#31761)

This PR handles 2 edge cases when marking tasks as failed:

1. When a job finishes, tasks that are still running should be marked as failed.
2. A task's finished or failed timestamp should not be overridden when an ancestor fails.
For 1:

It adds a handler function OnJobFinished, registered as a job-finished listener in the GcsJobManager, so that when a job is marked as finished, OnJobFinished is called to mark any non-terminal tasks of that job as failed.
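
As a rough sketch of that flow (a simplified, self-contained Python model for illustration only; the real implementation is C++ inside the GCS, and all class and method names below are invented for the sketch):

class JobManager:
    def __init__(self):
        self._job_finished_listeners = []

    def add_job_finished_listener(self, listener):
        self._job_finished_listeners.append(listener)

    def mark_job_as_finished(self, job_id):
        # Notify every registered listener, mirroring ClearJobInfos() in the diff below.
        for listener in self._job_finished_listeners:
            listener(job_id)


class TaskManager:
    def __init__(self):
        # job_id -> {task_id: state}
        self.tasks = {}

    def on_job_finished(self, job_id):
        # In Ray this runs after gcs_mark_task_failed_on_job_done_delay_ms so that
        # in-flight task events can still arrive; the delay is omitted in this sketch.
        for task_id, state in self.tasks.get(job_id, {}).items():
            if state not in ("FINISHED", "FAILED"):
                self.tasks[job_id][task_id] = "FAILED"


job_manager = JobManager()
task_manager = TaskManager()
task_manager.tasks["job-1"] = {"t1": "RUNNING", "t2": "FINISHED"}
job_manager.add_job_finished_listener(task_manager.on_job_finished)
job_manager.mark_job_as_finished("job-1")
assert task_manager.tasks["job-1"] == {"t1": "FAILED", "t2": "FINISHED"}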
For 2:

It adds an ancestor_failed_ts field to keep track of the ancestor failure time in the task tree.
This extra bit of info is necessary because we should not override the timestamps of child tasks that have already failed or finished, but we still need to know whether a task subtree has already been traversed (with all of its non-terminal children marked as failed) without re-walking the task tree.
When adding a new task event, if the task fails or one of its ancestors has failed, its failed_ts or ancestor_failed_ts (respectively) is set, and we traverse into its child task tree.
During that traversal, if a task already has failed_ts or ancestor_failed_ts set, its children must already have been traversed when that timestamp was set, so the subtree can be skipped.
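
A simplified Python model of this traversal rule (illustrative only, assuming an in-memory task tree; the real logic and field names live in the GCS C++ task manager):

import time


class TaskEntry:
    def __init__(self, task_id):
        self.task_id = task_id
        self.finished_ts = None         # set when the task finished
        self.failed_ts = None           # set when the task itself failed
        self.ancestor_failed_ts = None  # set when an ancestor failed


def mark_descendants_failed(task, tasks, children_of, failed_ts):
    # Never override a task that already finished or failed (edge case 2).
    if task.finished_ts is not None or task.failed_ts is not None:
        return
    # If ancestor_failed_ts is already set, this subtree was traversed when that
    # timestamp was set, so stop here instead of re-walking it.
    if task.ancestor_failed_ts is not None:
        return
    task.ancestor_failed_ts = failed_ts
    for child_id in children_of.get(task.task_id, []):
        mark_descendants_failed(tasks[child_id], tasks, children_of, failed_ts)


def on_task_failed(task, tasks, children_of):
    failed_ts = time.time()
    task.failed_ts = failed_ts
    for child_id in children_of.get(task.task_id, []):
        mark_descendants_failed(tasks[child_id], tasks, children_of, failed_ts)


tasks = {t: TaskEntry(t) for t in ("parent", "child", "grandchild")}
children_of = {"parent": ["child"], "child": ["grandchild"]}
tasks["grandchild"].finished_ts = time.time()  # already finished; left intact
on_task_failed(tasks["parent"], tasks, children_of)
assert tasks["child"].ancestor_failed_ts is not None
assert tasks["grandchild"].ancestor_failed_ts is None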
rickyyx authored Jan 23, 2023
1 parent 0c69020 commit 86bd6c6
Showing 9 changed files with 499 additions and 96 deletions.
300 changes: 299 additions & 1 deletion python/ray/tests/test_task_events.py
@@ -1,9 +1,12 @@
from collections import defaultdict
from typing import Dict
import pytest
import time

import ray
from ray._private.test_utils import (
    raw_metrics,
    run_string_as_driver_nonblocking,
    wait_for_condition,
)
from ray.experimental.state.api import list_tasks
@@ -14,6 +17,7 @@
"task_events_report_interval_ms": 100,
"metrics_report_interval_ms": 200,
"enable_timeline": False,
"gcs_mark_task_failed_on_job_done_delay_ms": 1000,
}


@@ -77,7 +81,6 @@ def verify():

def test_fault_tolerance_parent_failed(shutdown_only):
    ray.init(num_cpus=4, _system_config=_SYSTEM_CONFIG)
    import time

    # Each parent task spins off 2 child task, where each child spins off
    # 1 grand_child task.
@@ -118,3 +121,298 @@ def verify():
        timeout=10,
        retry_interval_ms=500,
    )


def test_fault_tolerance_job_failed(shutdown_only):
    ray.init(num_cpus=8, _system_config=_SYSTEM_CONFIG)
    script = """
import ray
import time
ray.init("auto")
NUM_CHILD = 2
@ray.remote
def grandchild():
    time.sleep(999)
@ray.remote
def child():
    ray.get(grandchild.remote())
@ray.remote
def finished_child():
    ray.put(1)
    return
@ray.remote
def parent():
    children = [child.remote() for _ in range(NUM_CHILD)]
    finished_children = ray.get([finished_child.remote() for _ in range(NUM_CHILD)])
    ray.get(children)
ray.get(parent.remote())
"""
    proc = run_string_as_driver_nonblocking(script)

    def verify():
        tasks = list_tasks()
        print(tasks)
        assert len(tasks) == 7, (
            "Incorrect number of tasks are reported. "
            "Expected length: 1 parent + 2 finished child + 2 failed child + "
            "2 failed grandchild tasks"
        )
        return True

    wait_for_condition(
        verify,
        timeout=10,
        retry_interval_ms=500,
    )

    proc.kill()

    def verify():
        tasks = list_tasks()
        assert len(tasks) == 7, (
            "Incorrect number of tasks are reported. "
            "Expected length: 1 parent + 2 finished child + 2 failed child + "
            "2 failed grandchild tasks"
        )
        for task in tasks:
            if "finished" in task["func_or_class_name"]:
                assert (
                    task["scheduling_state"] == "FINISHED"
                ), f"task {task['func_or_class_name']} has wrong state"
            else:
                assert (
                    task["scheduling_state"] == "FAILED"
                ), f"task {task['func_or_class_name']} has wrong state"

        return True

    wait_for_condition(
        verify,
        timeout=10,
        retry_interval_ms=500,
    )


@ray.remote
def task_finish_child():
    pass


@ray.remote
def task_sleep_child():
    time.sleep(999)


@ray.remote
class ChildActor:
    def children(self):
        ray.get(task_finish_child.remote())
        ray.get(task_sleep_child.remote())


@ray.remote
class Actor:
    def fail_parent(self):
        task_finish_child.remote()
        task_sleep_child.remote()
        raise ValueError("expected to fail.")

    def child_actor(self):
        a = ChildActor.remote()
        try:
            ray.get(a.children.remote(), timeout=2)
        except ray.exceptions.GetTimeoutError:
            pass
        raise ValueError("expected to fail.")


def test_fault_tolerance_actor_tasks_failed(shutdown_only):
    ray.init(_system_config=_SYSTEM_CONFIG)
    # Test actor tasks
    with pytest.raises(ray.exceptions.RayTaskError):
        a = Actor.remote()
        ray.get(a.fail_parent.remote())

    def verify():
        tasks = list_tasks()
        assert (
            len(tasks) == 4
        ), "1 creation task + 1 actor tasks + 2 normal tasks run by the actor tasks"
        for task in tasks:
            if "finish" in task["name"] or "__init__" in task["name"]:
                assert task["scheduling_state"] == "FINISHED", task
            else:
                assert task["scheduling_state"] == "FAILED", task

        return True

    wait_for_condition(
        verify,
        timeout=10,
        retry_interval_ms=500,
    )


def test_fault_tolerance_nested_actors_failed(shutdown_only):
    ray.init(_system_config=_SYSTEM_CONFIG)

    # Test nested actor tasks
    with pytest.raises(ray.exceptions.RayTaskError):
        a = Actor.remote()
        ray.get(a.child_actor.remote())

    def verify():
        tasks = list_tasks()
        assert len(tasks) == 6, (
            "2 creation task + 1 parent actor task + 1 child actor task "
            " + 2 normal tasks run by child actor"
        )
        for task in tasks:
            if "finish" in task["name"] or "__init__" in task["name"]:
                assert task["scheduling_state"] == "FINISHED", task
            else:
                assert task["scheduling_state"] == "FAILED", task

        return True

    wait_for_condition(
        verify,
        timeout=10,
        retry_interval_ms=500,
    )


@pytest.mark.parametrize("death_list", [["A"], ["Abb", "C"], ["Abb", "Ca", "A"]])
def test_fault_tolerance_advanced_tree(shutdown_only, death_list):
    import asyncio

    # Some constants
    NORMAL_TASK = 0
    ACTOR_TASK = 1

    # Root should always be finish
    execution_graph = {
        "root": [
            (NORMAL_TASK, "A"),
            (ACTOR_TASK, "B"),
            (NORMAL_TASK, "C"),
            (ACTOR_TASK, "D"),
        ],
        "A": [(ACTOR_TASK, "Aa"), (NORMAL_TASK, "Ab")],
        "C": [(ACTOR_TASK, "Ca"), (NORMAL_TASK, "Cb")],
        "D": [
            (NORMAL_TASK, "Da"),
            (NORMAL_TASK, "Db"),
            (ACTOR_TASK, "Dc"),
            (ACTOR_TASK, "Dd"),
        ],
        "Aa": [],
        "Ab": [(ACTOR_TASK, "Aba"), (NORMAL_TASK, "Abb"), (NORMAL_TASK, "Abc")],
        "Ca": [(ACTOR_TASK, "Caa"), (NORMAL_TASK, "Cab")],
        "Abb": [(NORMAL_TASK, "Abba")],
        "Abc": [],
        "Abba": [(NORMAL_TASK, "Abbaa"), (ACTOR_TASK, "Abbab")],
        "Abbaa": [(NORMAL_TASK, "Abbaaa"), (ACTOR_TASK, "Abbaab")],
    }

    ray.init(_system_config=_SYSTEM_CONFIG)

    @ray.remote
    class Killer:
        def __init__(self, death_list, wait_time):
            self.idx_ = 0
            self.death_list_ = death_list
            self.wait_time_ = wait_time
            self.start_ = time.time()

        async def next_to_kill(self):
            now = time.time()
            if now - self.start_ < self.wait_time_:
                # Sleep until killing starts...
                time.sleep(self.wait_time_ - (now - self.start_))

            # if no more tasks to kill - simply sleep to keep all running tasks blocked.
            while self.idx_ >= len(self.death_list_):
                await asyncio.sleep(999)

            to_kill = self.death_list_[self.idx_]
            print(f"{to_kill} to be killed")
            return to_kill

        async def advance_next(self):
            self.idx_ += 1

    def run_children(my_name, killer, execution_graph):
        children = execution_graph.get(my_name, [])
        for task_type, child_name in children:
            if task_type == NORMAL_TASK:
                task.options(name=child_name).remote(
                    child_name, killer, execution_graph
                )
            else:
                a = Actor.remote()
                a.actor_task.options(name=child_name).remote(
                    child_name, killer, execution_graph
                )

        # Block until killed
        while True:
            to_fail = ray.get(killer.next_to_kill.remote())
            if to_fail == my_name:
                ray.get(killer.advance_next.remote())
                raise ValueError(f"{my_name} expected to fail")

    @ray.remote
    class Actor:
        def actor_task(self, my_name, killer, execution_graph):
            run_children(my_name, killer, execution_graph)

    @ray.remote
    def task(my_name, killer, execution_graph):
        run_children(my_name, killer, execution_graph)

    killer = Killer.remote(death_list, 5)

    task.options(name="root").remote("root", killer, execution_graph)

    def verify():
        tasks = list_tasks()
        target_tasks = filter(
            lambda task: "__init__" not in task["name"]
            and "Killer" not in task["name"],
            tasks,
        )

        # Calculate tasks that should have failed
        dead_tasks = set()

        def add_death_tasks_recur(task, execution_graph, dead_tasks):
            children = execution_graph.get(task, [])
            dead_tasks.add(task)

            for _, child in children:
                add_death_tasks_recur(child, execution_graph, dead_tasks)

        for task in death_list:
            add_death_tasks_recur(task, execution_graph, dead_tasks)

        for task in target_tasks:
            if task["name"] in dead_tasks:
                assert task["scheduling_state"] == "FAILED", task["name"]
            else:
                assert task["scheduling_state"] == "RUNNING", task["name"]

        return True

    wait_for_condition(
        verify,
        timeout=15,
        retry_interval_ms=500,
    )
5 changes: 5 additions & 0 deletions src/ray/common/ray_config_def.h
@@ -466,6 +466,11 @@ RAY_CONFIG(int64_t, task_events_max_num_task_events_in_buffer, 10000)
/// Setting the value to -1 allows unlimited profile events to be sent.
RAY_CONFIG(int64_t, task_events_max_num_profile_events_for_task, 100)

/// The delay in ms after which GCS marks any still-running tasks of a finished job as failed.
/// Setting this value too small might result in some finished tasks being marked as failed by
/// GCS.
RAY_CONFIG(uint64_t, gcs_mark_task_failed_on_job_done_delay_ms, /* 15 secs */ 1000 * 15)

/// Whether or not we enable metrics collection.
RAY_CONFIG(bool, enable_metrics_collection, true)

9 changes: 4 additions & 5 deletions src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -82,7 +82,7 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data,
    } else {
      RAY_CHECK_OK(gcs_publisher_->PublishJob(job_id, job_table_data, nullptr));
      runtime_env_manager_.RemoveURIReference(job_id.Hex());
      ClearJobInfos(job_id);
      ClearJobInfos(job_table_data);
      RAY_LOG(INFO) << "Finished marking job state, job id = " << job_id;
    }
    function_manager_.RemoveJobReference(job_id);
@@ -121,10 +121,10 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request,
  }
}

void GcsJobManager::ClearJobInfos(const JobID &job_id) {
void GcsJobManager::ClearJobInfos(const rpc::JobTableData &job_data) {
  // Notify all listeners.
  for (auto &listener : job_finished_listeners_) {
    listener(std::make_shared<JobID>(job_id));
    listener(job_data);
  }
  // Clear cache.
  // TODO(qwang): This line will cause `test_actor_advanced.py::test_detached_actor`
@@ -137,8 +137,7 @@ void GcsJobManager::ClearJobInfos(const JobID &job_id) {
/// Add listener to monitor the add action of nodes.
///
/// \param listener The handler which process the add of nodes.
void GcsJobManager::AddJobFinishedListener(
    std::function<void(std::shared_ptr<JobID>)> listener) {
void GcsJobManager::AddJobFinishedListener(JobFinishListenerCallback listener) {
  RAY_CHECK(listener);
  job_finished_listeners_.emplace_back(std::move(listener));
}