Skip to content

Commit

Permalink
Move ECS Executor to its own file (#35418)
Browse files Browse the repository at this point in the history
* Move ECS Executor from __init__.py to its own file. This improves the logging because logs record the filename, and __init__.py was not a helpful name.

* Fix failing tests
  • Loading branch information
syedahsn authored Nov 3, 2023
1 parent 0988074 commit 92d1e8c
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 331 deletions.
326 changes: 0 additions & 326 deletions airflow/providers/amazon/aws/executors/ecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,329 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
AWS ECS Executor.
Each Airflow task gets delegated out to an Amazon ECS Task.
"""

from __future__ import annotations

import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING

from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
from airflow.providers.amazon.aws.executors.ecs.utils import (
CONFIG_DEFAULTS,
CONFIG_GROUP_NAME,
AllEcsConfigKeys,
EcsExecutorException,
EcsQueuedTask,
EcsTaskCollection,
)
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
ExecutorConfigType,
)


class AwsEcsExecutor(BaseExecutor):
"""
Executes the provided Airflow command on an ECS instance.
The Airflow Scheduler creates a shell command, and passes it to the executor. This ECS Executor
runs said Airflow command on a remote Amazon ECS Cluster with a task-definition configured to
launch the same containers as the Scheduler. It then periodically checks in with the launched
tasks (via task ARNs) to determine the status.
This allows individual tasks to specify CPU, memory, GPU, env variables, etc. When initializing a task,
there's an option for "executor config" which should be a dictionary with keys that match the
``ContainerOverride`` definition per AWS documentation (see link below).
Prerequisite: proper configuration of Boto3 library
.. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html for
authentication and access-key management. You can store an environmental variable, setup aws config from
console, or use IAM roles.
.. seealso:: https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_ContainerOverride.html for an
Airflow TaskInstance's executor_config.
"""

# Maximum number of retries to run an ECS task.
MAX_RUN_TASK_ATTEMPTS = conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
)

# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers: EcsTaskCollection = EcsTaskCollection()
self.pending_tasks: deque = deque()

self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
aws_conn_id = conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.AWS_CONN_ID,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
)
region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook

self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.run_task_kwargs = self._load_run_kwargs()

def sync(self):
try:
self.sync_running_tasks()
self.attempt_task_runs()
except Exception:
# We catch any and all exceptions because otherwise they would bubble
# up and kill the scheduler process
self.log.exception("Failed to sync %s", self.__class__.__name__)

def sync_running_tasks(self):
"""Checks and update state on all running tasks."""
all_task_arns = self.active_workers.get_all_arns()
if not all_task_arns:
self.log.debug("No active Airflow tasks, skipping sync.")
return

describe_tasks_response = self.__describe_tasks(all_task_arns)
self.log.debug("Active Workers: %s", describe_tasks_response)

if describe_tasks_response["failures"]:
for failure in describe_tasks_response["failures"]:
self.__handle_failed_task(failure["arn"], failure["reason"])

updated_tasks = describe_tasks_response["tasks"]
for task in updated_tasks:
self.__update_running_task(task)

def __update_running_task(self, task):
self.active_workers.update_task(task)
# Get state of current task.
task_state = task.get_task_state()
task_key = self.active_workers.arn_to_key[task.task_arn]
# Mark finished tasks as either a success/failure.
if task_state == State.FAILED:
self.fail(task_key)
elif task_state == State.SUCCESS:
self.success(task_key)
elif task_state == State.REMOVED:
self.__handle_failed_task(task.task_arn, task.stopped_reason)
if task_state in (State.FAILED, State.SUCCESS):
self.log.debug(
"Airflow task %s marked as %s after running on %s",
task_key,
task_state,
task.task_arn,
)
self.active_workers.pop_by_key(task_key)

def __describe_tasks(self, task_arns):
all_task_descriptions = {"tasks": [], "failures": []}
for i in range(0, len(task_arns), self.DESCRIBE_TASKS_BATCH_SIZE):
batched_task_arns = task_arns[i : i + self.DESCRIBE_TASKS_BATCH_SIZE]
if not batched_task_arns:
continue
boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster)
describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)

all_task_descriptions["tasks"].extend(describe_tasks_response["tasks"])
all_task_descriptions["failures"].extend(describe_tasks_response["failures"])
return all_task_descriptions

def __handle_failed_task(self, task_arn: str, reason: str):
"""If an API failure occurs, the task is rescheduled."""
task_key = self.active_workers.arn_to_key[task_arn]
task_info = self.active_workers.info_by_key(task_key)
task_cmd = task_info.cmd
queue = task_info.queue
exec_info = task_info.config
failure_count = self.active_workers.failure_count_by_key(task_key)
if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
self.__class__.MAX_RUN_TASK_ATTEMPTS,
task_arn,
)
self.active_workers.increment_failure_count(task_key)
self.pending_tasks.appendleft(
EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1)
)
else:
self.log.error(
"Airflow task %s has failed a maximum of %s times. Marking as failed",
task_key,
failure_count,
)
self.active_workers.pop_by_key(task_key)
self.fail(task_key)

def attempt_task_runs(self):
"""
Takes tasks from the pending_tasks queue, and attempts to find an instance to run it on.
If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
there are no EC2 instances available, no task is placed and this function will be
called again in the next heart-beat.
If the launch type is FARGATE, this will run the tasks on new AWS Fargate instances.
"""
queue_len = len(self.pending_tasks)
failure_reasons = defaultdict(int)
for _ in range(queue_len):
ecs_task = self.pending_tasks.popleft()
task_key = ecs_task.key
cmd = ecs_task.command
queue = ecs_task.queue
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
_failure_reasons = []
try:
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
except Exception as e:
# Failed to even get a response back from the Boto3 API or something else went
# wrong. For any possible failure we want to add the exception reasons to the
# failure list so that it is logged to the user and most importantly the task is
# added back to the pending list to be retried later.
_failure_reasons.append(str(e))
else:
# We got a response back, check if there were failures. If so, add them to the
# failures list so that it is logged to the user and most importantly the task
# is added back to the pending list to be retried later.
if run_task_response["failures"]:
_failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])

if _failure_reasons:
for reason in _failure_reasons:
failure_reasons[reason] += 1
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
ecs_task.attempt_number += 1
self.pending_tasks.appendleft(ecs_task)
else:
self.log.error(
"ECS task %s has failed a maximum of %s times. Marking as failed",
task_key,
attempt_number,
)
self.fail(task_key)
elif not run_task_response["tasks"]:
self.log.error("ECS RunTask Response: %s", run_task_response)
raise EcsExecutorException(
"No failures and no ECS tasks provided in response. This should never happen."
)
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
dict(failure_reasons),
)

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
):
"""
Run a queued-up Airflow task.
Not to be confused with execute_async() which inserts tasks into the queue.
The command and executor config will be placed in the container-override
section of the JSON request before calling Boto3's "run_task" function.
"""
run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
boto_run_task = self.ecs.run_task(**run_task_api)
run_task_response = BotoRunTaskSchema().load(boto_run_task)
return run_task_response

def _run_task_kwargs(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
) -> dict:
"""
Overrides the Airflow command to update the container overrides so kwargs are specific to this task.
One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
"""
run_task_api = deepcopy(self.run_task_kwargs)
container_override = self.get_container(run_task_api["overrides"]["containerOverrides"])
container_override["command"] = cmd
container_override.update(exec_config)

# Inject the env variable to configure logging for containerized execution environment
if "environment" not in container_override:
container_override["environment"] = []
container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"})

return run_task_api

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
if executor_config and ("name" in executor_config or "command" in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1))

def end(self, heartbeat_interval=10):
"""Waits for all currently running tasks to end, and doesn't launch any tasks."""
try:
while True:
self.sync()
if not self.active_workers:
break
time.sleep(heartbeat_interval)
except Exception:
# We catch any and all exceptions because otherwise they would bubble
# up and kill the scheduler process.
self.log.exception("Failed to end %s", self.__class__.__name__)

def terminate(self):
"""Kill all ECS processes by calling Boto3's StopTask API."""
try:
for arn in self.active_workers.get_all_arns():
self.ecs.stop_task(
cluster=self.cluster, task=arn, reason="Airflow Executor received a SIGTERM"
)
self.end()
except Exception:
# We catch any and all exceptions because otherwise they would bubble
# up and kill the scheduler process.
self.log.exception("Failed to terminate %s", self.__class__.__name__)

def _load_run_kwargs(self) -> dict:
from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs

ecs_executor_run_task_kwargs = build_task_kwargs()

try:
self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"]
except KeyError:
raise KeyError(
"Rendered JSON template does not contain key "
'"overrides[containerOverrides][containers][x][command]"'
)
return ecs_executor_run_task_kwargs

def get_container(self, container_list):
"""Searches task list for core Airflow container."""
for container in container_list:
if container["name"] == self.container_name:
return container
raise KeyError(f"No such container found by container name: {self.container_name}")
Loading

0 comments on commit 92d1e8c

Please sign in to comment.