diff --git a/python/ray/tests/test_task_events.py b/python/ray/tests/test_task_events.py index 28e49933b6aa..80d6bf0053a6 100644 --- a/python/ray/tests/test_task_events.py +++ b/python/ray/tests/test_task_events.py @@ -1,9 +1,12 @@ from collections import defaultdict from typing import Dict +import pytest +import time import ray from ray._private.test_utils import ( raw_metrics, + run_string_as_driver_nonblocking, wait_for_condition, ) from ray.experimental.state.api import list_tasks @@ -14,6 +17,7 @@ "task_events_report_interval_ms": 100, "metrics_report_interval_ms": 200, "enable_timeline": False, + "gcs_mark_task_failed_on_job_done_delay_ms": 1000, } @@ -77,7 +81,6 @@ def verify(): def test_fault_tolerance_parent_failed(shutdown_only): ray.init(num_cpus=4, _system_config=_SYSTEM_CONFIG) - import time # Each parent task spins off 2 child task, where each child spins off # 1 grand_child task. @@ -118,3 +121,298 @@ def verify(): timeout=10, retry_interval_ms=500, ) + + +def test_fault_tolerance_job_failed(shutdown_only): + ray.init(num_cpus=8, _system_config=_SYSTEM_CONFIG) + script = """ +import ray +import time + +ray.init("auto") +NUM_CHILD = 2 + +@ray.remote +def grandchild(): + time.sleep(999) + +@ray.remote +def child(): + ray.get(grandchild.remote()) + +@ray.remote +def finished_child(): + ray.put(1) + return + +@ray.remote +def parent(): + children = [child.remote() for _ in range(NUM_CHILD)] + finished_children = ray.get([finished_child.remote() for _ in range(NUM_CHILD)]) + ray.get(children) + +ray.get(parent.remote()) + +""" + proc = run_string_as_driver_nonblocking(script) + + def verify(): + tasks = list_tasks() + print(tasks) + assert len(tasks) == 7, ( + "Incorrect number of tasks are reported. " + "Expected length: 1 parent + 2 finished child + 2 failed child + " + "2 failed grandchild tasks" + ) + return True + + wait_for_condition( + verify, + timeout=10, + retry_interval_ms=500, + ) + + proc.kill() + + def verify(): + tasks = list_tasks() + assert len(tasks) == 7, ( + "Incorrect number of tasks are reported. " + "Expected length: 1 parent + 2 finished child + 2 failed child + " + "2 failed grandchild tasks" + ) + for task in tasks: + if "finished" in task["func_or_class_name"]: + assert ( + task["scheduling_state"] == "FINISHED" + ), f"task {task['func_or_class_name']} has wrong state" + else: + assert ( + task["scheduling_state"] == "FAILED" + ), f"task {task['func_or_class_name']} has wrong state" + + return True + + wait_for_condition( + verify, + timeout=10, + retry_interval_ms=500, + ) + + +@ray.remote +def task_finish_child(): + pass + + +@ray.remote +def task_sleep_child(): + time.sleep(999) + + +@ray.remote +class ChildActor: + def children(self): + ray.get(task_finish_child.remote()) + ray.get(task_sleep_child.remote()) + + +@ray.remote +class Actor: + def fail_parent(self): + task_finish_child.remote() + task_sleep_child.remote() + raise ValueError("expected to fail.") + + def child_actor(self): + a = ChildActor.remote() + try: + ray.get(a.children.remote(), timeout=2) + except ray.exceptions.GetTimeoutError: + pass + raise ValueError("expected to fail.") + + +def test_fault_tolerance_actor_tasks_failed(shutdown_only): + ray.init(_system_config=_SYSTEM_CONFIG) + # Test actor tasks + with pytest.raises(ray.exceptions.RayTaskError): + a = Actor.remote() + ray.get(a.fail_parent.remote()) + + def verify(): + tasks = list_tasks() + assert ( + len(tasks) == 4 + ), "1 creation task + 1 actor tasks + 2 normal tasks run by the actor tasks" + for task in tasks: + if "finish" in task["name"] or "__init__" in task["name"]: + assert task["scheduling_state"] == "FINISHED", task + else: + assert task["scheduling_state"] == "FAILED", task + + return True + + wait_for_condition( + verify, + timeout=10, + retry_interval_ms=500, + ) + + +def test_fault_tolerance_nested_actors_failed(shutdown_only): + ray.init(_system_config=_SYSTEM_CONFIG) + + # Test nested actor tasks + with pytest.raises(ray.exceptions.RayTaskError): + a = Actor.remote() + ray.get(a.child_actor.remote()) + + def verify(): + tasks = list_tasks() + assert len(tasks) == 6, ( + "2 creation task + 1 parent actor task + 1 child actor task " + " + 2 normal tasks run by child actor" + ) + for task in tasks: + if "finish" in task["name"] or "__init__" in task["name"]: + assert task["scheduling_state"] == "FINISHED", task + else: + assert task["scheduling_state"] == "FAILED", task + + return True + + wait_for_condition( + verify, + timeout=10, + retry_interval_ms=500, + ) + + +@pytest.mark.parametrize("death_list", [["A"], ["Abb", "C"], ["Abb", "Ca", "A"]]) +def test_fault_tolerance_advanced_tree(shutdown_only, death_list): + import asyncio + + # Some constants + NORMAL_TASK = 0 + ACTOR_TASK = 1 + + # Root should always be finish + execution_graph = { + "root": [ + (NORMAL_TASK, "A"), + (ACTOR_TASK, "B"), + (NORMAL_TASK, "C"), + (ACTOR_TASK, "D"), + ], + "A": [(ACTOR_TASK, "Aa"), (NORMAL_TASK, "Ab")], + "C": [(ACTOR_TASK, "Ca"), (NORMAL_TASK, "Cb")], + "D": [ + (NORMAL_TASK, "Da"), + (NORMAL_TASK, "Db"), + (ACTOR_TASK, "Dc"), + (ACTOR_TASK, "Dd"), + ], + "Aa": [], + "Ab": [(ACTOR_TASK, "Aba"), (NORMAL_TASK, "Abb"), (NORMAL_TASK, "Abc")], + "Ca": [(ACTOR_TASK, "Caa"), (NORMAL_TASK, "Cab")], + "Abb": [(NORMAL_TASK, "Abba")], + "Abc": [], + "Abba": [(NORMAL_TASK, "Abbaa"), (ACTOR_TASK, "Abbab")], + "Abbaa": [(NORMAL_TASK, "Abbaaa"), (ACTOR_TASK, "Abbaab")], + } + + ray.init(_system_config=_SYSTEM_CONFIG) + + @ray.remote + class Killer: + def __init__(self, death_list, wait_time): + self.idx_ = 0 + self.death_list_ = death_list + self.wait_time_ = wait_time + self.start_ = time.time() + + async def next_to_kill(self): + now = time.time() + if now - self.start_ < self.wait_time_: + # Sleep until killing starts... + time.sleep(self.wait_time_ - (now - self.start_)) + + # if no more tasks to kill - simply sleep to keep all running tasks blocked. + while self.idx_ >= len(self.death_list_): + await asyncio.sleep(999) + + to_kill = self.death_list_[self.idx_] + print(f"{to_kill} to be killed") + return to_kill + + async def advance_next(self): + self.idx_ += 1 + + def run_children(my_name, killer, execution_graph): + children = execution_graph.get(my_name, []) + for task_type, child_name in children: + if task_type == NORMAL_TASK: + task.options(name=child_name).remote( + child_name, killer, execution_graph + ) + else: + a = Actor.remote() + a.actor_task.options(name=child_name).remote( + child_name, killer, execution_graph + ) + + # Block until killed + while True: + to_fail = ray.get(killer.next_to_kill.remote()) + if to_fail == my_name: + ray.get(killer.advance_next.remote()) + raise ValueError(f"{my_name} expected to fail") + + @ray.remote + class Actor: + def actor_task(self, my_name, killer, execution_graph): + run_children(my_name, killer, execution_graph) + + @ray.remote + def task(my_name, killer, execution_graph): + run_children(my_name, killer, execution_graph) + + killer = Killer.remote(death_list, 5) + + task.options(name="root").remote("root", killer, execution_graph) + + def verify(): + tasks = list_tasks() + target_tasks = filter( + lambda task: "__init__" not in task["name"] + and "Killer" not in task["name"], + tasks, + ) + + # Calculate tasks that should have failed + dead_tasks = set() + + def add_death_tasks_recur(task, execution_graph, dead_tasks): + children = execution_graph.get(task, []) + dead_tasks.add(task) + + for _, child in children: + add_death_tasks_recur(child, execution_graph, dead_tasks) + + for task in death_list: + add_death_tasks_recur(task, execution_graph, dead_tasks) + + for task in target_tasks: + if task["name"] in dead_tasks: + assert task["scheduling_state"] == "FAILED", task["name"] + else: + assert task["scheduling_state"] == "RUNNING", task["name"] + + return True + + wait_for_condition( + verify, + timeout=15, + retry_interval_ms=500, + ) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index a50b84b41281..24ac74ec9889 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -463,6 +463,11 @@ RAY_CONFIG(int64_t, task_events_max_num_task_events_in_buffer, 10000) /// Setting the value to -1 allows unlimited profile events to be sent. RAY_CONFIG(int64_t, task_events_max_num_profile_events_for_task, 100) +/// The delay in ms that GCS should mark any running tasks from a job as failed. +/// Setting this value too smaller might result in some finished tasks marked as failed by +/// GCS. +RAY_CONFIG(uint64_t, gcs_mark_task_failed_on_job_done_delay_ms, /* 15 secs */ 1000 * 15) + /// Whether or not we enable metrics collection. RAY_CONFIG(bool, enable_metrics_collection, true) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index 938de847176a..bed2b0298fd0 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -82,7 +82,7 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data, } else { RAY_CHECK_OK(gcs_publisher_->PublishJob(job_id, job_table_data, nullptr)); runtime_env_manager_.RemoveURIReference(job_id.Hex()); - ClearJobInfos(job_id); + ClearJobInfos(job_table_data); RAY_LOG(INFO) << "Finished marking job state, job id = " << job_id; } function_manager_.RemoveJobReference(job_id); @@ -121,10 +121,10 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request, } } -void GcsJobManager::ClearJobInfos(const JobID &job_id) { +void GcsJobManager::ClearJobInfos(const rpc::JobTableData &job_data) { // Notify all listeners. for (auto &listener : job_finished_listeners_) { - listener(std::make_shared(job_id)); + listener(job_data); } // Clear cache. // TODO(qwang): This line will cause `test_actor_advanced.py::test_detached_actor` @@ -137,8 +137,7 @@ void GcsJobManager::ClearJobInfos(const JobID &job_id) { /// Add listener to monitor the add action of nodes. /// /// \param listener The handler which process the add of nodes. -void GcsJobManager::AddJobFinishedListener( - std::function)> listener) { +void GcsJobManager::AddJobFinishedListener(JobFinishListenerCallback listener) { RAY_CHECK(listener); job_finished_listeners_.emplace_back(std::move(listener)); } diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h index b3ac1c15055e..a7c0c25ec997 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/ray/gcs/gcs_server/gcs_job_manager.h @@ -24,6 +24,8 @@ namespace ray { namespace gcs { +using JobFinishListenerCallback = rpc::JobInfoHandler::JobFinishListenerCallback; + /// This implementation class of `JobInfoHandler`. class GcsJobManager : public rpc::JobInfoHandler { public: @@ -58,8 +60,7 @@ class GcsJobManager : public rpc::JobInfoHandler { rpc::GetNextJobIDReply *reply, rpc::SendReplyCallback send_reply_callback) override; - void AddJobFinishedListener( - std::function)> listener) override; + void AddJobFinishedListener(JobFinishListenerCallback listener) override; std::shared_ptr GetJobConfig(const JobID &job_id) const; @@ -68,14 +69,14 @@ class GcsJobManager : public rpc::JobInfoHandler { std::shared_ptr gcs_publisher_; /// Listeners which monitors the finish of jobs. - std::vector)>> job_finished_listeners_; + std::vector job_finished_listeners_; /// A cached mapping from job id to job config. absl::flat_hash_map> cached_job_configs_; ray::RuntimeEnvManager &runtime_env_manager_; GcsFunctionManager &function_manager_; - void ClearJobInfos(const JobID &job_id); + void ClearJobInfos(const rpc::JobTableData &job_data); void MarkJobAsFinished(rpc::JobTableData job_table_data, std::function done_callback); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 1343ec0f5e99..31e7a90ccec0 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -667,9 +667,11 @@ void GcsServer::InstallEventListeners() { }); // Install job event listeners. - gcs_job_manager_->AddJobFinishedListener([this](std::shared_ptr job_id) { - gcs_actor_manager_->OnJobFinished(*job_id); - gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(*job_id); + gcs_job_manager_->AddJobFinishedListener([this](const rpc::JobTableData &job_data) { + const auto job_id = JobID::FromBinary(job_data.job_id()); + gcs_actor_manager_->OnJobFinished(job_id); + gcs_task_manager_->OnJobFinished(job_id, job_data.end_time()); + gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(job_id); }); // Install scheduling event listeners. diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc index c066645e0150..68dc813758cc 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc @@ -111,6 +111,22 @@ const rpc::TaskEvents &GcsTaskManager::GcsTaskManagerStorage::GetTaskEvent( return task_events_.at(idx_itr->second); } +void GcsTaskManager::GcsTaskManagerStorage::MarkTaskAttemptFailed( + const TaskAttempt &task_attempt, int64_t failed_ts) { + auto &task_event = GetTaskEvent(task_attempt); + if (!task_event.has_state_updates()) { + return; + } + task_event.mutable_state_updates()->set_failed_ts(failed_ts); +} + +bool GcsTaskManager::GcsTaskManagerStorage::IsTaskTerminated( + const TaskID &task_id) const { + auto failed_ts = GetTaskStatusUpdateTime(task_id, rpc::TaskStatus::FAILED); + auto finished_ts = GetTaskStatusUpdateTime(task_id, rpc::TaskStatus::FINISHED); + return failed_ts.has_value() || finished_ts.has_value(); +} + absl::optional GcsTaskManager::GcsTaskManagerStorage::GetTaskStatusUpdateTime( const TaskID &task_id, const rpc::TaskStatus &task_status) const { auto latest_task_attempt = GetLatestTaskAttempt(task_id); @@ -124,15 +140,29 @@ absl::optional GcsTaskManager::GcsTaskManagerStorage::GetTaskStatusUpda : absl::nullopt; } +void GcsTaskManager::GcsTaskManagerStorage::MarkTasksFailed(const JobID &job_id, + int64_t job_finish_time_ns) { + auto task_attempts_itr = job_to_task_attempt_index_.find(job_id); + if (task_attempts_itr == job_to_task_attempt_index_.end()) { + // No tasks in the job. + return; + } + + // Iterate all task attempts from the job. + for (const auto &task_attempt : task_attempts_itr->second) { + if (!IsTaskTerminated(task_attempt.first)) { + MarkTaskAttemptFailed(task_attempt, job_finish_time_ns); + } + } +} + void GcsTaskManager::GcsTaskManagerStorage::MarkTaskFailed(const TaskID &task_id, int64_t failed_ts) { auto latest_task_attempt = GetLatestTaskAttempt(task_id); if (!latest_task_attempt.has_value()) { return; } - auto &task_event = GetTaskEvent(*latest_task_attempt); - task_event.mutable_state_updates()->set_failed_ts(failed_ts); - task_event.mutable_state_updates()->clear_finished_ts(); + MarkTaskAttemptFailed(*latest_task_attempt, failed_ts); } void GcsTaskManager::GcsTaskManagerStorage::MarkTaskTreeFailedIfNeeded( @@ -161,11 +191,8 @@ void GcsTaskManager::GcsTaskManagerStorage::MarkTaskTreeFailedIfNeeded( continue; } for (const auto &child_task_id : children_tasks_itr->second) { - // Mark any non-terminated child as failed with parent's (or ancestor's) failure - // timestamp. - if (!(GetTaskStatusUpdateTime(child_task_id, rpc::TaskStatus::FAILED).has_value() || - GetTaskStatusUpdateTime(child_task_id, rpc::TaskStatus::FINISHED) - .has_value())) { + // Mark any non-terminated child as failed with parent's failure timestamp. + if (!IsTaskTerminated(child_task_id)) { MarkTaskFailed(child_task_id, task_failed_ts.value()); failed_tasks.push_back(child_task_id); } @@ -349,17 +376,14 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request rpc::AddTaskEventDataReply *reply, rpc::SendReplyCallback send_reply_callback) { absl::MutexLock lock(&mutex_); - RAY_LOG(DEBUG) << "Adding task state event:" << request.data().ShortDebugString(); // Dispatch to the handler auto data = std::move(request.data()); - size_t num_to_process = data.events_by_task_size(); // Update counters. total_num_profile_task_events_dropped_ += data.num_profile_task_events_dropped(); total_num_status_task_events_dropped_ += data.num_status_task_events_dropped(); for (auto events_by_task : *data.mutable_events_by_task()) { total_num_task_events_reported_++; - auto task_id = TaskID::FromBinary(events_by_task.task_id()); // TODO(rickyx): add logic to handle too many profile events for a single task // attempt. https://github.com/ray-project/ray/issues/31279 @@ -378,11 +402,9 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request replaced_task_events->profile_events().events_size(); } } - RAY_LOG(DEBUG) << "Processed a task event. [task_id=" << task_id.Hex() << "]"; } // Processed all the task events - RAY_LOG(DEBUG) << "Processed all " << num_to_process << " task events"; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); } @@ -418,5 +440,22 @@ void GcsTaskManager::RecordMetrics() { task_event_storage_->GetTaskEventsBytes()); } +void GcsTaskManager::OnJobFinished(const JobID &job_id, int64_t job_finish_time_ms) { + RAY_LOG(DEBUG) << "Marking all running tasks of job " << job_id.Hex() << " as failed."; + timer_.expires_from_now(boost::posix_time::milliseconds( + RayConfig::instance().gcs_mark_task_failed_on_job_done_delay_ms())); + timer_.async_wait( + [this, job_id, job_finish_time_ms](const boost::system::error_code &error) { + if (error == boost::asio::error::operation_aborted) { + // timer canceled or aborted. + return; + } + absl::MutexLock lock(&mutex_); + // If there are any non-terminated tasks from the job, mark them failed since all + // workers associated with the job will be killed. + task_event_storage_->MarkTasksFailed(job_id, job_finish_time_ms * 1000); + }); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.h b/src/ray/gcs/gcs_server/gcs_task_manager.h index 43ca174b9098..afd49d9604eb 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.h +++ b/src/ray/gcs/gcs_server/gcs_task_manager.h @@ -48,7 +48,8 @@ class GcsTaskManager : public rpc::TaskInfoHandler { // Keep io_service_ alive. boost::asio::io_service::work io_service_work_(io_service_); io_service_.run(); - })) {} + })), + timer_(io_service_) {} /// Handles a AddTaskEventData request. /// @@ -76,6 +77,13 @@ class GcsTaskManager : public rpc::TaskInfoHandler { /// This function returns when the io thread is joined. void Stop() LOCKS_EXCLUDED(mutex_); + /// Handler to be called when a job finishes. This marks all non-terminated tasks + /// of the job as failed. + /// + /// \param job_id Job Id + /// \param job_finish_time_ms Job finish time in ms. + void OnJobFinished(const JobID &job_id, int64_t job_finish_time_ms); + /// Returns the io_service. /// /// \return Reference to its io_service. @@ -146,6 +154,13 @@ class GcsTaskManager : public rpc::TaskInfoHandler { std::vector GetTaskEvents( const absl::flat_hash_set &task_attempts) const; + /// Mark tasks from a job as failed. + /// + /// \param job_id Job ID + /// \param job_finish_time_ns job finished time in nanoseconds, which will be the task + /// failed time. + void MarkTasksFailed(const JobID &job_id, int64_t job_finish_time_ns); + private: /// Mark the task tree containing this task attempt as failure if necessary. /// @@ -192,6 +207,12 @@ class GcsTaskManager : public rpc::TaskInfoHandler { absl::optional GetTaskStatusUpdateTime( const TaskID &task_id, const rpc::TaskStatus &task_status) const; + /// Return if task has terminated. + /// + /// \param task_id Task id + /// \return True if the task has finished or failed timestamp sets, false otherwise. + bool IsTaskTerminated(const TaskID &task_id) const; + /// Mark the task as failure with the failed timestamp. /// /// This also overwrites the finished state of the task if the task has finished by @@ -202,6 +223,12 @@ class GcsTaskManager : public rpc::TaskInfoHandler { /// timestamp. void MarkTaskFailed(const TaskID &task_id, int64_t failed_ts); + /// Mark a task attempt as failed. + /// + /// \param task_attempt Task attempt. + /// \param failed_ts The failure timestamp. + void MarkTaskAttemptFailed(const TaskAttempt &task_attempt, int64_t failed_ts); + /// Get the latest task attempt for the task. /// /// If there is no such task or data loss due to task events dropped at the worker, @@ -279,10 +306,14 @@ class GcsTaskManager : public rpc::TaskInfoHandler { /// Its own IO thread from the main thread. std::unique_ptr io_service_thread_; + /// Timer for delay functions. + boost::asio::deadline_timer timer_; + FRIEND_TEST(GcsTaskManagerTest, TestHandleAddTaskEventBasic); FRIEND_TEST(GcsTaskManagerTest, TestMergeTaskEventsSameTaskAttempt); FRIEND_TEST(GcsTaskManagerMemoryLimitedTest, TestLimitTaskEvents); FRIEND_TEST(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak); + FRIEND_TEST(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks); }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc index 9809b0a4feaf..d106f25c7c1d 100644 --- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc @@ -30,7 +30,8 @@ class GcsTaskManagerTest : public ::testing::Test { RayConfig::instance().initialize( R"( { - "task_events_max_num_task_in_gcs": 1000 + "task_events_max_num_task_in_gcs": 1000, + "gcs_mark_task_failed_on_job_done_delay_ms": 100 } )"); } @@ -61,6 +62,21 @@ class GcsTaskManagerTest : public ::testing::Test { } } + void SyncAddTaskEvent( + const std::vector &tasks, + const std::vector> &status_timestamps, + const TaskID &parent_task_id = TaskID::Nil(), + int job_id = 0) { + auto events = GenTaskEvents(tasks, + /* attempt_number */ 0, + /* job_id */ job_id, + /* profile event */ absl::nullopt, + GenStateUpdate(status_timestamps), + GenTaskInfo(JobID::FromInt(job_id), parent_task_id)); + auto events_data = Mocker::GenTaskEventsData(events); + SyncAddTaskEventData(events_data); + } + rpc::AddTaskEventDataReply SyncAddTaskEventData(const rpc::TaskEventData &events_data) { rpc::AddTaskEventDataRequest request; rpc::AddTaskEventDataReply reply; @@ -449,38 +465,13 @@ TEST_F(GcsTaskManagerTest, TestFailingParentFailChildren) { auto child2 = task_ids[2]; // Parent task running - { - auto events = GenTaskEvents({parent}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::RUNNING, 1}})); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({parent}, {{rpc::TaskStatus::RUNNING, 1}}); // Child tasks running - { - auto events = GenTaskEvents({child1, child2}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::RUNNING, 2}}), - GenTaskInfo(/* job_id */ JobID::FromInt(0), parent)); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({child1, child2}, {{rpc::TaskStatus::RUNNING, 2}}, parent); // Parent task failed - { - auto events = GenTaskEvents({parent}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::FAILED, 3}})); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({parent}, {{rpc::TaskStatus::FAILED, 3}}); // Get all children task events should be failed { @@ -502,38 +493,13 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) { auto grand_child2 = task_ids[3]; // Parent task running - { - auto events = GenTaskEvents({parent}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::RUNNING, 1}})); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({parent}, {{rpc::TaskStatus::RUNNING, 1}}); // Grandchild tasks running - { - auto events = GenTaskEvents({grand_child1, grand_child2}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::RUNNING, 3}}), - GenTaskInfo(/* job_id */ JobID::FromInt(0), child)); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({grand_child1, grand_child2}, {{rpc::TaskStatus::RUNNING, 3}}, child); // Parent task failed - { - auto events = GenTaskEvents({parent}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::FAILED, 4}})); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({parent}, {{rpc::TaskStatus::FAILED, 4}}); // Get grand child should still be running since the parent-grand-child relationship is // not recorded yet. @@ -546,16 +512,7 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) { } // Child task reported running. - { - auto events = GenTaskEvents({child}, - /* attempt_number */ 0, - /* job_id */ 0, - /* profile event */ absl::nullopt, - GenStateUpdate({{rpc::TaskStatus::RUNNING, 2}}), - GenTaskInfo(/* job_id */ JobID::FromInt(0), parent)); - auto events_data = Mocker::GenTaskEventsData(events); - SyncAddTaskEventData(events_data); - } + SyncAddTaskEvent({child}, {{rpc::TaskStatus::RUNNING, 2}}, parent); // Both child and grand-child should report failure since their ancestor fail. // i.e. Child task should mark grandchildren failed. @@ -568,6 +525,76 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) { } } +TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) { + auto tasks_running_job1 = GenTaskIDs(10); + auto tasks_finished_job1 = GenTaskIDs(10); + auto tasks_failed_job1 = GenTaskIDs(10); + + auto tasks_running_job2 = GenTaskIDs(5); + + SyncAddTaskEvent(tasks_running_job1, {{rpc::TaskStatus::RUNNING, 1}}, TaskID::Nil(), 1); + SyncAddTaskEvent( + tasks_finished_job1, {{rpc::TaskStatus::FINISHED, 2}}, TaskID::Nil(), 1); + SyncAddTaskEvent(tasks_failed_job1, {{rpc::TaskStatus::FAILED, 3}}, TaskID::Nil(), 1); + + SyncAddTaskEvent(tasks_running_job2, {{rpc::TaskStatus::RUNNING, 4}}, TaskID::Nil(), 2); + + task_manager->OnJobFinished(JobID::FromInt(1), 5); // in ms + + // Wait for longer than the default timer + boost::asio::io_service io; + boost::asio::deadline_timer timer( + io, + boost::posix_time::milliseconds( + 2 * RayConfig::instance().gcs_mark_task_failed_on_job_done_delay_ms())); + timer.wait(); + + // Running tasks from job1 failed at 5 + { + absl::flat_hash_set tasks(tasks_running_job1.begin(), + tasks_running_job1.end()); + auto reply = SyncGetTaskEvents(tasks); + EXPECT_EQ(reply.events_by_task_size(), 10); + for (const auto &task_event : reply.events_by_task()) { + EXPECT_EQ(task_event.state_updates().failed_ts(), 5000); + } + } + + // Finished tasks from job1 remain finished + { + absl::flat_hash_set tasks(tasks_finished_job1.begin(), + tasks_finished_job1.end()); + auto reply = SyncGetTaskEvents(tasks); + EXPECT_EQ(reply.events_by_task_size(), 10); + for (const auto &task_event : reply.events_by_task()) { + EXPECT_EQ(task_event.state_updates().finished_ts(), 2); + EXPECT_FALSE(task_event.state_updates().has_failed_ts()); + } + } + + // Failed tasks from job1 failed timestamp not overriden + { + absl::flat_hash_set tasks(tasks_failed_job1.begin(), tasks_failed_job1.end()); + auto reply = SyncGetTaskEvents(tasks); + EXPECT_EQ(reply.events_by_task_size(), 10); + for (const auto &task_event : reply.events_by_task()) { + EXPECT_EQ(task_event.state_updates().failed_ts(), 3); + } + } + + // Tasks from job2 should not be affected. + { + absl::flat_hash_set tasks(tasks_running_job2.begin(), + tasks_running_job2.end()); + auto reply = SyncGetTaskEvents(tasks); + EXPECT_EQ(reply.events_by_task_size(), 5); + for (const auto &task_event : reply.events_by_task()) { + EXPECT_FALSE(task_event.state_updates().has_failed_ts()); + EXPECT_FALSE(task_event.state_updates().has_finished_ts()); + } + } +} + TEST_F(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak) { size_t num_limit = 100; // synced with test config size_t num_total = 1000; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index b1cc41bd3c06..c2453c88a1df 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -78,6 +78,8 @@ namespace rpc { class JobInfoGcsServiceHandler { public: + using JobFinishListenerCallback = std::function; + virtual ~JobInfoGcsServiceHandler() = default; virtual void HandleAddJob(AddJobRequest request, @@ -92,8 +94,7 @@ class JobInfoGcsServiceHandler { GetAllJobInfoReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void AddJobFinishedListener( - std::function)> listener) = 0; + virtual void AddJobFinishedListener(JobFinishListenerCallback listener) = 0; virtual void HandleReportJobError(ReportJobErrorRequest request, ReportJobErrorReply *reply,