Skip to content

Commit

Permalink
Add link for EMR Steps Sensor logs (#28180)
Browse files Browse the repository at this point in the history
  • Loading branch information
syedahsn authored Dec 20, 2022
1 parent 9eacf60 commit fefcb1d
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 12 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
:return: the parsed bucket name and key
"""
format = s3url.split("//")
if format[0].lower() == "s3:":
if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
parsed_url = urlsplit(s3url)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/links/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ class EmrClusterLink(BaseAwsLink):
format_str = (
BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
)


class EmrLogsLink(BaseAwsLink):
"""Helper class for constructing AWS EMR Logs Link"""

name = "EMR Cluster Logs"
key = "emr_logs"
format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
26 changes: 19 additions & 7 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.emr import EmrLogsLink
from airflow.sensors.base import BaseSensorOperator, poke_mode_only

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +63,7 @@ def get_hook(self) -> EmrHook:
return self.hook

def poke(self, context: Context):
response = self.get_emr_response()
response = self.get_emr_response(context=context)

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
self.log.info("Bad HTTP response: %s", response)
Expand All @@ -78,7 +80,7 @@ def poke(self, context: Context):

return False

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get response.
Expand Down Expand Up @@ -329,7 +331,7 @@ def __init__(
self.target_states = target_states or self.COMPLETED_STATES
self.failed_states = failed_states or self.FAILURE_STATES

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
emr_client = self.get_hook().get_conn()
self.log.info("Poking notebook %s", self.notebook_execution_id)

Expand Down Expand Up @@ -382,6 +384,7 @@ class EmrJobFlowSensor(EmrBaseSensor):

template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
operator_extra_links = (EmrLogsLink(),)

def __init__(
self,
Expand All @@ -396,7 +399,7 @@ def __init__(
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get cluster-level details.
Expand All @@ -406,9 +409,18 @@ def get_emr_response(self) -> dict[str, Any]:
:return: response
"""
emr_client = self.get_hook().get_conn()

self.log.info("Poking cluster %s", self.job_flow_id)
return emr_client.describe_cluster(ClusterId=self.job_flow_id)
response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.get_hook().conn_region_name,
aws_partition=self.get_hook().conn_partition,
job_flow_id=self.job_flow_id,
log_uri="/".join(log_uri),
)
return response

@staticmethod
def state_from_response(response: dict[str, Any]) -> str:
Expand Down Expand Up @@ -476,7 +488,7 @@ def __init__(
self.target_states = target_states or ["COMPLETED"]
self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get details about the cluster step.
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ extra-links:
- airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
- airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
- airflow.providers.amazon.aws.links.emr.EmrClusterLink
- airflow.providers.amazon.aws.links.emr.EmrLogsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink

connection-types:
Expand Down
8 changes: 8 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def test_parse_s3_url(self):
parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_s3a_style(self):
parsed = S3Hook.parse_s3_url("s3a://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_s3n_style(self):
parsed = S3Hook.parse_s3_url("s3n://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_path_style(self):
parsed = S3Hook.parse_s3_url("https://s3.us-west-2.amazonaws.com/DOC-EXAMPLE-BUCKET1/test.jpg")
assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.jpg"), "Incorrect parsing of the s3 url"
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/sensors/test_emr_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
from airflow.utils.context import Context

TARGET_STATE = "TARGET_STATE"
FAILED_STATE = "FAILED_STATE"
Expand All @@ -40,7 +41,7 @@ def __init__(self, *args, **kwargs):
self.failed_states = [FAILED_STATE]
self.response = {} # will be set in tests

def get_emr_response(self):
def get_emr_response(self, context: Context):
return self.response

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions tests/providers/amazon/aws/sensors/test_emr_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def setup_method(self):
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)

# Mock context used in execute function
self.mock_ctx = MagicMock()

def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
self.mock_emr_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN,
Expand All @@ -210,7 +213,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self
task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
)

operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 3
Expand All @@ -230,7 +233,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_e
)

with pytest.raises(AirflowException):
operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 2
Expand All @@ -256,7 +259,7 @@ def test_different_target_states(self):
target_states=["RUNNING", "WAITING"],
)

operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 3
Expand Down

0 comments on commit fefcb1d

Please sign in to comment.