diff --git a/astronomer_operators/hooks/databricks.py b/astronomer_operators/hooks/databricks.py index 0c0b28914..ae7a67805 100644 --- a/astronomer_operators/hooks/databricks.py +++ b/astronomer_operators/hooks/databricks.py @@ -16,21 +16,6 @@ class DatabricksHookAsync(DatabricksHook): - def __init__( - self, - databricks_conn_id: str = DEFAULT_CONN_NAME, - timeout_seconds: int = 180, - retry_limit: int = 3, - retry_delay: float = 1.0, - ) -> None: - self.databricks_conn_id = databricks_conn_id - self.databricks_conn = None # To be set asynchronously in create_hook() - self.timeout_seconds = timeout_seconds - if retry_limit < 1: - raise ValueError("Retry limit must be greater than equal to 1") - self.retry_limit = retry_limit - self.retry_delay = retry_delay - async def get_run_state_async(self, run_id: str) -> RunState: """ Retrieves run state of the run using an asyncronous api call. @@ -64,6 +49,11 @@ async def _do_api_call_async(self, endpoint_info, json): headers = USER_AGENT_HEADER attempt_num = 1 + if not self.databricks_conn: + self.databricks_conn = await sync_to_async(self.get_connection)( + self.databricks_conn_id + ) + if "token" in self.databricks_conn.extra_dejson: self.log.info("Using token auth. ") auth = self.databricks_conn.extra_dejson["token"] @@ -141,17 +131,3 @@ def _retryable_error_async(self, exception) -> bool: :rtype: bool """ return exception.status >= 500 - - -async def create_hook(): - """ - Initializes a new DatabricksHookAsync then sets its databricks_conn - field asynchronously. - :return: a new async Databricks hook - :rtype: DataBricksHookAsync() - """ - self = DatabricksHookAsync() - self.databricks_conn = await sync_to_async(self.get_connection)( - self.databricks_conn_id - ) - return self diff --git a/tests/astronomer_operators/hooks/test_databricks.py b/tests/astronomer_operators/hooks/test_databricks.py index e5eac0618..79e42285d 100644 --- a/tests/astronomer_operators/hooks/test_databricks.py +++ b/tests/astronomer_operators/hooks/test_databricks.py @@ -8,8 +8,9 @@ GET_RUN_ENDPOINT, SUBMIT_RUN_ENDPOINT, ) +from asgiref.sync import sync_to_async -from astronomer_operators.hooks.databricks import create_hook +from astronomer_operators.hooks.databricks import DatabricksHookAsync TASK_ID = "databricks_check" CONN_ID = "unit_test_conn_id" @@ -29,7 +30,7 @@ async def test_databricks_hook_get_run_state(mocked_response): is in a PENDING state (i.e. "RUNNING") and after it reaches a TERMINATED state (i.e. "SUCCESS"). """ - hook = await create_hook() + hook = DatabricksHookAsync() # Mock response while job is running mocked_response.return_value = { "state": { @@ -71,7 +72,10 @@ async def test_do_api_call_async_get_basic_auth(caplog, aioresponse): and basic auth. """ caplog.set_level(logging.INFO) - hook = await create_hook() + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD params = {"run_id": RUN_ID} @@ -97,7 +101,10 @@ async def test_do_api_call_async_get_auth_token(caplog, aioresponse): and basic auth. """ caplog.set_level(logging.INFO) - hook = await create_hook() + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.extra = json.dumps({"token": "test_token"}) params = {"run_id": RUN_ID} @@ -118,8 +125,10 @@ async def test_do_api_call_async_non_retryable_error(aioresponse): Asserts that the Databricks hook will throw an exception when a non-retryable error is returned by the API. """ - hook = await create_hook() - + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD params = {"run_id": RUN_ID} @@ -141,7 +150,10 @@ async def test_do_api_call_async_retryable_error(aioresponse): Asserts that the Databricks hook will attempt another API call as many times as the retry_limit when a retryable error is returned by the API. """ - hook = await create_hook() + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD params = {"run_id": RUN_ID} @@ -166,7 +178,10 @@ async def test_do_api_call_async_post(aioresponse): """ Asserts that the Databricks hook makes a POST call as expected. """ - hook = await create_hook() + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD json = { @@ -191,7 +206,10 @@ async def test_do_api_call_async_unknown_method(): Asserts that the Databricks hook throws an exception when it attempts to make an API call using a non-existent method. """ - hook = await create_hook() + hook = DatabricksHookAsync() + hook.databricks_conn = await sync_to_async(hook.get_connection)( + hook.databricks_conn_id + ) hook.databricks_conn.login = LOGIN hook.databricks_conn.password = PASSWORD json = { diff --git a/tests/astronomer_operators/test_external_task.py b/tests/astronomer_operators/test_external_task.py index 217da04bb..42aeb6e38 100644 --- a/tests/astronomer_operators/test_external_task.py +++ b/tests/astronomer_operators/test_external_task.py @@ -272,9 +272,7 @@ async def test_dag_state_trigger(session, dag): reaches an allowed state (i.e. SUCCESS). """ dag_run = DagRun( - dag.dag_id, - run_type="manual", - execution_date=DEFAULT_DATE, + dag.dag_id, run_type="manual", execution_date=DEFAULT_DATE, run_id=TEST_RUN_ID ) session.add(dag_run)