Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some important logging to the AWS Athena hook #27917

Merged
merged 4 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class AthenaHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
:param log_query: Whether to log athena query and other execution params when it's executed.
Defaults to *True*.
"""

INTERMEDIATE_STATES = (
Expand All @@ -61,9 +63,10 @@ class AthenaHook(AwsBaseHook):
"CANCELLED",
)

def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = True, **kwargs: Any) -> None:
super().__init__(client_type="athena", *args, **kwargs) # type: ignore
self.sleep_time = sleep_time
self.log_query = log_query

def run_query(
self,
Expand Down Expand Up @@ -91,8 +94,12 @@ def run_query(
}
if client_request_token:
params["ClientRequestToken"] = client_request_token
if self.log_query:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason we want to use a flag (log_query) to configure whether we log or not? In other words, why don't we log regardless?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops! I missed that. Thanks for pointing it out. LGTM

self.log.info("Running Query with params: %s", params)
uranusjr marked this conversation as resolved.
Show resolved Hide resolved
response = self.get_conn().start_query_execution(**params)
return response["QueryExecutionId"]
query_execution_id = response["QueryExecutionId"]
self.log.info("Query execution id: %s", query_execution_id)
return query_execution_id

def check_query_status(self, query_execution_id: str) -> str | None:
"""
Expand All @@ -105,8 +112,10 @@ def check_query_status(self, query_execution_id: str) -> str | None:
state = None
try:
state = response["QueryExecution"]["Status"]["State"]
except Exception as ex:
self.log.error("Exception while getting query state %s", ex)
except Exception:
self.log.exception(
"Exception while getting query state. Query execution id: %s", query_execution_id
)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
Expand All @@ -123,8 +132,11 @@ def get_state_change_reason(self, query_execution_id: str) -> str | None:
reason = None
try:
reason = response["QueryExecution"]["Status"]["StateChangeReason"]
except Exception as ex:
self.log.error("Exception while getting query state change reason: %s", ex)
except Exception:
self.log.exception(
"Exception while getting query state change reason. Query execution id: %s",
query_execution_id,
)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
Expand All @@ -144,10 +156,14 @@ def get_query_results(
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.error("Invalid Query state")
self.log.error("Invalid Query state. Query execution id: %s", query_execution_id)
return None
elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
self.log.error(
'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
query_state,
query_execution_id,
)
return None
result_params = {"QueryExecutionId": query_execution_id, "MaxResults": max_results}
if next_token_id:
Expand All @@ -174,10 +190,14 @@ def get_query_results_paginator(
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.error("Invalid Query state (null)")
self.log.error("Invalid Query state (null). Query execution id: %s", query_execution_id)
return None
if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
self.log.error(
'Query is in "%s" state. Cannot fetch results, Query execution id: %s',
query_state,
query_execution_id,
)
return None
result_params = {
"QueryExecutionId": query_execution_id,
Expand Down Expand Up @@ -222,15 +242,27 @@ def poll_query_status(
while True:
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.info("Trial %s: Invalid query state. Retrying again", try_number)
self.log.info(
"Query execution id: %s, trial %s: Invalid query state. Retrying again",
query_execution_id,
try_number,
)
elif query_state in self.TERMINAL_STATES:
self.log.info(
"Trial %s: Query execution completed. Final state is %s}", try_number, query_state
"Query execution id: %s, trial %s: Query execution completed. Final state is %s",
query_execution_id,
try_number,
query_state,
)
final_query_state = query_state
break
else:
self.log.info("Trial %s: Query is still in non-terminal state - %s", try_number, query_state)
self.log.info(
"Query execution id: %s, trial %s: Query is still in non-terminal state - %s",
query_execution_id,
try_number,
query_state,
)
if (
max_polling_attempts and try_number >= max_polling_attempts
): # Break loop if max_polling_attempts reached
Expand All @@ -256,12 +288,14 @@ def get_output_location(self, query_execution_id: str) -> str:
try:
output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
except KeyError:
self.log.error("Error retrieving OutputLocation")
self.log.error(
"Error retrieving OutputLocation. Query execution id: %s", query_execution_id
)
raise
else:
raise
else:
raise ValueError("Invalid Query execution id")
raise ValueError("Invalid Query execution id. Query execution id: %s", query_execution_id)

return output_location

Expand All @@ -272,4 +306,5 @@ def stop_query(self, query_execution_id: str) -> dict:
:param query_execution_id: Id of submitted athena query
:return: dict
"""
self.log.info("Stopping Query with executionId - %s", query_execution_id)
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
13 changes: 9 additions & 4 deletions airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class AthenaOperator(BaseOperator):
:param max_tries: Deprecated - use max_polling_attempts instead.
:param max_polling_attempts: Number of times to poll for query state before function exits
To limit task execution time, use execution_timeout.
:param log_query: Whether to log athena query and other execution params when it's executed.
Defaults to *True*.
"""

ui_color = "#44b5e2"
Expand All @@ -69,6 +71,7 @@ def __init__(
sleep_time: int = 30,
max_tries: int | None = None,
max_polling_attempts: int | None = None,
log_query: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -83,6 +86,7 @@ def __init__(
self.sleep_time = sleep_time
self.max_polling_attempts = max_polling_attempts
self.query_execution_id: str | None = None
self.log_query: bool = log_query

if max_tries:
warnings.warn(
Expand All @@ -99,7 +103,7 @@ def __init__(
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time, log_query=self.log_query)

def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena"""
Expand Down Expand Up @@ -135,13 +139,14 @@ def on_kill(self) -> None:
"""Cancel the submitted athena query"""
if self.query_execution_id:
self.log.info("Received a kill signal.")
self.log.info("Stopping Query with executionId - %s", self.query_execution_id)
response = self.hook.stop_query(self.query_execution_id)
http_status_code = None
try:
http_status_code = response["ResponseMetadata"]["HTTPStatusCode"]
except Exception as ex:
self.log.error("Exception while cancelling query: %s", ex)
except Exception:
self.log.exception(
"Exception while cancelling query. Query execution id: %s", self.query_execution_id
)
finally:
if http_status_code is None or http_status_code != 200:
self.log.error("Unable to request query cancel on athena. Exiting")
Expand Down
21 changes: 21 additions & 0 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ def test_hook_run_query_with_token(self, mock_conn):
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
assert result == MOCK_DATA["query_execution_id"]

@mock.patch.object(AthenaHook, "log")
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_log_query(self, mock_conn, log):
    """Check that ``run_query`` on the shared ``self.athena`` hook logs at info level twice.

    Both ``AthenaHook.log`` and ``AthenaHook.get_conn`` are patched, so no
    connection is opened and every log call lands on the mock.
    """
    run_query_kwargs = {
        "query": MOCK_DATA["query"],
        "query_context": mock_query_context,
        "result_configuration": mock_result_configuration,
    }
    self.athena.run_query(**run_query_kwargs)

    info_calls = self.athena.log.info.call_count
    assert info_calls == 2

@mock.patch.object(AthenaHook, "log")
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_no_log_query(self, mock_conn, log):
    """Check that a hook built with ``log_query=False`` logs at info level only once.

    Both ``AthenaHook.log`` and ``AthenaHook.get_conn`` are patched, so no
    connection is opened and every log call lands on the mock.
    """
    quiet_hook = AthenaHook(sleep_time=0, log_query=False)
    run_query_kwargs = {
        "query": MOCK_DATA["query"],
        "query_context": mock_query_context,
        "result_configuration": mock_result_configuration,
    }
    quiet_hook.run_query(**run_query_kwargs)

    info_calls = quiet_hook.log.info.call_count
    assert info_calls == 1

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
Expand Down