Add link for EMR Steps Sensor logs #28180

Merged 5 commits on Dec 20, 2022
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/s3.py
@@ -152,7 +152,7 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
:return: the parsed bucket name and key
"""
format = s3url.split("//")
if format[0].lower() == "s3:":
if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
parsed_url = urlsplit(s3url)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
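For context on the hook change above: the scheme check is widened so that `s3a://` and `s3n://` URLs parse the same way as `s3://`. A minimal standalone sketch of that matching logic (simplified; the real `S3Hook.parse_s3_url` also handles path- and virtual-hosted-style HTTPS URLs):

```python
import re
from urllib.parse import urlsplit


def parse_s3_url_sketch(s3url: str) -> tuple[str, str]:
    # Accept s3://, s3a:// and s3n:// schemes, case-insensitively.
    scheme = s3url.split("//")[0]
    if not re.match(r"s3[na]?:", scheme, re.IGNORECASE):
        raise ValueError(f"Unsupported scheme in {s3url!r}")
    parsed = urlsplit(s3url)
    if not parsed.netloc:
        raise ValueError(f'Please provide a bucket name using a valid format: "{s3url}"')
    return parsed.netloc, parsed.path.lstrip("/")


assert parse_s3_url_sketch("s3n://test/this/is/not/a-real-key.txt") == ("test", "this/is/not/a-real-key.txt")
```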
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/links/emr.py
@@ -27,3 +27,11 @@ class EmrClusterLink(BaseAwsLink):
format_str = (
BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
)


class EmrLogsLink(BaseAwsLink):
"""Helper class for constructing AWS EMR Logs Link"""

name = "EMR Cluster Logs"
key = "emr_logs"
format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
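As a rough illustration of what the new link resolves to: all values below are invented, and `BASE_AWS_CONSOLE_LINK` is assumed to be the provider's usual `"https://console.{aws_domain}"` template.

```python
# Assumed template from the provider's base link module; the concrete values are hypothetical.
BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"

print(
    format_str.format(
        aws_domain="aws.amazon.com",
        log_uri="my-emr-logs/elasticmapreduce",  # "bucket/key" derived from the cluster's LogUri
        region_name="us-east-1",
        job_flow_id="j-8989898989",
    )
)
# https://console.aws.amazon.com/s3/buckets/my-emr-logs/elasticmapreduce?region=us-east-1&prefix=j-8989898989/
```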
26 changes: 19 additions & 7 deletions airflow/providers/amazon/aws/sensors/emr.py
@@ -21,6 +21,8 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.emr import EmrLogsLink
from airflow.sensors.base import BaseSensorOperator, poke_mode_only

if TYPE_CHECKING:
@@ -61,7 +63,7 @@ def get_hook(self) -> EmrHook:
return self.hook

def poke(self, context: Context):
response = self.get_emr_response()
response = self.get_emr_response(context=context)

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
self.log.info("Bad HTTP response: %s", response)
@@ -78,7 +80,7 @@ def poke(self, context: Context):

return False

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get response.

@@ -329,7 +331,7 @@ def __init__(
self.target_states = target_states or self.COMPLETED_STATES
self.failed_states = failed_states or self.FAILURE_STATES

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
emr_client = self.get_hook().get_conn()
self.log.info("Poking notebook %s", self.notebook_execution_id)

@@ -382,6 +384,7 @@ class EmrJobFlowSensor(EmrBaseSensor):

template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
operator_extra_links = (EmrLogsLink(),)

def __init__(
self,
@@ -396,7 +399,7 @@ def __init__(
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get cluster-level details.

@@ -406,9 +409,18 @@ def get_emr_response(self) -> dict[str, Any]:
:return: response
"""
emr_client = self.get_hook().get_conn()

self.log.info("Poking cluster %s", self.job_flow_id)
return emr_client.describe_cluster(ClusterId=self.job_flow_id)
response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.get_hook().conn_region_name,
aws_partition=self.get_hook().conn_partition,
job_flow_id=self.job_flow_id,
log_uri="/".join(log_uri),
)
return response

@staticmethod
def state_from_response(response: dict[str, Any]) -> str:
@@ -476,7 +488,7 @@ def __init__(
self.target_states = target_states or ["COMPLETED"]
self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]

def get_emr_response(self) -> dict[str, Any]:
def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get details about the cluster step.

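To tie the sensor change together: `describe_cluster` returns the cluster's `LogUri`, which `S3Hook.parse_s3_url` splits into `(bucket, key)`; the sensor re-joins the pair so that `{log_uri}` in `EmrLogsLink.format_str` becomes `bucket/key`. A small sketch of that flow, using an invented `LogUri` value:

```python
from urllib.parse import urlsplit

# Hypothetical fragment of an emr_client.describe_cluster() response.
response = {"Cluster": {"LogUri": "s3n://my-emr-logs/elasticmapreduce/"}}

parsed = urlsplit(response["Cluster"]["LogUri"])
bucket, key = parsed.netloc, parsed.path.lstrip("/")  # what S3Hook.parse_s3_url returns
log_uri = "/".join((bucket, key))                      # "my-emr-logs/elasticmapreduce/"

# EmrLogsLink.persist(...) then stores log_uri together with region_name, aws_partition
# and job_flow_id, so the "EMR Cluster Logs" extra link can be rendered on the task instance.
print(log_uri)
```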
1 change: 1 addition & 0 deletions airflow/providers/amazon/provider.yaml
@@ -558,6 +558,7 @@ extra-links:
- airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
- airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
- airflow.providers.amazon.aws.links.emr.EmrClusterLink
- airflow.providers.amazon.aws.links.emr.EmrLogsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink

connection-types:
8 changes: 8 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -71,6 +71,14 @@ def test_parse_s3_url(self):
parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_s3a_style(self):
parsed = S3Hook.parse_s3_url("s3a://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_s3n_style(self):
parsed = S3Hook.parse_s3_url("s3n://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"

def test_parse_s3_url_path_style(self):
parsed = S3Hook.parse_s3_url("https://s3.us-west-2.amazonaws.com/DOC-EXAMPLE-BUCKET1/test.jpg")
assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.jpg"), "Incorrect parsing of the s3 url"
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/sensors/test_emr_base.py
@@ -21,6 +21,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
from airflow.utils.context import Context

TARGET_STATE = "TARGET_STATE"
FAILED_STATE = "FAILED_STATE"
@@ -40,7 +41,7 @@ def __init__(self, *args, **kwargs):
self.failed_states = [FAILED_STATE]
self.response = {} # will be set in tests

def get_emr_response(self):
def get_emr_response(self, context: Context):
return self.response

@staticmethod
9 changes: 6 additions & 3 deletions tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -199,6 +199,9 @@ def setup_method(self):
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)

# Mock context used in execute function
self.mock_ctx = MagicMock()

def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
self.mock_emr_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN,
@@ -210,7 +213,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self
task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
)

operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called three times
assert self.mock_emr_client.describe_cluster.call_count == 3
@@ -230,7 +233,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_e
)

with pytest.raises(AirflowException):
operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 2
@@ -256,7 +259,7 @@ def test_different_target_states(self):
target_states=["RUNNING", "WAITING"],
)

operator.execute(None)
operator.execute(self.mock_ctx)

# make sure we called three times
assert self.mock_emr_client.describe_cluster.call_count == 3