From 89c4fa1791653d51f752d20e51d5d0850bf95b3a Mon Sep 17 00:00:00 2001 From: AnatoliProM Date: Thu, 16 Sep 2021 14:44:07 -0700 Subject: [PATCH] Fixes a few small issues in the databricks trigger / hook: (#12) - Enables trigger logging - Updates HTTPException to ClientResponseError in databricks hook - Sets databricks connection asynchronously in hook --- astronomer_operators/databricks.py | 1 - astronomer_operators/hooks/databricks.py | 42 +++++++++++++++++-- .../hooks/test_databricks.py | 31 +++++--------- tests/astronomer_operators/test_databricks.py | 10 +++-- .../test_external_task.py | 8 ++++ 5 files changed, 63 insertions(+), 29 deletions(-) diff --git a/astronomer_operators/databricks.py b/astronomer_operators/databricks.py index ac4f2f1..4ad3dad 100644 --- a/astronomer_operators/databricks.py +++ b/astronomer_operators/databricks.py @@ -157,7 +157,6 @@ async def run(self): ) raise AirflowException(error_message) else: - # TODO: Figure out logging for trigger self.log.info("%s in run state: %s", self.task_id, run_state) self.log.info("Sleeping for %s seconds.", self.polling_period_seconds) await asyncio.sleep(self.polling_period_seconds) diff --git a/astronomer_operators/hooks/databricks.py b/astronomer_operators/hooks/databricks.py index 00ccc0b..0c0b289 100644 --- a/astronomer_operators/hooks/databricks.py +++ b/astronomer_operators/hooks/databricks.py @@ -2,7 +2,7 @@ import base64 import aiohttp -from aiohttp.web_exceptions import HTTPException +from aiohttp import ClientResponseError from airflow.exceptions import AirflowException from airflow.providers.databricks.hooks.databricks import ( GET_RUN_ENDPOINT, @@ -10,9 +10,27 @@ DatabricksHook, RunState, ) +from asgiref.sync import sync_to_async + +DEFAULT_CONN_NAME = "databricks_default" class DatabricksHookAsync(DatabricksHook): + def __init__( + self, + databricks_conn_id: str = DEFAULT_CONN_NAME, + timeout_seconds: int = 180, + retry_limit: int = 3, + retry_delay: float = 1.0, + ) -> None: + self.databricks_conn_id = databricks_conn_id + self.databricks_conn = None # To be set asynchronously in create_hook() + self.timeout_seconds = timeout_seconds + if retry_limit < 1: + raise ValueError("Retry limit must be greater than equal to 1") + self.retry_limit = retry_limit + self.retry_delay = retry_delay + async def get_run_state_async(self, run_id: str) -> RunState: """ Retrieves run state of the run using an asyncronous api call. @@ -88,12 +106,12 @@ async def _do_api_call_async(self, endpoint_info, json): ) response.raise_for_status() return await response.json() - except HTTPException as e: + except ClientResponseError as e: if not self._retryable_error_async(e): # In this case, the user probably made a mistake. # Don't retry. raise AirflowException( - f"Response: {e}, Status Code: {e.status_code}" + f"Response: {e.message}, Status Code: {e.status}" ) self._log_request_error(attempt_num, e) @@ -119,5 +137,21 @@ def _retryable_error_async(self, exception) -> bool: - anything with a status code >= 500 Most retryable errors are covered by status code >= 500. + :return: if the status is retryable + :rtype: bool """ - return exception.status_code >= 500 + return exception.status >= 500 + + +async def create_hook(): + """ + Initializes a new DatabricksHookAsync then sets its databricks_conn + field asynchronously. + :return: a new async Databricks hook + :rtype: DataBricksHookAsync() + """ + self = DatabricksHookAsync() + self.databricks_conn = await sync_to_async(self.get_connection)( + self.databricks_conn_id + ) + return self diff --git a/tests/astronomer_operators/hooks/test_databricks.py b/tests/astronomer_operators/hooks/test_databricks.py index 9a32202..e5eac06 100644 --- a/tests/astronomer_operators/hooks/test_databricks.py +++ b/tests/astronomer_operators/hooks/test_databricks.py @@ -3,14 +3,13 @@ from unittest import mock import pytest -from aiohttp.web_exceptions import HTTPBadRequest, HTTPInternalServerError from airflow.exceptions import AirflowException from airflow.providers.databricks.hooks.databricks import ( GET_RUN_ENDPOINT, SUBMIT_RUN_ENDPOINT, ) -from astronomer_operators.hooks.databricks import DatabricksHookAsync +from astronomer_operators.hooks.databricks import create_hook TASK_ID = "databricks_check" CONN_ID = "unit_test_conn_id" @@ -30,7 +29,7 @@ async def test_databricks_hook_get_run_state(mocked_response): is in a PENDING state (i.e. "RUNNING") and after it reaches a TERMINATED state (i.e. "SUCCESS"). """ - hook = DatabricksHookAsync() + hook = await create_hook() # Mock response while job is running mocked_response.return_value = { "state": { @@ -72,7 +71,7 @@ async def test_do_api_call_async_get_basic_auth(caplog, aioresponse): and basic auth. """ caplog.set_level(logging.INFO) - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD params = {"run_id": RUN_ID} @@ -98,7 +97,7 @@ async def test_do_api_call_async_get_auth_token(caplog, aioresponse): and basic auth. """ caplog.set_level(logging.INFO) - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.extra = json.dumps({"token": "test_token"}) params = {"run_id": RUN_ID} @@ -119,7 +118,7 @@ async def test_do_api_call_async_non_retryable_error(aioresponse): Asserts that the Databricks hook will throw an exception when a non-retryable error is returned by the API. """ - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD @@ -127,7 +126,7 @@ async def test_do_api_call_async_non_retryable_error(aioresponse): aioresponse.get( "https://localhost/api/2.0/jobs/runs/get?run_id=unit_test_run_id", - exception=HTTPBadRequest(), + status=400, ) with pytest.raises(AirflowException) as exc: @@ -142,14 +141,14 @@ async def test_do_api_call_async_retryable_error(aioresponse): Asserts that the Databricks hook will attempt another API call as many times as the retry_limit when a retryable error is returned by the API. """ - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD params = {"run_id": RUN_ID} aioresponse.get( "https://localhost/api/2.0/jobs/runs/get?run_id=unit_test_run_id", - exception=HTTPInternalServerError(), + status=500, repeat=True, ) @@ -167,7 +166,7 @@ async def test_do_api_call_async_post(aioresponse): """ Asserts that the Databricks hook makes a POST call as expected. """ - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD json = { @@ -192,7 +191,7 @@ async def test_do_api_call_async_unknown_method(): Asserts that the Databricks hook throws an exception when it attempts to make an API call using a non-existent method. """ - hook = DatabricksHookAsync() + hook = await create_hook() hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD json = { @@ -205,13 +204,3 @@ async def test_do_api_call_async_unknown_method(): await hook._do_api_call_async(("NOPE", "api/2.0/jobs/runs/submit"), json) assert str(exc.value) == "Unexpected HTTP Method: NOPE" - - -def test_retryable_error_async(): - """ - Asserts that HTTP errors are properly identified as retryable - when the status code >= 500. - """ - hook = DatabricksHookAsync() - assert hook._retryable_error_async(HTTPInternalServerError) - assert not hook._retryable_error_async(HTTPBadRequest) diff --git a/tests/astronomer_operators/test_databricks.py b/tests/astronomer_operators/test_databricks.py index 5ee3ed2..40ed497 100644 --- a/tests/astronomer_operators/test_databricks.py +++ b/tests/astronomer_operators/test_databricks.py @@ -193,12 +193,16 @@ async def test_databricks_trigger_running(run_state, caplog): ) task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + # TriggerEvent was not returned assert task.done() is False - # # TODO: uncomment these after logging is implemented in Trigger - # assert f"{TASK_ID} in run state: RUNNING" in caplog.text - # assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text + assert ( + f"{TASK_ID} in run state: {{'life_cycle_state': 'RUNNING', 'result_state': '', 'state_message': 'In run'}}" + in caplog.text + ) + assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text # Prevents error when task is destroyed while in "pending" state asyncio.get_event_loop().stop() diff --git a/tests/astronomer_operators/test_external_task.py b/tests/astronomer_operators/test_external_task.py index febaba3..217da04 100644 --- a/tests/astronomer_operators/test_external_task.py +++ b/tests/astronomer_operators/test_external_task.py @@ -19,6 +19,7 @@ DEFAULT_DATE = datetime(2015, 1, 1) TEST_DAG_ID = "unit_test_dag" TEST_TASK_ID = "external_task_sensor_check" +TEST_RUN_ID = "unit_test_dag_run_id" TEST_EXT_DAG_ID = "wait_for_me_dag" # DAG the external task sensor is waiting on TEST_EXT_TASK_ID = "wait_for_me_task" # Task the external task sensor is waiting on TEST_STATES = ["success", "fail"] @@ -224,6 +225,13 @@ async def test_task_state_trigger(session, dag): Asserts that the TaskStateTrigger only goes off on or after a TaskInstance reaches an allowed state (i.e. SUCCESS). """ + dag_run = DagRun( + dag.dag_id, run_type="manual", execution_date=DEFAULT_DATE, run_id=TEST_RUN_ID + ) + + session.add(dag_run) + session.commit() + external_task = DummyOperator(task_id=TEST_TASK_ID, dag=dag) instance = TaskInstance(external_task, DEFAULT_DATE) session.add(instance)