Update doc and sample dag for EMR Containers #24087

Merged · 4 commits · Jun 7, 2022

Changes from all commits
@@ -23,13 +23,15 @@
from airflow.providers.amazon.aws.operators.emr import (
EmrAddStepsOperator,
EmrCreateJobFlowOperator,
EmrModifyClusterOperator,
EmrTerminateJobFlowOperator,
)
from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor
from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor, EmrStepSensor

JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole')
SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole')

# [START howto_operator_emr_steps_config]
SPARK_STEPS = [
{
'Name': 'calculate_pi',
@@ -58,48 +60,66 @@
'KeepJobFlowAliveWhenNoSteps': False,
'TerminationProtected': False,
},
'Steps': SPARK_STEPS,
'JobFlowRole': JOB_FLOW_ROLE,
'ServiceRole': SERVICE_ROLE,
}

# [END howto_operator_emr_steps_config]

with DAG(
dag_id='example_emr_job_flow_manual_steps',
dag_id='example_emr',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=['example'],
catchup=False,
) as dag:

cluster_creator = EmrCreateJobFlowOperator(
# [START howto_operator_emr_create_job_flow]
job_flow_creator = EmrCreateJobFlowOperator(
task_id='create_job_flow',
job_flow_overrides=JOB_FLOW_OVERRIDES,
)
# [END howto_operator_emr_create_job_flow]

# [START howto_sensor_emr_job_flow]
job_sensor = EmrJobFlowSensor(
task_id='check_job_flow',
job_flow_id=job_flow_creator.output,
)
# [END howto_sensor_emr_job_flow]

# [START howto_operator_emr_modify_cluster]
cluster_modifier = EmrModifyClusterOperator(
task_id='modify_cluster', cluster_id=job_flow_creator.output, step_concurrency_level=1
)
# [END howto_operator_emr_modify_cluster]

# [START howto_operator_emr_add_steps]
step_adder = EmrAddStepsOperator(
task_id='add_steps',
job_flow_id=cluster_creator.output,
job_flow_id=job_flow_creator.output,
steps=SPARK_STEPS,
)
# [END howto_operator_emr_add_steps]

# [START howto_sensor_emr_step_sensor]
# [START howto_sensor_emr_step]
step_checker = EmrStepSensor(
task_id='watch_step',
job_flow_id=cluster_creator.output,
job_flow_id=job_flow_creator.output,
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
)
# [END howto_sensor_emr_step_sensor]
# [END howto_sensor_emr_step]

# [START howto_operator_emr_terminate_job_flow]
cluster_remover = EmrTerminateJobFlowOperator(
task_id='remove_cluster',
job_flow_id=cluster_creator.output,
job_flow_id=job_flow_creator.output,
)
# [END howto_operator_emr_terminate_job_flow]

chain(
job_flow_creator,
job_sensor,
cluster_modifier,
step_adder,
step_checker,
cluster_remover,
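Note on wiring: the .output attribute used above (e.g. job_flow_creator.output) is an XComArg pointing at the upstream task's pushed return value. As a minimal sketch, the explicit Jinja pull it stands in for, assuming the task IDs in this DAG, is the same pattern step_checker already uses:

    # Equivalent to job_flow_creator.output: pulls create_job_flow's return
    # value (the new job flow ID) from XCom at run time.
    job_flow_id = "{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}"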
@@ -19,7 +19,9 @@
from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor

VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")
@@ -51,18 +53,13 @@
# [END howto_operator_emr_eks_config]

with DAG(
dag_id='example_emr_eks_job',
dag_id='example_emr_eks',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=['example'],
catchup=False,
) as dag:

# An example of how to get the cluster id and arn from an Airflow connection
# VIRTUAL_CLUSTER_ID = '{{ conn.emr_eks.extra_dejson["virtual_cluster_id"] }}'
# JOB_ROLE_ARN = '{{ conn.emr_eks.extra_dejson["job_role_arn"] }}'

# [START howto_operator_emr_eks_job]
# [START howto_operator_emr_container]
job_starter = EmrContainerOperator(
task_id="start_job",
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
@@ -71,5 +68,14 @@
job_driver=JOB_DRIVER_ARG,
configuration_overrides=CONFIGURATION_OVERRIDES_ARG,
name="pi.py",
wait_for_completion=False,
)
# [END howto_operator_emr_eks_job]
# [END howto_operator_emr_container]

# [START howto_sensor_emr_container]
job_waiter = EmrContainerSensor(
task_id="job_waiter", virtual_cluster_id=VIRTUAL_CLUSTER_ID, job_id=str(job_starter.output)
)
# [END howto_sensor_emr_container]

chain(job_starter, job_waiter)

This file was deleted.

34 changes: 21 additions & 13 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -122,6 +122,10 @@ class EmrContainerOperator(BaseOperator):
"""
An operator that submits jobs to EMR on EKS virtual clusters.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EmrContainerOperator`

:param name: The name of the job run.
:param virtual_cluster_id: The EMR on EKS virtual cluster ID
:param execution_role_arn: The IAM role ARN associated with the job run.
Expand All @@ -133,6 +137,7 @@ class EmrContainerOperator(BaseOperator):
Use this if you want to specify a unique ID to prevent two jobs from getting started.
If no token is provided, a UUIDv4 token will be generated for you.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: Whether or not to wait in the operator for the job to complete.
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
:param max_tries: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
@@ -160,6 +165,7 @@ def __init__(
configuration_overrides: Optional[dict] = None,
client_request_token: Optional[str] = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
poll_interval: int = 30,
max_tries: Optional[int] = None,
tags: Optional[dict] = None,
@@ -174,6 +180,7 @@ def __init__(
self.configuration_overrides = configuration_overrides or {}
self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token or str(uuid4())
self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
self.max_tries = max_tries
self.tags = tags
@@ -198,19 +205,20 @@ def execute(self, context: 'Context') -> Optional[str]:
self.client_request_token,
self.tags,
)
query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)

if query_status in EmrContainerHook.FAILURE_STATES:
error_message = self.hook.get_job_failure_reason(self.job_id)
raise AirflowException(
f"EMR Containers job failed. Final state is {query_status}. "
f"query_execution_id is {self.job_id}. Error: {error_message}"
)
elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
raise AirflowException(
f"Final state of EMR Containers job is {query_status}. "
f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
)
if self.wait_for_completion:
query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)

if query_status in EmrContainerHook.FAILURE_STATES:
error_message = self.hook.get_job_failure_reason(self.job_id)
raise AirflowException(
f"EMR Containers job failed. Final state is {query_status}. "
f"query_execution_id is {self.job_id}. Error: {error_message}"
)
elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
raise AirflowException(
f"Final state of EMR Containers job is {query_status}. "
f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
)

return self.job_id

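The effect of the new wait_for_completion flag: with the default True, execute() itself polls until the job reaches a terminal state; with False, it returns the job ID immediately so an EmrContainerSensor can do the waiting, as the example_emr_eks DAG above does. A minimal sketch of the non-blocking mode follows; the dag_id, release_label, and job_driver payload are illustrative assumptions, not taken from this diff:

    import os
    from datetime import datetime

    from airflow import DAG
    from airflow.models.baseoperator import chain
    from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
    from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor

    VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
    JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")

    # Illustrative Spark job; any valid EMR on EKS job driver works here.
    JOB_DRIVER_ARG = {
        "sparkSubmitJobDriver": {
            "entryPoint": "local:///usr/lib/spark/examples/src/main/python/pi.py",
        }
    }

    with DAG(
        dag_id='example_emr_eks_wait_modes',  # illustrative dag_id
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:
        # Non-blocking submit: execute() returns the job ID without polling.
        job_starter = EmrContainerOperator(
            task_id="start_job",
            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
            execution_role_arn=JOB_ROLE_ARN,
            release_label="emr-6.3.0-latest",  # illustrative
            job_driver=JOB_DRIVER_ARG,
            name="pi.py",
            wait_for_completion=False,
        )

        # The sensor polls the submitted job (its ID pulled from XCom) until
        # it reaches a terminal state, failing the task on job failure.
        job_waiter = EmrContainerSensor(
            task_id="job_waiter",
            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
            job_id=str(job_starter.output),
        )

        chain(job_starter, job_waiter)

Dropping wait_for_completion (it defaults to True) restores the previous blocking behavior, and the sensor becomes unnecessary.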
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/athena.py
@@ -37,7 +37,7 @@ class AthenaSensor(BaseSensorOperator):
If the query fails, the task will fail.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:AthenaSensor`


4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -36,7 +36,7 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
Waits for a stack to be created successfully on AWS CloudFormation.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:CloudFormationCreateStackSensor`


@@ -74,7 +74,7 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
Waits for a stack to be deleted successfully on AWS CloudFormation.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:CloudFormationDeleteStackSensor`

:param stack_name: The name of the stack to wait for (templated)
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/dms.py
@@ -91,7 +91,7 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor):
Pokes DMS task until it is completed.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:DmsTaskCompletedSensor`

:param replication_task_arn: AWS DMS replication task ARN
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/eks.py
@@ -121,7 +121,7 @@ class EksFargateProfileStateSensor(BaseSensorOperator):
Check the state of an AWS Fargate profile until it reaches the target state or another terminal state.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:EksFargateProfileStateSensor`

:param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated)
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/glacier.py
@@ -38,7 +38,7 @@ class GlacierJobOperationSensor(BaseSensorOperator):
Glacier sensor for checking job state. This operator runs only in reschedule mode.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:GlacierJobOperationSensor`

:param aws_conn_id: The reference to the AWS connection details
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/rds.py
@@ -70,7 +70,7 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
Waits for RDS snapshot with a specific status.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsSnapshotExistenceSensor`

:param db_type: Type of the DB - either "instance" or "cluster"
@@ -112,7 +112,7 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
Waits for RDS export task with a specific status.

.. seealso::
For more information on how to use this operator, take a look at the guide:
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsExportTaskExistenceSensor`

:param export_task_identifier: A unique identifier for the snapshot export task.