diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index e5b1bad5804f2..1e2235f24af87 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -152,7 +152,8 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
         :return: the parsed bucket name and key
         """
         format = s3url.split("//")
-        if format[0].lower() == "s3:":
+        # Accept s3://, s3n:// and s3a:// URLs; the whole scheme prefix must match, case-insensitively.
+        if re.fullmatch(r"s3[na]?:", format[0], re.IGNORECASE):
             parsed_url = urlsplit(s3url)
             if not parsed_url.netloc:
                 raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
index aa739567fb919..83e190f663a8c 100644
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -27,3 +27,12 @@ class EmrClusterLink(BaseAwsLink):
     format_str = (
         BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
     )
+
+
+class EmrLogsLink(BaseAwsLink):
+    """Helper class for constructing AWS EMR Logs Link"""
+
+    name = "EMR Cluster Logs"
+    key = "emr_logs"
+    # {log_uri} is filled by EmrJobFlowSensor with the "<bucket>/<key>" form of the cluster's LogUri.
+    format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 04e28977169bc..d1cd0949e0591 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -21,6 +21,8 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.links.emr import EmrLogsLink
 from airflow.sensors.base import BaseSensorOperator, poke_mode_only
 
 if TYPE_CHECKING:
@@ -61,7 +63,7 @@ def get_hook(self) -> EmrHook:
         return self.hook
 
     def poke(self, context: Context):
-        response = self.get_emr_response()
+        response = self.get_emr_response(context=context)
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             self.log.info("Bad HTTP response: %s", response)
@@ -78,7 +80,7 @@ def poke(self, context: Context):
 
         return False
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get response.
 
@@ -329,7 +331,7 @@ def __init__(
         self.target_states = target_states or self.COMPLETED_STATES
         self.failed_states = failed_states or self.FAILURE_STATES
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         emr_client = self.get_hook().get_conn()
         self.log.info("Poking notebook %s", self.notebook_execution_id)
 
@@ -382,6 +384,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
 
     template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
     template_ext: Sequence[str] = ()
+    operator_extra_links = (EmrLogsLink(),)
 
     def __init__(
         self,
@@ -396,7 +399,7 @@ def __init__(
         self.target_states = target_states or ["TERMINATED"]
         self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get cluster-level details.
 
@@ -406,9 +409,22 @@ def get_emr_response(self) -> dict[str, Any]:
         :return: response
         """
         emr_client = self.get_hook().get_conn()
-
         self.log.info("Poking cluster %s", self.job_flow_id)
-        return emr_client.describe_cluster(ClusterId=self.job_flow_id)
+        response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
+        # The HTTP status is only checked by poke() after this method returns, and LogUri
+        # may be absent from the cluster description; skip the logs link in those cases
+        # instead of failing the poke with a KeyError.
+        log_uri = response.get("Cluster", {}).get("LogUri")
+        if log_uri:
+            EmrLogsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.get_hook().conn_region_name,
+                aws_partition=self.get_hook().conn_partition,
+                job_flow_id=self.job_flow_id,
+                log_uri="/".join(S3Hook.parse_s3_url(log_uri)),
+            )
+        return response
 
     @staticmethod
     def state_from_response(response: dict[str, Any]) -> str:
@@ -476,7 +492,7 @@ def __init__(
         self.target_states = target_states or ["COMPLETED"]
         self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get details about the cluster step.
 
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 3f7b793a1c2a9..230bc135ae1f5 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -558,6 +558,7 @@ extra-links:
   - airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
   - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
   - airflow.providers.amazon.aws.links.emr.EmrClusterLink
+  - airflow.providers.amazon.aws.links.emr.EmrLogsLink
   - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
 
 connection-types:
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 48cfb9912fc2b..ff231482da050 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -71,6 +71,14 @@ def test_parse_s3_url(self):
         parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
         assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
 
+    def test_parse_s3_url_s3a_style(self):
+        parsed = S3Hook.parse_s3_url("s3a://test/this/is/not/a-real-key.txt")
+        assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
+    def test_parse_s3_url_s3n_style(self):
+        parsed = S3Hook.parse_s3_url("s3n://test/this/is/not/a-real-key.txt")
+        assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
     def test_parse_s3_url_path_style(self):
         parsed = S3Hook.parse_s3_url("https://s3.us-west-2.amazonaws.com/DOC-EXAMPLE-BUCKET1/test.jpg")
         assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.jpg"), "Incorrect parsing of the s3 url"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py
index b0dfd66233801..f6b4351833989 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_base.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_base.py
@@ -21,6 +21,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
+from airflow.utils.context import Context
 
 TARGET_STATE = "TARGET_STATE"
 FAILED_STATE = "FAILED_STATE"
@@ -40,7 +41,7 @@ def __init__(self, *args, **kwargs):
         self.failed_states = [FAILED_STATE]
         self.response = {}  # will be set in tests
 
-    def get_emr_response(self):
+    def get_emr_response(self, context: Context):
         return self.response
 
     @staticmethod
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index 87a80d6a012ad..a10e68abb8483 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -199,6 +199,9 @@ def setup_method(self):
         # Mock out the emr_client creator
         self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
+        # Mock context used in execute function
+        self.mock_ctx = MagicMock()
+
     def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
         self.mock_emr_client.describe_cluster.side_effect = [
             DESCRIBE_CLUSTER_STARTING_RETURN,
@@ -210,7 +213,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self
                 task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
             )
 
-            operator.execute(None)
+            operator.execute(self.mock_ctx)
 
             # make sure we called twice
             assert self.mock_emr_client.describe_cluster.call_count == 3
@@ -230,7 +233,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_e
             )
 
             with pytest.raises(AirflowException):
-                operator.execute(None)
+                operator.execute(self.mock_ctx)
 
                 # make sure we called twice
                 assert self.mock_emr_client.describe_cluster.call_count == 2
@@ -256,7 +259,7 @@ def test_different_target_states(self):
                 target_states=["RUNNING", "WAITING"],
             )
 
-            operator.execute(None)
+            operator.execute(self.mock_ctx)
 
             # make sure we called twice
             assert self.mock_emr_client.describe_cluster.call_count == 3